From 3d378735733a955d211401c5f72b62584eb89553 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Thu, 12 Oct 2023 20:37:32 +0530 Subject: [PATCH] enable disable --- app/ui/src/@types/bot.ts | 15 + .../components/Bot/Settings/SettingsCard.tsx | 29 +- app/ui/src/routes/bot/settings.tsx | 16 +- .../src/routes/api/v1/bot/handlers/types.ts | 2 + .../bot/playground/handlers/post.handler.ts | 566 ++++++++++-------- server/src/routes/api/v1/bot/schema/index.ts | 8 +- .../src/routes/bot/handlers/post.handler.ts | 17 +- 7 files changed, 360 insertions(+), 293 deletions(-) create mode 100644 app/ui/src/@types/bot.ts diff --git a/app/ui/src/@types/bot.ts b/app/ui/src/@types/bot.ts new file mode 100644 index 00000000..8dda624e --- /dev/null +++ b/app/ui/src/@types/bot.ts @@ -0,0 +1,15 @@ +export type BotSettings = { + id: string; + name: string; + model: string; + public_id: string; + temperature: number; + embedding: string; + qaPrompt: string; + questionGeneratorPrompt: string; + streaming: boolean; + showRef: boolean; + use_hybrid_search: boolean; + bot_protect: boolean; + use_rag: boolean; +}; diff --git a/app/ui/src/components/Bot/Settings/SettingsCard.tsx b/app/ui/src/components/Bot/Settings/SettingsCard.tsx index 8eeaf6da..a7f1d37d 100644 --- a/app/ui/src/components/Bot/Settings/SettingsCard.tsx +++ b/app/ui/src/components/Bot/Settings/SettingsCard.tsx @@ -13,24 +13,9 @@ import { HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT, HELPFUL_ASSISTANT_WITHOUT_CONTEXT_PROMPT, } from "../../../utils/prompts"; +import { BotSettings } from "../../../@types/bot"; -export const SettingsCard = ({ - data, -}: { - data: { - id: string; - name: string; - model: string; - public_id: string; - temperature: number; - embedding: string; - qaPrompt: string; - questionGeneratorPrompt: string; - streaming: boolean; - showRef: boolean; - use_hybrid_search: boolean; - }; -}) => { +export const SettingsCard = ({ data }: { data: BotSettings }) => { const [form] = Form.useForm(); const [disableStreaming, setDisableStreaming] = React.useState(false); const params = useParams<{ id: string }>(); @@ -133,6 +118,7 @@ export const SettingsCard = ({ streaming: data.streaming, showRef: data.showRef, use_hybrid_search: data.use_hybrid_search, + bot_protect: data.bot_protect, }} form={form} requiredMark={false} @@ -317,6 +303,15 @@ export const SettingsCard = ({ > + + + + diff --git a/app/ui/src/routes/bot/settings.tsx b/app/ui/src/routes/bot/settings.tsx index 3eb65b57..748d0730 100644 --- a/app/ui/src/routes/bot/settings.tsx +++ b/app/ui/src/routes/bot/settings.tsx @@ -4,6 +4,7 @@ import api from "../../services/api"; import React from "react"; import { SkeletonLoading } from "../../components/Common/SkeletonLoading"; import { SettingsCard } from "../../components/Bot/Settings/SettingsCard"; +import { BotSettings } from "../../@types/bot"; export default function BotSettingsRoot() { const param = useParams<{ id: string }>(); @@ -14,19 +15,7 @@ export default function BotSettingsRoot() { async () => { const response = await api.get(`/bot/${param.id}`); return response.data as { - data: { - id: string; - name: string; - model: string; - public_id: string; - temperature: number; - embedding: string; - qaPrompt: string; - questionGeneratorPrompt: string; - streaming: boolean; - showRef: boolean; - use_hybrid_search: boolean; - }; + data: BotSettings }; }, { @@ -41,7 +30,6 @@ export default function BotSettingsRoot() { }, [status]); return (
     <div>
-
       {status === "loading" && <SkeletonLoading />}
       {status === "success" && <SettingsCard data={data.data} />}
     </div>
   );
 }
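A minimal sketch of how the settings form might persist the two new flags from the UI side. The PUT verb on /bot/:id is an assumption (only the GET appears above); the `api` client and the BotSettings type are the ones imported in settings.tsx, and the payload shape mirrors the updateBotByIdSchema additions in the server diff below.

// Sketch only: assumes the axios-style `api` client used above and that the
// settings form saves via PUT /bot/:id, validated by updateBotByIdSchema.
// bot_protect and use_rag ride along with the existing settings payload.
import api from "../../services/api";
import { BotSettings } from "../../@types/bot";

export async function saveBotSettings(
  id: string,
  values: Partial<BotSettings>
) {
  // e.g. values = { bot_protect: true, use_rag: false }
  const { data } = await api.put(`/bot/${id}`, values);
  return data;
}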
diff --git a/server/src/routes/api/v1/bot/handlers/types.ts b/server/src/routes/api/v1/bot/handlers/types.ts index 741dc27f..90722df0 100644 --- a/server/src/routes/api/v1/bot/handlers/types.ts +++ b/server/src/routes/api/v1/bot/handlers/types.ts @@ -64,5 +64,7 @@ export interface UpdateBotById { streaming: boolean; showRef: boolean; use_hybrid_search: boolean; + bot_protect: boolean; + use_rag: boolean; }; } diff --git a/server/src/routes/api/v1/bot/playground/handlers/post.handler.ts b/server/src/routes/api/v1/bot/playground/handlers/post.handler.ts index 733ec12f..fea1c419 100644 --- a/server/src/routes/api/v1/bot/playground/handlers/post.handler.ts +++ b/server/src/routes/api/v1/bot/playground/handlers/post.handler.ts @@ -14,135 +14,154 @@ export const chatRequestHandler = async ( const bot_id = request.params.id; const { message, history, history_id } = request.body; + try { + const prisma = request.server.prisma; - const prisma = request.server.prisma; - - const bot = await prisma.bot.findFirst({ - where: { - id: bot_id, - user_id: request.user.user_id, - }, - }); - - if (!bot) { - return { - bot: { - text: "You are in the wrong place, buddy.", - sourceDocuments: [], + const bot = await prisma.bot.findFirst({ + where: { + id: bot_id, + user_id: request.user.user_id, }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", + }); + + if (!bot) { + return { + bot: { text: "You are in the wrong place, buddy.", + sourceDocuments: [], }, - ], - }; - } + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: "You are in the wrong place, buddy.", + }, + ], + }; + } - const temperature = bot.temperature; + const temperature = bot.temperature; - const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingModel = embeddings(bot.embedding); + const sanitizedQuestion = message.trim().replaceAll("\n", " "); + const embeddingModel = embeddings(bot.embedding); - let retriever: BaseRetriever; + let retriever: BaseRetriever; - if (bot.use_hybrid_search) { - retriever = new DialoqbaseHybridRetrival(embeddingModel, { - botId: bot.id, - sourceId: null, - }); - } else { - const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( - embeddingModel, - { + if (bot.use_hybrid_search) { + retriever = new DialoqbaseHybridRetrival(embeddingModel, { botId: bot.id, sourceId: null, - }, - ); - - retriever = vectorstore.asRetriever(); - } - - const model = chatModelProvider(bot.provider, bot.model, temperature); + }); + } else { + const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( + embeddingModel, + { + botId: bot.id, + sourceId: null, + }, + ); - const chain = ConversationalRetrievalQAChain.fromLLM( - model, - retriever, - { - qaTemplate: bot.qaPrompt, - questionGeneratorTemplate: bot.questionGeneratorPrompt, - returnSourceDocuments: true, - }, - ); - - const chat_history = history - .map((chatMessage: any) => { - if (chatMessage.type === "human") { - return `Human: ${chatMessage.text}`; - } else if (chatMessage.type === "ai") { - return `Assistant: ${chatMessage.text}`; - } else { - return `${chatMessage.text}`; - } - }) - .join("\n"); - - console.log(chat_history); - - const response = await chain.call({ - question: sanitizedQuestion, - chat_history: chat_history, - }); + retriever = vectorstore.asRetriever(); + } - let historyId = history_id; + const model = chatModelProvider(bot.provider, bot.model, temperature); - if (!historyId) { - const newHistory = await prisma.botPlayground.create({ - 
data: { - botId: bot.id, - title: message, + const chain = ConversationalRetrievalQAChain.fromLLM( + model, + retriever, + { + qaTemplate: bot.qaPrompt, + questionGeneratorTemplate: bot.questionGeneratorPrompt, + returnSourceDocuments: true, }, + ); + + const chat_history = history + .map((chatMessage: any) => { + if (chatMessage.type === "human") { + return `Human: ${chatMessage.text}`; + } else if (chatMessage.type === "ai") { + return `Assistant: ${chatMessage.text}`; + } else { + return `${chatMessage.text}`; + } + }) + .join("\n"); + + console.log(chat_history); + + const response = await chain.call({ + question: sanitizedQuestion, + chat_history: chat_history, }); - historyId = newHistory.id; - } - await prisma.botPlaygroundMessage.create({ - data: { - type: "human", - message: message, - botPlaygroundId: historyId, - }, - }); + let historyId = history_id; - await prisma.botPlaygroundMessage.create({ - data: { - type: "ai", - message: response.text, - botPlaygroundId: historyId, - isBot: true, - sources: response?.sourceDocuments, - }, - }); + if (!historyId) { + const newHistory = await prisma.botPlayground.create({ + data: { + botId: bot.id, + title: message, + }, + }); + historyId = newHistory.id; + } - return { - bot: response, - history: [ - ...history, - { + await prisma.botPlaygroundMessage.create({ + data: { type: "human", - text: message, + message: message, + botPlaygroundId: historyId, }, - { + }); + + await prisma.botPlaygroundMessage.create({ + data: { type: "ai", - text: response.text, + message: response.text, + botPlaygroundId: historyId, + isBot: true, + sources: response?.sourceDocuments, }, - ], - }; + }); + + return { + bot: response, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: response.text, + }, + ], + }; + } catch (e) { + return { + bot: { + text: "There was an error processing your request.", + sourceDocuments: [], + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: "There was an error processing your request.", + }, + ], + }; + } }; function nextTick() { @@ -160,183 +179,212 @@ export const chatRequestStreamHandler = async ( // type: string; // text: string; // }[]; + try { + console.log("history", history); + const prisma = request.server.prisma; - console.log("history", history); - const prisma = request.server.prisma; - - const bot = await prisma.bot.findFirst({ - where: { - id: bot_id, - user_id: request.user.user_id, - }, - }); - - if (!bot) { - return { - bot: { - text: "You are in the wrong place, buddy.", - sourceDocuments: [], + const bot = await prisma.bot.findFirst({ + where: { + id: bot_id, + user_id: request.user.user_id, }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", + }); + + if (!bot) { + return { + bot: { text: "You are in the wrong place, buddy.", + sourceDocuments: [], }, - ], - }; - } + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: "You are in the wrong place, buddy.", + }, + ], + }; + } - const temperature = bot.temperature; + const temperature = bot.temperature; - const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingModel = embeddings(bot.embedding); + const sanitizedQuestion = message.trim().replaceAll("\n", " "); + const embeddingModel = embeddings(bot.embedding); - let retriever: BaseRetriever; + let retriever: BaseRetriever; - if (bot.use_hybrid_search) { - retriever = new DialoqbaseHybridRetrival(embeddingModel, { - 
botId: bot.id, - sourceId: null, - }); - } else { - const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( - embeddingModel, - { + if (bot.use_hybrid_search) { + retriever = new DialoqbaseHybridRetrival(embeddingModel, { botId: bot.id, sourceId: null, - }, - ); + }); + } else { + const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( + embeddingModel, + { + botId: bot.id, + sourceId: null, + }, + ); - retriever = vectorstore.asRetriever(); - } + retriever = vectorstore.asRetriever(); + } - let response: any = null; + let response: any = null; - reply.raw.on("close", () => { - console.log("closed"); - }); + reply.raw.on("close", () => { + console.log("closed"); + }); - const streamedModel = chatModelProvider( - bot.provider, - bot.model, - temperature, - { - streaming: true, - callbacks: [ - { - handleLLMNewToken(token: string) { - // if (token !== '[DONE]') { - // console.log(token); - return reply.sse({ - id: "", - event: "chunk", - data: JSON.stringify({ - message: token || "", - }), - }); - // } else { - // console.log("done"); - // } + const streamedModel = chatModelProvider( + bot.provider, + bot.model, + temperature, + { + streaming: true, + callbacks: [ + { + handleLLMNewToken(token: string) { + // if (token !== '[DONE]') { + // console.log(token); + return reply.sse({ + id: "", + event: "chunk", + data: JSON.stringify({ + message: token || "", + }), + }); + // } else { + // console.log("done"); + // } + }, }, + ], + }, + ); + + const nonStreamingModel = chatModelProvider( + bot.provider, + bot.model, + temperature, + ); + + const chain = ConversationalRetrievalQAChain.fromLLM( + streamedModel, + retriever, + { + qaTemplate: bot.qaPrompt, + questionGeneratorTemplate: bot.questionGeneratorPrompt, + returnSourceDocuments: true, + questionGeneratorChainOptions: { + llm: nonStreamingModel, }, - ], - }, - ); - - const nonStreamingModel = chatModelProvider( - bot.provider, - bot.model, - temperature, - ); - - const chain = ConversationalRetrievalQAChain.fromLLM( - streamedModel, - retriever, - { - qaTemplate: bot.qaPrompt, - questionGeneratorTemplate: bot.questionGeneratorPrompt, - returnSourceDocuments: true, - questionGeneratorChainOptions: { - llm: nonStreamingModel, }, - }, - ); - - const chat_history = history - .map((chatMessage: any) => { - if (chatMessage.type === "human") { - return `Human: ${chatMessage.text}`; - } else if (chatMessage.type === "ai") { - return `Assistant: ${chatMessage.text}`; - } else { - return `${chatMessage.text}`; - } - }) - .join("\n"); - - console.log("Waiting for response..."); - - response = await chain.call({ - question: sanitizedQuestion, - chat_history: chat_history, - }); + ); + + const chat_history = history + .map((chatMessage: any) => { + if (chatMessage.type === "human") { + return `Human: ${chatMessage.text}`; + } else if (chatMessage.type === "ai") { + return `Assistant: ${chatMessage.text}`; + } else { + return `${chatMessage.text}`; + } + }) + .join("\n"); + + console.log("Waiting for response..."); + + response = await chain.call({ + question: sanitizedQuestion, + chat_history: chat_history, + }); + + let historyId = history_id; - let historyId = history_id; + if (!historyId) { + const newHistory = await prisma.botPlayground.create({ + data: { + botId: bot.id, + title: message, + }, + }); + historyId = newHistory.id; + } - if (!historyId) { - const newHistory = await prisma.botPlayground.create({ + await prisma.botPlaygroundMessage.create({ data: { - botId: bot.id, - title: message, + type: "human", + message: 
message, + botPlaygroundId: historyId, }, }); - historyId = newHistory.id; - } - await prisma.botPlaygroundMessage.create({ - data: { - type: "human", - message: message, - botPlaygroundId: historyId, - }, - }); - - await prisma.botPlaygroundMessage.create({ - data: { - type: "ai", - message: response.text, - botPlaygroundId: historyId, - isBot: true, - sources: response?.sourceDocuments, - }, - }); + await prisma.botPlaygroundMessage.create({ + data: { + type: "ai", + message: response.text, + botPlaygroundId: historyId, + isBot: true, + sources: response?.sourceDocuments, + }, + }); - reply.sse({ - event: "result", - id: "", - data: JSON.stringify({ - bot: response, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: response.text, + reply.sse({ + event: "result", + id: "", + data: JSON.stringify({ + bot: response, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: response.text, + }, + ], + history_id: historyId, + }), + }); + await nextTick(); + return reply.raw.end(); + } catch (e) { + console.log(e); + reply.raw.setHeader("Content-Type", "text/event-stream"); + + reply.sse({ + event: "result", + id: "", + data: JSON.stringify({ + bot: { + text: "There was an error processing your request.", + sourceDocuments: [], }, - ], - history_id: historyId, - }), - }); - await nextTick(); - return reply.raw.end(); + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: "There was an error processing your request.", + }, + ], + }), + }); + await nextTick(); + + return reply.raw.end(); + } }; export const updateBotAudioSettingsHandler = async ( diff --git a/server/src/routes/api/v1/bot/schema/index.ts b/server/src/routes/api/v1/bot/schema/index.ts index 2d03f9c4..6170007f 100644 --- a/server/src/routes/api/v1/bot/schema/index.ts +++ b/server/src/routes/api/v1/bot/schema/index.ts @@ -122,7 +122,13 @@ export const updateBotByIdSchema: FastifySchema = { }, use_hybrid_search: { type: "boolean", - } + }, + bot_protect: { + type: "boolean", + }, + use_rag: { + type: "boolean", + }, }, }, }; diff --git a/server/src/routes/bot/handlers/post.handler.ts b/server/src/routes/bot/handlers/post.handler.ts index ffe3ee7b..6317f090 100644 --- a/server/src/routes/bot/handlers/post.handler.ts +++ b/server/src/routes/bot/handlers/post.handler.ts @@ -206,7 +206,10 @@ export const chatRequestStreamHandler = async ( if (bot.bot_protect) { if (!request.session.get("is_bot_allowed")) { console.log("not allowed"); - return reply.sse({ + + reply.raw.setHeader("Content-Type", "text/event-stream"); + + reply.sse({ event: "result", id: "", data: JSON.stringify({ @@ -227,6 +230,11 @@ export const chatRequestStreamHandler = async ( ], }), }); + + + await nextTick(); + + return reply.raw.end(); } } @@ -350,7 +358,9 @@ export const chatRequestStreamHandler = async ( return reply.raw.end(); } catch (e) { console.log(e); - return reply.sse({ + reply.raw.setHeader("Content-Type", "text/event-stream"); + + reply.sse({ event: "result", id: "", data: JSON.stringify({ @@ -371,5 +381,8 @@ export const chatRequestStreamHandler = async ( ], }), }); + await nextTick(); + + return reply.raw.end(); } };
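A closing note on the streaming contract this patch tightens: both stream handlers emit per-token "chunk" events and a final "result" event, and they now set the SSE content type and end the stream even on error or bot_protect rejection, so a client always receives a terminating "result". Below is a minimal consumer sketch; the endpoint path and the @microsoft/fetch-event-source client are assumptions, while the event names and payload fields come from the reply.sse() calls above.

import { fetchEventSource } from "@microsoft/fetch-event-source";

// Hypothetical endpoint path; the "chunk" / "result" event names and the
// message / bot / history_id payload fields match the handlers above.
async function streamChat(botId: string, message: string, history: any[]) {
  let answer = "";
  await fetchEventSource(`/api/v1/bot/playground/${botId}/stream`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ message, history, history_id: null }),
    onmessage(ev) {
      if (ev.event === "chunk") {
        // one token per event while the model is generating
        answer += JSON.parse(ev.data).message;
      } else if (ev.event === "result") {
        // final payload: the full response (or the error text), the
        // updated history, and the persisted history_id
        const { bot, history_id } = JSON.parse(ev.data);
        answer = bot.text ?? answer;
        console.log("done", history_id);
      }
    },
  });
  return answer;
}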