diff --git a/app/ui/src/routes/settings/application.tsx b/app/ui/src/routes/settings/application.tsx index ff70e79f..ec1579f1 100644 --- a/app/ui/src/routes/settings/application.tsx +++ b/app/ui/src/routes/settings/application.tsx @@ -1,4 +1,4 @@ -import { Form, InputNumber, Switch, notification, Select } from "antd"; +import { Form, InputNumber, Switch, notification, Select, Input } from "antd"; import React from "react"; import api from "../../services/api"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; @@ -25,6 +25,9 @@ export default function SettingsApplicationRoot() { defaultChunkOverlap: number; defaultChatModel: string; defaultEmbeddingModel: string; + hideDefaultModels: boolean; + dynamicallyFetchOllamaModels: boolean; + ollamaURL: string; }; }); @@ -63,6 +66,22 @@ export default function SettingsApplicationRoot() { } ); + const { mutateAsync: updateModelSettings, isLoading: isModelLoading } = + useMutation(onUpdateApplicatoon, { + onSuccess: (data) => { + queryClient.invalidateQueries(["fetchBotCreateConfig"]); + notification.success({ + message: "Success", + description: data.message, + }); + }, + onError: (error: any) => { + notification.error({ + message: "Error", + description: error?.response?.data?.message || "Something went wrong", + }); + }, + }); const { mutateAsync: updateRagSettings, isLoading: isRagLoading } = useMutation(onRagApplicationUpdate, { onSuccess: (data) => { @@ -135,6 +154,33 @@ export default function SettingsApplicationRoot() { > + +
+ +
+ + + + + +
+
+
+ + + + + + + + + +
+ { try { @@ -130,21 +131,29 @@ export const getAllModelsHandler = async ( request: FastifyRequest, reply: FastifyReply ) => { - try { - const prisma = request.server.prisma; - const user = request.user; + const prisma = request.server.prisma; + const user = request.user; - if (!user.is_admin) { - return reply.status(403).send({ - message: "Forbidden", - }); - } - const allModels = await prisma.dialoqbaseModels.findMany({ - where: { - deleted: false, - }, + if (!user.is_admin) { + return reply.status(403).send({ + message: "Forbidden", }); + } + const settings = await getSettings(prisma); + + const not_to_hide_providers = settings?.hideDefaultModels + ? [ "Local", "local", "ollama", "transformer", "Transformer"] + : undefined; + const allModels = await prisma.dialoqbaseModels.findMany({ + where: { + deleted: false, + model_provider: { + in: not_to_hide_providers, + }, + }, + }); + try { return { data: allModels.filter((model) => model.model_type !== "embedding"), embedding: allModels.filter((model) => model.model_type === "embedding"), @@ -245,7 +254,7 @@ export const saveModelFromInputedUrlHandler = async ( }); } - let newModelId = model_id.trim() + `_custom_${new Date().getTime()}`; + let newModelId = model_id.trim() + `_dialoqbase_${new Date().getTime()}`; await prisma.dialoqbaseModels.create({ data: { name: isModelExist.name, diff --git a/server/src/handlers/api/v1/bot/bot/api.handler.ts b/server/src/handlers/api/v1/bot/bot/api.handler.ts index eb9dbb5a..bda42af7 100644 --- a/server/src/handlers/api/v1/bot/bot/api.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/api.handler.ts @@ -16,6 +16,7 @@ import { uniqueNamesGenerator, } from "unique-names-generator"; import { validateDataSource } from "../../../../../utils/datasource-validation"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const createBotAPIHandler = async ( request: FastifyRequest, @@ -55,19 +56,11 @@ export const createBotAPIHandler = async ( message: `Reach maximum 
limit of ${maxBotsAllowed} bots per user`, }); } - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - hide: false, - deleted: false, - OR: [ - { - model_id: model, - }, - { - model_id: `${model}-dbase`, - }, - ], - }, + + const modelInfo = await getModelInfo({ + model, + prisma, + type: "chat", }); if (!modelInfo) { @@ -76,19 +69,10 @@ export const createBotAPIHandler = async ( }); } - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - OR: [ - { - model_id: embedding, - }, - { - model_id: `dialoqbase_eb_${embedding}`, - }, - ], - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/handlers/api/v1/bot/bot/chat.handler.ts b/server/src/handlers/api/v1/bot/bot/chat.handler.ts index 6a371bf6..18e5b10f 100644 --- a/server/src/handlers/api/v1/bot/bot/chat.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/chat.handler.ts @@ -7,377 +7,368 @@ import { DialoqbaseHybridRetrival } from "../../../../../utils/hybrid"; import { DialoqbaseVectorStore } from "../../../../../utils/store"; import { chatModelProvider } from "../../../../../utils/models"; import { createChain, groupMessagesByConversation } from "../../../../../chain"; +import { getModelInfo } from "../../../../../utils/get-model-info"; function nextTick() { - return new Promise((resolve) => setTimeout(resolve, 0)); + return new Promise((resolve) => setTimeout(resolve, 0)); } export const chatRequestAPIHandler = async ( - request: FastifyRequest, - reply: FastifyReply + request: FastifyRequest, + reply: FastifyReply ) => { - const { message, history, stream } = request.body; - if (stream) { - try { - const bot_id = request.params.id; - const prisma = request.server.prisma; - const user_id = request.user.user_id; - - const bot = await prisma.bot.findFirst({ - where: { - id: bot_id, - user_id - }, - }); - - if (!bot) { - return 
reply.status(404).send({ - message: "Bot not found", - }); - } - - - const temperature = bot.temperature; - - const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); - - if (!embeddingInfo) { - return reply.status(404).send({ - message: "Embedding not found", - }); - } - - const embeddingModel = embeddings( - embeddingInfo.model_provider!.toLowerCase(), - embeddingInfo.model_id, - embeddingInfo?.config - ); - - reply.raw.on("close", () => { - console.log("closed"); - }); - - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); - - if (!modelinfo) { - return reply.status(404).send({ - message: "Model not found", - }); - } - - const botConfig = (modelinfo.config as {}) || {}; - let retriever: BaseRetriever; - let resolveWithDocuments: (value: Document[]) => void; - const documentPromise = new Promise((resolve) => { - resolveWithDocuments = resolve; - }); - if (bot.use_hybrid_search) { - retriever = new DialoqbaseHybridRetrival(embeddingModel, { - botId: bot.id, - sourceId: null, - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } else { - const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( - embeddingModel, - { - botId: bot.id, - sourceId: null, - } - ); - - retriever = vectorstore.asRetriever({ - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } - - let response: string = ""; - const streamedModel = chatModelProvider( - bot.provider, - bot.model, - temperature, - { - streaming: true, - ...botConfig, - } - ); - - const nonStreamingModel = chatModelProvider( - bot.provider, - bot.model, - temperature, - { - ...botConfig, - } - ); - - reply.raw.on("close", () => { - // close the model - }); - - 
const chain = createChain({ - llm: streamedModel, - question_llm: nonStreamingModel, - question_template: bot.questionGeneratorPrompt, - response_template: bot.qaPrompt, - retriever, - }); - - let stream = await chain.stream({ - question: sanitizedQuestion, - chat_history: groupMessagesByConversation( - history.map((message) => ({ - type: message.role, - content: message.text, - })) - ), - }); - - for await (const token of stream) { - reply.sse({ - id: "", - event: "chunk", - data: JSON.stringify({ - bot: { - text: token || "", - sourceDocuments: [], - }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: token || "", - }, - ], - }), - }); - response += token; - } - - const documents = await documentPromise; - - - await prisma.botApiHistory.create({ - data: { - api_key: request.headers.authorization || "", - bot_id: bot.id, - human: message, - bot: response, - } - }) - - reply.sse({ - event: "result", - id: "", - data: JSON.stringify({ - bot: { - text: response, - sourceDocuments: documents, - }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: response, - }, - ], - }), - }); - await nextTick(); - return reply.raw.end(); - } catch (e) { - return reply.status(500).send({ - message: "Internal Server Error", - }); + const { message, history, stream } = request.body; + if (stream) { + try { + const bot_id = request.params.id; + const prisma = request.server.prisma; + const user_id = request.user.user_id; + + const bot = await prisma.bot.findFirst({ + where: { + id: bot_id, + user_id, + }, + }); + + if (!bot) { + return reply.status(404).send({ + message: "Bot not found", + }); + } + + const temperature = bot.temperature; + + const sanitizedQuestion = message.trim().replaceAll("\n", " "); + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }); + + if (!embeddingInfo) { + return reply.status(404).send({ + message: "Embedding not 
found", + }); + } + + const embeddingModel = embeddings( + embeddingInfo.model_provider!.toLowerCase(), + embeddingInfo.model_id, + embeddingInfo?.config + ); + + reply.raw.on("close", () => { + console.log("closed"); + }); + + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }); + + if (!modelinfo) { + return reply.status(404).send({ + message: "Model not found", + }); + } + + const botConfig = (modelinfo.config as {}) || {}; + let retriever: BaseRetriever; + let resolveWithDocuments: (value: Document[]) => void; + const documentPromise = new Promise((resolve) => { + resolveWithDocuments = resolve; + }); + if (bot.use_hybrid_search) { + retriever = new DialoqbaseHybridRetrival(embeddingModel, { + botId: bot.id, + sourceId: null, + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } else { + const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( + embeddingModel, + { + botId: bot.id, + sourceId: null, + } + ); + + retriever = vectorstore.asRetriever({ + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } + + let response: string = ""; + const streamedModel = chatModelProvider( + bot.provider, + bot.model, + temperature, + { + streaming: true, + ...botConfig, } - } else { - try { - const bot_id = request.params.id; - const user_id = request.user.user_id; - - const prisma = request.server.prisma; - - const bot = await prisma.bot.findFirst({ - where: { - id: bot_id, - user_id - }, - }); - - if (!bot) { - return reply.status(404).send({ - message: "Bot not found", - }); - } - - const temperature = bot.temperature; - - const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); - - if (!embeddingInfo) { - return reply.status(404).send({ - message: 
"Embedding not found", - }); - } - - const embeddingModel = embeddings( - embeddingInfo.model_provider!.toLowerCase(), - embeddingInfo.model_id, - embeddingInfo?.config - ); - - let retriever: BaseRetriever; - let resolveWithDocuments: (value: Document[]) => void; - const documentPromise = new Promise((resolve) => { - resolveWithDocuments = resolve; - }); - if (bot.use_hybrid_search) { - retriever = new DialoqbaseHybridRetrival(embeddingModel, { - botId: bot.id, - sourceId: null, - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } else { - const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( - embeddingModel, - { - botId: bot.id, - sourceId: null, - } - ); - - retriever = vectorstore.asRetriever({ - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } - - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); - - if (!modelinfo) { - return reply.status(404).send({ - message: "Model not found", - }); - } - - const botConfig: any = (modelinfo.config as {}) || {}; - if (bot.provider.toLowerCase() === "openai") { - if (bot.bot_model_api_key && bot.bot_model_api_key.trim() !== "") { - botConfig.configuration = { - apiKey: bot.bot_model_api_key, - }; - } - } - - const model = chatModelProvider(bot.provider, bot.model, temperature, { - ...botConfig, - }); - - const chain = createChain({ - llm: model, - question_llm: model, - question_template: bot.questionGeneratorPrompt, - response_template: bot.qaPrompt, - retriever, - }); - - const botResponse = await chain.invoke({ - question: sanitizedQuestion, - chat_history: groupMessagesByConversation( - history.map((message) => ({ - type: message.role, - content: message.text, - })) - ), - }); - - const documents = await documentPromise; - - await prisma.botApiHistory.create({ - data: { - api_key: 
request.headers.authorization || "", - bot_id: bot.id, - human: message, - bot: botResponse, - } - }) - - return { - bot: { - text: botResponse, - sourceDocuments: documents, - }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: botResponse, - }, - ], - }; - } catch (e) { - return reply.status(500).send({ - message: "Internal Server Error", - }); + ); + + const nonStreamingModel = chatModelProvider( + bot.provider, + bot.model, + temperature, + { + ...botConfig, } + ); + + reply.raw.on("close", () => { + // close the model + }); + + const chain = createChain({ + llm: streamedModel, + question_llm: nonStreamingModel, + question_template: bot.questionGeneratorPrompt, + response_template: bot.qaPrompt, + retriever, + }); + + let stream = await chain.stream({ + question: sanitizedQuestion, + chat_history: groupMessagesByConversation( + history.map((message) => ({ + type: message.role, + content: message.text, + })) + ), + }); + + for await (const token of stream) { + reply.sse({ + id: "", + event: "chunk", + data: JSON.stringify({ + bot: { + text: token || "", + sourceDocuments: [], + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: token || "", + }, + ], + }), + }); + response += token; + } + + const documents = await documentPromise; + + await prisma.botApiHistory.create({ + data: { + api_key: request.headers.authorization || "", + bot_id: bot.id, + human: message, + bot: response, + }, + }); + + reply.sse({ + event: "result", + id: "", + data: JSON.stringify({ + bot: { + text: response, + sourceDocuments: documents, + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: response, + }, + ], + }), + }); + await nextTick(); + return reply.raw.end(); + } catch (e) { + return reply.status(500).send({ + message: "Internal Server Error", + }); } + } else { + try { + const bot_id = request.params.id; + const user_id = 
request.user.user_id; + + const prisma = request.server.prisma; + + const bot = await prisma.bot.findFirst({ + where: { + id: bot_id, + user_id, + }, + }); + + if (!bot) { + return reply.status(404).send({ + message: "Bot not found", + }); + } + + const temperature = bot.temperature; + + const sanitizedQuestion = message.trim().replaceAll("\n", " "); + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }); + + if (!embeddingInfo) { + return reply.status(404).send({ + message: "Embedding not found", + }); + } + + const embeddingModel = embeddings( + embeddingInfo.model_provider!.toLowerCase(), + embeddingInfo.model_id, + embeddingInfo?.config + ); + + let retriever: BaseRetriever; + let resolveWithDocuments: (value: Document[]) => void; + const documentPromise = new Promise((resolve) => { + resolveWithDocuments = resolve; + }); + if (bot.use_hybrid_search) { + retriever = new DialoqbaseHybridRetrival(embeddingModel, { + botId: bot.id, + sourceId: null, + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } else { + const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( + embeddingModel, + { + botId: bot.id, + sourceId: null, + } + ); + + retriever = vectorstore.asRetriever({ + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } + + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }); + + if (!modelinfo) { + return reply.status(404).send({ + message: "Model not found", + }); + } + + const botConfig: any = (modelinfo.config as {}) || {}; + if (bot.provider.toLowerCase() === "openai") { + if (bot.bot_model_api_key && bot.bot_model_api_key.trim() !== "") { + botConfig.configuration = { + apiKey: bot.bot_model_api_key, + }; + } + } + + const model = chatModelProvider(bot.provider, bot.model, temperature, { + ...botConfig, + }); + + const chain = createChain({ 
+ llm: model, + question_llm: model, + question_template: bot.questionGeneratorPrompt, + response_template: bot.qaPrompt, + retriever, + }); + + const botResponse = await chain.invoke({ + question: sanitizedQuestion, + chat_history: groupMessagesByConversation( + history.map((message) => ({ + type: message.role, + content: message.text, + })) + ), + }); + + const documents = await documentPromise; + + await prisma.botApiHistory.create({ + data: { + api_key: request.headers.authorization || "", + bot_id: bot.id, + human: message, + bot: botResponse, + }, + }); + + return { + bot: { + text: botResponse, + sourceDocuments: documents, + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: botResponse, + }, + ], + }; + } catch (e) { + return reply.status(500).send({ + message: "Internal Server Error", + }); + } + } }; diff --git a/server/src/handlers/api/v1/bot/bot/get.handler.ts b/server/src/handlers/api/v1/bot/bot/get.handler.ts index c6f2cc9d..d02cfa9c 100644 --- a/server/src/handlers/api/v1/bot/bot/get.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/get.handler.ts @@ -2,6 +2,7 @@ import { FastifyReply, FastifyRequest } from "fastify"; import { GetBotRequestById } from "./types"; import { getSettings } from "../../../../../utils/common"; +import { getAllOllamaModels } from "../../../../../utils/ollama"; export const getBotByIdEmbeddingsHandler = async ( request: FastifyRequest, @@ -125,10 +126,18 @@ export const getCreateBotConfigHandler = async ( reply: FastifyReply ) => { const prisma = request.server.prisma; + const settings = await getSettings(prisma); + + const not_to_hide_providers = settings?.hideDefaultModels + ? 
["Local", "local", "ollama", "transformer", "Transformer"] + : undefined; const models = await prisma.dialoqbaseModels.findMany({ where: { hide: false, deleted: false, + model_provider: { + in: not_to_hide_providers, + }, }, }); @@ -146,16 +155,30 @@ export const getCreateBotConfigHandler = async ( .filter((model) => model.model_type === "embedding") .map((model) => { return { - label: `${model.name || model.model_id} ${model.model_id === "dialoqbase_eb_dialoqbase-ollama" + label: `${model.name || model.model_id} ${ + model.model_id === "dialoqbase_eb_dialoqbase-ollama" ? "(Deprecated)" : "" - }`, + }`, value: model.model_id, disabled: model.model_id === "dialoqbase_eb_dialoqbase-ollama", }; }); - - const settings = await getSettings(prisma); + + if (settings?.dynamicallyFetchOllamaModels) { + const ollamaModels = await getAllOllamaModels(settings.ollamaURL); + chatModel.push( + ...ollamaModels?.filter((model) => { + return ( + !model?.details?.families?.includes("bert") && + !model?.details?.families?.includes("nomic-bert") + ); + }) + ); + embeddingModel.push( + ...ollamaModels.map((model) => ({ ...model, disabled: false })) + ); + } return { chatModel, @@ -200,10 +223,11 @@ export const getBotByIdSettingsHandler = async ( .filter((model) => model.model_type === "embedding") .map((model) => { return { - label: `${model.name || model.model_id} ${model.model_id === "dialoqbase_eb_dialoqbase-ollama" + label: `${model.name || model.model_id} ${ + model.model_id === "dialoqbase_eb_dialoqbase-ollama" ? 
"(Deprecated)" : "" - }`, + }`, value: model.model_id, disabled: model.model_id === "dialoqbase_eb_dialoqbase-ollama", }; @@ -221,7 +245,6 @@ export const getBotByIdSettingsHandler = async ( }; }; - export const isBotReadyHandler = async ( request: FastifyRequest, reply: FastifyReply @@ -252,4 +275,4 @@ export const isBotReadyHandler = async ( return { is_ready: source === 0, }; -}; \ No newline at end of file +}; diff --git a/server/src/handlers/api/v1/bot/bot/post.handler.ts b/server/src/handlers/api/v1/bot/bot/post.handler.ts index 238b7133..fd557a7b 100644 --- a/server/src/handlers/api/v1/bot/bot/post.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/post.handler.ts @@ -16,6 +16,7 @@ import { HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT, HELPFUL_ASSISTANT_WITHOUT_CONTEXT_PROMPT, } from "../../../../../utils/prompts"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const createBotHandler = async ( request: FastifyRequest, @@ -55,12 +56,10 @@ export const createBotHandler = async ( }); } // const providerName = modelProviderName(model); - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: model, - hide: false, - deleted: false, - }, + const modelInfo = await getModelInfo({ + model, + prisma, + type: "chat", }); if (!modelInfo) { @@ -69,12 +68,10 @@ export const createBotHandler = async ( }); } - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/handlers/api/v1/bot/bot/put.handler.ts b/server/src/handlers/api/v1/bot/bot/put.handler.ts index 380d6fe7..7dc0fc97 100644 --- a/server/src/handlers/api/v1/bot/bot/put.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/put.handler.ts @@ -4,6 +4,7 @@ import { apiKeyValidaton, apiKeyValidatonMessage, } from "../../../../../utils/validate"; 
+import { getModelInfo } from "../../../../../utils/get-model-info"; export const updateBotByIdHandler = async ( request: FastifyRequest, @@ -25,12 +26,9 @@ export const updateBotByIdHandler = async ( }); } - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: request.body.model, - hide: false, - deleted: false, - }, + const modelInfo = await getModelInfo({ + model: request.body.model, + prisma, }); if (!modelInfo) { @@ -96,26 +94,15 @@ export const updateBotAPIByIdHandler = async ( questionGeneratorPrompt: request.body?.question_generator_prompt, system_prompt: undefined, question_generator_prompt: undefined, - } - + }; if (updateBody.model) { - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - hide: false, - deleted: false, - OR: [ - { - model_id: updateBody.model - }, - { - model_id: `${updateBody.model}-dbase` - } - ], - }, + const modelInfo = await getModelInfo({ + model: updateBody.model, + prisma, + type: "chat", }); - if (!modelInfo) { return reply.status(400).send({ message: "Model not found", @@ -140,8 +127,7 @@ export const updateBotAPIByIdHandler = async ( updateBody = { ...updateBody, provider: modelInfo.model_provider || "", - } - + }; } await prisma.bot.update({ where: { diff --git a/server/src/handlers/api/v1/bot/bot/upload.handler.ts b/server/src/handlers/api/v1/bot/bot/upload.handler.ts index cf0f1b9a..6264c04c 100644 --- a/server/src/handlers/api/v1/bot/bot/upload.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/upload.handler.ts @@ -19,6 +19,7 @@ const pump = util.promisify(pipeline); import { fileTypeFinder } from "../../../../../utils/fileType"; import { getSettings } from "../../../../../utils/common"; import { HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT } from "../../../../../utils/prompts"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const createBotFileHandler = async ( request: FastifyRequest, @@ -52,12 +53,10 @@ export const createBotFileHandler = async ( 
}); } - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -76,12 +75,10 @@ export const createBotFileHandler = async ( }); } - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: model, - hide: false, - deleted: false, - }, + const modelInfo = await getModelInfo({ + model, + prisma, + type: "chat", }); if (!modelInfo) { diff --git a/server/src/handlers/api/v1/bot/playground/chat.handler.ts b/server/src/handlers/api/v1/bot/playground/chat.handler.ts index ab950f97..2e7661ec 100644 --- a/server/src/handlers/api/v1/bot/playground/chat.handler.ts +++ b/server/src/handlers/api/v1/bot/playground/chat.handler.ts @@ -7,6 +7,7 @@ import { DialoqbaseHybridRetrival } from "../../../../../utils/hybrid"; import { BaseRetriever } from "@langchain/core/retrievers"; import { Document } from "langchain/document"; import { createChain, groupMessagesByConversation } from "../../../../../chain"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const chatRequestHandler = async ( request: FastifyRequest, @@ -48,14 +49,11 @@ export const chatRequestHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "all", }); - if (!embeddingInfo) { return { bot: { @@ -118,12 +116,10 @@ export const chatRequestHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: 
"chat", }); if (!modelinfo) { @@ -295,12 +291,10 @@ export const chatRequestStreamHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -375,12 +369,10 @@ export const chatRequestStreamHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { reply.raw.setHeader("Content-Type", "text/event-stream"); diff --git a/server/src/handlers/bot/api.handler.ts b/server/src/handlers/bot/api.handler.ts index aef50c1b..d7e995aa 100644 --- a/server/src/handlers/bot/api.handler.ts +++ b/server/src/handlers/bot/api.handler.ts @@ -8,6 +8,7 @@ import { Document } from "langchain/document"; import { BaseRetriever } from "@langchain/core/retrievers"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { createChain, groupMessagesByConversation } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; export const chatRequestAPIHandler = async ( request: FastifyRequest, @@ -40,12 +41,10 @@ export const chatRequestAPIHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + prisma, + model: bot.embedding, + type: "embedding", }); if (!embeddingInfo) { @@ -64,12 +63,10 @@ export const chatRequestAPIHandler = async ( console.log("closed"); }); - const modelinfo = 
await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + prisma, + model: bot.model, + type: "chat", }); if (!modelinfo) { @@ -172,15 +169,14 @@ export const chatRequestAPIHandler = async ( }); const documents = await documentPromise; - await prisma.botApiHistory.create({ data: { api_key: request.headers["x-api-key"], bot_id: bot.id, human: message, bot: response, - } - }) + }, + }); reply.sse({ event: "result", @@ -237,12 +233,10 @@ export const chatRequestAPIHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + prisma, + model: bot.embedding, + type: "embedding", }); if (!embeddingInfo) { @@ -308,12 +302,10 @@ export const chatRequestAPIHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + prisma, + model: bot.model, + type: "chat", }); if (!modelinfo) { @@ -375,8 +367,8 @@ export const chatRequestAPIHandler = async ( bot_id: bot.id, human: message, bot: botResponse, - } - }) + }, + }); return { bot: { diff --git a/server/src/handlers/bot/post.handler.ts b/server/src/handlers/bot/post.handler.ts index 9bbc17c9..04e2a2c9 100644 --- a/server/src/handlers/bot/post.handler.ts +++ b/server/src/handlers/bot/post.handler.ts @@ -7,6 +7,7 @@ import { BaseRetriever } from "@langchain/core/retrievers"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { Document } from "langchain/document"; import { createChain, groupMessagesByConversation } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; export const chatRequestHandler = async ( 
request: FastifyRequest, @@ -69,12 +70,10 @@ export const chatRequestHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -139,13 +138,11 @@ export const chatRequestHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }) if (!modelinfo) { return { @@ -341,13 +338,11 @@ export const chatRequestStreamHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { return { @@ -416,13 +411,11 @@ export const chatRequestStreamHandler = async ( console.log("closed"); }); - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }) if (!modelinfo) { reply.raw.setHeader("Content-Type", "text/event-stream"); diff --git a/server/src/integration/handlers/discord.handler.ts b/server/src/integration/handlers/discord.handler.ts index 3eb9c4da..4549b37e 100644 --- a/server/src/integration/handlers/discord.handler.ts +++ b/server/src/integration/handlers/discord.handler.ts @@ -6,6 +6,7 @@ import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { 
Document } from "langchain/document"; import { BaseRetriever } from "@langchain/core/retrievers"; import { createChain } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; const prisma = new PrismaClient(); export const discordBotHandler = async ( @@ -51,14 +52,11 @@ export const discordBotHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); - + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { return { text: "Opps! Model not found", @@ -108,12 +106,10 @@ export const discordBotHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { diff --git a/server/src/integration/handlers/telegram.handler.ts b/server/src/integration/handlers/telegram.handler.ts index 9f675ff2..0be218d1 100644 --- a/server/src/integration/handlers/telegram.handler.ts +++ b/server/src/integration/handlers/telegram.handler.ts @@ -5,6 +5,7 @@ import { chatModelProvider } from "../../utils/models"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { BaseRetriever } from "@langchain/core/retrievers"; import { createChain } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; const prisma = new PrismaClient(); export const telegramBotHandler = async ( @@ -46,12 +47,10 @@ export const telegramBotHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - 
deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -83,14 +82,12 @@ export const telegramBotHandler = async ( retriever = vectorstore.asRetriever({}); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); - + if (!modelinfo) { return "Unable to find model"; } diff --git a/server/src/integration/handlers/whatsapp.handler.ts b/server/src/integration/handlers/whatsapp.handler.ts index 7945312f..bfed6e21 100644 --- a/server/src/integration/handlers/whatsapp.handler.ts +++ b/server/src/integration/handlers/whatsapp.handler.ts @@ -5,6 +5,7 @@ import { chatModelProvider } from "../../utils/models"; import { BaseRetriever } from "@langchain/core/retrievers"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { createChain } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; const prisma = new PrismaClient(); export const whatsappBotHandler = async ( @@ -55,12 +56,10 @@ export const whatsappBotHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -92,12 +91,10 @@ export const whatsappBotHandler = async ( retriever = vectorstore.asRetriever({}); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { diff --git 
a/server/src/queue/controllers/audio.controller.ts b/server/src/queue/controllers/audio.controller.ts index 86f6b07c..daeddd40 100644 --- a/server/src/queue/controllers/audio.controller.ts +++ b/server/src/queue/controllers/audio.controller.ts @@ -5,6 +5,7 @@ import { embeddings } from "../../utils/embeddings"; import { DialoqbaseAudioVideoLoader } from "../../loader/audio-video"; import { convertMp3ToWave } from "../../utils/ffmpeg"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const audioQueueController = async ( source: QSource, @@ -25,13 +26,11 @@ export const audioQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/csv.controller.ts b/server/src/queue/controllers/csv.controller.ts index 8e159e5b..c9e83179 100644 --- a/server/src/queue/controllers/csv.controller.ts +++ b/server/src/queue/controllers/csv.controller.ts @@ -4,6 +4,7 @@ import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const csvQueueController = async ( source: QSource, @@ -21,12 +22,10 @@ export const csvQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/queue/controllers/docx.controller.ts b/server/src/queue/controllers/docx.controller.ts index 185a8e26..aea1dc4e 100644 --- a/server/src/queue/controllers/docx.controller.ts +++ b/server/src/queue/controllers/docx.controller.ts @@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbaseDocxLoader } from "../../loader/docx"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const DocxQueueController = async ( source: QSource, @@ -22,13 +23,11 @@ export const DocxQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + 
prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/github.controller.ts b/server/src/queue/controllers/github.controller.ts index 7300c90f..fc95d477 100644 --- a/server/src/queue/controllers/github.controller.ts +++ b/server/src/queue/controllers/github.controller.ts @@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbaseGithub } from "../../loader/github"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const githubQueueController = async ( source: QSource, @@ -25,13 +26,11 @@ export const githubQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/pdf.controller.ts b/server/src/queue/controllers/pdf.controller.ts index 2bfcaa65..2b211935 100644 --- a/server/src/queue/controllers/pdf.controller.ts +++ b/server/src/queue/controllers/pdf.controller.ts @@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbasePDFLoader } from "../../loader/pdf"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const pdfQueueController = async ( source: QSource, @@ -22,13 +23,11 @@ export const pdfQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/rest.controller.ts b/server/src/queue/controllers/rest.controller.ts index 722dc08a..b6375b0f 100644 --- a/server/src/queue/controllers/rest.controller.ts +++ b/server/src/queue/controllers/rest.controller.ts @@ -3,6 +3,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbaseRestApi } from "../../loader/rest"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const restQueueController = async ( source: QSource, @@ -18,13 +19,11 @@ export const restQueueController = async ( }); const docs = await loader.load(); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/text.controller.ts b/server/src/queue/controllers/text.controller.ts index 11e8c2a3..b4d73c65 100644 --- a/server/src/queue/controllers/text.controller.ts +++ b/server/src/queue/controllers/text.controller.ts @@ -3,6 +3,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const textQueueController = async ( source: QSource, @@ -20,15 +21,12 @@ export const textQueueController = async ( }, }, ]); - - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); - + + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); } diff --git a/server/src/queue/controllers/txt.controller.ts b/server/src/queue/controllers/txt.controller.ts index baf894fd..445fa24b 100644 --- a/server/src/queue/controllers/txt.controller.ts +++ b/server/src/queue/controllers/txt.controller.ts @@ -4,6 +4,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { TextLoader } from "langchain/document_loaders/fs/text"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const txtQueueController = async ( source: QSource, @@ -20,13 +21,11 @@ export const txtQueueController = async ( chunkOverlap: source.chunkOverlap, }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/video.controller.ts b/server/src/queue/controllers/video.controller.ts index c0646a3a..c41db9df 100644 --- a/server/src/queue/controllers/video.controller.ts +++ b/server/src/queue/controllers/video.controller.ts @@ -6,6 +6,7 @@ import { embeddings } from "../../utils/embeddings"; import { DialoqbaseAudioVideoLoader } from "../../loader/audio-video"; import { convertMp4ToWave } from "../../utils/ffmpeg"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const videoQueueController = async ( source: QSource, @@ -26,13 +27,11 @@ export const videoQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/website.controller.ts b/server/src/queue/controllers/website.controller.ts index d0843f75..2c492326 100644 --- a/server/src/queue/controllers/website.controller.ts +++ b/server/src/queue/controllers/website.controller.ts @@ -8,6 +8,7 @@ import { DialoqbasePDFLoader } from "../../loader/pdf"; import { DialoqbaseWebLoader } from "../../loader/web"; import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const websiteQueueController = async ( source: QSource, @@ -37,13 +38,11 @@ export const websiteQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); @@ -79,13 +78,11 @@ export const websiteQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. 
Please verify the embedding id"); diff --git a/server/src/queue/controllers/youtube.controller.ts b/server/src/queue/controllers/youtube.controller.ts index fdb0e12a..5f1f30ef 100644 --- a/server/src/queue/controllers/youtube.controller.ts +++ b/server/src/queue/controllers/youtube.controller.ts @@ -6,6 +6,7 @@ import { embeddings } from "../../utils/embeddings"; import { DialoqbaseYoutube } from "../../loader/youtube"; import { PrismaClient } from "@prisma/client"; import { DialoqbaseYoutubeTranscript } from "../../loader/youtube-transcript"; +import { getModelInfo } from "../../utils/get-model-info"; export const youtubeQueueController = async ( source: QSource, @@ -29,12 +30,10 @@ export const youtubeQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -66,12 +65,10 @@ export const youtubeQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/schema/api/v1/admin/index.ts b/server/src/schema/api/v1/admin/index.ts index 6f4c9599..424b6e6d 100644 --- a/server/src/schema/api/v1/admin/index.ts +++ b/server/src/schema/api/v1/admin/index.ts @@ -20,6 +20,8 @@ export const dialoqbaseSettingsSchema: FastifySchema = { defaultChatModel: { type: "string" }, defaultEmbeddingModel: { type: "string" }, dynamicallyFetchOllamaModels: { type: "boolean" }, + hideDefaultModels: { type: "boolean" }, + ollamaURL: { type: "string" }, }, }, }; @@ -43,12 +45,9 
@@ export const updateDialoqbaseSettingsSchema: FastifySchema = { dynamicallyFetchOllamaModels: { type: "boolean" }, defaultChatModel: { type: "string" }, defaultEmbeddingModel: { type: "string" }, + hideDefaultModels: { type: "boolean" }, + ollamaURL: { type: "string" }, }, - required: [ - "noOfBotsPerUser", - "allowUserToCreateBots", - "allowUserToRegister", - ], }, response: { 200: {
diff --git a/server/src/utils/get-model-info.ts b/server/src/utils/get-model-info.ts
new file mode 100644
index 00000000..e5aa21f6
--- /dev/null
+++ b/server/src/utils/get-model-info.ts
@@ -0,0 +1,83 @@
+import { PrismaClient, DialoqbaseModels } from "@prisma/client";
+import { getSettings } from "./common";
+import { cleanUrl, getAllOllamaModels } from "./ollama";
+
+/**
+ * Resolves the model record behind a bot's or source's model id.
+ *
+ * The `dialoqbaseModels` table is queried first, skipping hidden and deleted
+ * rows. A "chat" lookup also accepts the `${model}-dbase` id and an
+ * "embedding" lookup the `dialoqbase_eb_${model}` id. When the
+ * `hideDefaultModels` setting is on, only rows whose provider is in the
+ * local/ollama/transformer list are considered.
+ *
+ * If no row matches and `dynamicallyFetchOllamaModels` is enabled, the models
+ * listed by the configured Ollama server are searched by name and a synthetic,
+ * non-persisted record is returned for a match.
+ *
+ * @param prisma - Client used for the settings and model queries.
+ * @param model - Model id to resolve.
+ * @param type - Kind of model expected; defaults to "all".
+ * @returns The matching record, or `null` when nothing matches.
+ */
+export const getModelInfo = async ({
+  model,
+  prisma,
+  type = "all",
+}: {
+  prisma: PrismaClient;
+  model: string;
+  type?: "all" | "chat" | "embedding";
+}): Promise<DialoqbaseModels | null> => {
+  const settings = await getSettings(prisma);
+  // Prisma ignores an `undefined` filter value, so every provider is allowed
+  // unless default models are hidden.
+  const not_to_hide_providers = settings?.hideDefaultModels
+    ? ["Local", "local", "ollama", "transformer", "Transformer"]
+    : undefined;
+
+  // Ids a stored row may use for this model, depending on the lookup type.
+  const candidateIds = [model];
+  if (type === "chat") {
+    candidateIds.push(`${model}-dbase`);
+  } else if (type === "embedding") {
+    candidateIds.push(`dialoqbase_eb_${model}`);
+  }
+
+  const modelInfo = await prisma.dialoqbaseModels.findFirst({
+    where: {
+      model_id: { in: candidateIds },
+      hide: false,
+      deleted: false,
+      model_provider: { in: not_to_hide_providers },
+    },
+  });
+  if (modelInfo || !settings?.dynamicallyFetchOllamaModels) {
+    return modelInfo;
+  }
+
+  // Fall back to the models served by the configured Ollama instance.
+  const ollamaModels = await getAllOllamaModels(settings.ollamaURL);
+  const ollamaInfo = ollamaModels.find((m) => m.value === model);
+  if (!ollamaInfo) {
+    return null;
+  }
+
+  return {
+    name: ollamaInfo.name,
+    model_id: ollamaInfo.name,
+    stream_available: true,
+    local_model: true,
+    model_provider: "ollama",
+    config: {
+      baseURL: cleanUrl(settings.ollamaURL),
+    },
+    createdAt: new Date(),
+    // Report the kind that was asked for instead of always claiming "chat".
+    model_type: type === "embedding" ? "embedding" : "chat",
+    deleted: false,
+    hide: false,
+    // Placeholder id: this record is built on the fly and never persisted.
+    id: 1,
+  };
+};
diff --git a/server/src/utils/ollama.ts b/server/src/utils/ollama.ts
new file mode 100644
index 00000000..bdfaac77
--- /dev/null
+++ b/server/src/utils/ollama.ts
@@ -0,0 +1,51 @@
+import axios from "axios";
+
+/**
+ * Removes trailing slashes from `url` so that paths can be appended to it
+ * without producing a double slash.
+ */
+export const cleanUrl = (url: string) => {
+  return url.replace(/\/+$/, "");
+};
+
+/**
+ * Lists the models served by the Ollama instance at `url` via `GET /api/tags`.
+ *
+ * Best-effort: a missing URL, a failed request or a response without a
+ * `models` array yields an empty list instead of throwing.
+ *
+ * @param url - Base URL of the Ollama server.
+ * @returns The reported models, each extended with `label`/`value` (both the
+ * model name) and `stream: true`.
+ */
+export const getAllOllamaModels = async (url: string) => {
+  // Nothing to query when no Ollama URL has been configured.
+  if (!url) {
+    return [];
+  }
+  try {
+    const response = await axios.get(`${cleanUrl(url)}/api/tags`);
+    const { models } = response.data as {
+      models?: {
+        name: string;
+        details?: {
+          parent_model?: string;
+          format: string;
+          family: string;
+          families: string[];
+          parameter_size: string;
+          quantization_level: string;
+        };
+      }[];
+    };
+    return (models ?? []).map((data) => ({
+      ...data,
+      label: data.name,
+      value: data.name,
+      stream: true,
+    }));
+  } catch (error) {
+    console.log(`Error fetching Ollama models`, error);
+    return [];
+  }
+};