From ebbe18ab4547eb01b2460313ffab863cf2447fa6 Mon Sep 17 00:00:00 2001 From: Hristo <53634432+izo0x90@users.noreply.github.com> Date: Tue, 14 May 2024 15:05:17 -0400 Subject: [PATCH] Adds Google VertexAI as model provider --- README.md | 21 +++++++++--------- docker-compose.yaml | 2 +- src/app.ts | 2 +- src/auth.ts | 34 ++++++++++++++++++------------ src/config.ts | 23 ++++++++++++-------- src/lib/providers.ts | 18 ++++++++++++++++ src/websocket/connectionManager.ts | 2 +- ui/components/ChatWindow.tsx | 15 ++++++------- ui/components/SearchImages.tsx | 25 ++++++++++------------ ui/components/SearchVideos.tsx | 25 ++++++++++------------ ui/lib/config.ts | 5 +++-- ui/lib/utils.ts | 6 +++--- 12 files changed, 101 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index 79989e3..df257a8 100644 --- a/README.md +++ b/README.md @@ -112,16 +112,17 @@ You need to edit the ports accordingly. 1a: Copy the `sample.env` file to `.env` 1b: Copy the `deploy/gcp/sample.env` file to `deploy/gcp/.env` 2a: Fillout desired LLM provider access keys etc. in `.env` - - Note: you will have to comeback and edit this file again once you have the address of the K8s backend deploy -2b: Fillout the GCP info in `deploy/gcp/.env` -3: Edit `GCP_REPO` to the correct docker image repo path if you are using something other than Container registry -4: Edit the `PREFIX` if you would like images and GKE entities to be prefixed with something else -5: In `deploy/gcp` run `make init` to initialize terraform -6: Follow the normal Preplexica configuration steps outlined in the project readme -7: Auth docker with the appropriate credential for repo Ex. for `gcr.io` -> `gcloud auth configure-docker` -8: In `deploy/gcp` run `make build-deplpy` to build and push the project images to the repo, create a GKE cluster and deploy the app -9: Once deployed successfully edit the `.env` file in the root project folder and update the `REMOTE_BACKEND_ADDRESS` with the remote k8s deployment address and port -10: In root project folder run `make rebuild-run-app-only` + +- Note: you will have to comeback and edit this file again once you have the address of the K8s backend deploy + 2b: Fillout the GCP info in `deploy/gcp/.env` + 3: Edit `GCP_REPO` to the correct docker image repo path if you are using something other than Container registry + 4: Edit the `PREFIX` if you would like images and GKE entities to be prefixed with something else + 5: In `deploy/gcp` run `make init` to initialize terraform + 6: Follow the normal Preplexica configuration steps outlined in the project readme + 7: Auth docker with the appropriate credential for repo Ex. for `gcr.io` -> `gcloud auth configure-docker` + 8: In `deploy/gcp` run `make build-deplpy` to build and push the project images to the repo, create a GKE cluster and deploy the app + 9: Once deployed successfully edit the `.env` file in the root project folder and update the `REMOTE_BACKEND_ADDRESS` with the remote k8s deployment address and port + 10: In root project folder run `make rebuild-run-app-only` If you configured everything correctly frontend app will run locally and provide you with a local url to open it. Now you can run queries against the remotely deployed backend from your local machine. :celebrate: diff --git a/docker-compose.yaml b/docker-compose.yaml index 0559871..9faa891 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -17,7 +17,7 @@ services: args: - SEARXNG_API_URL=null environment: - SEARXNG_API_URL: "http://searxng:8080" + SEARXNG_API_URL: 'http://searxng:8080' SUPER_SECRET_KEY: ${SUPER_SECRET_KEY} OPENAI: ${OPENAI} GROQ: ${GROQ} diff --git a/src/app.ts b/src/app.ts index 4109997..d5dcc68 100644 --- a/src/app.ts +++ b/src/app.ts @@ -21,7 +21,7 @@ app.use(cors(corsOptions)); if (getAccessKey()) { app.all('*', requireAccessKey); -}; +} app.use(express.json()); diff --git a/src/auth.ts b/src/auth.ts index c6f3cef..b9f7e3b 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -1,21 +1,29 @@ -import { - getAccessKey, -} from './config'; +import { auth } from 'google-auth-library'; +import { getAccessKey } from './config'; export const requireAccessKey = (req, res, next) => { - const authHeader = req.headers.authorization; + const authHeader = req.headers.authorization; - if (authHeader) { - if (!checkAccessKey(authHeader)) { - return res.sendStatus(403); - } - next(); - } else { - res.sendStatus(401); + if (authHeader) { + if (!checkAccessKey(authHeader)) { + return res.sendStatus(403); } + next(); + } else { + res.sendStatus(401); + } }; export const checkAccessKey = (authHeader) => { - const token = authHeader.split(' ')[1]; - return Boolean(authHeader && (token === getAccessKey())); + const token = authHeader.split(' ')[1]; + return Boolean(authHeader && token === getAccessKey()); +}; + +export const hasGCPCredentials = async () => { + try { + const credentials = await auth.getCredentials(); + return Object.keys(credentials).length > 0; + } catch (e) { + return false; + } }; diff --git a/src/config.ts b/src/config.ts index 05d824d..39f5c28 100644 --- a/src/config.ts +++ b/src/config.ts @@ -34,33 +34,38 @@ const loadEnv = () => { GENERAL: { PORT: Number(process.env.PORT), SIMILARITY_MEASURE: process.env.SIMILARITY_MEASURE, - SUPER_SECRET_KEY: process.env.SUPER_SECRET_KEY + SUPER_SECRET_KEY: process.env.SUPER_SECRET_KEY, }, API_KEYS: { OPENAI: process.env.OPENAI, - GROQ: process.env.GROQ + GROQ: process.env.GROQ, }, API_ENDPOINTS: { SEARXNG: process.env.SEARXNG_API_URL, - OLLAMA: process.env.OLLAMA_API_URL - } + OLLAMA: process.env.OLLAMA_API_URL, + }, } as Config; }; export const getPort = () => loadConfig().GENERAL.PORT; -export const getAccessKey = () => loadEnv().GENERAL.SUPER_SECRET_KEY || loadConfig().GENERAL.SUPER_SECRET_KEY; +export const getAccessKey = () => + loadEnv().GENERAL.SUPER_SECRET_KEY || loadConfig().GENERAL.SUPER_SECRET_KEY; export const getSimilarityMeasure = () => loadConfig().GENERAL.SIMILARITY_MEASURE; -export const getOpenaiApiKey = () => loadEnv().API_KEYS.OPENAI || loadConfig().API_KEYS.OPENAI; +export const getOpenaiApiKey = () => + loadEnv().API_KEYS.OPENAI || loadConfig().API_KEYS.OPENAI; -export const getGroqApiKey = () => loadEnv().API_KEYS.GROQ || loadConfig().API_KEYS.GROQ; +export const getGroqApiKey = () => + loadEnv().API_KEYS.GROQ || loadConfig().API_KEYS.GROQ; -export const getSearxngApiEndpoint = () => loadEnv().API_ENDPOINTS.SEARXNG || loadConfig().API_ENDPOINTS.SEARXNG; +export const getSearxngApiEndpoint = () => + loadEnv().API_ENDPOINTS.SEARXNG || loadConfig().API_ENDPOINTS.SEARXNG; -export const getOllamaApiEndpoint = () => loadEnv().API_ENDPOINTS.OLLAMA || loadConfig().API_ENDPOINTS.OLLAMA; +export const getOllamaApiEndpoint = () => + loadEnv().API_ENDPOINTS.OLLAMA || loadConfig().API_ENDPOINTS.OLLAMA; export const updateConfig = (config: RecursivePartial) => { const currentConfig = loadConfig(); diff --git a/src/lib/providers.ts b/src/lib/providers.ts index 751d23f..44867af 100644 --- a/src/lib/providers.ts +++ b/src/lib/providers.ts @@ -1,6 +1,7 @@ import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'; import { ChatOllama } from '@langchain/community/chat_models/ollama'; import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; +import { hasGCPCredentials } from '../auth'; import { getGroqApiKey, getOllamaApiEndpoint, @@ -111,6 +112,23 @@ export const getAvailableChatModelProviders = async () => { } } + if (await hasGCPCredentials()) { + try { + models['vertexai'] = { + 'gemini-1.5-pro (preview-0409)': new VertexAI({ + temperature: 0.7, + modelName: 'gemini-1.5-pro-preview-0409', + }), + 'gemini-1.0-pro (Latest)': new VertexAI({ + temperature: 0.7, + modelName: 'gemini-1.0-pro', + }), + }; + } catch (err) { + logger.error(`Error loading VertexAI models: ${err}`); + } + } + models['custom_openai'] = {}; return models; diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index f504c10..b584e2d 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -31,7 +31,7 @@ export const handleConnection = async ( }), ); ws.close(); - }; + } } const [chatModelProviders, embeddingModelProviders] = await Promise.all([ diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx index 3a58b9f..e50f2ee 100644 --- a/ui/components/ChatWindow.tsx +++ b/ui/components/ChatWindow.tsx @@ -36,14 +36,11 @@ const useSocket = (url: string) => { !embeddingModel || !embeddingModelProvider ) { - const providers = await clientFetch( - '/models', - { - headers: { - 'Content-Type': 'application/json', - }, + const providers = await clientFetch('/models', { + headers: { + 'Content-Type': 'application/json', }, - ).then(async (res) => await res.json()); + }).then(async (res) => await res.json()); const chatModelProviders = providers.chatModelProviders; const embeddingModelProviders = providers.embeddingModelProviders; @@ -103,8 +100,8 @@ const useSocket = (url: string) => { const secretToken = getAccessKey(); if (secretToken) { - protocols = ["Authorization", `${secretToken}`]; - }; + protocols = ['Authorization', `${secretToken}`]; + } const ws = new WebSocket(wsURL.toString(), protocols); diff --git a/ui/components/SearchImages.tsx b/ui/components/SearchImages.tsx index 076537f..68a305a 100644 --- a/ui/components/SearchImages.tsx +++ b/ui/components/SearchImages.tsx @@ -34,21 +34,18 @@ const SearchImages = ({ const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModel = localStorage.getItem('chatModel'); - const res = await clientFetch( - '/images', - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - query: query, - chat_history: chat_history, - chat_model_provider: chatModelProvider, - chat_model: chatModel, - }), + const res = await clientFetch('/images', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', }, - ); + body: JSON.stringify({ + query: query, + chat_history: chat_history, + chat_model_provider: chatModelProvider, + chat_model: chatModel, + }), + }); const data = await res.json(); diff --git a/ui/components/SearchVideos.tsx b/ui/components/SearchVideos.tsx index b91627c..e9ef479 100644 --- a/ui/components/SearchVideos.tsx +++ b/ui/components/SearchVideos.tsx @@ -47,21 +47,18 @@ const Searchvideos = ({ const chatModelProvider = localStorage.getItem('chatModelProvider'); const chatModel = localStorage.getItem('chatModel'); - const res = await clientFetch( - '/videos', - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - query: query, - chat_history: chat_history, - chat_model_provider: chatModelProvider, - chat_model: chatModel, - }), + const res = await clientFetch('/videos', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', }, - ); + body: JSON.stringify({ + query: query, + chat_history: chat_history, + chat_model_provider: chatModelProvider, + chat_model: chatModel, + }), + }); const data = await res.json(); diff --git a/ui/lib/config.ts b/ui/lib/config.ts index 675cd8d..f2d0eea 100644 --- a/ui/lib/config.ts +++ b/ui/lib/config.ts @@ -11,11 +11,12 @@ const loadEnv = () => { GENERAL: { NEXT_PUBLIC_SUPER_SECRET_KEY: process.env.NEXT_PUBLIC_SUPER_SECRET_KEY!, NEXT_PUBLIC_API_URL: process.env.NEXT_PUBLIC_API_URL!, - NEXT_PUBLIC_WS_URL: process.env.NEXT_PUBLIC_WS_URL! + NEXT_PUBLIC_WS_URL: process.env.NEXT_PUBLIC_WS_URL!, }, } as Config; }; -export const getAccessKey = () => loadEnv().GENERAL.NEXT_PUBLIC_SUPER_SECRET_KEY; +export const getAccessKey = () => + loadEnv().GENERAL.NEXT_PUBLIC_SUPER_SECRET_KEY; export const getBackendURL = () => loadEnv().GENERAL.NEXT_PUBLIC_API_URL; diff --git a/ui/lib/utils.ts b/ui/lib/utils.ts index 6d59e79..318873e 100644 --- a/ui/lib/utils.ts +++ b/ui/lib/utils.ts @@ -1,6 +1,6 @@ import clsx, { ClassValue } from 'clsx'; import { twMerge } from 'tailwind-merge'; -import { getAccessKey, getBackendURL } from './config' +import { getAccessKey, getBackendURL } from './config'; export const cn = (...classes: ClassValue[]) => twMerge(clsx(...classes)); @@ -29,11 +29,11 @@ export const clientFetch = async (path: string, payload: any): Promise => { if (secretToken) { if (headers == null) { headers = {}; - }; + } headers['Authorization'] = `Bearer ${secretToken}`; payload.headers = headers; - }; + } return await fetch(url, payload); };