feat(app): add file uploads

This commit is contained in:
ItzCrazyKns 2024-11-23 15:04:19 +05:30
parent c650d1c3d9
commit 4b89008f3a
25 changed files with 1035 additions and 86 deletions

View file

@ -20,10 +20,12 @@ import eventEmitter from 'events';
import computeSimilarity from '../utils/computeSimilarity';
import logger from '../utils/logger';
import LineListOutputParser from '../lib/outputParsers/listLineOutputParser';
import { getDocumentsFromLinks } from '../lib/linkDocument';
import LineOutputParser from '../lib/outputParsers/lineOutputParser';
import { IterableReadableStream } from '@langchain/core/utils/stream';
import { ChatOpenAI } from '@langchain/openai';
import path from 'path';
import fs from 'fs';
import { getDocumentsFromLinks } from '../utils/documents';
const basicSearchRetrieverPrompt = `
You are an AI question rephraser. You will be given a conversation and a follow-up question, you will have to rephrase the follow up question so it is a standalone question and can be used by another LLM to search the web for information to answer it.
@ -316,6 +318,7 @@ const createBasicWebSearchAnsweringChain = (
llm: BaseChatModel,
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
fileIds: string[],
) => {
const basicWebSearchRetrieverChain = createBasicWebSearchRetrieverChain(llm);
@ -336,8 +339,32 @@ const createBasicWebSearchAnsweringChain = (
return docs;
}
const filesData = fileIds
.map((file) => {
const filePath = path.join(process.cwd(), 'uploads', file);
const contentPath = filePath + '-extracted.json';
const embeddingsPath = filePath + '-embeddings.json';
const content = JSON.parse(fs.readFileSync(contentPath, 'utf8'));
const embeddings = JSON.parse(fs.readFileSync(embeddingsPath, 'utf8'));
const fileSimilaritySearchObject = content.contents.map(
(c: string, i) => {
return {
fileName: content.title,
content: c,
embeddings: embeddings.embeddings[i],
};
},
);
return fileSimilaritySearchObject;
})
.flat();
if (query.toLocaleLowerCase() === 'summarize') {
return docs.slice(0, 15)
return docs.slice(0, 15);
}
const docsWithContent = docs.filter(
@ -345,7 +372,43 @@ const createBasicWebSearchAnsweringChain = (
);
if (optimizationMode === 'speed') {
return docsWithContent.slice(0, 15);
if (filesData.length > 0) {
const [queryEmbedding] = await Promise.all([
embeddings.embedQuery(query),
]);
const fileDocs = filesData.map((fileData) => {
return new Document({
pageContent: fileData.content,
metadata: {
title: fileData.fileName,
url: `File`,
},
});
});
const similarity = filesData.map((fileData, i) => {
const sim = computeSimilarity(queryEmbedding, fileData.embeddings);
return {
index: i,
similarity: sim,
};
});
const sortedDocs = similarity
.filter((sim) => sim.similarity > 0.3)
.sort((a, b) => b.similarity - a.similarity)
.slice(0, 8)
.map((sim) => fileDocs[sim.index]);
return [
...sortedDocs,
...docsWithContent.slice(0, 15 - sortedDocs.length),
];
} else {
return docsWithContent.slice(0, 15);
}
} else if (optimizationMode === 'balanced') {
const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(
@ -354,6 +417,20 @@ const createBasicWebSearchAnsweringChain = (
embeddings.embedQuery(query),
]);
docsWithContent.push(
...filesData.map((fileData) => {
return new Document({
pageContent: fileData.content,
metadata: {
title: fileData.fileName,
url: `File`,
},
});
}),
);
docEmbeddings.push(...filesData.map((fileData) => fileData.embeddings));
const similarity = docEmbeddings.map((docEmbedding, i) => {
const sim = computeSimilarity(queryEmbedding, docEmbedding);
@ -408,6 +485,7 @@ const basicWebSearch = (
llm: BaseChatModel,
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
fileIds: string[],
) => {
const emitter = new eventEmitter();
@ -416,6 +494,7 @@ const basicWebSearch = (
llm,
embeddings,
optimizationMode,
fileIds,
);
const stream = basicWebSearchAnsweringChain.streamEvents(
@ -446,6 +525,7 @@ const handleWebSearch = (
llm: BaseChatModel,
embeddings: Embeddings,
optimizationMode: 'speed' | 'balanced' | 'quality',
fileIds: string[],
) => {
const emitter = basicWebSearch(
message,
@ -453,6 +533,7 @@ const handleWebSearch = (
llm,
embeddings,
optimizationMode,
fileIds,
);
return emitter;
};

View file

@ -1,3 +1,4 @@
import { sql } from 'drizzle-orm';
import { text, integer, sqliteTable } from 'drizzle-orm/sqlite-core';
export const messages = sqliteTable('messages', {
@ -11,9 +12,17 @@ export const messages = sqliteTable('messages', {
}),
});
interface File {
name: string;
fileId: string;
}
export const chats = sqliteTable('chats', {
id: text('id').primaryKey(),
title: text('title').notNull(),
createdAt: text('createdAt').notNull(),
focusMode: text('focusMode').notNull(),
files: text('files', { mode: 'json' })
.$type<File[]>()
.default(sql`'[]'`),
});

View file

@ -6,7 +6,7 @@ import { ChatOllama } from '@langchain/community/chat_models/ollama';
export const loadOllamaChatModels = async () => {
const ollamaEndpoint = getOllamaApiEndpoint();
const keepAlive = getKeepAlive();
if (!ollamaEndpoint) return {};
try {
@ -25,7 +25,7 @@ export const loadOllamaChatModels = async () => {
baseUrl: ollamaEndpoint,
model: model.model,
temperature: 0.7,
keepAlive: keepAlive
keepAlive: keepAlive,
}),
};

View file

@ -7,6 +7,7 @@ import suggestionsRouter from './suggestions';
import chatsRouter from './chats';
import searchRouter from './search';
import discoverRouter from './discover';
import uploadsRouter from './uploads';
const router = express.Router();
@ -18,5 +19,6 @@ router.use('/suggestions', suggestionsRouter);
router.use('/chats', chatsRouter);
router.use('/search', searchRouter);
router.use('/discover', discoverRouter);
router.use('/uploads', uploadsRouter);
export default router;

151
src/routes/uploads.ts Normal file
View file

@ -0,0 +1,151 @@
import express from 'express';
import logger from '../utils/logger';
import multer from 'multer';
import path from 'path';
import crypto from 'crypto';
import fs from 'fs';
import { Embeddings } from '@langchain/core/embeddings';
import { getAvailableEmbeddingModelProviders } from '../lib/providers';
import { PDFLoader } from '@langchain/community/document_loaders/fs/pdf';
import { DocxLoader } from '@langchain/community/document_loaders/fs/docx';
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters';
import { Document } from 'langchain/document';
const router = express.Router();
const splitter = new RecursiveCharacterTextSplitter({
chunkSize: 500,
chunkOverlap: 100,
});
const storage = multer.diskStorage({
destination: (req, file, cb) => {
cb(null, path.join(process.cwd(), './uploads'));
},
filename: (req, file, cb) => {
const splitedFileName = file.originalname.split('.');
const fileExtension = splitedFileName[splitedFileName.length - 1];
if (!['pdf', 'docx', 'txt'].includes(fileExtension)) {
return cb(new Error('File type is not supported'), '');
}
cb(null, `${crypto.randomBytes(16).toString('hex')}.${fileExtension}`);
},
});
const upload = multer({ storage });
router.post(
'/',
upload.fields([
{ name: 'files' },
{ name: 'embedding_model', maxCount: 1 },
{ name: 'embedding_model_provider', maxCount: 1 },
]),
async (req, res) => {
try {
const { embedding_model, embedding_model_provider } = req.body;
if (!embedding_model || !embedding_model_provider) {
res
.status(400)
.json({ message: 'Missing embedding model or provider' });
return;
}
const embeddingModels = await getAvailableEmbeddingModelProviders();
const provider =
embedding_model_provider ?? Object.keys(embeddingModels)[0];
const embeddingModel: Embeddings =
embedding_model ?? Object.keys(embeddingModels[provider])[0];
let embeddingsModel: Embeddings | undefined;
if (
embeddingModels[provider] &&
embeddingModels[provider][embeddingModel]
) {
embeddingsModel = embeddingModels[provider][embeddingModel].model as
| Embeddings
| undefined;
}
if (!embeddingsModel) {
res.status(400).json({ message: 'Invalid LLM model selected' });
return;
}
const files = req.files['files'] as Express.Multer.File[];
if (!files || files.length === 0) {
res.status(400).json({ message: 'No files uploaded' });
return;
}
await Promise.all(
files.map(async (file) => {
let docs: Document[] = [];
if (file.mimetype === 'application/pdf') {
const loader = new PDFLoader(file.path);
docs = await loader.load();
} else if (
file.mimetype ===
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
) {
const loader = new DocxLoader(file.path);
docs = await loader.load();
} else if (file.mimetype === 'text/plain') {
const text = fs.readFileSync(file.path, 'utf-8');
docs = [
new Document({
pageContent: text,
metadata: {
title: file.originalname,
},
}),
];
}
const splitted = await splitter.splitDocuments(docs);
const json = JSON.stringify({
title: file.originalname,
contents: splitted.map((doc) => doc.pageContent),
});
const pathToSave = file.path.replace(/\.\w+$/, '-extracted.json');
fs.writeFileSync(pathToSave, json);
const embeddings = await embeddingsModel.embedDocuments(
splitted.map((doc) => doc.pageContent),
);
const embeddingsJSON = JSON.stringify({
title: file.originalname,
embeddings: embeddings,
});
const pathToSaveEmbeddings = file.path.replace(
/\.\w+$/,
'-embeddings.json',
);
fs.writeFileSync(pathToSaveEmbeddings, embeddingsJSON);
}),
);
res.status(200).json({
files: files.map((file) => {
return {
fileName: file.originalname,
fileExtension: file.filename.split('.').pop(),
fileId: file.filename.replace(/\.\w+$/, ''),
};
}),
});
} catch (err: any) {
logger.error(`Error in uploading file results: ${err.message}`);
res.status(500).json({ message: 'An error has occurred.' });
}
},
);
export default router;

View file

@ -3,7 +3,7 @@ import { htmlToText } from 'html-to-text';
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
import { Document } from '@langchain/core/documents';
import pdfParse from 'pdf-parse';
import logger from '../utils/logger';
import logger from './logger';
export const getDocumentsFromLinks = async ({ links }: { links: string[] }) => {
const splitter = new RecursiveCharacterTextSplitter();

16
src/utils/files.ts Normal file
View file

@ -0,0 +1,16 @@
import path from 'path';
import fs from 'fs';
export const getFileDetails = (fileId: string) => {
const fileLoc = path.join(
process.cwd(),
'./uploads',
fileId + '-extracted.json',
);
const parsedFile = JSON.parse(fs.readFileSync(fileLoc, 'utf8'));
return {
name: parsedFile.title,
fileId: fileId,
};
};

View file

@ -13,6 +13,7 @@ import db from '../db';
import { chats, messages as messagesSchema } from '../db/schema';
import { eq, asc, gt } from 'drizzle-orm';
import crypto from 'crypto';
import { getFileDetails } from '../utils/files';
type Message = {
messageId: string;
@ -26,6 +27,7 @@ type WSMessage = {
type: string;
focusMode: string;
history: Array<[string, string]>;
files: Array<string>;
};
export const searchHandlers = {
@ -141,6 +143,7 @@ export const handleMessage = async (
llm,
embeddings,
parsedWSMessage.optimizationMode,
parsedWSMessage.files,
);
handleEmitterEvents(emitter, ws, aiMessageId, parsedMessage.chatId);
@ -157,6 +160,7 @@ export const handleMessage = async (
title: parsedMessage.content,
createdAt: new Date().toString(),
focusMode: parsedWSMessage.focusMode,
files: parsedWSMessage.files.map(getFileDetails),
})
.execute();
}