Adds Google VertexAI as model provider

This commit is contained in:
Hristo 2024-05-14 15:05:17 -04:00
parent cef75279c5
commit ebbe18ab45
12 changed files with 101 additions and 77 deletions

View file

@ -112,16 +112,17 @@ You need to edit the ports accordingly.
1a: Copy the `sample.env` file to `.env` 1a: Copy the `sample.env` file to `.env`
1b: Copy the `deploy/gcp/sample.env` file to `deploy/gcp/.env` 1b: Copy the `deploy/gcp/sample.env` file to `deploy/gcp/.env`
2a: Fill out desired LLM provider access keys etc. in `.env` 2a: Fill out desired LLM provider access keys etc. in `.env`
- Note: you will have to come back and edit this file again once you have the address of the K8s backend deploy
2b: Fill out the GCP info in `deploy/gcp/.env` - Note: you will have to come back and edit this file again once you have the address of the K8s backend deploy
3: Edit `GCP_REPO` to the correct docker image repo path if you are using something other than Container registry 2b: Fill out the GCP info in `deploy/gcp/.env`
4: Edit the `PREFIX` if you would like images and GKE entities to be prefixed with something else 3: Edit `GCP_REPO` to the correct docker image repo path if you are using something other than Container registry
5: In `deploy/gcp` run `make init` to initialize terraform 4: Edit the `PREFIX` if you would like images and GKE entities to be prefixed with something else
6: Follow the normal Perplexica configuration steps outlined in the project readme 5: In `deploy/gcp` run `make init` to initialize terraform
7: Auth docker with the appropriate credential for repo Ex. for `gcr.io` -> `gcloud auth configure-docker` 6: Follow the normal Perplexica configuration steps outlined in the project readme
8: In `deploy/gcp` run `make build-deplpy` to build and push the project images to the repo, create a GKE cluster and deploy the app 7: Auth docker with the appropriate credential for repo Ex. for `gcr.io` -> `gcloud auth configure-docker`
9: Once deployed successfully edit the `.env` file in the root project folder and update the `REMOTE_BACKEND_ADDRESS` with the remote k8s deployment address and port 8: In `deploy/gcp` run `make build-deplpy` to build and push the project images to the repo, create a GKE cluster and deploy the app
10: In root project folder run `make rebuild-run-app-only` 9: Once deployed successfully edit the `.env` file in the root project folder and update the `REMOTE_BACKEND_ADDRESS` with the remote k8s deployment address and port
10: In root project folder run `make rebuild-run-app-only`
If you configured everything correctly frontend app will run locally and provide you with a local url to open it. If you configured everything correctly frontend app will run locally and provide you with a local url to open it.
Now you can run queries against the remotely deployed backend from your local machine. :celebrate: Now you can run queries against the remotely deployed backend from your local machine. :celebrate:

View file

@ -17,7 +17,7 @@ services:
args: args:
- SEARXNG_API_URL=null - SEARXNG_API_URL=null
environment: environment:
SEARXNG_API_URL: "http://searxng:8080" SEARXNG_API_URL: 'http://searxng:8080'
SUPER_SECRET_KEY: ${SUPER_SECRET_KEY} SUPER_SECRET_KEY: ${SUPER_SECRET_KEY}
OPENAI: ${OPENAI} OPENAI: ${OPENAI}
GROQ: ${GROQ} GROQ: ${GROQ}

View file

@ -21,7 +21,7 @@ app.use(cors(corsOptions));
if (getAccessKey()) { if (getAccessKey()) {
app.all('*', requireAccessKey); app.all('*', requireAccessKey);
}; }
app.use(express.json()); app.use(express.json());

View file

