From c870ee0e73568e06086d852eae577a4b2f181191 Mon Sep 17 00:00:00 2001 From: sagitchu <601096721@qq.com> Date: Sun, 25 Aug 2024 11:38:53 +0800 Subject: [PATCH] add custom Embedding --- src/lib/providers/index.ts | 2 + src/websocket/connectionManager.ts | 13 +++++- ui/components/SettingsDialog.tsx | 75 ++++++++++++++++++------------ 3 files changed, 58 insertions(+), 32 deletions(-) diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index d919fd4..49f2e22 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -42,5 +42,7 @@ export const getAvailableEmbeddingModelProviders = async () => { } } + models['custom_openai'] = {}; + return models; }; diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index 70e20d9..b77106c 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -8,7 +8,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; import type { IncomingMessage } from 'http'; import logger from '../utils/logger'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'; export const handleConnection = async ( ws: WebSocket, @@ -61,11 +61,20 @@ export const handleConnection = async ( if ( embeddingModelProviders[embeddingModelProvider] && - embeddingModelProviders[embeddingModelProvider][embeddingModel] + embeddingModelProviders[embeddingModelProvider][embeddingModel] && + embeddingModelProvider != 'custom_openai' ) { embeddings = embeddingModelProviders[embeddingModelProvider][ embeddingModel ] as Embeddings | undefined; + } else if (embeddingModelProvider == 'custom_openai') { + embeddings = new OpenAIEmbeddings({ + modelName: embeddingModel, + openAIApiKey: searchParams.get('openAIApiKey'), + configuration: { + baseURL: searchParams.get('openAIBaseURL'), + }, + }) as unknown as Embeddings } if (!llm || !embeddings) { diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 171e812..7b84bf5 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -9,7 +9,7 @@ import React, { } from 'react'; import ThemeSwitcher from './theme/Switcher'; -interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {} +interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> { } const Input = ({ className, ...restProps }: InputProps) => { return ( @@ -258,30 +258,30 @@ const SettingsDialog = ({ options={(() => { const chatModelProvider = config.chatModelProviders[ - selectedChatModelProvider + selectedChatModelProvider ]; return chatModelProvider ? chatModelProvider.length > 0 ? chatModelProvider.map((model) => ({ - value: model, - label: model, - })) + value: model, + label: model, + })) : [ - { - value: '', - label: 'No models available', - disabled: true, - }, - ] - : [ { value: '', - label: - 'Invalid provider, please check backend logs', + label: 'No models available', disabled: true, }, - ]; + ] + : [ + { + value: '', + label: + 'Invalid provider, please check backend logs', + disabled: true, + }, + ]; })()} /> </div> @@ -355,7 +355,7 @@ const SettingsDialog = ({ /> </div> )} - {selectedEmbeddingModelProvider && ( + {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider != 'custom_openai' && ( <div className="flex flex-col space-y-1"> <p className="text-black/70 dark:text-white/70 text-sm"> Embedding Model @@ -368,34 +368,49 @@ const SettingsDialog = ({ options={(() => { const embeddingModelProvider = config.embeddingModelProviders[ - selectedEmbeddingModelProvider + selectedEmbeddingModelProvider ]; return embeddingModelProvider ? embeddingModelProvider.length > 0 ? embeddingModelProvider.map((model) => ({ - label: model, - value: model, - })) + label: model, + value: model, + })) : [ - { - label: 'No embedding models available', - value: '', - disabled: true, - }, - ] - : [ { - label: - 'Invalid provider, please check backend logs', + label: 'No embedding models available', value: '', disabled: true, }, - ]; + ] + : [ + { + label: + 'Invalid provider, please check backend logs', + value: '', + disabled: true, + }, + ]; })()} /> </div> )} + {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider === 'custom_openai' && ( + <div className="flex flex-col space-y-1"> + <p className="text-black/70 dark:text-white/70 text-sm"> + Embedding Model + </p> + <Input + type="text" + placeholder="Embedding Model name" + defaultValue={selectedEmbeddingModel!} + onChange={(e) => + setSelectedEmbeddingModel(e.target.value) + } + /> + </div> + )} <div className="flex flex-col space-y-1"> <p className="text-black/70 dark:text-white/70 text-sm"> OpenAI API Key