// File-viewer metadata: 150 lines · 4.2 KiB · TypeScript
import express from 'express';
|
|
import logger from '../utils/logger';
|
|
import { BaseChatModel } from 'langchain/chat_models/base';
|
|
import { Embeddings } from 'langchain/embeddings/base';
|
|
import { ChatOpenAI } from '@langchain/openai';
|
|
import {
|
|
getAvailableChatModelProviders,
|
|
getAvailableEmbeddingModelProviders,
|
|
} from '../lib/providers';
|
|
import { searchHandlers } from '../websocket/messageHandler';
|
|
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
|
|
|
|
// Router instance for this module; the search route below is registered on it
// and it is the module's default export.
const router = express.Router();
|
|
|
|
/**
 * Chat model selection supplied by the client in the request body.
 *
 * When `provider` is `'custom_openai'` the route handler requires both
 * `customOpenAIBaseURL` and `customOpenAIKey` and rejects the request with a
 * 400 if either is missing; for any other provider they are not read.
 */
interface chatModel {
  /** Key of the chat model provider to use (e.g. `'custom_openai'`). */
  provider: string;
  /** Model identifier within the chosen provider. */
  model: string;
  /** Base URL of the OpenAI-compatible endpoint (custom provider only). */
  customOpenAIBaseURL?: string;
  /** API key for the OpenAI-compatible endpoint (custom provider only). */
  customOpenAIKey?: string;
}
|
|
|
|
/**
 * Embedding model selection supplied by the client in the request body.
 * Both values are used as lookup keys into the available embedding providers.
 */
interface embeddingModel {
  /** Key of the embedding model provider to use. */
  provider: string;
  /** Model identifier within the chosen provider. */
  model: string;
}
|
|
|
|
/**
 * JSON payload accepted by the search route.
 *
 * `focusMode` and `query` are mandatory (the handler answers 400 without
 * them). Model selections are optional; when omitted the handler falls back
 * to the first available provider/model.
 */
interface RequestBody {
  /** Key used to pick the search handler. */
  focusMode: string;
  /** Optional chat model override. */
  chatModel?: chatModel;
  /** Optional embedding model override. */
  embeddingModel?: embeddingModel;
  /** The user's search query. */
  query: string;
  /**
   * Prior conversation as `[role, content]` pairs. A role of `'human'` maps to
   * a human message; any other role is treated as an AI message. Optional:
   * the handler substitutes an empty history when it is absent.
   */
  history?: Array<[string, string]>;
}
|
|
|
|
/**
 * POST / — runs one search and replies with the complete answer as JSON.
 *
 * Flow:
 *  1. Validate `focusMode` and `query` (400 when missing).
 *  2. Convert the `[role, content]` history pairs into LangChain messages.
 *  3. Resolve the chat and embedding models, either from the request or by
 *     falling back to the first available provider/model (400 when nothing
 *     usable is found, or when a custom OpenAI setup lacks URL/key).
 *  4. Invoke the search handler for the focus mode and accumulate the events
 *     it emits, replying `{ message, sources }` on `end`.
 *
 * Every failure path answers at most once: listener bodies are guarded so a
 * malformed event payload or a late `end` after `error` cannot throw outside
 * the surrounding try/catch or attempt a second response.
 */
