feat: add support for defining custom models in config; switch TOML library for proper serialization
parent d04ba91c85
commit 62910b5879

7 changed files with 9767 additions and 5596 deletions
.gitignore (vendored, 3 changes)

@@ -2,10 +2,13 @@
 node_modules/
 npm-debug.log
 yarn-error.log
+.yarnrc.yml
 
 # Build output
 /.next/
 /out/
+dist/
+.yarn/
 
 # IDE/Editor specific
 .vscode/
package.json

@@ -20,7 +20,6 @@
     "typescript": "^5.4.3"
   },
   "dependencies": {
-    "@iarna/toml": "^2.2.5",
     "@langchain/openai": "^0.0.25",
     "@xenova/transformers": "^2.17.1",
     "axios": "^1.6.8",
@@ -30,8 +29,10 @@
     "dotenv": "^16.4.5",
     "express": "^4.19.2",
     "langchain": "^0.1.30",
+    "smol-toml": "^1.0.0",
     "winston": "^3.13.0",
     "ws": "^8.16.0",
     "zod": "^3.22.4"
-  }
+  },
+  "packageManager": "yarn@3.6.1+sha512.de524adec81a6c3d7a26d936d439d2832e351cdfc5728f9d91f3fc85dd20b04391c038e9b4ecab11cae2b0dd9f0d55fd355af766bc5c1a7f8d25d96bb2a0b2ca"
 }
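The swap from @iarna/toml to smol-toml appears to be what the "proper serialization" in the commit message refers to: the config writer now has to re-emit the new `[[MODELS]]` and `[[EMBEDDINGS]]` arrays of tables when saving. A minimal round-trip sketch, using only the `parse` and `stringify` exports that the config module below imports (the sample values are illustrative, not from the repo):

```typescript
// Round-tripping an array-of-tables config with smol-toml.
// Only `parse` and `stringify` are used; both are smol-toml's
// documented exports. The values below are illustrative.
import { parse, stringify } from 'smol-toml';

const doc = `
[[MODELS]]
name = "text-generation-webui"
api_key = "blah"
base_url = "http://localhost:5000/v1"
provider = "openai"
`;

const config = parse(doc) as { MODELS: Array<Record<string, string>> };
console.log(config.MODELS[0].name); // "text-generation-webui"

// stringify re-emits the array of tables as [[MODELS]] blocks,
// so updateConfig can write the file back without losing structure.
console.log(stringify(config));
```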
@@ -8,4 +8,24 @@ GROQ = "" # Groq API key - gsk_1234567890abcdef1234567890abcdef
 
 [API_ENDPOINTS]
 SEARXNG = "http://localhost:32768" # SearxNG API URL
 OLLAMA = "" # Ollama API URL - http://host.docker.internal:11434
+
+[[MODELS]]
+name = "text-generation-webui"
+api_key = "blah"
+base_url = "http://localhost:5000/v1"
+provider = "openai"
+
+[[EMBEDDINGS]]
+name = "text-generation-webui-small"
+model = "text-embedding-3-small"
+api_key = "blah"
+base_url = "http://localhost:5000/v1"
+provider = "openai"
+
+[[EMBEDDINGS]]
+name = "text-generation-webui-large"
+model = "text-embedding-3-large"
+api_key = "blah"
+base_url = "http://localhost:5000/v1"
+provider = "openai"
@@ -1,37 +1,54 @@
 import fs from 'fs';
 import path from 'path';
-import toml from '@iarna/toml';
+import {parse, stringify} from "smol-toml";
 
 const configFileName = 'config.toml';
 
 interface Config {
   GENERAL: {
     PORT: number;
     SIMILARITY_MEASURE: string;
   };
   API_KEYS: {
     OPENAI: string;
     GROQ: string;
   };
   API_ENDPOINTS: {
     SEARXNG: string;
     OLLAMA: string;
   };
+  MODELS: [
+    {
+      "name": string;
+      "api_key": string;
+      "base_url": string;
+      "provider": string;
+    }
+  ];
+  EMBEDDINGS: [
+    {
+      "name": string;
+      "model": string;
+      "api_key": string;
+      "base_url": string;
+      "provider": string;
+    }
+  ];
 }
 
 type RecursivePartial<T> = {
   [P in keyof T]?: RecursivePartial<T[P]>;
 };
 
 const loadConfig = () =>
-  toml.parse(
+  parse(
     fs.readFileSync(path.join(__dirname, `../${configFileName}`), 'utf-8'),
   ) as any as Config;
 
 export const getPort = () => loadConfig().GENERAL.PORT;
 
 export const getSimilarityMeasure = () =>
   loadConfig().GENERAL.SIMILARITY_MEASURE;
 
 export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;

@@ -41,29 +58,22 @@ export const getSearxngApiEndpoint = () => loadConfig().API_ENDPOINTS.SEARXNG;
 
 export const getOllamaApiEndpoint = () => loadConfig().API_ENDPOINTS.OLLAMA;
 
+export const getCustomModels = () => loadConfig().MODELS;
+
+export const getCustomEmbeddingModels = () => loadConfig().EMBEDDINGS;
+
 export const updateConfig = (config: RecursivePartial<Config>) => {
   const currentConfig = loadConfig();
 
-  for (const key in currentConfig) {
-    if (!config[key]) config[key] = {};
-
-    if (typeof currentConfig[key] === 'object' && currentConfig[key] !== null) {
-      for (const nestedKey in currentConfig[key]) {
-        if (
-          !config[key][nestedKey] &&
-          currentConfig[key][nestedKey] &&
-          config[key][nestedKey] !== ''
-        ) {
-          config[key][nestedKey] = currentConfig[key][nestedKey];
-        }
-      }
-    } else if (currentConfig[key] && config[key] !== '') {
-      config[key] = currentConfig[key];
-    }
-  }
+  const updatedConfig = {
+    ...currentConfig,
+    ...config
+  };
+
+  const toml = stringify(updatedConfig);
 
   fs.writeFileSync(
     path.join(__dirname, `../${configFileName}`),
-    toml.stringify(config),
+    toml,
   );
 };
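A side effect of the `updateConfig` rewrite above that is easy to miss: the removed loop merged one level deep, keeping any existing value the caller omitted, whereas the object spread is shallow, so a partial section in `config` now replaces that whole section on disk. A minimal sketch of the difference (all keys and values hypothetical):

```typescript
// Shallow vs. one-level-deep merge of config sections.
// Every key and value here is hypothetical; only the merge
// behavior being compared is the point.
type Sections = Record<string, Record<string, string>>;

const current: Sections = {
  API_KEYS: { OPENAI: 'sk-old', GROQ: 'gsk-old' },
};
const update: Sections = {
  API_KEYS: { OPENAI: 'sk-new' }, // partial section, GROQ omitted
};

// What the new updateConfig does: the API_KEYS section from
// `update` wins wholesale, and GROQ is dropped from the file.
const shallow = { ...current, ...update };
console.log(shallow.API_KEYS); // { OPENAI: 'sk-new' }

// What the removed loop approximated: omitted keys are kept.
const deep: Sections = { ...current };
for (const key of Object.keys(update)) {
  deep[key] = { ...current[key], ...update[key] };
}
console.log(deep.API_KEYS); // { OPENAI: 'sk-new', GROQ: 'gsk-old' }
```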
@@ -1,187 +1,236 @@
-import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
-import { ChatOllama } from '@langchain/community/chat_models/ollama';
-import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
-import { HuggingFaceTransformersEmbeddings } from './huggingfaceTransformer';
+import {ChatOpenAI, OpenAIEmbeddings} from '@langchain/openai';
+import {ChatOllama} from '@langchain/community/chat_models/ollama';
+import {OllamaEmbeddings} from '@langchain/community/embeddings/ollama';
+import {HuggingFaceTransformersEmbeddings} from './huggingfaceTransformer';
 import {
+  getCustomEmbeddingModels,
+  getCustomModels,
   getGroqApiKey,
   getOllamaApiEndpoint,
   getOpenaiApiKey,
 } from '../config';
 import logger from '../utils/logger';
 
 export const getAvailableChatModelProviders = async () => {
   const openAIApiKey = getOpenaiApiKey();
   const groqApiKey = getGroqApiKey();
   const ollamaEndpoint = getOllamaApiEndpoint();
+  const customModels = getCustomModels();
 
   const models = {};
 
   if (openAIApiKey) {
     try {
       models['openai'] = {
         'GPT-3.5 turbo': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-3.5-turbo',
           temperature: 0.7,
         }),
         'GPT-4': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-4',
           temperature: 0.7,
         }),
         'GPT-4 turbo': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-4-turbo',
           temperature: 0.7,
         }),
         'GPT-4 omni': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-4o',
           temperature: 0.7,
         }),
       };
     } catch (err) {
       logger.error(`Error loading OpenAI models: ${err}`);
     }
   }
 
   if (groqApiKey) {
     try {
       models['groq'] = {
         'LLaMA3 8b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'llama3-8b-8192',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
         'LLaMA3 70b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'llama3-70b-8192',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
         'Mixtral 8x7b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'mixtral-8x7b-32768',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
         'Gemma 7b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'gemma-7b-it',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
       };
     } catch (err) {
       logger.error(`Error loading Groq models: ${err}`);
     }
   }
 
   if (ollamaEndpoint) {
     try {
       const response = await fetch(`${ollamaEndpoint}/api/tags`, {
         headers: {
           'Content-Type': 'application/json',
         },
       });
 
-      const { models: ollamaModels } = (await response.json()) as any;
+      const {models: ollamaModels} = (await response.json()) as any;
 
       models['ollama'] = ollamaModels.reduce((acc, model) => {
         acc[model.model] = new ChatOllama({
           baseUrl: ollamaEndpoint,
           model: model.model,
           temperature: 0.7,
         });
         return acc;
       }, {});
     } catch (err) {
       logger.error(`Error loading Ollama models: ${err}`);
     }
   }
 
   models['custom_openai'] = {};
 
+  if (customModels && customModels.length > 0) {
+    models['custom'] = {};
+    try {
+      customModels.forEach((model) => {
+        if (model.provider === "openai") {
+          models['custom'] = {
+            ...models['custom'],
+            [model.name]: new ChatOpenAI({
+              openAIApiKey: model.api_key,
+              modelName: model.name,
+              temperature: 0.7,
+              configuration: {
+                baseURL: model.base_url,
+              }
+            })
+          }
+        }
+      });
+    } catch (err) {
+      logger.error(`Error loading custom models: ${err}`);
+    }
+  }
+
   return models;
 };
 
 export const getAvailableEmbeddingModelProviders = async () => {
   const openAIApiKey = getOpenaiApiKey();
   const ollamaEndpoint = getOllamaApiEndpoint();
+  const customEmbeddingModels = getCustomEmbeddingModels();
 
   const models = {};
 
   if (openAIApiKey) {
     try {
       models['openai'] = {
         'Text embedding 3 small': new OpenAIEmbeddings({
           openAIApiKey,
           modelName: 'text-embedding-3-small',
-        }),
+        }, {baseURL: "http://10.0.1.2:5000/v1"}),
         'Text embedding 3 large': new OpenAIEmbeddings({
           openAIApiKey,
           modelName: 'text-embedding-3-large',
         }),
       };
     } catch (err) {
       logger.error(`Error loading OpenAI embeddings: ${err}`);
     }
   }
 
   if (ollamaEndpoint) {
     try {
       const response = await fetch(`${ollamaEndpoint}/api/tags`, {
         headers: {
           'Content-Type': 'application/json',
         },
       });
 
-      const { models: ollamaModels } = (await response.json()) as any;
+      const {models: ollamaModels} = (await response.json()) as any;
 
       models['ollama'] = ollamaModels.reduce((acc, model) => {
         acc[model.model] = new OllamaEmbeddings({
           baseUrl: ollamaEndpoint,
           model: model.model,
         });
         return acc;
       }, {});
     } catch (err) {
       logger.error(`Error loading Ollama embeddings: ${err}`);
     }
   }
 
+  if (customEmbeddingModels && customEmbeddingModels.length > 0) {
+    models['custom'] = {};
+    try {
+      customEmbeddingModels.forEach((model) => {
+        if (model.provider === "openai") {
+          models['custom'] = {
+            ...models['custom'],
+            [model.name]: new OpenAIEmbeddings({
+              openAIApiKey: model.api_key,
+              modelName: model.model,
+            },
+            {
+              baseURL: model.base_url,
+            }),
+          }
+        }
+      });
+    } catch (err) {
+      logger.error(`Error loading custom models: ${err}`);
+    }
+  }
+
   try {
     models['local'] = {
       'BGE Small': new HuggingFaceTransformersEmbeddings({
         modelName: 'Xenova/bge-small-en-v1.5',
       }),
       'GTE Small': new HuggingFaceTransformersEmbeddings({
         modelName: 'Xenova/gte-small',
       }),
       'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
         modelName: 'Xenova/bert-base-multilingual-uncased',
       }),
     };
   } catch (err) {
     logger.error(`Error loading local embeddings: ${err}`);
   }
 
   return models;
 };
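For orientation, a sketch of how a `[[MODELS]]` entry surfaces at a call site once the provider map above is built. The import path and the model key are assumptions (the key is taken from the sample config earlier in this commit); each leaf of the map is a ready-to-invoke LangChain chat model:

```typescript
// Consuming the provider map built by getAvailableChatModelProviders.
// The import path is assumed; adjust it to wherever the providers
// module lives in the tree. 'text-generation-webui' is the name from
// the sample [[MODELS]] table above.
import { getAvailableChatModelProviders } from './lib/providers';

const main = async () => {
  const providers: Record<string, Record<string, any>> =
    await getAvailableChatModelProviders();

  // Custom OpenAI-compatible models are grouped under the 'custom' key.
  const model = providers['custom']?.['text-generation-webui'];
  if (!model) throw new Error('custom model not configured');

  // Each entry is a LangChain chat model, so invoke() works directly.
  const res = await model.invoke('Hello!');
  console.log(res.content);
};

main();
```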
ui/yarn.lock (8777 changes)

File diff suppressed because it is too large.