From c870ee0e73568e06086d852eae577a4b2f181191 Mon Sep 17 00:00:00 2001 From: sagitchu <601096721@qq.com> Date: Sun, 25 Aug 2024 11:38:53 +0800 Subject: [PATCH 1/3] add custom Embedding --- src/lib/providers/index.ts | 2 + src/websocket/connectionManager.ts | 13 +++++- ui/components/SettingsDialog.tsx | 75 ++++++++++++++++++------------ 3 files changed, 58 insertions(+), 32 deletions(-) diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index d919fd4..49f2e22 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -42,5 +42,7 @@ export const getAvailableEmbeddingModelProviders = async () => { } } + models['custom_openai'] = {}; + return models; }; diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index 70e20d9..b77106c 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -8,7 +8,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; import type { IncomingMessage } from 'http'; import logger from '../utils/logger'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'; export const handleConnection = async ( ws: WebSocket, @@ -61,11 +61,20 @@ export const handleConnection = async ( if ( embeddingModelProviders[embeddingModelProvider] && - embeddingModelProviders[embeddingModelProvider][embeddingModel] + embeddingModelProviders[embeddingModelProvider][embeddingModel] && + embeddingModelProvider != 'custom_openai' ) { embeddings = embeddingModelProviders[embeddingModelProvider][ embeddingModel ] as Embeddings | undefined; + } else if (embeddingModelProvider == 'custom_openai') { + embeddings = new OpenAIEmbeddings({ + modelName: embeddingModel, + openAIApiKey: searchParams.get('openAIApiKey'), + configuration: { + baseURL: searchParams.get('openAIBaseURL'), + }, + }) as unknown as Embeddings } if (!llm || !embeddings) { diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 171e812..7b84bf5 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -9,7 +9,7 @@ import React, { } from 'react'; import ThemeSwitcher from './theme/Switcher'; -interface InputProps extends React.InputHTMLAttributes {} +interface InputProps extends React.InputHTMLAttributes { } const Input = ({ className, ...restProps }: InputProps) => { return ( @@ -258,30 +258,30 @@ const SettingsDialog = ({ options={(() => { const chatModelProvider = config.chatModelProviders[ - selectedChatModelProvider + selectedChatModelProvider ]; return chatModelProvider ? chatModelProvider.length > 0 ? chatModelProvider.map((model) => ({ - value: model, - label: model, - })) + value: model, + label: model, + })) : [ - { - value: '', - label: 'No models available', - disabled: true, - }, - ] - : [ { value: '', - label: - 'Invalid provider, please check backend logs', + label: 'No models available', disabled: true, }, - ]; + ] + : [ + { + value: '', + label: + 'Invalid provider, please check backend logs', + disabled: true, + }, + ]; })()} /> @@ -355,7 +355,7 @@ const SettingsDialog = ({ /> )} - {selectedEmbeddingModelProvider && ( + {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider != 'custom_openai' && (

Embedding Model @@ -368,34 +368,49 @@ const SettingsDialog = ({ options={(() => { const embeddingModelProvider = config.embeddingModelProviders[ - selectedEmbeddingModelProvider + selectedEmbeddingModelProvider ]; return embeddingModelProvider ? embeddingModelProvider.length > 0 ? embeddingModelProvider.map((model) => ({ - label: model, - value: model, - })) + label: model, + value: model, + })) : [ - { - label: 'No embedding models available', - value: '', - disabled: true, - }, - ] - : [ { - label: - 'Invalid provider, please check backend logs', + label: 'No embedding models available', value: '', disabled: true, }, - ]; + ] + : [ + { + label: + 'Invalid provider, please check backend logs', + value: '', + disabled: true, + }, + ]; })()} />

)} + {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider === 'custom_openai' && ( +
+

+ Embedding Model +

+ + setSelectedEmbeddingModel(e.target.value) + } + /> +
+ )}

OpenAI API Key From 441b218245255c0b8653eef0f8e88dfeda5d4dbf Mon Sep 17 00:00:00 2001 From: sagitchu <601096721@qq.com> Date: Fri, 30 Aug 2024 15:15:49 +0800 Subject: [PATCH 2/3] add test --- ui/components/ChatWindow.tsx | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx index b3d0089..3d00cb2 100644 --- a/ui/components/ChatWindow.tsx +++ b/ui/components/ChatWindow.tsx @@ -130,6 +130,7 @@ const useSocket = ( if ( embeddingModelProvider && + embeddingModelProvider != 'custom_openai' && !embeddingModelProviders[embeddingModelProvider][embeddingModel] ) { embeddingModel = Object.keys( @@ -159,6 +160,17 @@ const useSocket = ( searchParams.append('embeddingModel', embeddingModel!); searchParams.append('embeddingModelProvider', embeddingModelProvider); + if (embeddingModelProvider === 'custom_openai') { + searchParams.append( + 'openAIApiKey', + localStorage.getItem('openAIApiKey')!, + ); + searchParams.append( + 'openAIBaseURL', + localStorage.getItem('openAIBaseURL')!, + ); + } + wsURL.search = searchParams.toString(); const ws = new WebSocket(wsURL.toString()); From ee9d1d24710fbc745752373e0e3971c40468bfe7 Mon Sep 17 00:00:00 2001 From: sagitchu <601096721@qq.com> Date: Fri, 30 Aug 2024 15:33:39 +0800 Subject: [PATCH 3/3] add test --- ui/components/SettingsDialog.tsx | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 7b84bf5..b03bd5c 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -340,9 +340,13 @@ const SettingsDialog = ({ value={selectedEmbeddingModelProvider ?? undefined} onChange={(e) => { setSelectedEmbeddingModelProvider(e.target.value); - setSelectedEmbeddingModel( - config.embeddingModelProviders[e.target.value][0], - ); + if (e.target.value === 'custom_openai') { + setSelectedEmbeddingModel(''); + } else { + setSelectedEmbeddingModel( + config.embeddingModelProviders[e.target.value][0], + ); + } }} options={Object.keys( config.embeddingModelProviders,