Perplexica/src/lib/providers.ts

183 lines
4.6 KiB
TypeScript
Raw Normal View History

2024-04-20 11:18:52 +05:30
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { ChatOllama } from '@langchain/community/chat_models/ollama';
import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
import { HuggingFaceTransformersEmbeddings } from './huggingfaceTransformer';
2024-05-01 19:43:06 +05:30
import {
getGroqApiKey,
getOllamaApiEndpoint,
getOpenaiApiKey,
} from '../config';
2024-04-30 12:18:18 +05:30
import logger from '../utils/logger';
2024-04-20 11:18:52 +05:30
/**
 * Builds the catalogue of chat models that can be offered to the user,
 * grouped by provider id.
 *
 * A provider group is only added when that provider is configured:
 * - `openai`: requires an OpenAI API key.
 * - `groq`: requires a Groq API key.
 * - `ollama`: requires an Ollama endpoint; the available models are discovered
 *   at call time via `GET {endpoint}/api/tags`.
 *
 * A failure while loading one provider is logged and skipped, so the other
 * providers are still returned. An empty `custom_openai` group is always
 * present.
 *
 * @returns A map of provider id -> (display name -> chat model instance).
 */
export const getAvailableChatModelProviders = async () => {
  const openAIApiKey = getOpenaiApiKey();
  const groqApiKey = getGroqApiKey();
  const ollamaEndpoint = getOllamaApiEndpoint();

  // Deliberately loose: callers index this by provider/model names chosen at
  // runtime, which matches the previous untyped `{}` contract.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const models: Record<string, Record<string, any>> = {};

  if (openAIApiKey) {
    try {
      models['openai'] = {
        'GPT-3.5 turbo': new ChatOpenAI({
          openAIApiKey,
          modelName: 'gpt-3.5-turbo',
          temperature: 0.7,
        }),
        'GPT-4': new ChatOpenAI({
          openAIApiKey,
          modelName: 'gpt-4',
          temperature: 0.7,
        }),
        'GPT-4 turbo': new ChatOpenAI({
          openAIApiKey,
          modelName: 'gpt-4-turbo',
          temperature: 0.7,
        }),
      };
    } catch (err) {
      logger.error(`Error loading OpenAI models: ${err}`);
    }
  }

  if (groqApiKey) {
    try {
      // Groq is reached through the OpenAI client pointed at Groq's
      // OpenAI-compatible base URL.
      const groqBaseURL = 'https://api.groq.com/openai/v1';
      const groqModels: [label: string, modelName: string][] = [
        ['LLaMA3 8b', 'llama3-8b-8192'],
        ['LLaMA3 70b', 'llama3-70b-8192'],
        ['Mixtral 8x7b', 'mixtral-8x7b-32768'],
        ['Gemma 7b', 'gemma-7b-it'],
      ];
      // Build into a local first so a constructor failure leaves no
      // half-populated `groq` group behind.
      const groq: Record<string, unknown> = {};
      for (const [label, modelName] of groqModels) {
        groq[label] = new ChatOpenAI(
          {
            openAIApiKey: groqApiKey,
            modelName,
            temperature: 0.7,
          },
          {
            baseURL: groqBaseURL,
          },
        );
      }
      models['groq'] = groq;
    } catch (err) {
      logger.error(`Error loading Groq models: ${err}`);
    }
  }

  if (ollamaEndpoint) {
    try {
      const response = await fetch(`${ollamaEndpoint}/api/tags`, {
        headers: {
          'Content-Type': 'application/json',
        },
      });
      // Surface HTTP failures explicitly instead of failing later on a
      // missing `models` field with an opaque TypeError.
      if (!response.ok) {
        throw new Error(`Ollama responded with HTTP ${response.status}`);
      }
      const { models: ollamaModels } = (await response.json()) as {
        models?: { model: string }[];
      };
      if (!Array.isArray(ollamaModels)) {
        throw new Error('Ollama /api/tags response has no "models" array');
      }

      const ollama: Record<string, unknown> = {};
      for (const { model } of ollamaModels) {
        ollama[model] = new ChatOllama({
          baseUrl: ollamaEndpoint,
          model,
          temperature: 0.7,
        });
      }
      models['ollama'] = ollama;
    } catch (err) {
      logger.error(`Error loading Ollama models: ${err}`);
    }
  }

  models['custom_openai'] = {};
  return models;
};
/**
 * Builds the catalogue of embedding models that can be offered to the user,
 * grouped by provider id.
 *
 * - `openai`: added only when an OpenAI API key is configured.
 * - `ollama`: added only when an Ollama endpoint is configured; every model
 *   reported by `GET {endpoint}/api/tags` is listed (no filtering for
 *   embedding capability is done here).
 * - `local`: in-process HuggingFace Transformers models, always attempted.
 *
 * A failure while loading one provider is logged and skipped, so the other
 * providers are still returned.
 *
 * @returns A map of provider id -> (display name -> embeddings instance).
 */
export const getAvailableEmbeddingModelProviders = async () => {
  const openAIApiKey = getOpenaiApiKey();
  const ollamaEndpoint = getOllamaApiEndpoint();

  // Deliberately loose: callers index this by provider/model names chosen at
  // runtime, which matches the previous untyped `{}` contract.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const models: Record<string, Record<string, any>> = {};

  if (openAIApiKey) {
    try {
      models['openai'] = {
        'Text embedding 3 small': new OpenAIEmbeddings({
          openAIApiKey,
          modelName: 'text-embedding-3-small',
        }),
        'Text embedding 3 large': new OpenAIEmbeddings({
          openAIApiKey,
          modelName: 'text-embedding-3-large',
        }),
      };
    } catch (err) {
      logger.error(`Error loading OpenAI embeddings: ${err}`);
    }
  }

  if (ollamaEndpoint) {
    try {
      const response = await fetch(`${ollamaEndpoint}/api/tags`, {
        headers: {
          'Content-Type': 'application/json',
        },
      });
      // Surface HTTP failures explicitly instead of failing later on a
      // missing `models` field with an opaque TypeError.
      if (!response.ok) {
        throw new Error(`Ollama responded with HTTP ${response.status}`);
      }
      const { models: ollamaModels } = (await response.json()) as {
        models?: { model: string }[];
      };
      if (!Array.isArray(ollamaModels)) {
        throw new Error('Ollama /api/tags response has no "models" array');
      }

      const ollama: Record<string, unknown> = {};
      for (const { model } of ollamaModels) {
        ollama[model] = new OllamaEmbeddings({
          baseUrl: ollamaEndpoint,
          model,
        });
      }
      models['ollama'] = ollama;
    } catch (err) {
      logger.error(`Error loading Ollama embeddings: ${err}`);
    }
  }

  try {
    models['local'] = {
      'BGE Small': new HuggingFaceTransformersEmbeddings({
        modelName: 'Xenova/bge-small-en-v1.5',
      }),
      'GTE Small': new HuggingFaceTransformersEmbeddings({
        modelName: 'Xenova/gte-small',
      }),
      'Bert Multilingual': new HuggingFaceTransformersEmbeddings({
        modelName: 'Xenova/bert-base-multilingual-uncased',
      }),
    };
  } catch (err) {
    logger.error(`Error loading local embeddings: ${err}`);
  }

  return models;
};