feat: add support for defining custom models in config; switch TOML library for proper serialization

Justin Luoma 2024-05-24 06:48:15 -04:00
parent d04ba91c85
commit 62910b5879
7 changed files with 9767 additions and 5596 deletions

.gitignore (vendored): 3 changes

@@ -2,10 +2,13 @@
 node_modules/
 npm-debug.log
 yarn-error.log
+.yarnrc.yml

 # Build output
 /.next/
 /out/
+dist/
+.yarn/

 # IDE/Editor specific
 .vscode/

package.json

@@ -20,7 +20,6 @@
     "typescript": "^5.4.3"
   },
   "dependencies": {
-    "@iarna/toml": "^2.2.5",
     "@langchain/openai": "^0.0.25",
     "@xenova/transformers": "^2.17.1",
     "axios": "^1.6.8",
@@ -30,8 +29,10 @@
     "dotenv": "^16.4.5",
     "express": "^4.19.2",
     "langchain": "^0.1.30",
+    "smol-toml": "^1.0.0",
     "winston": "^3.13.0",
     "ws": "^8.16.0",
     "zod": "^3.22.4"
-  }
+  },
+  "packageManager": "yarn@3.6.1+sha512.de524adec81a6c3d7a26d936d439d2832e351cdfc5728f9d91f3fc85dd20b04391c038e9b4ecab11cae2b0dd9f0d55fd355af766bc5c1a7f8d25d96bb2a0b2ca"
 }
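The swap from @iarna/toml to smol-toml is what the commit message calls "proper serialization": the config is now written back with the same library that parses it. A minimal round-trip sketch, assuming only smol-toml's documented parse/stringify API (not part of this commit):

    import { parse, stringify } from 'smol-toml';

    // Parse a fragment shaped like the project config, then write it back out.
    const doc = parse('[GENERAL]\nPORT = 3001\n\n[[MODELS]]\nname = "local"');
    // stringify() re-emits valid TOML, including [[...]] arrays of tables,
    // which is what the new updateConfig() below relies on.
    console.log(stringify(doc));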

sample.config.toml

@@ -8,4 +8,24 @@ GROQ = "" # Groq API key - gsk_1234567890abcdef1234567890abcdef

 [API_ENDPOINTS]
 SEARXNG = "http://localhost:32768" # SearxNG API URL
 OLLAMA = "" # Ollama API URL - http://host.docker.internal:11434
+
+[[MODELS]]
+name = "text-generation-webui"
+api_key = "blah"
+base_url = "http://localhost:5000/v1"
+provider = "openai"
+
+[[EMBEDDINGS]]
+name = "text-generation-webui-small"
+model = "text-embedding-3-small"
+api_key = "blah"
+base_url = "http://localhost:5000/v1"
+provider = "openai"
+
+[[EMBEDDINGS]]
+name = "text-generation-webui-large"
+model = "text-embedding-3-large"
+api_key = "blah"
+base_url = "http://localhost:5000/v1"
+provider = "openai"
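Since [[MODELS]] and [[EMBEDDINGS]] are TOML arrays of tables, each block above parses to one element of an array of objects. A sketch of the resulting shape (illustration only, using smol-toml's parse()):

    import { parse } from 'smol-toml';

    const cfg = parse(`
    [[MODELS]]
    name = "text-generation-webui"
    api_key = "blah"
    base_url = "http://localhost:5000/v1"
    provider = "openai"
    `) as any;

    // Each [[MODELS]] block becomes one element of cfg.MODELS.
    console.log(cfg.MODELS[0].base_url); // http://localhost:5000/v1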

src/config.ts

@@ -1,37 +1,54 @@
 import fs from 'fs';
 import path from 'path';
-import toml from '@iarna/toml';
+import {parse, stringify} from "smol-toml";

 const configFileName = 'config.toml';

 interface Config {
   GENERAL: {
     PORT: number;
     SIMILARITY_MEASURE: string;
   };
   API_KEYS: {
     OPENAI: string;
     GROQ: string;
   };
   API_ENDPOINTS: {
     SEARXNG: string;
     OLLAMA: string;
   };
+  MODELS: [
+    {
+      "name": string;
+      "api_key": string;
+      "base_url": string;
+      "provider": string;
+    }
+  ];
+  EMBEDDINGS: [
+    {
+      "name": string;
+      "model": string;
+      "api_key": string;
+      "base_url": string;
+      "provider": string;
+    }
+  ];
 }

 type RecursivePartial<T> = {
   [P in keyof T]?: RecursivePartial<T[P]>;
 };

 const loadConfig = () =>
-  toml.parse(
+  parse(
     fs.readFileSync(path.join(__dirname, `../${configFileName}`), 'utf-8'),
   ) as any as Config;

 export const getPort = () => loadConfig().GENERAL.PORT;

 export const getSimilarityMeasure = () =>
   loadConfig().GENERAL.SIMILARITY_MEASURE;

 export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;

@@ -41,29 +58,22 @@ export const getSearxngApiEndpoint = () => loadConfig().API_ENDPOINTS.SEARXNG;

 export const getOllamaApiEndpoint = () => loadConfig().API_ENDPOINTS.OLLAMA;

+export const getCustomModels = () => loadConfig().MODELS;
+
+export const getCustomEmbeddingModels = () => loadConfig().EMBEDDINGS;
+
 export const updateConfig = (config: RecursivePartial<Config>) => {
   const currentConfig = loadConfig();

-  for (const key in currentConfig) {
-    if (!config[key]) config[key] = {};
-
-    if (typeof currentConfig[key] === 'object' && currentConfig[key] !== null) {
-      for (const nestedKey in currentConfig[key]) {
-        if (
-          !config[key][nestedKey] &&
-          currentConfig[key][nestedKey] &&
-          config[key][nestedKey] !== ''
-        ) {
-          config[key][nestedKey] = currentConfig[key][nestedKey];
-        }
-      }
-    } else if (currentConfig[key] && config[key] !== '') {
-      config[key] = currentConfig[key];
-    }
-  }
+  const updatedConfig = {
+    ...currentConfig,
+    ...config
+  };
+
+  const toml = stringify(updatedConfig);

   fs.writeFileSync(
     path.join(__dirname, `../${configFileName}`),
-    toml.stringify(config),
+    toml,
   );
 };
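One behavioral note: the old field-by-field loop preserved nested values missing from the incoming partial, while the new {...currentConfig, ...config} spread merges only one level deep. An illustration of the difference (not part of the commit):

    // Object spread is shallow, so a partial section replaces the stored
    // section wholesale rather than being merged field by field.
    const currentConfig = { API_KEYS: { OPENAI: 'sk-old', GROQ: 'gsk-old' } };
    const config = { API_KEYS: { OPENAI: 'sk-new' } };

    const updatedConfig = { ...currentConfig, ...config };
    // updatedConfig.API_KEYS is { OPENAI: 'sk-new' }: GROQ is dropped,
    // so callers now need to send complete sections.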

src/lib/providers.ts

@@ -1,187 +1,236 @@
 import {ChatOpenAI, OpenAIEmbeddings} from '@langchain/openai';
 import {ChatOllama} from '@langchain/community/chat_models/ollama';
 import {OllamaEmbeddings} from '@langchain/community/embeddings/ollama';
 import {HuggingFaceTransformersEmbeddings} from './huggingfaceTransformer';
 import {
+  getCustomEmbeddingModels,
+  getCustomModels,
   getGroqApiKey,
   getOllamaApiEndpoint,
   getOpenaiApiKey,
 } from '../config';
 import logger from '../utils/logger';

 export const getAvailableChatModelProviders = async () => {
   const openAIApiKey = getOpenaiApiKey();
   const groqApiKey = getGroqApiKey();
   const ollamaEndpoint = getOllamaApiEndpoint();
+  const customModels = getCustomModels();

   const models = {};

   if (openAIApiKey) {
     try {
       models['openai'] = {
         'GPT-3.5 turbo': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-3.5-turbo',
           temperature: 0.7,
         }),
         'GPT-4': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-4',
           temperature: 0.7,
         }),
         'GPT-4 turbo': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-4-turbo',
           temperature: 0.7,
         }),
         'GPT-4 omni': new ChatOpenAI({
           openAIApiKey,
           modelName: 'gpt-4o',
           temperature: 0.7,
         }),
       };
     } catch (err) {
       logger.error(`Error loading OpenAI models: ${err}`);
     }
   }

   if (groqApiKey) {
     try {
       models['groq'] = {
         'LLaMA3 8b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'llama3-8b-8192',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
         'LLaMA3 70b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'llama3-70b-8192',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
         'Mixtral 8x7b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'mixtral-8x7b-32768',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
         'Gemma 7b': new ChatOpenAI(
           {
             openAIApiKey: groqApiKey,
             modelName: 'gemma-7b-it',
             temperature: 0.7,
           },
           {
             baseURL: 'https://api.groq.com/openai/v1',
           },
         ),
       };
     } catch (err) {
       logger.error(`Error loading Groq models: ${err}`);
     }
   }

   if (ollamaEndpoint) {
     try {
       const response = await fetch(`${ollamaEndpoint}/api/tags`, {
         headers: {
           'Content-Type': 'application/json',
         },
       });

       const {models: ollamaModels} = (await response.json()) as any;

       models['ollama'] = ollamaModels.reduce((acc, model) => {
         acc[model.model] = new ChatOllama({
           baseUrl: ollamaEndpoint,
           model: model.model,
           temperature: 0.7,
         });
         return acc;
       }, {});
     } catch (err) {
       logger.error(`Error loading Ollama models: ${err}`);
     }
   }

   models['custom_openai'] = {};

+  if (customModels && customModels.length > 0) {
+    models['custom'] = {};
+    try {
+      customModels.forEach((model) => {
+        if (model.provider === "openai") {
+          models['custom'] = {
+            ...models['custom'],
+            [model.name]: new ChatOpenAI({
+              openAIApiKey: model.api_key,
+              modelName: model.name,
+              temperature: 0.7,
+              configuration: {
+                baseURL: model.base_url,
+              }
+            })
+          }
+        }
+      });
+    } catch (err) {
+      logger.error(`Error loading custom models: ${err}`);
+    }
+  }
+
   return models;
 };

 export const getAvailableEmbeddingModelProviders = async () => {
   const openAIApiKey = getOpenaiApiKey();
   const ollamaEndpoint = getOllamaApiEndpoint();
+  const customEmbeddingModels = getCustomEmbeddingModels();

   const models = {};

   if (openAIApiKey) {
     try {
       models['openai'] = {
         'Text embedding 3 small': new OpenAIEmbeddings({
           openAIApiKey,
           modelName: 'text-embedding-3-small',
-        }),
+        }, {baseURL: "http://10.0.1.2:5000/v1"}),
         'Text embedding 3 large': new OpenAIEmbeddings({
           openAIApiKey,
           modelName: 'text-embedding-3-large',
         }),
       };
     } catch (err) {
       logger.error(`Error loading OpenAI embeddings: ${err}`);
     }
   }

   if (ollamaEndpoint) {
     try {
       const response = await fetch(`${ollamaEndpoint}/api/tags`, {
         headers: {
           'Content-Type': 'application/json',
         },
       });

       const {models: ollamaModels} = (await response.json()) as any;

       models['ollama'] = ollamaModels.reduce((acc, model) => {
         acc[model.model] = new OllamaEmbeddings({
           baseUrl: ollamaEndpoint,
           model: model.model,
         });
         return acc;
       }, {});
     } catch (err) {
       logger.error(`Error loading Ollama embeddings: ${err}`);
     }
   }

+  if (customEmbeddingModels && customEmbeddingModels.length > 0) {
+    models['custom'] = {};
+    try {
+      customEmbeddingModels.forEach((model) => {
+        if (model.provider === "openai") {
+          models['custom'] = {
+            ...models['custom'],
+            [model.name]: new OpenAIEmbeddings({
+              openAIApiKey: model.api_key,
+              modelName: model.model,
+            },
+            {
+              baseURL: model.base_url,
+            }),
+          }
+        }
+      });
+    } catch (err) {
+      logger.error(`Error loading custom models: ${err}`);
+    }
+  }
+
   try {
     models['local'] = {
       'BGE Small': new HuggingFaceTransformersEmbeddings({
         modelName: 'Xenova/bge-small-en-v1.5',
       }),
       'GTE Small': new HuggingFaceTransformersEmbeddings({
         modelName: 'Xenova/gte-small',
       }),
       'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
         modelName: 'Xenova/bert-base-multilingual-uncased',
       }),
     };
   } catch (err) {
     logger.error(`Error loading local embeddings: ${err}`);
   }

   return models;
 };
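For illustration, here is how a caller might reach one of the new custom chat models (a hypothetical usage sketch; only getAvailableChatModelProviders and the [[MODELS]] entry name come from this diff):

    import { getAvailableChatModelProviders } from './lib/providers';

    const providers = await getAvailableChatModelProviders();
    // Custom entries are keyed by the `name` field of their [[MODELS]] block.
    const llm = providers['custom']?.['text-generation-webui'];
    if (llm) {
      // llm is a ChatOpenAI instance pointed at the configured base_url.
      const res = await llm.invoke('Hello!');
      console.log(res.content);
    }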

File diff suppressed because it is too large

yarn.lock: 6103 changes

File diff suppressed because it is too large