@ -1,6 +1,5 @@
import { import { auth } from 'google-auth-library';
getAccessKey, import { getAccessKey } from './config';
} from './config';
export const requireAccessKey = (req, res, next) => { export const requireAccessKey = (req, res, next) => {
const authHeader = req.headers.authorization; const authHeader = req.headers.authorization;
@ -17,5 +16,14 @@ export const requireAccessKey = (req, res, next) => {
export const checkAccessKey = (authHeader) => { export const checkAccessKey = (authHeader) => {
const token = authHeader.split(' ')[1]; const token = authHeader.split(' ')[1];
return Boolean(authHeader && (token === getAccessKey())); return Boolean(authHeader && token === getAccessKey());
};
export const hasGCPCredentials = async () => {
try {
const credentials = await auth.getCredentials();
return Object.keys(credentials).length > 0;
} catch (e) {
return false;
}
}; };

View file

@ -34,33 +34,38 @@ const loadEnv = () => {
GENERAL: { GENERAL: {
PORT: Number(process.env.PORT), PORT: Number(process.env.PORT),
SIMILARITY_MEASURE: process.env.SIMILARITY_MEASURE, SIMILARITY_MEASURE: process.env.SIMILARITY_MEASURE,
SUPER_SECRET_KEY: process.env.SUPER_SECRET_KEY SUPER_SECRET_KEY: process.env.SUPER_SECRET_KEY,
}, },
API_KEYS: { API_KEYS: {
OPENAI: process.env.OPENAI, OPENAI: process.env.OPENAI,
GROQ: process.env.GROQ GROQ: process.env.GROQ,
}, },
API_ENDPOINTS: { API_ENDPOINTS: {
SEARXNG: process.env.SEARXNG_API_URL, SEARXNG: process.env.SEARXNG_API_URL,
OLLAMA: process.env.OLLAMA_API_URL OLLAMA: process.env.OLLAMA_API_URL,
} },
} as Config; } as Config;
}; };
export const getPort = () => loadConfig().GENERAL.PORT; export const getPort = () => loadConfig().GENERAL.PORT;
export const getAccessKey = () => loadEnv().GENERAL.SUPER_SECRET_KEY || loadConfig().GENERAL.SUPER_SECRET_KEY; export const getAccessKey = () =>
loadEnv().GENERAL.SUPER_SECRET_KEY || loadConfig().GENERAL.SUPER_SECRET_KEY;
export const getSimilarityMeasure = () => export const getSimilarityMeasure = () =>
loadConfig().GENERAL.SIMILARITY_MEASURE; loadConfig().GENERAL.SIMILARITY_MEASURE;
export const getOpenaiApiKey = () => loadEnv().API_KEYS.OPENAI || loadConfig().API_KEYS.OPENAI; export const getOpenaiApiKey = () =>
loadEnv().API_KEYS.OPENAI || loadConfig().API_KEYS.OPENAI;
export const getGroqApiKey = () => loadEnv().API_KEYS.GROQ || loadConfig().API_KEYS.GROQ; export const getGroqApiKey = () =>
loadEnv().API_KEYS.GROQ || loadConfig().API_KEYS.GROQ;
export const getSearxngApiEndpoint = () => loadEnv().API_ENDPOINTS.SEARXNG || loadConfig().API_ENDPOINTS.SEARXNG; export const getSearxngApiEndpoint = () =>
loadEnv().API_ENDPOINTS.SEARXNG || loadConfig().API_ENDPOINTS.SEARXNG;
export const getOllamaApiEndpoint = () => loadEnv().API_ENDPOINTS.OLLAMA || loadConfig().API_ENDPOINTS.OLLAMA; export const getOllamaApiEndpoint = () =>
loadEnv().API_ENDPOINTS.OLLAMA || loadConfig().API_ENDPOINTS.OLLAMA;
export const updateConfig = (config: RecursivePartial<Config>) => { export const updateConfig = (config: RecursivePartial<Config>) => {
const currentConfig = loadConfig(); const currentConfig = loadConfig();

View file

@ -1,6 +1,7 @@
import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'; import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { ChatOllama } from '@langchain/community/chat_models/ollama'; import { ChatOllama } from '@langchain/community/chat_models/ollama';
import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama';
import { hasGCPCredentials } from '../auth';
import { import {
getGroqApiKey, getGroqApiKey,
getOllamaApiEndpoint, getOllamaApiEndpoint,
@ -111,6 +112,23 @@ export const getAvailableChatModelProviders = async () => {
} }
} }
if (await hasGCPCredentials()) {
try {
models['vertexai'] = {
'gemini-1.5-pro (preview-0409)': new VertexAI({
temperature: 0.7,
modelName: 'gemini-1.5-pro-preview-0409',
}),
'gemini-1.0-pro (Latest)': new VertexAI({
temperature: 0.7,
modelName: 'gemini-1.0-pro',
}),
};
} catch (err) {
logger.error(`Error loading VertexAI models: ${err}`);
}
}
models['custom_openai'] = {}; models['custom_openai'] = {};
return models; return models;

View file

@ -31,7 +31,7 @@ export const handleConnection = async (
}), }),
); );
ws.close(); ws.close();
}; }
} }
const [chatModelProviders, embeddingModelProviders] = await Promise.all([ const [chatModelProviders, embeddingModelProviders] = await Promise.all([

View file

@ -36,14 +36,11 @@ const useSocket = (url: string) => {
!embeddingModel || !embeddingModel ||
!embeddingModelProvider !embeddingModelProvider
) { ) {
const providers = await clientFetch( const providers = await clientFetch('/models', {
'/models',
{
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
}, }).then(async (res) => await res.json());
).then(async (res) => await res.json());
const chatModelProviders = providers.chatModelProviders; const chatModelProviders = providers.chatModelProviders;
const embeddingModelProviders = providers.embeddingModelProviders; const embeddingModelProviders = providers.embeddingModelProviders;
@ -103,8 +100,8 @@ const useSocket = (url: string) => {
const secretToken = getAccessKey(); const secretToken = getAccessKey();
if (secretToken) { if (secretToken) {
protocols = ["Authorization", `${secretToken}`]; protocols = ['Authorization', `${secretToken}`];
}; }
const ws = new WebSocket(wsURL.toString(), protocols); const ws = new WebSocket(wsURL.toString(), protocols);

View file

@ -34,9 +34,7 @@ const SearchImages = ({
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel'); const chatModel = localStorage.getItem('chatModel');
const res = await clientFetch( const res = await clientFetch('/images', {
'/images',
{
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -47,8 +45,7 @@ const SearchImages = ({
chat_model_provider: chatModelProvider, chat_model_provider: chatModelProvider,
chat_model: chatModel, chat_model: chatModel,
}), }),
}, });
);
const data = await res.json(); const data = await res.json();

View file

@ -47,9 +47,7 @@ const Searchvideos = ({
const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel'); const chatModel = localStorage.getItem('chatModel');
const res = await clientFetch( const res = await clientFetch('/videos', {
'/videos',
{
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -60,8 +58,7 @@ const Searchvideos = ({
chat_model_provider: chatModelProvider, chat_model_provider: chatModelProvider,
chat_model: chatModel, chat_model: chatModel,
}), }),
}, });
);
const data = await res.json(); const data = await res.json();

View file

@ -11,11 +11,12 @@ const loadEnv = () => {
GENERAL: { GENERAL: {
NEXT_PUBLIC_SUPER_SECRET_KEY: process.env.NEXT_PUBLIC_SUPER_SECRET_KEY!, NEXT_PUBLIC_SUPER_SECRET_KEY: process.env.NEXT_PUBLIC_SUPER_SECRET_KEY!,
NEXT_PUBLIC_API_URL: process.env.NEXT_PUBLIC_API_URL!, NEXT_PUBLIC_API_URL: process.env.NEXT_PUBLIC_API_URL!,
NEXT_PUBLIC_WS_URL: process.env.NEXT_PUBLIC_WS_URL! NEXT_PUBLIC_WS_URL: process.env.NEXT_PUBLIC_WS_URL!,
}, },
} as Config; } as Config;
}; };
export const getAccessKey = () => loadEnv().GENERAL.NEXT_PUBLIC_SUPER_SECRET_KEY; export const getAccessKey = () =>
loadEnv().GENERAL.NEXT_PUBLIC_SUPER_SECRET_KEY;
export const getBackendURL = () => loadEnv().GENERAL.NEXT_PUBLIC_API_URL; export const getBackendURL = () => loadEnv().GENERAL.NEXT_PUBLIC_API_URL;

View file

@ -1,6 +1,6 @@
import clsx, { ClassValue } from 'clsx'; import clsx, { ClassValue } from 'clsx';
import { twMerge } from 'tailwind-merge'; import { twMerge } from 'tailwind-merge';
import { getAccessKey, getBackendURL } from './config' import { getAccessKey, getBackendURL } from './config';
export const cn = (...classes: ClassValue[]) => twMerge(clsx(...classes)); export const cn = (...classes: ClassValue[]) => twMerge(clsx(...classes));
@ -29,11 +29,11 @@ export const clientFetch = async (path: string, payload: any): Promise<any> => {
if (secretToken) { if (secretToken) {
if (headers == null) { if (headers == null) {
headers = {}; headers = {};
}; }
headers['Authorization'] = `Bearer ${secretToken}`; headers['Authorization'] = `Bearer ${secretToken}`;
payload.headers = headers; payload.headers = headers;
}; }
return await fetch(url, payload); return await fetch(url, payload);
}; };