This commit is contained in:
sagit 2024-09-05 12:26:56 -04:00 committed by GitHub
commit dd1cae4cb8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 77 additions and 35 deletions

View file

@ -42,5 +42,7 @@ export const getAvailableEmbeddingModelProviders = async () => {
} }
} }
models['custom_openai'] = {};
return models; return models;
}; };

View file

@ -8,7 +8,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings'; import type { Embeddings } from '@langchain/core/embeddings';
import type { IncomingMessage } from 'http'; import type { IncomingMessage } from 'http';
import logger from '../utils/logger'; import logger from '../utils/logger';
import { ChatOpenAI } from '@langchain/openai'; import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
export const handleConnection = async ( export const handleConnection = async (
ws: WebSocket, ws: WebSocket,
@ -61,11 +61,20 @@ export const handleConnection = async (
if ( if (
embeddingModelProviders[embeddingModelProvider] && embeddingModelProviders[embeddingModelProvider] &&
embeddingModelProviders[embeddingModelProvider][embeddingModel] embeddingModelProviders[embeddingModelProvider][embeddingModel] &&
embeddingModelProvider != 'custom_openai'
) { ) {
embeddings = embeddingModelProviders[embeddingModelProvider][ embeddings = embeddingModelProviders[embeddingModelProvider][
embeddingModel embeddingModel
] as Embeddings | undefined; ] as Embeddings | undefined;
} else if (embeddingModelProvider == 'custom_openai') {
embeddings = new OpenAIEmbeddings({
modelName: embeddingModel,
openAIApiKey: searchParams.get('openAIApiKey'),
configuration: {
baseURL: searchParams.get('openAIBaseURL'),
},
}) as unknown as Embeddings
} }
if (!llm || !embeddings) { if (!llm || !embeddings) {

View file

@ -130,6 +130,7 @@ const useSocket = (
if ( if (
embeddingModelProvider && embeddingModelProvider &&
embeddingModelProvider != 'custom_openai' &&
!embeddingModelProviders[embeddingModelProvider][embeddingModel] !embeddingModelProviders[embeddingModelProvider][embeddingModel]
) { ) {
embeddingModel = Object.keys( embeddingModel = Object.keys(
@ -159,6 +160,17 @@ const useSocket = (
searchParams.append('embeddingModel', embeddingModel!); searchParams.append('embeddingModel', embeddingModel!);
searchParams.append('embeddingModelProvider', embeddingModelProvider); searchParams.append('embeddingModelProvider', embeddingModelProvider);
if (embeddingModelProvider === 'custom_openai') {
searchParams.append(
'openAIApiKey',
localStorage.getItem('openAIApiKey')!,
);
searchParams.append(
'openAIBaseURL',
localStorage.getItem('openAIBaseURL')!,
);
}
wsURL.search = searchParams.toString(); wsURL.search = searchParams.toString();
const ws = new WebSocket(wsURL.toString()); const ws = new WebSocket(wsURL.toString());

View file

@ -9,7 +9,7 @@ import React, {
} from 'react'; } from 'react';
import ThemeSwitcher from './theme/Switcher'; import ThemeSwitcher from './theme/Switcher';
interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {} interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> { }
const Input = ({ className, ...restProps }: InputProps) => { const Input = ({ className, ...restProps }: InputProps) => {
return ( return (
@ -258,30 +258,30 @@ const SettingsDialog = ({
options={(() => { options={(() => {
const chatModelProvider = const chatModelProvider =
config.chatModelProviders[ config.chatModelProviders[
selectedChatModelProvider selectedChatModelProvider
]; ];
return chatModelProvider return chatModelProvider
? chatModelProvider.length > 0 ? chatModelProvider.length > 0
? chatModelProvider.map((model) => ({ ? chatModelProvider.map((model) => ({
value: model, value: model,
label: model, label: model,
})) }))
: [ : [
{
value: '',
label: 'No models available',
disabled: true,
},
]
: [
{ {
value: '', value: '',
label: label: 'No models available',
'Invalid provider, please check backend logs',
disabled: true, disabled: true,
}, },
]; ]
: [
{
value: '',
label:
'Invalid provider, please check backend logs',
disabled: true,
},
];
})()} })()}
/> />
</div> </div>
@ -340,9 +340,13 @@ const SettingsDialog = ({
value={selectedEmbeddingModelProvider ?? undefined} value={selectedEmbeddingModelProvider ?? undefined}
onChange={(e) => { onChange={(e) => {
setSelectedEmbeddingModelProvider(e.target.value); setSelectedEmbeddingModelProvider(e.target.value);
setSelectedEmbeddingModel( if (e.target.value === 'custom_openai') {
config.embeddingModelProviders[e.target.value][0], setSelectedEmbeddingModel('');
); } else {
setSelectedEmbeddingModel(
config.embeddingModelProviders[e.target.value][0],
);
}
}} }}
options={Object.keys( options={Object.keys(
config.embeddingModelProviders, config.embeddingModelProviders,
@ -355,7 +359,7 @@ const SettingsDialog = ({
/> />
</div> </div>
)} )}
{selectedEmbeddingModelProvider && ( {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider != 'custom_openai' && (
<div className="flex flex-col space-y-1"> <div className="flex flex-col space-y-1">
<p className="text-black/70 dark:text-white/70 text-sm"> <p className="text-black/70 dark:text-white/70 text-sm">
Embedding Model Embedding Model
@ -368,34 +372,49 @@ const SettingsDialog = ({
options={(() => { options={(() => {
const embeddingModelProvider = const embeddingModelProvider =
config.embeddingModelProviders[ config.embeddingModelProviders[
selectedEmbeddingModelProvider selectedEmbeddingModelProvider
]; ];
return embeddingModelProvider return embeddingModelProvider
? embeddingModelProvider.length > 0 ? embeddingModelProvider.length > 0
? embeddingModelProvider.map((model) => ({ ? embeddingModelProvider.map((model) => ({
label: model, label: model,
value: model, value: model,
})) }))
: [ : [
{
label: 'No embedding models available',
value: '',
disabled: true,
},
]
: [
{ {
label: label: 'No embedding models available',
'Invalid provider, please check backend logs',
value: '', value: '',
disabled: true, disabled: true,
}, },
]; ]
: [
{
label:
'Invalid provider, please check backend logs',
value: '',
disabled: true,
},
];
})()} })()}
/> />
</div> </div>
)} )}
{selectedEmbeddingModelProvider && selectedEmbeddingModelProvider === 'custom_openai' && (
<div className="flex flex-col space-y-1">
<p className="text-black/70 dark:text-white/70 text-sm">
Embedding Model
</p>
<Input
type="text"
placeholder="Embedding Model name"
defaultValue={selectedEmbeddingModel!}
onChange={(e) =>
setSelectedEmbeddingModel(e.target.value)
}
/>
</div>
)}
<div className="flex flex-col space-y-1"> <div className="flex flex-col space-y-1">
<p className="text-black/70 dark:text-white/70 text-sm"> <p className="text-black/70 dark:text-white/70 text-sm">
OpenAI API Key OpenAI API Key