This commit is contained in:
Jin Yucong 2024-07-05 14:36:50 +08:00
parent 5b1aaee605
commit 3b737a078a
63 changed files with 1132 additions and 1853 deletions

View file

@ -1,24 +1,16 @@
import { BaseMessage } from '@langchain/core/messages';
import {
PromptTemplate,
ChatPromptTemplate,
MessagesPlaceholder,
} from '@langchain/core/prompts';
import {
RunnableSequence,
RunnableMap,
RunnableLambda,
} from '@langchain/core/runnables';
import { StringOutputParser } from '@langchain/core/output_parsers';
import { Document } from '@langchain/core/documents';
import { searchSearxng } from '../lib/searxng';
import type { StreamEvent } from '@langchain/core/tracers/log_stream';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings';
import formatChatHistoryAsString from '../utils/formatHistory';
import eventEmitter from 'events';
import computeSimilarity from '../utils/computeSimilarity';
import logger from '../utils/logger';
import { BaseMessage } from "@langchain/core/messages";
import { PromptTemplate, ChatPromptTemplate, MessagesPlaceholder } from "@langchain/core/prompts";
import { RunnableSequence, RunnableMap, RunnableLambda } from "@langchain/core/runnables";
import { StringOutputParser } from "@langchain/core/output_parsers";
import { Document } from "@langchain/core/documents";
import { searchSearxng } from "../lib/searxng";
import type { StreamEvent } from "@langchain/core/tracers/log_stream";
import type { BaseChatModel } from "@langchain/core/language_models/chat_models";
import type { Embeddings } from "@langchain/core/embeddings";
import formatChatHistoryAsString from "../utils/formatHistory";
import eventEmitter from "events";
import computeSimilarity from "../utils/computeSimilarity";
import logger from "../utils/logger";
const basicSearchRetrieverPrompt = `
You will be given a conversation below and a follow up question. You need to rephrase the follow-up question if needed so it is a standalone question that can be used by the LLM to search the web for information.
@ -65,34 +57,16 @@ const basicWebSearchResponsePrompt = `
// Shared parser that converts the LLM's chat-message output into a plain string,
// reused as the final stage of the chains below.
const strParser = new StringOutputParser();
const handleStream = async (
stream: AsyncGenerator<StreamEvent, unknown, unknown>,
emitter: eventEmitter,
) => {
const handleStream = async (stream: AsyncGenerator<StreamEvent, unknown, unknown>, emitter: eventEmitter) => {
for await (const event of stream) {
if (
event.event === 'on_chain_end' &&
event.name === 'FinalSourceRetriever'
) {
emitter.emit(
'data',
JSON.stringify({ type: 'sources', data: event.data.output }),
);
if (event.event === "on_chain_end" && event.name === "FinalSourceRetriever") {
emitter.emit("data", JSON.stringify({ type: "sources", data: event.data.output }));
}
if (
event.event === 'on_chain_stream' &&
event.name === 'FinalResponseGenerator'
) {
emitter.emit(
'data',
JSON.stringify({ type: 'response', data: event.data.chunk }),
);
if (event.event === "on_chain_stream" && event.name === "FinalResponseGenerator") {
emitter.emit("data", JSON.stringify({ type: "response", data: event.data.chunk }));
}
if (
event.event === 'on_chain_end' &&
event.name === 'FinalResponseGenerator'
) {
emitter.emit('end');
if (event.event === "on_chain_end" && event.name === "FinalResponseGenerator") {
emitter.emit("end");
}
}
};
@ -108,16 +82,16 @@ const createBasicWebSearchRetrieverChain = (llm: BaseChatModel) => {
llm,
strParser,
RunnableLambda.from(async (input: string) => {
if (input === 'not_needed') {
return { query: '', docs: [] };
if (input === "not_needed") {
return { query: "", docs: [] };
}
const res = await searchSearxng(input, {
language: 'en',
language: "en",
});
const documents = res.results.map(
(result) =>
result =>
new Document({
pageContent: result.content,
metadata: {
@ -133,35 +107,22 @@ const createBasicWebSearchRetrieverChain = (llm: BaseChatModel) => {
]);
};
const createBasicWebSearchAnsweringChain = (
llm: BaseChatModel,
embeddings: Embeddings,
) => {
const createBasicWebSearchAnsweringChain = (llm: BaseChatModel, embeddings: Embeddings) => {
const basicWebSearchRetrieverChain = createBasicWebSearchRetrieverChain(llm);
const processDocs = async (docs: Document[]) => {
return docs
.map((_, index) => `${index + 1}. ${docs[index].pageContent}`)
.join('\n');
return docs.map((_, index) => `${index + 1}. ${docs[index].pageContent}`).join("\n");
};
const rerankDocs = async ({
query,
docs,
}: {
query: string;
docs: Document[];
}) => {
const rerankDocs = async ({ query, docs }: { query: string; docs: Document[] }) => {
if (docs.length === 0) {
return docs;
}
const docsWithContent = docs.filter(
(doc) => doc.pageContent && doc.pageContent.length > 0,
);
const docsWithContent = docs.filter(doc => doc.pageContent && doc.pageContent.length > 0);
const [docEmbeddings, queryEmbedding] = await Promise.all([
embeddings.embedDocuments(docsWithContent.map((doc) => doc.pageContent)),
embeddings.embedDocuments(docsWithContent.map(doc => doc.pageContent)),
embeddings.embedQuery(query),
]);
@ -176,9 +137,9 @@ const createBasicWebSearchAnsweringChain = (
const sortedDocs = similarity
.sort((a, b) => b.similarity - a.similarity)
.filter((sim) => sim.similarity > 0.5)
.filter(sim => sim.similarity > 0.5)
.slice(0, 15)
.map((sim) => docsWithContent[sim.index]);
.map(sim => docsWithContent[sim.index]);
return sortedDocs;
};
@ -188,43 +149,35 @@ const createBasicWebSearchAnsweringChain = (
query: (input: BasicChainInput) => input.query,
chat_history: (input: BasicChainInput) => input.chat_history,
context: RunnableSequence.from([
(input) => ({
input => ({
query: input.query,
chat_history: formatChatHistoryAsString(input.chat_history),
}),
basicWebSearchRetrieverChain
.pipe(rerankDocs)
.withConfig({
runName: 'FinalSourceRetriever',
runName: "FinalSourceRetriever",
})
.pipe(processDocs),
]),
}),
ChatPromptTemplate.fromMessages([
['system', basicWebSearchResponsePrompt],
new MessagesPlaceholder('chat_history'),
['user', '{query}'],
["system", basicWebSearchResponsePrompt],
new MessagesPlaceholder("chat_history"),
["user", "{query}"],
]),
llm,
strParser,
]).withConfig({
runName: 'FinalResponseGenerator',
runName: "FinalResponseGenerator",
});
};
const basicWebSearch = (
query: string,
history: BaseMessage[],
llm: BaseChatModel,
embeddings: Embeddings,
) => {
const basicWebSearch = (query: string, history: BaseMessage[], llm: BaseChatModel, embeddings: Embeddings) => {
const emitter = new eventEmitter();
try {
const basicWebSearchAnsweringChain = createBasicWebSearchAnsweringChain(
llm,
embeddings,
);
const basicWebSearchAnsweringChain = createBasicWebSearchAnsweringChain(llm, embeddings);
const stream = basicWebSearchAnsweringChain.streamEvents(
{
@ -232,28 +185,20 @@ const basicWebSearch = (
query: query,
},
{
version: 'v1',
version: "v1",
},
);
handleStream(stream, emitter);
} catch (err) {
emitter.emit(
'error',
JSON.stringify({ data: 'An error has occurred please try again later' }),
);
emitter.emit("error", JSON.stringify({ data: "An error has occurred please try again later" }));
logger.error(`Error in websearch: ${err}`);
}
return emitter;
};
const handleWebSearch = (
message: string,
history: BaseMessage[],
llm: BaseChatModel,
embeddings: Embeddings,
) => {
const handleWebSearch = (message: string, history: BaseMessage[], llm: BaseChatModel, embeddings: Embeddings) => {
const emitter = basicWebSearch(message, history, llm, embeddings);
return emitter;
};