From 2685455c3e5eee6b80f467450e070741d8b4cf2d Mon Sep 17 00:00:00 2001
From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com>
Date: Wed, 30 Oct 2024 10:28:31 +0530
Subject: [PATCH] feat(video-search): handle custom OpenAI

---
 src/routes/videos.ts           | 62 ++++++++++++++++++++++++++++------
 ui/components/MessageBox.tsx   |  4 +--
 ui/components/SearchImages.tsx | 19 ++++++++---
 ui/components/SearchVideos.tsx | 19 ++++++++---
 ui/lib/actions.ts              | 15 ++++++--
 5 files changed, 94 insertions(+), 25 deletions(-)

diff --git a/src/routes/videos.ts b/src/routes/videos.ts
index 9d43fd2..a2555f5 100644
--- a/src/routes/videos.ts
+++ b/src/routes/videos.ts
@@ -4,14 +4,28 @@ import { getAvailableChatModelProviders } from '../lib/providers';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
 import handleVideoSearch from '../agents/videoSearchAgent';
+import { ChatOpenAI } from '@langchain/openai';
 
 const router = express.Router();
 
+interface ChatModel {
+  provider: string;
+  model: string;
+  customOpenAIBaseURL?: string;
+  customOpenAIKey?: string;
+}
+
+interface VideoSearchBody {
+  query: string;
+  chatHistory: any[];
+  chatModel?: ChatModel;
+}
+
 router.post('/', async (req, res) => {
   try {
-    let { query, chat_history, chat_model_provider, chat_model } = req.body;
+    let body: VideoSearchBody = req.body;
 
-    chat_history = chat_history.map((msg: any) => {
+    const chatHistory = body.chatHistory.map((msg: any) => {
       if (msg.role === 'user') {
         return new HumanMessage(msg.content);
       } else if (msg.role === 'assistant') {
@@ -19,22 +33,50 @@ router.post('/', async (req, res) => {
       }
     });
 
-    const chatModels = await getAvailableChatModelProviders();
-    const provider = chat_model_provider ?? Object.keys(chatModels)[0];
-    const chatModel = chat_model ?? Object.keys(chatModels[provider])[0];
+    const chatModelProviders = await getAvailableChatModelProviders();
+
+    const chatModelProvider =
+      body.chatModel?.provider || Object.keys(chatModelProviders)[0];
+    const chatModel =
+      body.chatModel?.model ||
+      Object.keys(chatModelProviders[chatModelProvider])[0];
 
     let llm: BaseChatModel | undefined;
 
-    if (chatModels[provider] && chatModels[provider][chatModel]) {
-      llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
+    if (body.chatModel?.provider === 'custom_openai') {
+      if (
+        !body.chatModel?.customOpenAIBaseURL ||
+        !body.chatModel?.customOpenAIKey
+      ) {
+        return res
+          .status(400)
+          .json({ message: 'Missing custom OpenAI base URL or key' });
+      }
+
+      llm = new ChatOpenAI({
+        modelName: body.chatModel.model,
+        openAIApiKey: body.chatModel.customOpenAIKey,
+        temperature: 0.7,
+        configuration: {
+          baseURL: body.chatModel.customOpenAIBaseURL,
+        },
+      }) as unknown as BaseChatModel;
+    } else if (
+      chatModelProviders[chatModelProvider] &&
+      chatModelProviders[chatModelProvider][chatModel]
+    ) {
+      llm = chatModelProviders[chatModelProvider][chatModel]
+        .model as unknown as BaseChatModel | undefined;
     }
 
     if (!llm) {
-      res.status(500).json({ message: 'Invalid LLM model selected' });
-      return;
+      return res.status(400).json({ message: 'Invalid model selected' });
     }
 
-    const videos = await handleVideoSearch({ chat_history, query }, llm);
+    const videos = await handleVideoSearch(
+      { chat_history: chatHistory, query: body.query },
+      llm,
+    );
 
     res.status(200).json({ videos });
   } catch (err) {
diff --git a/ui/components/MessageBox.tsx b/ui/components/MessageBox.tsx
index b111088..5222c7c 100644
--- a/ui/components/MessageBox.tsx
+++ b/ui/components/MessageBox.tsx
@@ -186,10 +186,10 @@ const MessageBox = ({
           <div className="lg:sticky lg:top-20 flex flex-col items-center space-y-3 w-full lg:w-3/12 z-30 h-full pb-4">
             <SearchImages
               query={history[messageIndex - 1].content}
-              chat_history={history.slice(0, messageIndex - 1)}
+              chatHistory={history.slice(0, messageIndex - 1)}
             />
             <SearchVideos
-              chat_history={history.slice(0, messageIndex - 1)}
+              chatHistory={history.slice(0, messageIndex - 1)}
               query={history[messageIndex - 1].content}
             />
           </div>
diff --git a/ui/components/SearchImages.tsx b/ui/components/SearchImages.tsx
index 6025925..b083af7 100644
--- a/ui/components/SearchImages.tsx
+++ b/ui/components/SearchImages.tsx
@@ -13,10 +13,10 @@ type Image = {
 
 const SearchImages = ({
   query,
-  chat_history,
+  chatHistory,
 }: {
   query: string;
-  chat_history: Message[];
+  chatHistory: Message[];
 }) => {
   const [images, setImages] = useState<Image[] | null>(null);
   const [loading, setLoading] = useState(false);
@@ -33,6 +33,9 @@ const SearchImages = ({
             const chatModelProvider = localStorage.getItem('chatModelProvider');
             const chatModel = localStorage.getItem('chatModel');
 
+            const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+            const customOpenAIKey = localStorage.getItem('openAIApiKey');
+
             const res = await fetch(
               `${process.env.NEXT_PUBLIC_API_URL}/images`,
               {
@@ -42,9 +45,15 @@ const SearchImages = ({
                 },
                 body: JSON.stringify({
                   query: query,
-                  chat_history: chat_history,
-                  chat_model_provider: chatModelProvider,
-                  chat_model: chatModel,
+                  chatHistory: chatHistory,
+                  chatModel: {
+                    provider: chatModelProvider,
+                    model: chatModel,
+                    ...(chatModelProvider === 'custom_openai' && {
+                      customOpenAIBaseURL: customOpenAIBaseURL,
+                      customOpenAIKey: customOpenAIKey,
+                    }),
+                  },
                 }),
               },
             );
diff --git a/ui/components/SearchVideos.tsx b/ui/components/SearchVideos.tsx
index fec229c..2d820ef 100644
--- a/ui/components/SearchVideos.tsx
+++ b/ui/components/SearchVideos.tsx
@@ -26,10 +26,10 @@ declare module 'yet-another-react-lightbox' {
 
 const Searchvideos = ({
   query,
-  chat_history,
+  chatHistory,
 }: {
   query: string;
-  chat_history: Message[];
+  chatHistory: Message[];
 }) => {
   const [videos, setVideos] = useState<Video[] | null>(null);
   const [loading, setLoading] = useState(false);
@@ -46,6 +46,9 @@ const Searchvideos = ({
             const chatModelProvider = localStorage.getItem('chatModelProvider');
             const chatModel = localStorage.getItem('chatModel');
 
+            const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+            const customOpenAIKey = localStorage.getItem('openAIApiKey');
+
             const res = await fetch(
               `${process.env.NEXT_PUBLIC_API_URL}/videos`,
               {
@@ -55,9 +58,15 @@ const Searchvideos = ({
                 },
                 body: JSON.stringify({
                   query: query,
-                  chat_history: chat_history,
-                  chat_model_provider: chatModelProvider,
-                  chat_model: chatModel,
+                  chatHistory: chatHistory,
+                  chatModel: {
+                    provider: chatModelProvider,
+                    model: chatModel,
+                    ...(chatModelProvider === 'custom_openai' && {
+                      customOpenAIBaseURL: customOpenAIBaseURL,
+                      customOpenAIKey: customOpenAIKey,
+                    }),
+                  },
                 }),
               },
             );
diff --git a/ui/lib/actions.ts b/ui/lib/actions.ts
index d7eb71f..a4409b0 100644
--- a/ui/lib/actions.ts
+++ b/ui/lib/actions.ts
@@ -4,15 +4,24 @@ export const getSuggestions = async (chatHisory: Message[]) => {
   const chatModel = localStorage.getItem('chatModel');
   const chatModelProvider = localStorage.getItem('chatModelProvider');
 
+  const customOpenAIKey = localStorage.getItem('openAIApiKey');
+  const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+
   const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/suggestions`, {
     method: 'POST',
     headers: {
       'Content-Type': 'application/json',
     },
     body: JSON.stringify({
-      chat_history: chatHisory,
-      chat_model: chatModel,
-      chat_model_provider: chatModelProvider,
+      chatHistory: chatHisory,
+      chatModel: {
+        provider: chatModelProvider,
+        model: chatModel,
+        ...(chatModelProvider === 'custom_openai' && {
+          customOpenAIKey,
+          customOpenAIBaseURL,
+        }),
+      },
     }),
   });