From 7c6ee2ead1ee0489ffdfaf6ca04e31ddc7ff20d9 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:28:31 +0530 Subject: [PATCH] feat(video-search): handle custom OpenAI --- src/routes/videos.ts | 62 ++++++++++++++++++++++++++++------ ui/components/MessageBox.tsx | 4 +-- ui/components/SearchImages.tsx | 19 ++++++++--- ui/components/SearchVideos.tsx | 19 ++++++++--- ui/lib/actions.ts | 15 ++++++-- 5 files changed, 94 insertions(+), 25 deletions(-) diff --git a/src/routes/videos.ts b/src/routes/videos.ts index 9d43fd2..a2555f5 100644 --- a/src/routes/videos.ts +++ b/src/routes/videos.ts @@ -4,14 +4,28 @@ import { getAvailableChatModelProviders } from '../lib/providers'; import { HumanMessage, AIMessage } from '@langchain/core/messages'; import logger from '../utils/logger'; import handleVideoSearch from '../agents/videoSearchAgent'; +import { ChatOpenAI } from '@langchain/openai'; const router = express.Router(); +interface ChatModel { + provider: string; + model: string; + customOpenAIBaseURL?: string; + customOpenAIKey?: string; +} + +interface VideoSearchBody { + query: string; + chatHistory: any[]; + chatModel?: ChatModel; +} + router.post('/', async (req, res) => { try { - let { query, chat_history, chat_model_provider, chat_model } = req.body; + let body: VideoSearchBody = req.body; - chat_history = chat_history.map((msg: any) => { + const chatHistory = body.chatHistory.map((msg: any) => { if (msg.role === 'user') { return new HumanMessage(msg.content); } else if (msg.role === 'assistant') { @@ -19,22 +33,50 @@ router.post('/', async (req, res) => { } }); - const chatModels = await getAvailableChatModelProviders(); - const provider = chat_model_provider ?? Object.keys(chatModels)[0]; - const chatModel = chat_model ?? Object.keys(chatModels[provider])[0]; + const chatModelProviders = await getAvailableChatModelProviders(); + + const chatModelProvider = + body.chatModel?.provider || Object.keys(chatModelProviders)[0]; + const chatModel = + body.chatModel?.model || + Object.keys(chatModelProviders[chatModelProvider])[0]; let llm: BaseChatModel | undefined; - if (chatModels[provider] && chatModels[provider][chatModel]) { - llm = chatModels[provider][chatModel].model as BaseChatModel | undefined; + if (body.chatModel?.provider === 'custom_openai') { + if ( + !body.chatModel?.customOpenAIBaseURL || + !body.chatModel?.customOpenAIKey + ) { + return res + .status(400) + .json({ message: 'Missing custom OpenAI base URL or key' }); + } + + llm = new ChatOpenAI({ + modelName: body.chatModel.model, + openAIApiKey: body.chatModel.customOpenAIKey, + temperature: 0.7, + configuration: { + baseURL: body.chatModel.customOpenAIBaseURL, + }, + }) as unknown as BaseChatModel; + } else if ( + chatModelProviders[chatModelProvider] && + chatModelProviders[chatModelProvider][chatModel] + ) { + llm = chatModelProviders[chatModelProvider][chatModel] + .model as unknown as BaseChatModel | undefined; } if (!llm) { - res.status(500).json({ message: 'Invalid LLM model selected' }); - return; + return res.status(400).json({ message: 'Invalid model selected' }); } - const videos = await handleVideoSearch({ chat_history, query }, llm); + const videos = await handleVideoSearch( + { chat_history: chatHistory, query: body.query }, + llm, + ); res.status(200).json({ videos }); } catch (err) { diff --git a/ui/components/MessageBox.tsx b/ui/components/MessageBox.tsx index b111088..5222c7c 100644 --- a/ui/components/MessageBox.tsx +++ b/ui/components/MessageBox.tsx @@ -186,10 +186,10 @@ const MessageBox = ({