diff --git a/src/routes/videos.ts b/src/routes/videos.ts
index 9d43fd2..a2555f5 100644
--- a/src/routes/videos.ts
+++ b/src/routes/videos.ts
@@ -4,14 +4,28 @@ import { getAvailableChatModelProviders } from '../lib/providers';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
 import handleVideoSearch from '../agents/videoSearchAgent';
+import { ChatOpenAI } from '@langchain/openai';
 
 const router = express.Router();
 
+interface ChatModel {
+  provider: string;
+  model: string;
+  customOpenAIBaseURL?: string;
+  customOpenAIKey?: string;
+}
+
+interface VideoSearchBody {
+  query: string;
+  chatHistory: any[];
+  chatModel?: ChatModel;
+}
+
 router.post('/', async (req, res) => {
   try {
-    let { query, chat_history, chat_model_provider, chat_model } = req.body;
+    let body: VideoSearchBody = req.body;
 
-    chat_history = chat_history.map((msg: any) => {
+    const chatHistory = body.chatHistory.map((msg: any) => {
       if (msg.role === 'user') {
         return new HumanMessage(msg.content);
       } else if (msg.role === 'assistant') {
@@ -19,22 +33,50 @@ router.post('/', async (req, res) => {
       }
     });
 
-    const chatModels = await getAvailableChatModelProviders();
-    const provider = chat_model_provider ?? Object.keys(chatModels)[0];
-    const chatModel = chat_model ?? Object.keys(chatModels[provider])[0];
+    const chatModelProviders = await getAvailableChatModelProviders();
+
+    const chatModelProvider =
+      body.chatModel?.provider || Object.keys(chatModelProviders)[0];
+    const chatModel =
+      body.chatModel?.model ||
+      Object.keys(chatModelProviders[chatModelProvider])[0];
 
     let llm: BaseChatModel | undefined;
 
-    if (chatModels[provider] && chatModels[provider][chatModel]) {
-      llm = chatModels[provider][chatModel].model as BaseChatModel | undefined;
+    if (body.chatModel?.provider === 'custom_openai') {
+      if (
+        !body.chatModel?.customOpenAIBaseURL ||
+        !body.chatModel?.customOpenAIKey
+      ) {
+        return res
+          .status(400)
+          .json({ message: 'Missing custom OpenAI base URL or key' });
+      }
+
+      llm = new ChatOpenAI({
+        modelName: body.chatModel.model,
+        openAIApiKey: body.chatModel.customOpenAIKey,
+        temperature: 0.7,
+        configuration: {
+          baseURL: body.chatModel.customOpenAIBaseURL,
+        },
+      }) as unknown as BaseChatModel;
+    } else if (
+      chatModelProviders[chatModelProvider] &&
+      chatModelProviders[chatModelProvider][chatModel]
+    ) {
+      llm = chatModelProviders[chatModelProvider][chatModel]
+        .model as unknown as BaseChatModel | undefined;
     }
 
     if (!llm) {
-      res.status(500).json({ message: 'Invalid LLM model selected' });
-      return;
+      return res.status(400).json({ message: 'Invalid model selected' });
     }
 
-    const videos = await handleVideoSearch({ chat_history, query }, llm);
+    const videos = await handleVideoSearch(
+      { chat_history: chatHistory, query: body.query },
+      llm,
+    );
 
     res.status(200).json({ videos });
   } catch (err) {
diff --git a/ui/components/MessageBox.tsx b/ui/components/MessageBox.tsx
index b111088..5222c7c 100644
--- a/ui/components/MessageBox.tsx
+++ b/ui/components/MessageBox.tsx
@@ -186,10 +186,10 @@ const MessageBox = ({
diff --git a/ui/components/SearchImages.tsx b/ui/components/SearchImages.tsx
index 6025925..b083af7 100644
--- a/ui/components/SearchImages.tsx
+++ b/ui/components/SearchImages.tsx
@@ -13,10 +13,10 @@ type Image = {
 
 const SearchImages = ({
   query,
-  chat_history,
+  chatHistory,
 }: {
   query: string;
-  chat_history: Message[];
+  chatHistory: Message[];
 }) => {
   const [images, setImages] = useState<Image[] | null>(null);
   const [loading, setLoading] = useState(false);
@@ -33,6 +33,9 @@ const SearchImages = ({
             const chatModelProvider = localStorage.getItem('chatModelProvider');
             const chatModel = localStorage.getItem('chatModel');
 
+            const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+            const customOpenAIKey = localStorage.getItem('openAIApiKey');
+
             const res = await fetch(
               `${process.env.NEXT_PUBLIC_API_URL}/images`,
               {
@@ -42,9 +45,15 @@ const SearchImages = ({
                 },
                 body: JSON.stringify({
                   query: query,
-                  chat_history: chat_history,
-                  chat_model_provider: chatModelProvider,
-                  chat_model: chatModel,
+                  chatHistory: chatHistory,
+                  chatModel: {
+                    provider: chatModelProvider,
+                    model: chatModel,
+                    ...(chatModelProvider === 'custom_openai' && {
+                      customOpenAIBaseURL: customOpenAIBaseURL,
+                      customOpenAIKey: customOpenAIKey,
+                    }),
+                  },
                 }),
               },
             );
diff --git a/ui/components/SearchVideos.tsx b/ui/components/SearchVideos.tsx
index fec229c..2d820ef 100644
--- a/ui/components/SearchVideos.tsx
+++ b/ui/components/SearchVideos.tsx
@@ -26,10 +26,10 @@ declare module 'yet-another-react-lightbox' {
 
 const Searchvideos = ({
   query,
-  chat_history,
+  chatHistory,
 }: {
   query: string;
-  chat_history: Message[];
+  chatHistory: Message[];
 }) => {
   const [videos, setVideos] = useState<Video[] | null>(null);
   const [loading, setLoading] = useState(false);
@@ -46,6 +46,9 @@ const Searchvideos = ({
             const chatModelProvider = localStorage.getItem('chatModelProvider');
             const chatModel = localStorage.getItem('chatModel');
 
+            const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+            const customOpenAIKey = localStorage.getItem('openAIApiKey');
+
             const res = await fetch(
               `${process.env.NEXT_PUBLIC_API_URL}/videos`,
               {
@@ -55,9 +58,15 @@ const Searchvideos = ({
                 },
                 body: JSON.stringify({
                   query: query,
-                  chat_history: chat_history,
-                  chat_model_provider: chatModelProvider,
-                  chat_model: chatModel,
+                  chatHistory: chatHistory,
+                  chatModel: {
+                    provider: chatModelProvider,
+                    model: chatModel,
+                    ...(chatModelProvider === 'custom_openai' && {
+                      customOpenAIBaseURL: customOpenAIBaseURL,
+                      customOpenAIKey: customOpenAIKey,
+                    }),
+                  },
                 }),
               },
             );
diff --git a/ui/lib/actions.ts b/ui/lib/actions.ts
index d7eb71f..a4409b0 100644
--- a/ui/lib/actions.ts
+++ b/ui/lib/actions.ts
@@ -4,15 +4,24 @@ export const getSuggestions = async (chatHisory: Message[]) => {
   const chatModel = localStorage.getItem('chatModel');
   const chatModelProvider = localStorage.getItem('chatModelProvider');
 
+  const customOpenAIKey = localStorage.getItem('openAIApiKey');
+  const customOpenAIBaseURL = localStorage.getItem('openAIBaseURL');
+
   const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/suggestions`, {
     method: 'POST',
     headers: {
       'Content-Type': 'application/json',
     },
     body: JSON.stringify({
-      chat_history: chatHisory,
-      chat_model: chatModel,
-      chat_model_provider: chatModelProvider,
+      chatHistory: chatHisory,
+      chatModel: {
+        provider: chatModelProvider,
+        model: chatModel,
+        ...(chatModelProvider === 'custom_openai' && {
+          customOpenAIKey,
+          customOpenAIBaseURL,
+        }),
+      },
     }),
   });
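
For reference, a minimal sketch of the request shape the reworked /videos route now accepts after this change, mirroring the new VideoSearchBody interface and the nested chatModel object the UI components send. The model name, base URL, and key below are illustrative placeholders, and the custom OpenAI fields are only required when the provider is 'custom_openai'; for any other provider they can be omitted.

    // Sketch: calling the updated /videos route with the new nested chatModel payload.
    // Placeholder values throughout; the route returns 400 if the custom_openai
    // provider is selected without customOpenAIBaseURL and customOpenAIKey.
    const body = {
      query: 'How do solar panels work?',
      chatHistory: [
        { role: 'user', content: 'Tell me about renewable energy.' },
        { role: 'assistant', content: 'Renewable energy comes from naturally replenishing sources.' },
      ],
      chatModel: {
        provider: 'custom_openai',
        model: 'gpt-4o-mini',
        customOpenAIBaseURL: 'https://my-openai-compatible-host/v1',
        customOpenAIKey: 'sk-...',
      },
    };

    const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/videos`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(body),
    });

    // The route responds with { videos } on success.
    const { videos } = await res.json();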