From 8117815346bc23c15931550c45a65b887ba0725b Mon Sep 17 00:00:00 2001 From: Patrick Wiltrout Date: Thu, 7 Nov 2024 13:31:01 -0500 Subject: [PATCH] added small change to add LLM name to the chats. --- src/db/schema.ts | 2 ++ src/websocket/messageHandler.ts | 10 +++++++++- ui/components/ChatWindow.tsx | 10 ++++++++++ ui/components/MessageBox.tsx | 5 +++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/db/schema.ts b/src/db/schema.ts index 9eefa55..4e37e1a 100644 --- a/src/db/schema.ts +++ b/src/db/schema.ts @@ -5,6 +5,8 @@ export const messages = sqliteTable('messages', { content: text('content').notNull(), chatId: text('chatId').notNull(), messageId: text('messageId').notNull(), + llmName: text('llmName').default(""), + llmProvider: text('llmProvider').default(""), role: text('type', { enum: ['assistant', 'user'] }), metadata: text('metadata', { mode: 'json', diff --git a/src/websocket/messageHandler.ts b/src/websocket/messageHandler.ts index e915b22..eb7ae61 100644 --- a/src/websocket/messageHandler.ts +++ b/src/websocket/messageHandler.ts @@ -18,6 +18,8 @@ type Message = { messageId: string; chatId: string; content: string; + llmName: string; + llmProvider: string; }; type WSMessage = { @@ -42,6 +44,8 @@ const handleEmitterEvents = ( ws: WebSocket, messageId: string, chatId: string, + llmName: string, + llmProvider: string, ) => { let recievedMessage = ''; let sources = []; @@ -76,6 +80,8 @@ const handleEmitterEvents = ( content: recievedMessage, chatId: chatId, messageId: messageId, + llmName: llmName, + llmProvider: llmProvider, role: 'assistant', metadata: JSON.stringify({ createdAt: new Date(), @@ -143,7 +149,7 @@ export const handleMessage = async ( parsedWSMessage.optimizationMode, ); - handleEmitterEvents(emitter, ws, aiMessageId, parsedMessage.chatId); + handleEmitterEvents(emitter, ws, aiMessageId, parsedMessage.chatId, parsedMessage.llmName, parsedMessage.llmProvider); const chat = await db.query.chats.findFirst({ where: eq(chats.id, parsedMessage.chatId), @@ -172,6 +178,8 @@ export const handleMessage = async ( content: parsedMessage.content, chatId: parsedMessage.chatId, messageId: humanMessageId, + llmName: parsedMessage.llmName, + llmProvider: parsedMessage.llmProvider, role: 'user', metadata: JSON.stringify({ createdAt: new Date(), diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx index f9bd583..31a0bc9 100644 --- a/ui/components/ChatWindow.tsx +++ b/ui/components/ChatWindow.tsx @@ -16,6 +16,8 @@ export type Message = { chatId: string; createdAt: Date; content: string; + llmName: string; + llmProvider: string; role: 'user' | 'assistant'; suggestions?: string[]; sources?: Document[]; @@ -353,6 +355,8 @@ const ChatWindow = ({ id }: { id?: string }) => { messageId: messageId, chatId: chatId!, content: message, + llmName: localStorage.getItem('chatModel') || 'NOT_SET', + llmProvider: localStorage.getItem('chatModelProvider') || 'NOT_SET', }, focusMode: focusMode, optimizationMode: optimizationMode, @@ -368,6 +372,8 @@ const ChatWindow = ({ id }: { id?: string }) => { chatId: chatId!, role: 'user', createdAt: new Date(), + llmName: localStorage.getItem('chatModel') || 'NOT_SET', + llmProvider: localStorage.getItem('chatModelProvider') || 'NOT_SET', }, ]); @@ -392,6 +398,8 @@ const ChatWindow = ({ id }: { id?: string }) => { role: 'assistant', sources: sources, createdAt: new Date(), + llmName: localStorage.getItem('chatModel') || 'NOT_SET', + llmProvider: localStorage.getItem('chatModelProvider') || 'NOT_SET', }, ]); added = true; @@ -410,6 +418,8 @@ const ChatWindow = ({ id }: { id?: string }) => { role: 'assistant', sources: sources, createdAt: new Date(), + llmName: localStorage.getItem('chatModel') || 'NOT_SET', + llmProvider: localStorage.getItem('chatModelProvider') || 'NOT_SET', }, ]); added = true; diff --git a/ui/components/MessageBox.tsx b/ui/components/MessageBox.tsx index 5222c7c..e0977a0 100644 --- a/ui/components/MessageBox.tsx +++ b/ui/components/MessageBox.tsx @@ -77,10 +77,15 @@ const MessageBox = ({ {message.role === 'assistant' && (
+
+
+

{message.llmName}

+
+ {message.sources && message.sources.length > 0 && (