79 lines
2.8 KiB
TypeScript
79 lines
2.8 KiB
TypeScript
import { WebSocket } from "ws";
|
|
import { handleMessage } from "./messageHandler";
|
|
import { getAvailableEmbeddingModelProviders, getAvailableChatModelProviders } from "../lib/providers";
|
|
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
|
|
import type { Embeddings } from "@langchain/core/embeddings";
|
|
import type { IncomingMessage } from "node:http";
|
|
import logger from "../utils/logger";
|
|
import { ChatOpenAI } from "@langchain/openai";
|
|
|
|
/**
 * Sets up a newly accepted WebSocket connection.
 *
 * The chat and embedding models are chosen from the connection URL's query string
 * (`chatModelProvider`, `chatModel`, `embeddingModelProvider`, `embeddingModel`). Any parameter
 * that is missing or empty falls back to the first entry of the corresponding provider registry.
 * The special `custom_openai` chat provider does not use the registry entry. Instead it builds a
 * `ChatOpenAI` client from the `chatModel`, `openAIApiKey` and `openAIBaseURL` parameters.
 *
 * If either model cannot be resolved, an `INVALID_MODEL_SELECTED` error frame is sent, the socket
 * is closed, and no message listener is attached. Any exception during setup sends an
 * `INTERNAL_SERVER_ERROR` frame, closes the socket and logs the error.
 *
 * @param ws - The accepted socket.
 * @param request - The HTTP upgrade request; only its `url` and `Host` header are read.
 */
export const handleConnection = async (ws: WebSocket, request: IncomingMessage) => {
  try {
    // `request.url` is typed `string | undefined`. The base URL is only needed because the
    // request URL is relative.
    const searchParameters = new URL(request.url ?? "/", `http://${request.headers.host}`).searchParams;

    // NOTE(review): the "message" listener is only attached after this await. Frames the client
    // sends before the providers finish loading have no listener yet — confirm the client waits.
    const [chatModelProviders, embeddingModelProviders] = await Promise.all([
      getAvailableChatModelProviders(),
      getAvailableEmbeddingModelProviders(),
    ]);

    // `||` (not `??`) is intentional: an empty query value also falls back to the default.
    // `?? {}` stops an unknown provider name from throwing inside Object.keys. The client then
    // receives INVALID_MODEL_SELECTED instead of INTERNAL_SERVER_ERROR.
    const chatModelProvider = searchParameters.get("chatModelProvider") || Object.keys(chatModelProviders)[0];
    const chatModel =
      searchParameters.get("chatModel") || Object.keys(chatModelProviders[chatModelProvider] ?? {})[0];

    const embeddingModelProvider =
      searchParameters.get("embeddingModelProvider") || Object.keys(embeddingModelProviders)[0];
    const embeddingModel =
      searchParameters.get("embeddingModel") ||
      Object.keys(embeddingModelProviders[embeddingModelProvider] ?? {})[0];

    /** Returns the requested chat model, or `undefined` when the provider/model pair is unknown. */
    const resolveChatModel = (): BaseChatModel | undefined => {
      if (chatModelProvider === "custom_openai") {
        // Require an explicit model name rather than constructing a client without one.
        if (!chatModel) return undefined;
        return new ChatOpenAI({
          modelName: chatModel,
          // URLSearchParams.get returns `null` when absent; the option is typed `string | undefined`.
          openAIApiKey: searchParameters.get("openAIApiKey") ?? undefined,
          temperature: 0.7,
          configuration: {
            baseURL: searchParameters.get("openAIBaseURL"),
          },
        });
      }
      return chatModelProviders[chatModelProvider]?.[chatModel] as BaseChatModel | undefined;
    };

    // `const` bindings keep the `!llm || !embeddings` narrowing valid inside the listener closure.
    const llm = resolveChatModel();
    const embeddings = embeddingModelProviders[embeddingModelProvider]?.[embeddingModel] as
      | Embeddings
      | undefined;

    if (!llm || !embeddings) {
      ws.send(
        JSON.stringify({
          type: "error",
          data: "Invalid LLM or embeddings model selected, please refresh the page and try again.",
          key: "INVALID_MODEL_SELECTED",
        }),
      );
      ws.close();
      // Stop here: without this return the message handler would be attached with no model.
      return;
    }

    ws.on("message", async message => {
      try {
        await handleMessage(message.toString(), ws, llm, embeddings);
      } catch (error) {
        // An async EventEmitter listener's rejection is not awaited by anyone. Log it here so it
        // does not surface as an unhandled promise rejection.
        logger.error(error);
      }
    });

    ws.on("close", () => logger.debug("Connection closed"));
  } catch (error) {
    ws.send(
      JSON.stringify({
        type: "error",
        data: "Internal server error.",
        key: "INTERNAL_SERVER_ERROR",
      }),
    );
    ws.close();
    logger.error(error);
  }
};
|