router.post('/', async (req, res) => {
  try {
    const body: RequestBody = req.body;

    if (!body.focusMode || !body.query) {
      res.status(400).json({ message: 'Missing focus mode or query' });
      return;
    }

    // A missing (or otherwise falsy) history means "no prior conversation".
    const rawHistory = body.history || [];
    const history: BaseMessage[] = rawHistory.map((msg) =>
      msg[0] === 'human'
        ? new HumanMessage({ content: msg[1] })
        : new AIMessage({ content: msg[1] }),
    );

    const [chatModelProviders, embeddingModelProviders] = await Promise.all([
      getAvailableChatModelProviders(),
      getAvailableEmbeddingModelProviders(),
    ]);

    // Fall back to the first available provider/model when none is requested.
    // `?? {}` keeps an unknown provider key from throwing in Object.keys so
    // the request ends in the 400 below instead of a 500.
    const chatModelProvider =
      body.chatModel?.provider || Object.keys(chatModelProviders)[0];
    const chatModel =
      body.chatModel?.model ||
      Object.keys(chatModelProviders[chatModelProvider] ?? {})[0];

    const embeddingModelProvider =
      body.embeddingModel?.provider || Object.keys(embeddingModelProviders)[0];
    const embeddingModel =
      body.embeddingModel?.model ||
      Object.keys(embeddingModelProviders[embeddingModelProvider] ?? {})[0];

    let llm: BaseChatModel | undefined;

    if (body.chatModel?.provider === 'custom_openai') {
      if (
        !body.chatModel.customOpenAIBaseURL ||
        !body.chatModel.customOpenAIKey
      ) {
        res
          .status(400)
          .json({ message: 'Missing custom OpenAI base URL or key' });
        return;
      }

      // NOTE(review): the double cast is kept from the original; presumably it
      // bridges differing type declarations between packages — confirm.
      llm = new ChatOpenAI({
        modelName: body.chatModel.model,
        openAIApiKey: body.chatModel.customOpenAIKey,
        temperature: 0.7,
        configuration: {
          baseURL: body.chatModel.customOpenAIBaseURL,
        },
      }) as unknown as BaseChatModel;
    } else {
      llm = chatModelProviders[chatModelProvider]?.[chatModel]
        ?.model as unknown as BaseChatModel | undefined;
    }

    const embeddings = embeddingModelProviders[embeddingModelProvider]?.[
      embeddingModel
    ]?.model as Embeddings | undefined;

    if (!llm || !embeddings) {
      res.status(400).json({ message: 'Invalid model selected' });
      return;
    }

    const searchHandler = searchHandlers[body.focusMode];

    if (!searchHandler) {
      res.status(400).json({ message: 'Invalid focus mode' });
      return;
    }

    const emitter = searchHandler(body.query, history, llm, embeddings);

    let message = '';
    let sources: unknown = [];

    // Sends a response only if none has gone out yet, so `end` after `error`
    // (or repeated `error` events) cannot trigger a second write.
    const respondOnce = (status: number, payload: Record<string, unknown>) => {
      if (res.headersSent) {
        return;
      }
      res.status(status).json(payload);
    };

    // Listener bodies run outside the surrounding try/catch, so failures in
    // them are reported here rather than escaping as uncaught exceptions.
    const failFromListener = (err: unknown) => {
      const reason = err instanceof Error ? err.message : String(err);
      logger.error(`Error in getting search results: ${reason}`);
      respondOnce(500, { message: 'An error has occurred.' });
    };

    emitter.on('data', (data: string) => {
      try {
        const parsedData = JSON.parse(data);
        if (parsedData.type === 'response') {
          // Response text arrives in chunks; concatenate them in order.
          message += parsedData.data;
        } else if (parsedData.type === 'sources') {
          sources = parsedData.data;
        }
      } catch (err: unknown) {
        failFromListener(err);
      }
    });

    emitter.on('end', () => {
      respondOnce(200, { message, sources });
    });

    emitter.on('error', (data: string) => {
      try {
        const parsedData = JSON.parse(data);
        respondOnce(500, { message: parsedData.data });
      } catch (err: unknown) {
        failFromListener(err);
      }
    });
  } catch (err: unknown) {
    const reason = err instanceof Error ? err.message : String(err);
    logger.error(`Error in getting search results: ${reason}`);
    if (!res.headersSent) {
      res.status(500).json({ message: 'An error has occurred.' });
    }
  }
});
|
|
|
|
// Expose the configured router as this module's default export.
export default router;
|