feat(app): add file uploads
This commit is contained in:
parent
c650d1c3d9
commit
4b89008f3a
25 changed files with 1035 additions and 86 deletions
|
@ -20,10 +20,12 @@ import eventEmitter from 'events';
|
|||
import computeSimilarity from '../utils/computeSimilarity';
|
||||
import logger from '../utils/logger';
|
||||
import LineListOutputParser from '../lib/outputParsers/listLineOutputParser';
|
||||
import { getDocumentsFromLinks } from '../lib/linkDocument';
|
||||
import LineOutputParser from '../lib/outputParsers/lineOutputParser';
|
||||
import { IterableReadableStream } from '@langchain/core/utils/stream';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getDocumentsFromLinks } from '../utils/documents';
|
||||
|
||||
const basicSearchRetrieverPrompt = `
|
||||
You are an AI question rephraser. You will be given a conversation and a follow-up question, you will have to rephrase the follow up question so it is a standalone question and can be used by another LLM to search the web for information to answer it.
|
||||
|
@ -316,6 +318,7 @@ const createBasicWebSearchAnsweringChain = (
|
|||
llm: BaseChatModel,
|
||||
embeddings: Embeddings,
|
||||
optimizationMode: 'speed' | 'balanced' | 'quality',
|
||||
fileIds: string[],
|
||||
) => {
|
||||
const basicWebSearchRetrieverChain = createBasicWebSearchRetrieverChain(llm);
|
||||
|
||||
|
@ -336,8 +339,32 @@ const createBasicWebSearchAnsweringChain = (
|
|||
return docs;
|
||||
}
|
||||
|
||||
const filesData = fileIds
|
||||
.map((file) => {
|
||||
const filePath = path.join(process.cwd(), 'uploads', file);
|
||||
|
||||
const contentPath = filePath + '-extracted.json';
|
||||
const embeddingsPath = filePath + '-embeddings.json';
|
||||
|
||||
const content = JSON.parse(fs.readFileSync(contentPath, 'utf8'));
|
||||
const embeddings = JSON.parse(fs.readFileSync(embeddingsPath, 'utf8'));
|
||||
|
||||
const fileSimilaritySearchObject = content.contents.map(
|
||||
(c: string, i) => {
|
||||
return {
|
||||
fileName: content.title,
|
||||
content: c,
|
||||
embeddings: embeddings.embeddings[i],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
return fileSimilaritySearchObject;
|
||||
})
|
||||
.flat();
|
||||
|
||||
if (query.toLocaleLowerCase() === 'summarize') {
|
||||
return docs.slice(0, 15)
|
||||
return docs.slice(0, 15);
|
||||
}
|
||||
|
||||
const docsWithContent = docs.filter(
|
||||
|
@ -345,7 +372,43 @@ const createBasicWebSearchAnsweringChain = (
|
|||
);
|
||||
|
||||
if (optimizationMode === 'speed') {
|
||||
return docsWithContent.slice(0, 15);
|
||||
if (filesData.length > 0) {
|
||||
const [queryEmbedding] = await Promise.all([
|
||||
embeddings.embedQuery(query),
|
||||
]);
|
||||
|
||||
const fileDocs = filesData.map((fileData) => {
|
||||
return new Document({
|
||||
pageContent: fileData.content,
|
||||
metadata: {
|
||||
title: fileData.fileName,
|
||||
url: `File`,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const similarity = filesData.map((fileData, i) => {
|
||||
const sim = computeSimilarity(queryEmbedding, fileData.embeddings);
|
||||
|
||||
return {
|
||||
index: i,
|
||||
similarity: sim,
|
||||
};
|
||||
});
|
||||
|
||||
const sortedDocs = similarity
|
||||
.filter((sim) => sim.similarity > 0.3)
|
||||
.sort((a, b) => b.similarity - a.similarity)
|
||||
.slice(0, 8)
|
||||
.map((sim) => fileDocs[sim.index]);
|
||||
|
||||
return [
|
||||
...sortedDocs,
|
||||
...docsWithContent.slice(0, 15 - sortedDocs.length),
|
||||
];
|
||||
} else {
|
||||
return docsWithContent.slice(0, 15);
|
||||
}
|
||||
} else if (optimizationMode === 'balanced') {
|
||||
const [docEmbeddings, queryEmbedding] = await Promise.all([
|
||||
embeddings.embedDocuments(
|
||||
|
@ -354,6 +417,20 @@ const createBasicWebSearchAnsweringChain = (
|
|||
embeddings.embedQuery(query),
|
||||
]);
|
||||
|
||||
docsWithContent.push(
|
||||
...filesData.map((fileData) => {
|
||||
return new Document({
|
||||
pageContent: fileData.content,
|
||||
metadata: {
|
||||
title: fileData.fileName,
|
||||
url: `File`,
|
||||
},
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
docEmbeddings.push(...filesData.map((fileData) => fileData.embeddings));
|
||||
|
||||
const similarity = docEmbeddings.map((docEmbedding, i) => {
|
||||
const sim = computeSimilarity(queryEmbedding, docEmbedding);
|
||||
|
||||
|
@ -408,6 +485,7 @@ const basicWebSearch = (
|
|||
llm: BaseChatModel,
|
||||
embeddings: Embeddings,
|
||||
optimizationMode: 'speed' | 'balanced' | 'quality',
|
||||
fileIds: string[],
|
||||
) => {
|
||||
const emitter = new eventEmitter();
|
||||
|
||||
|
@ -416,6 +494,7 @@ const basicWebSearch = (
|
|||
llm,
|
||||
embeddings,
|
||||
optimizationMode,
|
||||
fileIds,
|
||||
);
|
||||
|
||||
const stream = basicWebSearchAnsweringChain.streamEvents(
|
||||
|
@ -446,6 +525,7 @@ const handleWebSearch = (
|
|||
llm: BaseChatModel,
|
||||
embeddings: Embeddings,
|
||||
optimizationMode: 'speed' | 'balanced' | 'quality',
|
||||
fileIds: string[],
|
||||
) => {
|
||||
const emitter = basicWebSearch(
|
||||
message,
|
||||
|
@ -453,6 +533,7 @@ const handleWebSearch = (
|
|||
llm,
|
||||
embeddings,
|
||||
optimizationMode,
|
||||
fileIds,
|
||||
);
|
||||
return emitter;
|
||||
};
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import { sql } from 'drizzle-orm';
|
||||
import { text, integer, sqliteTable } from 'drizzle-orm/sqlite-core';
|
||||
|
||||
export const messages = sqliteTable('messages', {
|
||||
|
@ -11,9 +12,17 @@ export const messages = sqliteTable('messages', {
|
|||
}),
|
||||
});
|
||||
|
||||
interface File {
|
||||
name: string;
|
||||
fileId: string;
|
||||
}
|
||||
|
||||
export const chats = sqliteTable('chats', {
|
||||
id: text('id').primaryKey(),
|
||||
title: text('title').notNull(),
|
||||
createdAt: text('createdAt').notNull(),
|
||||
focusMode: text('focusMode').notNull(),
|
||||
files: text('files', { mode: 'json' })
|
||||
.$type<File[]>()
|
||||
.default(sql`'[]'`),
|
||||
});
|
||||
|
|
|
@ -6,7 +6,7 @@ import { ChatOllama } from '@langchain/community/chat_models/ollama';
|
|||
export const loadOllamaChatModels = async () => {
|
||||
const ollamaEndpoint = getOllamaApiEndpoint();
|
||||
const keepAlive = getKeepAlive();
|
||||
|
||||
|
||||
if (!ollamaEndpoint) return {};
|
||||
|
||||
try {
|
||||
|
@ -25,7 +25,7 @@ export const loadOllamaChatModels = async () => {
|
|||
baseUrl: ollamaEndpoint,
|
||||
model: model.model,
|
||||
temperature: 0.7,
|
||||
keepAlive: keepAlive
|
||||
keepAlive: keepAlive,
|
||||
}),
|
||||
};
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import suggestionsRouter from './suggestions';
|
|||
import chatsRouter from './chats';
|
||||
import searchRouter from './search';
|
||||
import discoverRouter from './discover';
|
||||
import uploadsRouter from './uploads';
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
|
@ -18,5 +19,6 @@ router.use('/suggestions', suggestionsRouter);
|
|||
router.use('/chats', chatsRouter);
|
||||
router.use('/search', searchRouter);
|
||||
router.use('/discover', discoverRouter);
|
||||
router.use('/uploads', uploadsRouter);
|
||||
|
||||
export default router;
|
||||
|
|
151
src/routes/uploads.ts
Normal file
151
src/routes/uploads.ts
Normal file
|
@ -0,0 +1,151 @@
|
|||
import express from 'express';
|
||||
import logger from '../utils/logger';
|
||||
import multer from 'multer';
|
||||
import path from 'path';
|
||||
import crypto from 'crypto';
|
||||
import fs from 'fs';
|
||||
import { Embeddings } from '@langchain/core/embeddings';
|
||||
import { getAvailableEmbeddingModelProviders } from '../lib/providers';
|
||||
import { PDFLoader } from '@langchain/community/document_loaders/fs/pdf';
|
||||
import { DocxLoader } from '@langchain/community/document_loaders/fs/docx';
|
||||
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters';
|
||||
import { Document } from 'langchain/document';
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const splitter = new RecursiveCharacterTextSplitter({
|
||||
chunkSize: 500,
|
||||
chunkOverlap: 100,
|
||||
});
|
||||
|
||||
const storage = multer.diskStorage({
|
||||
destination: (req, file, cb) => {
|
||||
cb(null, path.join(process.cwd(), './uploads'));
|
||||
},
|
||||
filename: (req, file, cb) => {
|
||||
const splitedFileName = file.originalname.split('.');
|
||||
const fileExtension = splitedFileName[splitedFileName.length - 1];
|
||||
if (!['pdf', 'docx', 'txt'].includes(fileExtension)) {
|
||||
return cb(new Error('File type is not supported'), '');
|
||||
}
|
||||
cb(null, `${crypto.randomBytes(16).toString('hex')}.${fileExtension}`);
|
||||
},
|
||||
});
|
||||
|
||||
const upload = multer({ storage });
|
||||
|
||||
router.post(
|
||||
'/',
|
||||
upload.fields([
|
||||
{ name: 'files' },
|
||||
{ name: 'embedding_model', maxCount: 1 },
|
||||
{ name: 'embedding_model_provider', maxCount: 1 },
|
||||
]),
|
||||
async (req, res) => {
|
||||
try {
|
||||
const { embedding_model, embedding_model_provider } = req.body;
|
||||
|
||||
if (!embedding_model || !embedding_model_provider) {
|
||||
res
|
||||
.status(400)
|
||||
.json({ message: 'Missing embedding model or provider' });
|
||||
return;
|
||||
}
|
||||
|
||||
const embeddingModels = await getAvailableEmbeddingModelProviders();
|
||||
const provider =
|
||||
embedding_model_provider ?? Object.keys(embeddingModels)[0];
|
||||
const embeddingModel: Embeddings =
|
||||
embedding_model ?? Object.keys(embeddingModels[provider])[0];
|
||||
|
||||
let embeddingsModel: Embeddings | undefined;
|
||||
|
||||
if (
|
||||
embeddingModels[provider] &&
|
||||
embeddingModels[provider][embeddingModel]
|
||||
) {
|
||||
embeddingsModel = embeddingModels[provider][embeddingModel].model as
|
||||
| Embeddings
|
||||
| undefined;
|
||||
}
|
||||
|
||||
if (!embeddingsModel) {
|
||||
res.status(400).json({ message: 'Invalid LLM model selected' });
|
||||
return;
|
||||
}
|
||||
|
||||
const files = req.files['files'] as Express.Multer.File[];
|
||||
if (!files || files.length === 0) {
|
||||
res.status(400).json({ message: 'No files uploaded' });
|
||||
return;
|
||||
}
|
||||
|
||||
await Promise.all(
|
||||
files.map(async (file) => {
|
||||
let docs: Document[] = [];
|
||||
|
||||
if (file.mimetype === 'application/pdf') {
|
||||
const loader = new PDFLoader(file.path);
|
||||
docs = await loader.load();
|
||||
} else if (
|
||||
file.mimetype ===
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||
) {
|
||||
const loader = new DocxLoader(file.path);
|
||||
docs = await loader.load();
|
||||
} else if (file.mimetype === 'text/plain') {
|
||||
const text = fs.readFileSync(file.path, 'utf-8');
|
||||
docs = [
|
||||
new Document({
|
||||
pageContent: text,
|
||||
metadata: {
|
||||
title: file.originalname,
|
||||
},
|
||||
}),
|
||||
];
|
||||
}
|
||||
|
||||
const splitted = await splitter.splitDocuments(docs);
|
||||
|
||||
const json = JSON.stringify({
|
||||
title: file.originalname,
|
||||
contents: splitted.map((doc) => doc.pageContent),
|
||||
});
|
||||
|
||||
const pathToSave = file.path.replace(/\.\w+$/, '-extracted.json');
|
||||
fs.writeFileSync(pathToSave, json);
|
||||
|
||||
const embeddings = await embeddingsModel.embedDocuments(
|
||||
splitted.map((doc) => doc.pageContent),
|
||||
);
|
||||
|
||||
const embeddingsJSON = JSON.stringify({
|
||||
title: file.originalname,
|
||||
embeddings: embeddings,
|
||||
});
|
||||
|
||||
const pathToSaveEmbeddings = file.path.replace(
|
||||
/\.\w+$/,
|
||||
'-embeddings.json',
|
||||
);
|
||||
fs.writeFileSync(pathToSaveEmbeddings, embeddingsJSON);
|
||||
}),
|
||||
);
|
||||
|
||||
res.status(200).json({
|
||||
files: files.map((file) => {
|
||||
return {
|
||||
fileName: file.originalname,
|
||||
fileExtension: file.filename.split('.').pop(),
|
||||
fileId: file.filename.replace(/\.\w+$/, ''),
|
||||
};
|
||||
}),
|
||||
});
|
||||
} catch (err: any) {
|
||||
logger.error(`Error in uploading file results: ${err.message}`);
|
||||
res.status(500).json({ message: 'An error has occurred.' });
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
export default router;
|
|
@ -3,7 +3,7 @@ import { htmlToText } from 'html-to-text';
|
|||
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
|
||||
import { Document } from '@langchain/core/documents';
|
||||
import pdfParse from 'pdf-parse';
|
||||
import logger from '../utils/logger';
|
||||
import logger from './logger';
|
||||
|
||||
export const getDocumentsFromLinks = async ({ links }: { links: string[] }) => {
|
||||
const splitter = new RecursiveCharacterTextSplitter();
|
16
src/utils/files.ts
Normal file
16
src/utils/files.ts
Normal file
|
@ -0,0 +1,16 @@
|
|||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
export const getFileDetails = (fileId: string) => {
|
||||
const fileLoc = path.join(
|
||||
process.cwd(),
|
||||
'./uploads',
|
||||
fileId + '-extracted.json',
|
||||
);
|
||||
|
||||
const parsedFile = JSON.parse(fs.readFileSync(fileLoc, 'utf8'));
|
||||
|
||||
return {
|
||||
name: parsedFile.title,
|
||||
fileId: fileId,
|
||||
};
|
||||
};
|
|
@ -13,6 +13,7 @@ import db from '../db';
|
|||
import { chats, messages as messagesSchema } from '../db/schema';
|
||||
import { eq, asc, gt } from 'drizzle-orm';
|
||||
import crypto from 'crypto';
|
||||
import { getFileDetails } from '../utils/files';
|
||||
|
||||
type Message = {
|
||||
messageId: string;
|
||||
|
@ -26,6 +27,7 @@ type WSMessage = {
|
|||
type: string;
|
||||
focusMode: string;
|
||||
history: Array<[string, string]>;
|
||||
files: Array<string>;
|
||||
};
|
||||
|
||||
export const searchHandlers = {
|
||||
|
@ -141,6 +143,7 @@ export const handleMessage = async (
|
|||
llm,
|
||||
embeddings,
|
||||
parsedWSMessage.optimizationMode,
|
||||
parsedWSMessage.files,
|
||||
);
|
||||
|
||||
handleEmitterEvents(emitter, ws, aiMessageId, parsedMessage.chatId);
|
||||
|
@ -157,6 +160,7 @@ export const handleMessage = async (
|
|||
title: parsedMessage.content,
|
||||
createdAt: new Date().toString(),
|
||||
focusMode: parsedWSMessage.focusMode,
|
||||
files: parsedWSMessage.files.map(getFileDetails),
|
||||
})
|
||||
.execute();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue