From 811822c03ba04c9b3b77a92690cd2c8844ec2fee Mon Sep 17 00:00:00 2001 From: ItzCrazyKns Date: Wed, 17 Apr 2024 10:22:20 +0530 Subject: [PATCH] feat(agents): use ollama models --- .env.example | 4 ++-- src/agents/academicSearchAgent.ts | 19 ++++++++++++------- src/agents/imageSearchAgent.ts | 6 +++--- src/agents/redditSearchAgent.ts | 19 ++++++++++++------- src/agents/webSearchAgent.ts | 19 ++++++++++++------- src/agents/wolframAlphaSearchAgent.ts | 13 ++++++++----- src/agents/writingAssistant.ts | 7 ++++--- src/agents/youtubeSearchAgent.ts | 19 ++++++++++++------- 8 files changed, 65 insertions(+), 41 deletions(-) diff --git a/.env.example b/.env.example index bc67919..1c735cc 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ PORT=3001 -OPENAI_API_KEY= +OLLAMA_URL=http://localhost:11434 # url of the ollama server SIMILARITY_MEASURE=cosine # cosine or dot SEARXNG_API_URL= # no need to fill this if using docker -MODEL_NAME=gpt-3.5-turbo \ No newline at end of file +MODEL_NAME=llama2 \ No newline at end of file diff --git a/src/agents/academicSearchAgent.ts b/src/agents/academicSearchAgent.ts index 7c3d448..0e78581 100644 --- a/src/agents/academicSearchAgent.ts +++ b/src/agents/academicSearchAgent.ts @@ -9,7 +9,9 @@ import { RunnableMap, RunnableLambda, } from '@langchain/core/runnables'; -import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai'; +import { ChatOllama } from '@langchain/community/chat_models/ollama'; +import { Ollama } from '@langchain/community/llms/ollama'; +import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { Document } from '@langchain/core/documents'; import { searchSearxng } from '../core/searxng'; @@ -18,18 +20,21 @@ import formatChatHistoryAsString from '../utils/formatHistory'; import eventEmitter from 'events'; import computeSimilarity from '../utils/computeSimilarity'; -const chatLLM = new ChatOpenAI({ - modelName: process.env.MODEL_NAME, +const chatLLM = new ChatOllama({ + baseUrl: process.env.OLLAMA_URL, + model: process.env.MODEL_NAME, temperature: 0.7, }); -const llm = new OpenAI({ +const llm = new Ollama({ temperature: 0, - modelName: process.env.MODEL_NAME, + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); -const embeddings = new OpenAIEmbeddings({ - modelName: 'text-embedding-3-large', +const embeddings = new OllamaEmbeddings({ + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); const basicAcademicSearchRetrieverPrompt = ` diff --git a/src/agents/imageSearchAgent.ts b/src/agents/imageSearchAgent.ts index 3a2c9db..1692aba 100644 --- a/src/agents/imageSearchAgent.ts +++ b/src/agents/imageSearchAgent.ts @@ -4,15 +4,15 @@ import { RunnableLambda, } from '@langchain/core/runnables'; import { PromptTemplate } from '@langchain/core/prompts'; -import { OpenAI } from '@langchain/openai'; +import { Ollama } from '@langchain/community/llms/ollama'; import formatChatHistoryAsString from '../utils/formatHistory'; import { BaseMessage } from '@langchain/core/messages'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { searchSearxng } from '../core/searxng'; -const llm = new OpenAI({ +const llm = new Ollama({ temperature: 0, - modelName: process.env.MODEL_NAME, + model: process.env.MODEL_NAME, }); const imageSearchChainPrompt = ` diff --git a/src/agents/redditSearchAgent.ts b/src/agents/redditSearchAgent.ts index 77f293e..d5ab77c 100644 --- a/src/agents/redditSearchAgent.ts +++ b/src/agents/redditSearchAgent.ts @@ -9,7 +9,9 @@ import { RunnableMap, RunnableLambda, } from '@langchain/core/runnables'; -import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai'; +import { ChatOllama } from '@langchain/community/chat_models/ollama'; +import { Ollama } from '@langchain/community/llms/ollama'; +import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { Document } from '@langchain/core/documents'; import { searchSearxng } from '../core/searxng'; @@ -18,18 +20,21 @@ import formatChatHistoryAsString from '../utils/formatHistory'; import eventEmitter from 'events'; import computeSimilarity from '../utils/computeSimilarity'; -const chatLLM = new ChatOpenAI({ - modelName: process.env.MODEL_NAME, +const chatLLM = new ChatOllama({ + baseUrl: process.env.OLLAMA_URL, + model: process.env.MODEL_NAME, temperature: 0.7, }); -const llm = new OpenAI({ +const llm = new Ollama({ temperature: 0, - modelName: process.env.MODEL_NAME, + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); -const embeddings = new OpenAIEmbeddings({ - modelName: 'text-embedding-3-large', +const embeddings = new OllamaEmbeddings({ + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); const basicRedditSearchRetrieverPrompt = ` diff --git a/src/agents/webSearchAgent.ts b/src/agents/webSearchAgent.ts index f5799e3..5d60dda 100644 --- a/src/agents/webSearchAgent.ts +++ b/src/agents/webSearchAgent.ts @@ -9,7 +9,9 @@ import { RunnableMap, RunnableLambda, } from '@langchain/core/runnables'; -import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai'; +import { ChatOllama } from '@langchain/community/chat_models/ollama'; +import { Ollama } from '@langchain/community/llms/ollama'; +import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { Document } from '@langchain/core/documents'; import { searchSearxng } from '../core/searxng'; @@ -18,18 +20,21 @@ import formatChatHistoryAsString from '../utils/formatHistory'; import eventEmitter from 'events'; import computeSimilarity from '../utils/computeSimilarity'; -const chatLLM = new ChatOpenAI({ - modelName: process.env.MODEL_NAME, +const chatLLM = new ChatOllama({ + baseUrl: process.env.OLLAMA_URL, + model: process.env.MODEL_NAME, temperature: 0.7, }); -const llm = new OpenAI({ +const llm = new Ollama({ temperature: 0, - modelName: process.env.MODEL_NAME, + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); -const embeddings = new OpenAIEmbeddings({ - modelName: 'text-embedding-3-large', +const embeddings = new OllamaEmbeddings({ + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); const basicSearchRetrieverPrompt = ` diff --git a/src/agents/wolframAlphaSearchAgent.ts b/src/agents/wolframAlphaSearchAgent.ts index a9a3202..5f42ed7 100644 --- a/src/agents/wolframAlphaSearchAgent.ts +++ b/src/agents/wolframAlphaSearchAgent.ts @@ -9,7 +9,8 @@ import { RunnableMap, RunnableLambda, } from '@langchain/core/runnables'; -import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai'; +import { ChatOllama } from '@langchain/community/chat_models/ollama'; +import { Ollama } from '@langchain/community/llms/ollama'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { Document } from '@langchain/core/documents'; import { searchSearxng } from '../core/searxng'; @@ -17,14 +18,16 @@ import type { StreamEvent } from '@langchain/core/tracers/log_stream'; import formatChatHistoryAsString from '../utils/formatHistory'; import eventEmitter from 'events'; -const chatLLM = new ChatOpenAI({ - modelName: process.env.MODEL_NAME, +const chatLLM = new ChatOllama({ + baseUrl: process.env.OLLAMA_URL, + model: process.env.MODEL_NAME, temperature: 0.7, }); -const llm = new OpenAI({ +const llm = new Ollama({ temperature: 0, - modelName: process.env.MODEL_NAME, + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); const basicWolframAlphaSearchRetrieverPrompt = ` diff --git a/src/agents/writingAssistant.ts b/src/agents/writingAssistant.ts index 2c8d66e..eba9872 100644 --- a/src/agents/writingAssistant.ts +++ b/src/agents/writingAssistant.ts @@ -4,13 +4,14 @@ import { MessagesPlaceholder, } from '@langchain/core/prompts'; import { RunnableSequence } from '@langchain/core/runnables'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOllama } from '@langchain/community/chat_models/ollama'; import { StringOutputParser } from '@langchain/core/output_parsers'; import type { StreamEvent } from '@langchain/core/tracers/log_stream'; import eventEmitter from 'events'; -const chatLLM = new ChatOpenAI({ - modelName: process.env.MODEL_NAME, +const chatLLM = new ChatOllama({ + baseUrl: process.env.OLLAMA_URL, + model: process.env.MODEL_NAME, temperature: 0.7, }); diff --git a/src/agents/youtubeSearchAgent.ts b/src/agents/youtubeSearchAgent.ts index 9ab5ed8..7fa258b 100644 --- a/src/agents/youtubeSearchAgent.ts +++ b/src/agents/youtubeSearchAgent.ts @@ -9,7 +9,9 @@ import { RunnableMap, RunnableLambda, } from '@langchain/core/runnables'; -import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai'; +import { ChatOllama } from '@langchain/community/chat_models/ollama'; +import { Ollama } from '@langchain/community/llms/ollama'; +import { OllamaEmbeddings } from '@langchain/community/embeddings/ollama'; import { StringOutputParser } from '@langchain/core/output_parsers'; import { Document } from '@langchain/core/documents'; import { searchSearxng } from '../core/searxng'; @@ -18,18 +20,21 @@ import formatChatHistoryAsString from '../utils/formatHistory'; import eventEmitter from 'events'; import computeSimilarity from '../utils/computeSimilarity'; -const chatLLM = new ChatOpenAI({ - modelName: process.env.MODEL_NAME, +const chatLLM = new ChatOllama({ + baseUrl: process.env.OLLAMA_URL, + model: process.env.MODEL_NAME, temperature: 0.7, }); -const llm = new OpenAI({ +const llm = new Ollama({ temperature: 0, - modelName: process.env.MODEL_NAME, + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); -const embeddings = new OpenAIEmbeddings({ - modelName: 'text-embedding-3-large', +const embeddings = new OllamaEmbeddings({ + model: process.env.MODEL_NAME, + baseUrl: process.env.OLLAMA_URL, }); const basicYoutubeSearchRetrieverPrompt = `