From c870ee0e73568e06086d852eae577a4b2f181191 Mon Sep 17 00:00:00 2001
From: sagitchu <601096721@qq.com>
Date: Sun, 25 Aug 2024 11:38:53 +0800
Subject: [PATCH] add custom Embedding

---
 src/lib/providers/index.ts         |  2 +
 src/websocket/connectionManager.ts | 13 +++++-
 ui/components/SettingsDialog.tsx   | 75 ++++++++++++++++++------------
 3 files changed, 58 insertions(+), 32 deletions(-)

diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts
index d919fd4..49f2e22 100644
--- a/src/lib/providers/index.ts
+++ b/src/lib/providers/index.ts
@@ -42,5 +42,7 @@ export const getAvailableEmbeddingModelProviders = async () => {
     }
   }
 
+  models['custom_openai'] = {};
+
   return models;
 };
diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts
index 70e20d9..b77106c 100644
--- a/src/websocket/connectionManager.ts
+++ b/src/websocket/connectionManager.ts
@@ -8,7 +8,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import type { Embeddings } from '@langchain/core/embeddings';
 import type { IncomingMessage } from 'http';
 import logger from '../utils/logger';
-import { ChatOpenAI } from '@langchain/openai';
+import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
 
 export const handleConnection = async (
   ws: WebSocket,
@@ -61,11 +61,20 @@ export const handleConnection = async (
 
     if (
       embeddingModelProviders[embeddingModelProvider] &&
-      embeddingModelProviders[embeddingModelProvider][embeddingModel]
+      embeddingModelProviders[embeddingModelProvider][embeddingModel] &&
+      embeddingModelProvider != 'custom_openai'
     ) {
       embeddings = embeddingModelProviders[embeddingModelProvider][
         embeddingModel
       ] as Embeddings | undefined;
+    } else if (embeddingModelProvider == 'custom_openai') {
+      embeddings = new OpenAIEmbeddings({
+        modelName: embeddingModel,
+        openAIApiKey: searchParams.get('openAIApiKey'),
+        configuration: {
+          baseURL: searchParams.get('openAIBaseURL'),
+        },
+      }) as unknown as Embeddings
     }
 
     if (!llm || !embeddings) {
diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx
index 171e812..7b84bf5 100644
--- a/ui/components/SettingsDialog.tsx
+++ b/ui/components/SettingsDialog.tsx
@@ -9,7 +9,7 @@ import React, {
 } from 'react';
 import ThemeSwitcher from './theme/Switcher';
 
-interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {}
+interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> { }
 
 const Input = ({ className, ...restProps }: InputProps) => {
   return (
@@ -258,30 +258,30 @@ const SettingsDialog = ({
                             options={(() => {
                               const chatModelProvider =
                                 config.chatModelProviders[
-                                  selectedChatModelProvider
+                                selectedChatModelProvider
                                 ];
 
                               return chatModelProvider
                                 ? chatModelProvider.length > 0
                                   ? chatModelProvider.map((model) => ({
-                                      value: model,
-                                      label: model,
-                                    }))
+                                    value: model,
+                                    label: model,
+                                  }))
                                   : [
-                                      {
-                                        value: '',
-                                        label: 'No models available',
-                                        disabled: true,
-                                      },
-                                    ]
-                                : [
                                     {
                                       value: '',
-                                      label:
-                                        'Invalid provider, please check backend logs',
+                                      label: 'No models available',
                                       disabled: true,
                                     },
-                                  ];
+                                  ]
+                                : [
+                                  {
+                                    value: '',
+                                    label:
+                                      'Invalid provider, please check backend logs',
+                                    disabled: true,
+                                  },
+                                ];
                             })()}
                           />
                         </div>
@@ -355,7 +355,7 @@ const SettingsDialog = ({
                         />
                       </div>
                     )}
-                    {selectedEmbeddingModelProvider && (
+                    {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider != 'custom_openai' && (
                       <div className="flex flex-col space-y-1">
                         <p className="text-black/70 dark:text-white/70 text-sm">
                           Embedding Model
@@ -368,34 +368,49 @@ const SettingsDialog = ({
                           options={(() => {
                             const embeddingModelProvider =
                               config.embeddingModelProviders[
-                                selectedEmbeddingModelProvider
+                              selectedEmbeddingModelProvider
                               ];
 
                             return embeddingModelProvider
                               ? embeddingModelProvider.length > 0
                                 ? embeddingModelProvider.map((model) => ({
-                                    label: model,
-                                    value: model,
-                                  }))
+                                  label: model,
+                                  value: model,
+                                }))
                                 : [
-                                    {
-                                      label: 'No embedding models available',
-                                      value: '',
-                                      disabled: true,
-                                    },
-                                  ]
-                              : [
                                   {
-                                    label:
-                                      'Invalid provider, please check backend logs',
+                                    label: 'No embedding models available',
                                     value: '',
                                     disabled: true,
                                   },
-                                ];
+                                ]
+                              : [
+                                {
+                                  label:
+                                    'Invalid provider, please check backend logs',
+                                  value: '',
+                                  disabled: true,
+                                },
+                              ];
                           })()}
                         />
                       </div>
                     )}
+                    {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider === 'custom_openai' && (
+                      <div className="flex flex-col space-y-1">
+                        <p className="text-black/70 dark:text-white/70 text-sm">
+                          Embedding Model
+                        </p>
+                        <Input
+                          type="text"
+                          placeholder="Embedding Model name"
+                          defaultValue={selectedEmbeddingModel!}
+                          onChange={(e) =>
+                            setSelectedEmbeddingModel(e.target.value)
+                          }
+                        />
+                      </div>
+                    )}
                     <div className="flex flex-col space-y-1">
                       <p className="text-black/70 dark:text-white/70 text-sm">
                         OpenAI API Key