diff --git a/app/ui/src/routes/settings/application.tsx b/app/ui/src/routes/settings/application.tsx
index ff70e79f..ec1579f1 100644
--- a/app/ui/src/routes/settings/application.tsx
+++ b/app/ui/src/routes/settings/application.tsx
@@ -1,4 +1,4 @@
-import { Form, InputNumber, Switch, notification, Select } from "antd";
+import { Form, InputNumber, Switch, notification, Select, Input } from "antd";
import React from "react";
import api from "../../services/api";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
@@ -25,6 +25,9 @@ export default function SettingsApplicationRoot() {
defaultChunkOverlap: number;
defaultChatModel: string;
defaultEmbeddingModel: string;
+ hideDefaultModels: boolean;
+ dynamicallyFetchOllamaModels: boolean;
+ ollamaURL: string;
};
});
@@ -63,6 +66,22 @@ export default function SettingsApplicationRoot() {
}
);
+ const { mutateAsync: updateModelSettings, isLoading: isModelLoading } =
+ useMutation(onUpdateApplicatoon, {
+ onSuccess: (data) => {
+ queryClient.invalidateQueries(["fetchBotCreateConfig"]);
+ notification.success({
+ message: "Success",
+ description: data.message,
+ });
+ },
+ onError: (error: any) => {
+ notification.error({
+ message: "Error",
+ description: error?.response?.data?.message || "Something went wrong",
+ });
+ },
+ });
const { mutateAsync: updateRagSettings, isLoading: isRagLoading } =
useMutation(onRagApplicationUpdate, {
onSuccess: (data) => {
@@ -135,6 +154,33 @@ export default function SettingsApplicationRoot() {
>
+
+
+
+
+
+
+
+
+
+
+
{
try {
@@ -130,21 +131,29 @@ export const getAllModelsHandler = async (
request: FastifyRequest,
reply: FastifyReply
) => {
- try {
- const prisma = request.server.prisma;
- const user = request.user;
+ const prisma = request.server.prisma;
+ const user = request.user;
- if (!user.is_admin) {
- return reply.status(403).send({
- message: "Forbidden",
- });
- }
- const allModels = await prisma.dialoqbaseModels.findMany({
- where: {
- deleted: false,
- },
+ if (!user.is_admin) {
+ return reply.status(403).send({
+ message: "Forbidden",
});
+ }
+ const settings = await getSettings(prisma);
+
+ const not_to_hide_providers = settings?.hideDefaultModels
+ ? [ "Local", "local", "ollama", "transformer", "Transformer"]
+ : undefined;
+ const allModels = await prisma.dialoqbaseModels.findMany({
+ where: {
+ deleted: false,
+ model_provider: {
+ in: not_to_hide_providers,
+ },
+ },
+ });
+ try {
return {
data: allModels.filter((model) => model.model_type !== "embedding"),
embedding: allModels.filter((model) => model.model_type === "embedding"),
@@ -245,7 +254,7 @@ export const saveModelFromInputedUrlHandler = async (
});
}
- let newModelId = model_id.trim() + `_custom_${new Date().getTime()}`;
+ let newModelId = model_id.trim() + `_dialoqbase_${new Date().getTime()}`;
await prisma.dialoqbaseModels.create({
data: {
name: isModelExist.name,
diff --git a/server/src/handlers/api/v1/bot/bot/api.handler.ts b/server/src/handlers/api/v1/bot/bot/api.handler.ts
index eb9dbb5a..bda42af7 100644
--- a/server/src/handlers/api/v1/bot/bot/api.handler.ts
+++ b/server/src/handlers/api/v1/bot/bot/api.handler.ts
@@ -16,6 +16,7 @@ import {
uniqueNamesGenerator,
} from "unique-names-generator";
import { validateDataSource } from "../../../../../utils/datasource-validation";
+import { getModelInfo } from "../../../../../utils/get-model-info";
export const createBotAPIHandler = async (
request: FastifyRequest,
@@ -55,19 +56,11 @@ export const createBotAPIHandler = async (
message: `Reach maximum limit of ${maxBotsAllowed} bots per user`,
});
}
- const modelInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- hide: false,
- deleted: false,
- OR: [
- {
- model_id: model,
- },
- {
- model_id: `${model}-dbase`,
- },
- ],
- },
+
+ const modelInfo = await getModelInfo({
+ model,
+ prisma,
+ type: "chat",
});
if (!modelInfo) {
@@ -76,19 +69,10 @@ export const createBotAPIHandler = async (
});
}
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- OR: [
- {
- model_id: embedding,
- },
- {
- model_id: `dialoqbase_eb_${embedding}`,
- },
- ],
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
diff --git a/server/src/handlers/api/v1/bot/bot/chat.handler.ts b/server/src/handlers/api/v1/bot/bot/chat.handler.ts
index 6a371bf6..18e5b10f 100644
--- a/server/src/handlers/api/v1/bot/bot/chat.handler.ts
+++ b/server/src/handlers/api/v1/bot/bot/chat.handler.ts
@@ -7,377 +7,368 @@ import { DialoqbaseHybridRetrival } from "../../../../../utils/hybrid";
import { DialoqbaseVectorStore } from "../../../../../utils/store";
import { chatModelProvider } from "../../../../../utils/models";
import { createChain, groupMessagesByConversation } from "../../../../../chain";
+import { getModelInfo } from "../../../../../utils/get-model-info";
function nextTick() {
- return new Promise((resolve) => setTimeout(resolve, 0));
+ return new Promise((resolve) => setTimeout(resolve, 0));
}
export const chatRequestAPIHandler = async (
- request: FastifyRequest,
- reply: FastifyReply
+ request: FastifyRequest,
+ reply: FastifyReply
) => {
- const { message, history, stream } = request.body;
- if (stream) {
- try {
- const bot_id = request.params.id;
- const prisma = request.server.prisma;
- const user_id = request.user.user_id;
-
- const bot = await prisma.bot.findFirst({
- where: {
- id: bot_id,
- user_id
- },
- });
-
- if (!bot) {
- return reply.status(404).send({
- message: "Bot not found",
- });
- }
-
-
- const temperature = bot.temperature;
-
- const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
- });
-
- if (!embeddingInfo) {
- return reply.status(404).send({
- message: "Embedding not found",
- });
- }
-
- const embeddingModel = embeddings(
- embeddingInfo.model_provider!.toLowerCase(),
- embeddingInfo.model_id,
- embeddingInfo?.config
- );
-
- reply.raw.on("close", () => {
- console.log("closed");
- });
-
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
- });
-
- if (!modelinfo) {
- return reply.status(404).send({
- message: "Model not found",
- });
- }
-
- const botConfig = (modelinfo.config as {}) || {};
- let retriever: BaseRetriever;
- let resolveWithDocuments: (value: Document[]) => void;
- const documentPromise = new Promise((resolve) => {
- resolveWithDocuments = resolve;
- });
- if (bot.use_hybrid_search) {
- retriever = new DialoqbaseHybridRetrival(embeddingModel, {
- botId: bot.id,
- sourceId: null,
- callbacks: [
- {
- handleRetrieverEnd(documents) {
- resolveWithDocuments(documents);
- },
- },
- ],
- });
- } else {
- const vectorstore = await DialoqbaseVectorStore.fromExistingIndex(
- embeddingModel,
- {
- botId: bot.id,
- sourceId: null,
- }
- );
-
- retriever = vectorstore.asRetriever({
- callbacks: [
- {
- handleRetrieverEnd(documents) {
- resolveWithDocuments(documents);
- },
- },
- ],
- });
- }
-
- let response: string = "";
- const streamedModel = chatModelProvider(
- bot.provider,
- bot.model,
- temperature,
- {
- streaming: true,
- ...botConfig,
- }
- );
-
- const nonStreamingModel = chatModelProvider(
- bot.provider,
- bot.model,
- temperature,
- {
- ...botConfig,
- }
- );
-
- reply.raw.on("close", () => {
- // close the model
- });
-
- const chain = createChain({
- llm: streamedModel,
- question_llm: nonStreamingModel,
- question_template: bot.questionGeneratorPrompt,
- response_template: bot.qaPrompt,
- retriever,
- });
-
- let stream = await chain.stream({
- question: sanitizedQuestion,
- chat_history: groupMessagesByConversation(
- history.map((message) => ({
- type: message.role,
- content: message.text,
- }))
- ),
- });
-
- for await (const token of stream) {
- reply.sse({
- id: "",
- event: "chunk",
- data: JSON.stringify({
- bot: {
- text: token || "",
- sourceDocuments: [],
- },
- history: [
- ...history,
- {
- type: "human",
- text: message,
- },
- {
- type: "ai",
- text: token || "",
- },
- ],
- }),
- });
- response += token;
- }
-
- const documents = await documentPromise;
-
-
- await prisma.botApiHistory.create({
- data: {
- api_key: request.headers.authorization || "",
- bot_id: bot.id,
- human: message,
- bot: response,
- }
- })
-
- reply.sse({
- event: "result",
- id: "",
- data: JSON.stringify({
- bot: {
- text: response,
- sourceDocuments: documents,
- },
- history: [
- ...history,
- {
- type: "human",
- text: message,
- },
- {
- type: "ai",
- text: response,
- },
- ],
- }),
- });
- await nextTick();
- return reply.raw.end();
- } catch (e) {
- return reply.status(500).send({
- message: "Internal Server Error",
- });
+ const { message, history, stream } = request.body;
+ if (stream) {
+ try {
+ const bot_id = request.params.id;
+ const prisma = request.server.prisma;
+ const user_id = request.user.user_id;
+
+ const bot = await prisma.bot.findFirst({
+ where: {
+ id: bot_id,
+ user_id,
+ },
+ });
+
+ if (!bot) {
+ return reply.status(404).send({
+ message: "Bot not found",
+ });
+ }
+
+ const temperature = bot.temperature;
+
+ const sanitizedQuestion = message.trim().replaceAll("\n", " ");
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
+ });
+
+ if (!embeddingInfo) {
+ return reply.status(404).send({
+ message: "Embedding not found",
+ });
+ }
+
+ const embeddingModel = embeddings(
+ embeddingInfo.model_provider!.toLowerCase(),
+ embeddingInfo.model_id,
+ embeddingInfo?.config
+ );
+
+ reply.raw.on("close", () => {
+ console.log("closed");
+ });
+
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
+ });
+
+ if (!modelinfo) {
+ return reply.status(404).send({
+ message: "Model not found",
+ });
+ }
+
+ const botConfig = (modelinfo.config as {}) || {};
+ let retriever: BaseRetriever;
+ let resolveWithDocuments: (value: Document[]) => void;
+ const documentPromise = new Promise((resolve) => {
+ resolveWithDocuments = resolve;
+ });
+ if (bot.use_hybrid_search) {
+ retriever = new DialoqbaseHybridRetrival(embeddingModel, {
+ botId: bot.id,
+ sourceId: null,
+ callbacks: [
+ {
+ handleRetrieverEnd(documents) {
+ resolveWithDocuments(documents);
+ },
+ },
+ ],
+ });
+ } else {
+ const vectorstore = await DialoqbaseVectorStore.fromExistingIndex(
+ embeddingModel,
+ {
+ botId: bot.id,
+ sourceId: null,
+ }
+ );
+
+ retriever = vectorstore.asRetriever({
+ callbacks: [
+ {
+ handleRetrieverEnd(documents) {
+ resolveWithDocuments(documents);
+ },
+ },
+ ],
+ });
+ }
+
+ let response: string = "";
+ const streamedModel = chatModelProvider(
+ bot.provider,
+ bot.model,
+ temperature,
+ {
+ streaming: true,
+ ...botConfig,
}
- } else {
- try {
- const bot_id = request.params.id;
- const user_id = request.user.user_id;
-
- const prisma = request.server.prisma;
-
- const bot = await prisma.bot.findFirst({
- where: {
- id: bot_id,
- user_id
- },
- });
-
- if (!bot) {
- return reply.status(404).send({
- message: "Bot not found",
- });
- }
-
- const temperature = bot.temperature;
-
- const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
- });
-
- if (!embeddingInfo) {
- return reply.status(404).send({
- message: "Embedding not found",
- });
- }
-
- const embeddingModel = embeddings(
- embeddingInfo.model_provider!.toLowerCase(),
- embeddingInfo.model_id,
- embeddingInfo?.config
- );
-
- let retriever: BaseRetriever;
- let resolveWithDocuments: (value: Document[]) => void;
- const documentPromise = new Promise((resolve) => {
- resolveWithDocuments = resolve;
- });
- if (bot.use_hybrid_search) {
- retriever = new DialoqbaseHybridRetrival(embeddingModel, {
- botId: bot.id,
- sourceId: null,
- callbacks: [
- {
- handleRetrieverEnd(documents) {
- resolveWithDocuments(documents);
- },
- },
- ],
- });
- } else {
- const vectorstore = await DialoqbaseVectorStore.fromExistingIndex(
- embeddingModel,
- {
- botId: bot.id,
- sourceId: null,
- }
- );
-
- retriever = vectorstore.asRetriever({
- callbacks: [
- {
- handleRetrieverEnd(documents) {
- resolveWithDocuments(documents);
- },
- },
- ],
- });
- }
-
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
- });
-
- if (!modelinfo) {
- return reply.status(404).send({
- message: "Model not found",
- });
- }
-
- const botConfig: any = (modelinfo.config as {}) || {};
- if (bot.provider.toLowerCase() === "openai") {
- if (bot.bot_model_api_key && bot.bot_model_api_key.trim() !== "") {
- botConfig.configuration = {
- apiKey: bot.bot_model_api_key,
- };
- }
- }
-
- const model = chatModelProvider(bot.provider, bot.model, temperature, {
- ...botConfig,
- });
-
- const chain = createChain({
- llm: model,
- question_llm: model,
- question_template: bot.questionGeneratorPrompt,
- response_template: bot.qaPrompt,
- retriever,
- });
-
- const botResponse = await chain.invoke({
- question: sanitizedQuestion,
- chat_history: groupMessagesByConversation(
- history.map((message) => ({
- type: message.role,
- content: message.text,
- }))
- ),
- });
-
- const documents = await documentPromise;
-
- await prisma.botApiHistory.create({
- data: {
- api_key: request.headers.authorization || "",
- bot_id: bot.id,
- human: message,
- bot: botResponse,
- }
- })
-
- return {
- bot: {
- text: botResponse,
- sourceDocuments: documents,
- },
- history: [
- ...history,
- {
- type: "human",
- text: message,
- },
- {
- type: "ai",
- text: botResponse,
- },
- ],
- };
- } catch (e) {
- return reply.status(500).send({
- message: "Internal Server Error",
- });
+ );
+
+ const nonStreamingModel = chatModelProvider(
+ bot.provider,
+ bot.model,
+ temperature,
+ {
+ ...botConfig,
}
+ );
+
+ reply.raw.on("close", () => {
+ // close the model
+ });
+
+ const chain = createChain({
+ llm: streamedModel,
+ question_llm: nonStreamingModel,
+ question_template: bot.questionGeneratorPrompt,
+ response_template: bot.qaPrompt,
+ retriever,
+ });
+
+ let stream = await chain.stream({
+ question: sanitizedQuestion,
+ chat_history: groupMessagesByConversation(
+ history.map((message) => ({
+ type: message.role,
+ content: message.text,
+ }))
+ ),
+ });
+
+ for await (const token of stream) {
+ reply.sse({
+ id: "",
+ event: "chunk",
+ data: JSON.stringify({
+ bot: {
+ text: token || "",
+ sourceDocuments: [],
+ },
+ history: [
+ ...history,
+ {
+ type: "human",
+ text: message,
+ },
+ {
+ type: "ai",
+ text: token || "",
+ },
+ ],
+ }),
+ });
+ response += token;
+ }
+
+ const documents = await documentPromise;
+
+ await prisma.botApiHistory.create({
+ data: {
+ api_key: request.headers.authorization || "",
+ bot_id: bot.id,
+ human: message,
+ bot: response,
+ },
+ });
+
+ reply.sse({
+ event: "result",
+ id: "",
+ data: JSON.stringify({
+ bot: {
+ text: response,
+ sourceDocuments: documents,
+ },
+ history: [
+ ...history,
+ {
+ type: "human",
+ text: message,
+ },
+ {
+ type: "ai",
+ text: response,
+ },
+ ],
+ }),
+ });
+ await nextTick();
+ return reply.raw.end();
+ } catch (e) {
+ return reply.status(500).send({
+ message: "Internal Server Error",
+ });
}
+ } else {
+ try {
+ const bot_id = request.params.id;
+ const user_id = request.user.user_id;
+
+ const prisma = request.server.prisma;
+
+ const bot = await prisma.bot.findFirst({
+ where: {
+ id: bot_id,
+ user_id,
+ },
+ });
+
+ if (!bot) {
+ return reply.status(404).send({
+ message: "Bot not found",
+ });
+ }
+
+ const temperature = bot.temperature;
+
+ const sanitizedQuestion = message.trim().replaceAll("\n", " ");
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
+ });
+
+ if (!embeddingInfo) {
+ return reply.status(404).send({
+ message: "Embedding not found",
+ });
+ }
+
+ const embeddingModel = embeddings(
+ embeddingInfo.model_provider!.toLowerCase(),
+ embeddingInfo.model_id,
+ embeddingInfo?.config
+ );
+
+ let retriever: BaseRetriever;
+ let resolveWithDocuments: (value: Document[]) => void;
+ const documentPromise = new Promise((resolve) => {
+ resolveWithDocuments = resolve;
+ });
+ if (bot.use_hybrid_search) {
+ retriever = new DialoqbaseHybridRetrival(embeddingModel, {
+ botId: bot.id,
+ sourceId: null,
+ callbacks: [
+ {
+ handleRetrieverEnd(documents) {
+ resolveWithDocuments(documents);
+ },
+ },
+ ],
+ });
+ } else {
+ const vectorstore = await DialoqbaseVectorStore.fromExistingIndex(
+ embeddingModel,
+ {
+ botId: bot.id,
+ sourceId: null,
+ }
+ );
+
+ retriever = vectorstore.asRetriever({
+ callbacks: [
+ {
+ handleRetrieverEnd(documents) {
+ resolveWithDocuments(documents);
+ },
+ },
+ ],
+ });
+ }
+
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
+ });
+
+ if (!modelinfo) {
+ return reply.status(404).send({
+ message: "Model not found",
+ });
+ }
+
+ const botConfig: any = (modelinfo.config as {}) || {};
+ if (bot.provider.toLowerCase() === "openai") {
+ if (bot.bot_model_api_key && bot.bot_model_api_key.trim() !== "") {
+ botConfig.configuration = {
+ apiKey: bot.bot_model_api_key,
+ };
+ }
+ }
+
+ const model = chatModelProvider(bot.provider, bot.model, temperature, {
+ ...botConfig,
+ });
+
+ const chain = createChain({
+ llm: model,
+ question_llm: model,
+ question_template: bot.questionGeneratorPrompt,
+ response_template: bot.qaPrompt,
+ retriever,
+ });
+
+ const botResponse = await chain.invoke({
+ question: sanitizedQuestion,
+ chat_history: groupMessagesByConversation(
+ history.map((message) => ({
+ type: message.role,
+ content: message.text,
+ }))
+ ),
+ });
+
+ const documents = await documentPromise;
+
+ await prisma.botApiHistory.create({
+ data: {
+ api_key: request.headers.authorization || "",
+ bot_id: bot.id,
+ human: message,
+ bot: botResponse,
+ },
+ });
+
+ return {
+ bot: {
+ text: botResponse,
+ sourceDocuments: documents,
+ },
+ history: [
+ ...history,
+ {
+ type: "human",
+ text: message,
+ },
+ {
+ type: "ai",
+ text: botResponse,
+ },
+ ],
+ };
+ } catch (e) {
+ return reply.status(500).send({
+ message: "Internal Server Error",
+ });
+ }
+ }
};
diff --git a/server/src/handlers/api/v1/bot/bot/get.handler.ts b/server/src/handlers/api/v1/bot/bot/get.handler.ts
index c6f2cc9d..d02cfa9c 100644
--- a/server/src/handlers/api/v1/bot/bot/get.handler.ts
+++ b/server/src/handlers/api/v1/bot/bot/get.handler.ts
@@ -2,6 +2,7 @@ import { FastifyReply, FastifyRequest } from "fastify";
import { GetBotRequestById } from "./types";
import { getSettings } from "../../../../../utils/common";
+import { getAllOllamaModels } from "../../../../../utils/ollama";
export const getBotByIdEmbeddingsHandler = async (
request: FastifyRequest,
@@ -125,10 +126,18 @@ export const getCreateBotConfigHandler = async (
reply: FastifyReply
) => {
const prisma = request.server.prisma;
+ const settings = await getSettings(prisma);
+
+ const not_to_hide_providers = settings?.hideDefaultModels
+ ? ["Local", "local", "ollama", "transformer", "Transformer"]
+ : undefined;
const models = await prisma.dialoqbaseModels.findMany({
where: {
hide: false,
deleted: false,
+ model_provider: {
+ in: not_to_hide_providers,
+ },
},
});
@@ -146,16 +155,30 @@ export const getCreateBotConfigHandler = async (
.filter((model) => model.model_type === "embedding")
.map((model) => {
return {
- label: `${model.name || model.model_id} ${model.model_id === "dialoqbase_eb_dialoqbase-ollama"
+ label: `${model.name || model.model_id} ${
+ model.model_id === "dialoqbase_eb_dialoqbase-ollama"
? "(Deprecated)"
: ""
- }`,
+ }`,
value: model.model_id,
disabled: model.model_id === "dialoqbase_eb_dialoqbase-ollama",
};
});
-
- const settings = await getSettings(prisma);
+
+ if (settings?.dynamicallyFetchOllamaModels) {
+ const ollamaModels = await getAllOllamaModels(settings.ollamaURL);
+ chatModel.push(
+ ...ollamaModels?.filter((model) => {
+ return (
+ !model?.details?.families?.includes("bert") &&
+ !model?.details?.families?.includes("nomic-bert")
+ );
+ })
+ );
+ embeddingModel.push(
+ ...ollamaModels.map((model) => ({ ...model, disabled: false }))
+ );
+ }
return {
chatModel,
@@ -200,10 +223,11 @@ export const getBotByIdSettingsHandler = async (
.filter((model) => model.model_type === "embedding")
.map((model) => {
return {
- label: `${model.name || model.model_id} ${model.model_id === "dialoqbase_eb_dialoqbase-ollama"
+ label: `${model.name || model.model_id} ${
+ model.model_id === "dialoqbase_eb_dialoqbase-ollama"
? "(Deprecated)"
: ""
- }`,
+ }`,
value: model.model_id,
disabled: model.model_id === "dialoqbase_eb_dialoqbase-ollama",
};
@@ -221,7 +245,6 @@ export const getBotByIdSettingsHandler = async (
};
};
-
export const isBotReadyHandler = async (
request: FastifyRequest,
reply: FastifyReply
@@ -252,4 +275,4 @@ export const isBotReadyHandler = async (
return {
is_ready: source === 0,
};
-};
\ No newline at end of file
+};
diff --git a/server/src/handlers/api/v1/bot/bot/post.handler.ts b/server/src/handlers/api/v1/bot/bot/post.handler.ts
index 238b7133..fd557a7b 100644
--- a/server/src/handlers/api/v1/bot/bot/post.handler.ts
+++ b/server/src/handlers/api/v1/bot/bot/post.handler.ts
@@ -16,6 +16,7 @@ import {
HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT,
HELPFUL_ASSISTANT_WITHOUT_CONTEXT_PROMPT,
} from "../../../../../utils/prompts";
+import { getModelInfo } from "../../../../../utils/get-model-info";
export const createBotHandler = async (
request: FastifyRequest,
@@ -55,12 +56,10 @@ export const createBotHandler = async (
});
}
// const providerName = modelProviderName(model);
- const modelInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: model,
- hide: false,
- deleted: false,
- },
+ const modelInfo = await getModelInfo({
+ model,
+ prisma,
+ type: "chat",
});
if (!modelInfo) {
@@ -69,12 +68,10 @@ export const createBotHandler = async (
});
}
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
diff --git a/server/src/handlers/api/v1/bot/bot/put.handler.ts b/server/src/handlers/api/v1/bot/bot/put.handler.ts
index 380d6fe7..7dc0fc97 100644
--- a/server/src/handlers/api/v1/bot/bot/put.handler.ts
+++ b/server/src/handlers/api/v1/bot/bot/put.handler.ts
@@ -4,6 +4,7 @@ import {
apiKeyValidaton,
apiKeyValidatonMessage,
} from "../../../../../utils/validate";
+import { getModelInfo } from "../../../../../utils/get-model-info";
export const updateBotByIdHandler = async (
request: FastifyRequest,
@@ -25,12 +26,9 @@ export const updateBotByIdHandler = async (
});
}
- const modelInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: request.body.model,
- hide: false,
- deleted: false,
- },
+ const modelInfo = await getModelInfo({
+ model: request.body.model,
+ prisma,
});
if (!modelInfo) {
@@ -96,26 +94,15 @@ export const updateBotAPIByIdHandler = async (
questionGeneratorPrompt: request.body?.question_generator_prompt,
system_prompt: undefined,
question_generator_prompt: undefined,
- }
-
+ };
if (updateBody.model) {
- const modelInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- hide: false,
- deleted: false,
- OR: [
- {
- model_id: updateBody.model
- },
- {
- model_id: `${updateBody.model}-dbase`
- }
- ],
- },
+ const modelInfo = await getModelInfo({
+ model: updateBody.model,
+ prisma,
+ type: "chat",
});
-
if (!modelInfo) {
return reply.status(400).send({
message: "Model not found",
@@ -140,8 +127,7 @@ export const updateBotAPIByIdHandler = async (
updateBody = {
...updateBody,
provider: modelInfo.model_provider || "",
- }
-
+ };
}
await prisma.bot.update({
where: {
diff --git a/server/src/handlers/api/v1/bot/bot/upload.handler.ts b/server/src/handlers/api/v1/bot/bot/upload.handler.ts
index cf0f1b9a..6264c04c 100644
--- a/server/src/handlers/api/v1/bot/bot/upload.handler.ts
+++ b/server/src/handlers/api/v1/bot/bot/upload.handler.ts
@@ -19,6 +19,7 @@ const pump = util.promisify(pipeline);
import { fileTypeFinder } from "../../../../../utils/fileType";
import { getSettings } from "../../../../../utils/common";
import { HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT } from "../../../../../utils/prompts";
+import { getModelInfo } from "../../../../../utils/get-model-info";
export const createBotFileHandler = async (
request: FastifyRequest,
@@ -52,12 +53,10 @@ export const createBotFileHandler = async (
});
}
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -76,12 +75,10 @@ export const createBotFileHandler = async (
});
}
- const modelInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: model,
- hide: false,
- deleted: false,
- },
+ const modelInfo = await getModelInfo({
+ model,
+ prisma,
+ type: "chat",
});
if (!modelInfo) {
diff --git a/server/src/handlers/api/v1/bot/playground/chat.handler.ts b/server/src/handlers/api/v1/bot/playground/chat.handler.ts
index ab950f97..2e7661ec 100644
--- a/server/src/handlers/api/v1/bot/playground/chat.handler.ts
+++ b/server/src/handlers/api/v1/bot/playground/chat.handler.ts
@@ -7,6 +7,7 @@ import { DialoqbaseHybridRetrival } from "../../../../../utils/hybrid";
import { BaseRetriever } from "@langchain/core/retrievers";
import { Document } from "langchain/document";
import { createChain, groupMessagesByConversation } from "../../../../../chain";
+import { getModelInfo } from "../../../../../utils/get-model-info";
export const chatRequestHandler = async (
request: FastifyRequest,
@@ -48,14 +49,11 @@ export const chatRequestHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "all",
});
-
if (!embeddingInfo) {
return {
bot: {
@@ -118,12 +116,10 @@ export const chatRequestHandler = async (
});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
});
if (!modelinfo) {
@@ -295,12 +291,10 @@ export const chatRequestStreamHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -375,12 +369,10 @@ export const chatRequestStreamHandler = async (
});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
});
if (!modelinfo) {
reply.raw.setHeader("Content-Type", "text/event-stream");
diff --git a/server/src/handlers/bot/api.handler.ts b/server/src/handlers/bot/api.handler.ts
index aef50c1b..d7e995aa 100644
--- a/server/src/handlers/bot/api.handler.ts
+++ b/server/src/handlers/bot/api.handler.ts
@@ -8,6 +8,7 @@ import { Document } from "langchain/document";
import { BaseRetriever } from "@langchain/core/retrievers";
import { DialoqbaseHybridRetrival } from "../../utils/hybrid";
import { createChain, groupMessagesByConversation } from "../../chain";
+import { getModelInfo } from "../../utils/get-model-info";
export const chatRequestAPIHandler = async (
request: FastifyRequest,
@@ -40,12 +41,10 @@ export const chatRequestAPIHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ prisma,
+ model: bot.embedding,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -64,12 +63,10 @@ export const chatRequestAPIHandler = async (
console.log("closed");
});
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ prisma,
+ model: bot.model,
+ type: "chat",
});
if (!modelinfo) {
@@ -172,15 +169,14 @@ export const chatRequestAPIHandler = async (
});
const documents = await documentPromise;
-
await prisma.botApiHistory.create({
data: {
api_key: request.headers["x-api-key"],
bot_id: bot.id,
human: message,
bot: response,
- }
- })
+ },
+ });
reply.sse({
event: "result",
@@ -237,12 +233,10 @@ export const chatRequestAPIHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ prisma,
+ model: bot.embedding,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -308,12 +302,10 @@ export const chatRequestAPIHandler = async (
});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ prisma,
+ model: bot.model,
+ type: "chat",
});
if (!modelinfo) {
@@ -375,8 +367,8 @@ export const chatRequestAPIHandler = async (
bot_id: bot.id,
human: message,
bot: botResponse,
- }
- })
+ },
+ });
return {
bot: {
diff --git a/server/src/handlers/bot/post.handler.ts b/server/src/handlers/bot/post.handler.ts
index 9bbc17c9..04e2a2c9 100644
--- a/server/src/handlers/bot/post.handler.ts
+++ b/server/src/handlers/bot/post.handler.ts
@@ -7,6 +7,7 @@ import { BaseRetriever } from "@langchain/core/retrievers";
import { DialoqbaseHybridRetrival } from "../../utils/hybrid";
import { Document } from "langchain/document";
import { createChain, groupMessagesByConversation } from "../../chain";
+import { getModelInfo } from "../../utils/get-model-info";
export const chatRequestHandler = async (
request: FastifyRequest,
@@ -69,12 +70,10 @@ export const chatRequestHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -139,13 +138,11 @@ export const chatRequestHandler = async (
});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
- });
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
+ })
if (!modelinfo) {
return {
@@ -341,13 +338,11 @@ export const chatRequestStreamHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
return {
@@ -416,13 +411,11 @@ export const chatRequestStreamHandler = async (
console.log("closed");
});
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
- });
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
+ })
if (!modelinfo) {
reply.raw.setHeader("Content-Type", "text/event-stream");
diff --git a/server/src/integration/handlers/discord.handler.ts b/server/src/integration/handlers/discord.handler.ts
index 3eb9c4da..4549b37e 100644
--- a/server/src/integration/handlers/discord.handler.ts
+++ b/server/src/integration/handlers/discord.handler.ts
@@ -6,6 +6,7 @@ import { DialoqbaseHybridRetrival } from "../../utils/hybrid";
import { Document } from "langchain/document";
import { BaseRetriever } from "@langchain/core/retrievers";
import { createChain } from "../../chain";
+import { getModelInfo } from "../../utils/get-model-info";
const prisma = new PrismaClient();
export const discordBotHandler = async (
@@ -51,14 +52,11 @@ export const discordBotHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
- });
-
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
return {
text: "Opps! Model not found",
@@ -108,12 +106,10 @@ export const discordBotHandler = async (
});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
});
if (!modelinfo) {
diff --git a/server/src/integration/handlers/telegram.handler.ts b/server/src/integration/handlers/telegram.handler.ts
index 9f675ff2..0be218d1 100644
--- a/server/src/integration/handlers/telegram.handler.ts
+++ b/server/src/integration/handlers/telegram.handler.ts
@@ -5,6 +5,7 @@ import { chatModelProvider } from "../../utils/models";
import { DialoqbaseHybridRetrival } from "../../utils/hybrid";
import { BaseRetriever } from "@langchain/core/retrievers";
import { createChain } from "../../chain";
+import { getModelInfo } from "../../utils/get-model-info";
const prisma = new PrismaClient();
export const telegramBotHandler = async (
@@ -46,12 +47,10 @@ export const telegramBotHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -83,14 +82,12 @@ export const telegramBotHandler = async (
retriever = vectorstore.asRetriever({});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
});
-
+
if (!modelinfo) {
return "Unable to find model";
}
diff --git a/server/src/integration/handlers/whatsapp.handler.ts b/server/src/integration/handlers/whatsapp.handler.ts
index 7945312f..bfed6e21 100644
--- a/server/src/integration/handlers/whatsapp.handler.ts
+++ b/server/src/integration/handlers/whatsapp.handler.ts
@@ -5,6 +5,7 @@ import { chatModelProvider } from "../../utils/models";
import { BaseRetriever } from "@langchain/core/retrievers";
import { DialoqbaseHybridRetrival } from "../../utils/hybrid";
import { createChain } from "../../chain";
+import { getModelInfo } from "../../utils/get-model-info";
const prisma = new PrismaClient();
export const whatsappBotHandler = async (
@@ -55,12 +56,10 @@ export const whatsappBotHandler = async (
const temperature = bot.temperature;
const sanitizedQuestion = message.trim().replaceAll("\n", " ");
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: bot.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -92,12 +91,10 @@ export const whatsappBotHandler = async (
retriever = vectorstore.asRetriever({});
}
- const modelinfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: bot.model,
- hide: false,
- deleted: false,
- },
+ const modelinfo = await getModelInfo({
+ model: bot.model,
+ prisma,
+ type: "chat",
});
if (!modelinfo) {
diff --git a/server/src/queue/controllers/audio.controller.ts b/server/src/queue/controllers/audio.controller.ts
index 86f6b07c..daeddd40 100644
--- a/server/src/queue/controllers/audio.controller.ts
+++ b/server/src/queue/controllers/audio.controller.ts
@@ -5,6 +5,7 @@ import { embeddings } from "../../utils/embeddings";
import { DialoqbaseAudioVideoLoader } from "../../loader/audio-video";
import { convertMp3ToWave } from "../../utils/ffmpeg";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const audioQueueController = async (
source: QSource,
@@ -25,13 +26,11 @@ export const audioQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/csv.controller.ts b/server/src/queue/controllers/csv.controller.ts
index 8e159e5b..c9e83179 100644
--- a/server/src/queue/controllers/csv.controller.ts
+++ b/server/src/queue/controllers/csv.controller.ts
@@ -4,6 +4,7 @@ import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const csvQueueController = async (
source: QSource,
@@ -21,12 +22,10 @@ export const csvQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
diff --git a/server/src/queue/controllers/docx.controller.ts b/server/src/queue/controllers/docx.controller.ts
index 185a8e26..aea1dc4e 100644
--- a/server/src/queue/controllers/docx.controller.ts
+++ b/server/src/queue/controllers/docx.controller.ts
@@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { DialoqbaseDocxLoader } from "../../loader/docx";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const DocxQueueController = async (
source: QSource,
@@ -22,13 +23,11 @@ export const DocxQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/github.controller.ts b/server/src/queue/controllers/github.controller.ts
index 7300c90f..fc95d477 100644
--- a/server/src/queue/controllers/github.controller.ts
+++ b/server/src/queue/controllers/github.controller.ts
@@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { DialoqbaseGithub } from "../../loader/github";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const githubQueueController = async (
source: QSource,
@@ -25,13 +26,11 @@ export const githubQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/pdf.controller.ts b/server/src/queue/controllers/pdf.controller.ts
index 2bfcaa65..2b211935 100644
--- a/server/src/queue/controllers/pdf.controller.ts
+++ b/server/src/queue/controllers/pdf.controller.ts
@@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { DialoqbasePDFLoader } from "../../loader/pdf";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const pdfQueueController = async (
source: QSource,
@@ -22,13 +23,11 @@ export const pdfQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/rest.controller.ts b/server/src/queue/controllers/rest.controller.ts
index 722dc08a..b6375b0f 100644
--- a/server/src/queue/controllers/rest.controller.ts
+++ b/server/src/queue/controllers/rest.controller.ts
@@ -3,6 +3,7 @@ import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { DialoqbaseRestApi } from "../../loader/rest";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const restQueueController = async (
source: QSource,
@@ -18,13 +19,11 @@ export const restQueueController = async (
});
const docs = await loader.load();
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/text.controller.ts b/server/src/queue/controllers/text.controller.ts
index 11e8c2a3..b4d73c65 100644
--- a/server/src/queue/controllers/text.controller.ts
+++ b/server/src/queue/controllers/text.controller.ts
@@ -3,6 +3,7 @@ import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const textQueueController = async (
source: QSource,
@@ -20,15 +21,12 @@ export const textQueueController = async (
},
},
]);
-
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
-
+
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
}
diff --git a/server/src/queue/controllers/txt.controller.ts b/server/src/queue/controllers/txt.controller.ts
index baf894fd..445fa24b 100644
--- a/server/src/queue/controllers/txt.controller.ts
+++ b/server/src/queue/controllers/txt.controller.ts
@@ -4,6 +4,7 @@ import { DialoqbaseVectorStore } from "../../utils/store";
import { embeddings } from "../../utils/embeddings";
import { TextLoader } from "langchain/document_loaders/fs/text";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const txtQueueController = async (
source: QSource,
@@ -20,13 +21,11 @@ export const txtQueueController = async (
chunkOverlap: source.chunkOverlap,
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/video.controller.ts b/server/src/queue/controllers/video.controller.ts
index c0646a3a..c41db9df 100644
--- a/server/src/queue/controllers/video.controller.ts
+++ b/server/src/queue/controllers/video.controller.ts
@@ -6,6 +6,7 @@ import { embeddings } from "../../utils/embeddings";
import { DialoqbaseAudioVideoLoader } from "../../loader/audio-video";
import { convertMp4ToWave } from "../../utils/ffmpeg";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const videoQueueController = async (
source: QSource,
@@ -26,13 +27,11 @@ export const videoQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/website.controller.ts b/server/src/queue/controllers/website.controller.ts
index d0843f75..2c492326 100644
--- a/server/src/queue/controllers/website.controller.ts
+++ b/server/src/queue/controllers/website.controller.ts
@@ -8,6 +8,7 @@ import { DialoqbasePDFLoader } from "../../loader/pdf";
import { DialoqbaseWebLoader } from "../../loader/web";
import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio";
import { PrismaClient } from "@prisma/client";
+import { getModelInfo } from "../../utils/get-model-info";
export const websiteQueueController = async (
source: QSource,
@@ -37,13 +38,11 @@ export const websiteQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
@@ -79,13 +78,11 @@ export const websiteQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
- });
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
+ })
if (!embeddingInfo) {
throw new Error("Embedding not found. Please verify the embedding id");
diff --git a/server/src/queue/controllers/youtube.controller.ts b/server/src/queue/controllers/youtube.controller.ts
index fdb0e12a..5f1f30ef 100644
--- a/server/src/queue/controllers/youtube.controller.ts
+++ b/server/src/queue/controllers/youtube.controller.ts
@@ -6,6 +6,7 @@ import { embeddings } from "../../utils/embeddings";
import { DialoqbaseYoutube } from "../../loader/youtube";
import { PrismaClient } from "@prisma/client";
import { DialoqbaseYoutubeTranscript } from "../../loader/youtube-transcript";
+import { getModelInfo } from "../../utils/get-model-info";
export const youtubeQueueController = async (
source: QSource,
@@ -29,12 +30,10 @@ export const youtubeQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
@@ -66,12 +65,10 @@ export const youtubeQueueController = async (
});
const chunks = await textSplitter.splitDocuments(docs);
- const embeddingInfo = await prisma.dialoqbaseModels.findFirst({
- where: {
- model_id: source.embedding,
- hide: false,
- deleted: false,
- },
+ const embeddingInfo = await getModelInfo({
+ model: source.embedding,
+ prisma,
+ type: "embedding",
});
if (!embeddingInfo) {
diff --git a/server/src/schema/api/v1/admin/index.ts b/server/src/schema/api/v1/admin/index.ts
index 6f4c9599..424b6e6d 100644
--- a/server/src/schema/api/v1/admin/index.ts
+++ b/server/src/schema/api/v1/admin/index.ts
@@ -20,6 +20,8 @@ export const dialoqbaseSettingsSchema: FastifySchema = {
defaultChatModel: { type: "string" },
defaultEmbeddingModel: { type: "string" },
dynamicallyFetchOllamaModels: { type: "boolean" },
+ hideDefaultModels: { type: "boolean" },
+ ollamaURL: { type: "string" },
},
},
};
@@ -43,12 +45,9 @@ export const updateDialoqbaseSettingsSchema: FastifySchema = {
dynamicallyFetchOllamaModels: { type: "boolean" },
defaultChatModel: { type: "string" },
defaultEmbeddingModel: { type: "string" },
+ hideDefaultModels: { type: "boolean" },
+ ollamaURL: { type: "string" },
},
- required: [
- "noOfBotsPerUser",
- "allowUserToCreateBots",
- "allowUserToRegister",
- ],
},
response: {
200: {
diff --git a/server/src/utils/get-model-info.ts b/server/src/utils/get-model-info.ts
new file mode 100644
index 00000000..e5aa21f6
--- /dev/null
+++ b/server/src/utils/get-model-info.ts
@@ -0,0 +1,92 @@
+import { PrismaClient, DialoqbaseModels } from "@prisma/client";
+import { getSettings } from "./common";
+import { cleanUrl, getAllOllamaModels } from "./ollama";
+
+export const getModelInfo = async ({
+ model,
+ prisma,
+ type = "all",
+}: {
+ prisma: PrismaClient;
+ model: string;
+ type?: "all" | "chat" | "embedding";
+}): Promise => {
+ let modelInfo: DialoqbaseModels | null = null;
+ const settings = await getSettings(prisma);
+ const not_to_hide_providers = settings?.hideDefaultModels
+ ? [ "Local", "local", "ollama", "transformer", "Transformer"]
+ : undefined;
+ if (type === "all") {
+ modelInfo = await prisma.dialoqbaseModels.findFirst({
+ where: {
+ model_id: model,
+ hide: false,
+ deleted: false,
+ model_provider: {
+ in: not_to_hide_providers,
+ },
+ },
+ });
+ } else if (type === "chat") {
+ modelInfo = await prisma.dialoqbaseModels.findFirst({
+ where: {
+ hide: false,
+ deleted: false,
+ model_provider: {
+ in: not_to_hide_providers,
+ },
+ OR: [
+ {
+ model_id: model,
+ },
+ {
+ model_id: `${model}-dbase`,
+ },
+ ],
+ },
+ });
+ } else if (type === "embedding") {
+ modelInfo = await prisma.dialoqbaseModels.findFirst({
+ where: {
+ OR: [
+ {
+ model_id: model,
+ },
+ {
+ model_id: `dialoqbase_eb_${model}`,
+ },
+ ],
+ hide: false,
+ deleted: false,
+ model_provider: {
+ in: not_to_hide_providers,
+ },
+ },
+ });
+ }
+ if (!modelInfo) {
+ if (settings?.dynamicallyFetchOllamaModels) {
+ const ollamaModles = await getAllOllamaModels(settings.ollamaURL);
+ const ollamaInfo = ollamaModles.find((m) => m.value === model);
+ if (ollamaInfo) {
+ return {
+ name: ollamaInfo.name,
+ model_id: ollamaInfo.name,
+ stream_available: true,
+ local_model: true,
+ model_provider: "ollama",
+ config: {
+ baseURL: cleanUrl(settings.ollamaURL),
+ },
+ createdAt: new Date(),
+ model_type: "chat",
+ deleted: false,
+ hide: false,
+ id: 1,
+ };
+ }
+ }
+ }
+
+ return modelInfo;
+};
diff --git a/server/src/utils/ollama.ts b/server/src/utils/ollama.ts
new file mode 100644
index 00000000..bdfaac77
--- /dev/null
+++ b/server/src/utils/ollama.ts
@@ -0,0 +1,38 @@
+import axios from "axios";
+
+export const cleanUrl = (url: string) => {
+ if (url.endsWith("/")) {
+ return url.slice(0, -1);
+ }
+ return url;
+};
+
+export const getAllOllamaModels = async (url: string) => {
+ try {
+ const response = await axios.get(`${cleanUrl(url)}/api/tags`);
+ const { models } = response.data as {
+ models: {
+ name: string;
+ details?: {
+ parent_model?: string
+ format: string
+ family: string
+ families: string[]
+ parameter_size: string
+ quantization_level: string
+ }
+ }[];
+ };
+ return models.map((data) => {
+ return {
+ ...data,
+ label: data.name,
+ value: data.name,
+ stream: true
+ };
+ });
+ } catch (error) {
+ console.log(`Error fetching Ollama models`, error);
+ return [];
+ }
+};