From 10fc488525e852463d8a6264ddec87d45178afb4 Mon Sep 17 00:00:00 2001 From: James Date: Tue, 31 Dec 2024 18:24:57 +0700 Subject: [PATCH] feat: add run --- engine/common/additional_message.h | 103 +++++ .../common/api-dto/delete_success_response.h | 2 +- engine/common/assistant.h | 4 +- .../common/assistant_code_interpreter_tool.h | 2 +- engine/common/assistant_file_search_tool.h | 6 +- engine/common/assistant_function_tool.h | 10 +- engine/common/cortex/sync_queue.h | 29 ++ engine/common/dto/run_create_dto.h | 350 ++++++++++++++ engine/common/dto/run_update_dto.h | 64 +++ engine/common/events/assistant_stream_event.h | 55 +++ engine/common/events/done.h | 19 + engine/common/events/error.h | 32 ++ engine/common/events/thread_created.h | 12 + .../common/events/thread_message_completed.h | 20 + engine/common/events/thread_message_created.h | 20 + engine/common/events/thread_message_delta.h | 51 ++ .../events/thread_message_in_progress.h | 20 + .../common/events/thread_message_incomplete.h | 12 + engine/common/events/thread_run_cancelled.h | 12 + engine/common/events/thread_run_cancelling.h | 12 + engine/common/events/thread_run_completed.h | 12 + .../common/events/thread_run_created_event.h | 12 + engine/common/events/thread_run_expired.h | 12 + engine/common/events/thread_run_failed.h | 12 + engine/common/events/thread_run_in_progress.h | 12 + engine/common/events/thread_run_incomplete.h | 12 + engine/common/events/thread_run_queued.h | 12 + .../events/thread_run_requires_action.h | 12 + .../common/events/thread_run_step_cancelled.h | 12 + .../common/events/thread_run_step_completed.h | 12 + .../common/events/thread_run_step_created.h | 12 + engine/common/events/thread_run_step_delta.h | 12 + .../common/events/thread_run_step_expired.h | 12 + engine/common/events/thread_run_step_failed.h | 12 + .../events/thread_run_step_in_progress.h | 12 + engine/common/file.h | 2 +- engine/common/json_serializable.h | 2 +- engine/common/last_error.h | 71 +++ 
engine/common/message.h | 42 +- engine/common/message_attachment.h | 2 +- engine/common/message_content_image_file.h | 2 +- engine/common/message_content_image_url.h | 4 +- engine/common/message_content_refusal.h | 2 +- engine/common/message_content_text.h | 22 +- engine/common/message_delta.h | 96 ++++ engine/common/message_incomplete_detail.h | 25 +- .../common/repository/assistant_repository.h | 2 +- engine/common/repository/message_repository.h | 4 +- engine/common/repository/run_repository.h | 23 + engine/common/repository/thread_repository.h | 2 +- engine/common/required_action.h | 32 ++ engine/common/run.h | 343 ++++++++++++++ engine/common/run_step.h | 76 +++ engine/common/run_step_delta.h | 30 ++ engine/common/run_step_detail.h | 26 ++ engine/common/run_usage.h | 50 ++ engine/common/thread.h | 4 +- engine/common/tool_choice.h | 51 ++ engine/common/tool_resources.h | 6 +- engine/common/truncation_strategy.h | 88 ++++ engine/common/variant_map.h | 36 ++ engine/controllers/runs.cc | 80 ++++ engine/controllers/runs.h | 63 +++ engine/controllers/server.cc | 3 - engine/controllers/server.h | 2 +- engine/controllers/swagger.cc | 2 +- engine/database/database.h | 3 +- engine/database/runs.cc | 437 ++++++++++++++++++ engine/database/runs.h | 38 ++ engine/main.cc | 19 +- engine/migrations/db_helper.h | 4 +- engine/migrations/migration_manager.cc | 9 + engine/migrations/v3/migration.h | 4 +- engine/migrations/v4/migration.h | 92 ++++ .../repositories/assistant_fs_repository.cc | 4 +- engine/repositories/assistant_fs_repository.h | 5 +- engine/repositories/file_fs_repository.cc | 1 - engine/repositories/message_fs_repository.cc | 4 +- engine/repositories/message_fs_repository.h | 4 +- engine/repositories/run_sqlite_repository.cc | 28 ++ engine/repositories/run_sqlite_repository.h | 28 ++ engine/repositories/thread_fs_repository.cc | 4 +- engine/repositories/thread_fs_repository.h | 5 +- engine/services/assistant_service.cc | 7 +- engine/services/database_service.cc | 
31 +- engine/services/database_service.h | 16 +- engine/services/file_service.cc | 8 +- engine/services/inference_service.h | 30 +- engine/services/message_service.cc | 33 +- engine/services/message_service.h | 6 +- engine/services/model_service.cc | 10 +- engine/services/run_service.cc | 324 +++++++++++++ engine/services/run_service.h | 76 +++ engine/services/thread_service.cc | 9 +- .../test/components/test_function_calling.cc | 3 +- engine/test/components/test_tool_resources.cc | 2 +- engine/utils/cpuid/cpu_info.cc | 2 +- engine/utils/time_utils.h | 12 + 98 files changed, 3315 insertions(+), 151 deletions(-) create mode 100644 engine/common/additional_message.h create mode 100644 engine/common/cortex/sync_queue.h create mode 100644 engine/common/dto/run_create_dto.h create mode 100644 engine/common/dto/run_update_dto.h create mode 100644 engine/common/events/assistant_stream_event.h create mode 100644 engine/common/events/done.h create mode 100644 engine/common/events/error.h create mode 100644 engine/common/events/thread_created.h create mode 100644 engine/common/events/thread_message_completed.h create mode 100644 engine/common/events/thread_message_created.h create mode 100644 engine/common/events/thread_message_delta.h create mode 100644 engine/common/events/thread_message_in_progress.h create mode 100644 engine/common/events/thread_message_incomplete.h create mode 100644 engine/common/events/thread_run_cancelled.h create mode 100644 engine/common/events/thread_run_cancelling.h create mode 100644 engine/common/events/thread_run_completed.h create mode 100644 engine/common/events/thread_run_created_event.h create mode 100644 engine/common/events/thread_run_expired.h create mode 100644 engine/common/events/thread_run_failed.h create mode 100644 engine/common/events/thread_run_in_progress.h create mode 100644 engine/common/events/thread_run_incomplete.h create mode 100644 engine/common/events/thread_run_queued.h create mode 100644 
engine/common/events/thread_run_requires_action.h create mode 100644 engine/common/events/thread_run_step_cancelled.h create mode 100644 engine/common/events/thread_run_step_completed.h create mode 100644 engine/common/events/thread_run_step_created.h create mode 100644 engine/common/events/thread_run_step_delta.h create mode 100644 engine/common/events/thread_run_step_expired.h create mode 100644 engine/common/events/thread_run_step_failed.h create mode 100644 engine/common/events/thread_run_step_in_progress.h create mode 100644 engine/common/last_error.h create mode 100644 engine/common/message_delta.h create mode 100644 engine/common/repository/run_repository.h create mode 100644 engine/common/required_action.h create mode 100644 engine/common/run.h create mode 100644 engine/common/run_step.h create mode 100644 engine/common/run_step_delta.h create mode 100644 engine/common/run_step_detail.h create mode 100644 engine/common/run_usage.h create mode 100644 engine/common/tool_choice.h create mode 100644 engine/common/truncation_strategy.h create mode 100644 engine/controllers/runs.cc create mode 100644 engine/controllers/runs.h create mode 100644 engine/database/runs.cc create mode 100644 engine/database/runs.h create mode 100644 engine/migrations/v4/migration.h create mode 100644 engine/repositories/run_sqlite_repository.cc create mode 100644 engine/repositories/run_sqlite_repository.h create mode 100644 engine/services/run_service.cc create mode 100644 engine/services/run_service.h create mode 100644 engine/utils/time_utils.h diff --git a/engine/common/additional_message.h b/engine/common/additional_message.h new file mode 100644 index 000000000..d793b6ab0 --- /dev/null +++ b/engine/common/additional_message.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include "common/message_attachment.h" +#include "common/message_attachment_factory.h" +#include "common/message_content.h" +#include "common/message_content_factory.h" +#include "common/message_role.h" 
+#include "common/variant_map.h" + +namespace OpenAi { +struct AdditionalMessage { + AdditionalMessage() = default; + + AdditionalMessage(const AdditionalMessage&) = delete; + + AdditionalMessage& operator=(const AdditionalMessage&) = delete; + + AdditionalMessage(AdditionalMessage&& other) noexcept + : role{std::move(other.role)}, + content{std::move(other.content)}, + attachments{std::move(other.attachments)}, + metadata{std::move(other.metadata)} {} + + AdditionalMessage& operator=(AdditionalMessage&& other) noexcept { + if (this != &other) { + role = std::move(other.role); + content = std::move(other.content); + attachments = std::move(other.attachments); + metadata = std::move(other.metadata); + } + + return *this; + } + + /** + * The role of the entity that is creating the message. + * Allowed values include: User or Assistant. + */ + Role role; + + std::variant>> + content; + + /** + * A list of files attached to the message, and the tools they were added to. + */ + std::optional> attachments; + + /** + * Set of 16 key-value pairs that can be attached to an object. This can be useful + * for storing additional information about the object in a structured format. + * Keys can be a maximum of 64 characters long and values can be a maximum of + * 512 characters long. 
+ */ + std::optional metadata; + + static cpp::result FromJson( + Json::Value&& json) { + try { + AdditionalMessage msg; + if (json.isMember("role") && json["role"].isString()) { + msg.role = RoleFromString(json["role"].asString()); + } + if (!json.isMember("content")) { + return cpp::fail("content is mandatory"); + } + if (json["content"].isString()) { + msg.content = std::move(json["content"].asString()); + } else if (json["content"].isArray()) { + auto result = ParseContents(std::move(json["content"])); + if (result.has_error()) { + return cpp::fail("Failed to parse content array: " + result.error()); + } + if (result.value().empty()) { + return cpp::fail("Content array cannot be empty"); + } + msg.content = std::move(result.value()); + } else { + return cpp::fail("content must be either a string or an array"); + } + + if (json.isMember("attachments")) { + msg.attachments = + ParseAttachments(std::move(json["attachments"])).value(); + } + if (json.isMember("metadata") && json["metadata"].isObject() && + !json["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + msg.metadata = res.value(); + } + } + return msg; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } +}; +} // namespace OpenAi diff --git a/engine/common/api-dto/delete_success_response.h b/engine/common/api-dto/delete_success_response.h index ebb8f36f0..c13fea1e9 100644 --- a/engine/common/api-dto/delete_success_response.h +++ b/engine/common/api-dto/delete_success_response.h @@ -8,7 +8,7 @@ struct DeleteSuccessResponse : JsonSerializable { std::string object; bool deleted; - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; json["id"] = id; json["object"] = object; diff --git a/engine/common/assistant.h b/engine/common/assistant.h index 6210a0c2c..cc5275aa3 100644 --- 
a/engine/common/assistant.h +++ b/engine/common/assistant.h @@ -29,7 +29,7 @@ struct JanAssistant : JsonSerializable { ~JanAssistant() = default; - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; @@ -201,7 +201,7 @@ struct Assistant : JsonSerializable { std::variant response_format; - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value root; diff --git a/engine/common/assistant_code_interpreter_tool.h b/engine/common/assistant_code_interpreter_tool.h index 43bfac47c..576cbb122 100644 --- a/engine/common/assistant_code_interpreter_tool.h +++ b/engine/common/assistant_code_interpreter_tool.h @@ -23,7 +23,7 @@ struct AssistantCodeInterpreterTool : public AssistantTool { return std::move(tool); } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; json["type"] = type; return json; diff --git a/engine/common/assistant_file_search_tool.h b/engine/common/assistant_file_search_tool.h index 2abaa7f6e..528e74f93 100644 --- a/engine/common/assistant_file_search_tool.h +++ b/engine/common/assistant_file_search_tool.h @@ -42,7 +42,7 @@ struct FileSearchRankingOption : public JsonSerializable { return option; } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; json["ranker"] = ranker; json["score_threshold"] = score_threshold; @@ -99,7 +99,7 @@ struct AssistantFileSearch : public JsonSerializable { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value root; root["max_num_results"] = max_num_results; root["ranking_options"] = ranking_options.ToJson().value(); @@ -137,7 +137,7 @@ struct AssistantFileSearchTool : public AssistantTool { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value root; root["type"] = type; diff --git a/engine/common/assistant_function_tool.h b/engine/common/assistant_function_tool.h index 7998cb8ff..5656ffbd6 
100644 --- a/engine/common/assistant_function_tool.h +++ b/engine/common/assistant_function_tool.h @@ -64,10 +64,6 @@ struct AssistantFunction : public JsonSerializable { return cpp::fail("Function name can't be empty"); } - if (!json.isMember("description")) { - return cpp::fail("Function description is mandatory"); - } - if (!json.isMember("parameters")) { return cpp::fail("Function parameters are mandatory"); } @@ -76,14 +72,14 @@ struct AssistantFunction : public JsonSerializable { if (json.isMember("strict")) { is_strict = json["strict"].asBool(); } - AssistantFunction function{json["description"].asString(), + AssistantFunction function{json.get("description", "").asString(), json["name"].asString(), json["parameters"], is_strict}; function.parameters = json["parameters"]; return function; } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; json["description"] = description; json["name"] = name; @@ -120,7 +116,7 @@ struct AssistantFunctionTool : public AssistantTool { return AssistantFunctionTool{function_res.value()}; } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value root; root["type"] = type; root["function"] = function.ToJson().value(); diff --git a/engine/common/cortex/sync_queue.h b/engine/common/cortex/sync_queue.h new file mode 100644 index 000000000..68dad122e --- /dev/null +++ b/engine/common/cortex/sync_queue.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include + +// Status and result +using InferResult = std::pair; + +struct SyncQueue { + void push(InferResult&& p) { + std::unique_lock l(mtx); + q.push(p); + cond.notify_one(); + } + + InferResult wait_and_pop() { + std::unique_lock l(mtx); + cond.wait(l, [this] { return !q.empty(); }); + auto res = q.front(); + q.pop(); + return res; + } + + std::mutex mtx; + std::condition_variable cond; + std::queue q; +}; diff --git a/engine/common/dto/run_create_dto.h b/engine/common/dto/run_create_dto.h new 
file mode 100644 index 000000000..f4697a407 --- /dev/null +++ b/engine/common/dto/run_create_dto.h @@ -0,0 +1,350 @@ +#pragma once + +#include +#include "common/additional_message.h" +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/assistant_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_choice.h" +#include "common/truncation_strategy.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct RunCreateDto : public BaseDto { + RunCreateDto() = default; + + ~RunCreateDto() = default; + + RunCreateDto(const RunCreateDto&) = delete; + + RunCreateDto& operator=(const RunCreateDto&) = delete; + + RunCreateDto(RunCreateDto&& other) noexcept + : assistant_id{std::move(other.assistant_id)}, + model{std::move(other.model)}, + instructions{std::move(other.instructions)}, + additional_instructions{std::move(other.additional_instructions)}, + additional_messages{std::move(other.additional_messages)}, + tools{std::move(other.tools)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + stream{std::move(other.stream)}, + max_prompt_tokens{std::move(other.max_prompt_tokens)}, + max_completion_tokens{std::move(other.max_completion_tokens)}, + truncation_strategy{std::move(other.truncation_strategy)}, + tool_choice{std::move(other.tool_choice)}, + parallel_tool_calls{std::move(other.parallel_tool_calls)}, + response_format{std::move(other.response_format)} {} + + RunCreateDto& operator=(RunCreateDto&& other) noexcept { + if (this != &other) { + assistant_id = std::move(other.assistant_id); + model = std::move(other.model); + instructions = std::move(other.instructions); + additional_instructions = std::move(other.additional_instructions); + additional_messages = std::move(other.additional_messages); + tools = std::move(other.tools); + metadata = 
std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + stream = std::move(other.stream); + max_prompt_tokens = std::move(other.max_prompt_tokens); + max_completion_tokens = std::move(other.max_completion_tokens); + truncation_strategy = std::move(other.truncation_strategy); + tool_choice = std::move(other.tool_choice); + parallel_tool_calls = std::move(other.parallel_tool_calls); + response_format = std::move(other.response_format); + } + return *this; + } + + /** + * The ID of the assistant to use to execute this run. + */ + std::string assistant_id; + + /** + * The ID of the Model to be used to execute this run. + * If a value is provided here, it will override the model associated with + * the assistant. If not, the model associated with the assistant will be used. + */ + std::optional model; + + /** + * Overrides the instructions of the assistant. This is useful for modifying + * the behavior on a per-run basis. + */ + std::optional instructions; + + /** + * Appends additional instructions at the end of the instructions for the run. + * This is useful for modifying the behavior on a per-run basis without overriding + * other instructions. + */ + std::optional additional_instructions; + + /** + * Adds additional messages to the thread before creating the run. + */ + std::optional> additional_messages; + + /** + * A list of tool enabled on the assistant. There can be a maximum of 128 + * tools per assistant. Tools can be of types code_interpreter, file_search, + * or function. + */ + std::optional>> tools; + + /** + * Set of 16 key-value pairs that can be attached to an object. + * This can be useful for storing additional information about the object + * in a structured format. Keys can be a maximum of 64 characters long + * and values can be a maximum of 512 characters long. + */ + std::optional metadata; + + /** + * What sampling temperature to use, between 0 and 2. 
Higher values like + * 0.8 will make the output more random, while lower values like 0.2 will + * make it more focused and deterministic. + */ + std::optional temperature; + + /** + * An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p + * probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. + * + * We generally recommend altering this or temperature but not both. + */ + std::optional top_p; + + /** + * If true, returns a stream of events that happen during the Run as + * server-sent events, terminating when the Run enters a terminal + * state with a data: [DONE] message. + */ + std::optional stream; + + /** + * The maximum number of prompt tokens that may be used over the + * course of the run. The run will make a best effort to use only + * the number of prompt tokens specified, across multiple turns of + * the run. If the run exceeds the number of prompt tokens specified, + * the run will end with status incomplete. + * + * See incomplete_details for more info. + */ + std::optional max_prompt_tokens; + + /** + * The maximum number of completion tokens that may be used over the + * course of the run. The run will make a best effort to use only the + * number of completion tokens specified, across multiple turns of the + * run. If the run exceeds the number of completion tokens specified, + * the run will end with status incomplete. + * + * See incomplete_details for more info. + */ + std::optional max_completion_tokens; + + /** + * Controls for how a thread will be truncated prior to the run. + * Use this to control the intial context window of the run. + */ + std::optional truncation_strategy; + + /** + * Controls which (if any) tool is called by the model. + * none means the model will not call any tools and instead generates a message. 
+ * auto is the default value and means the model can pick between generating a message + * or calling one or more tools. + * required means the model must call one or more tools before responding to the user. + * Specifying a particular tool like {"type": "file_search"} or {"type": "function", "function": {"name": "my_function"}} forces the model to call that tool. + */ + std::optional> tool_choice; + + /** + * Whether to enable parallel function calling during tool use. + */ + std::optional parallel_tool_calls{true}; + + /** + * + */ + std::optional> response_format; + + cpp::result Validate() const { + if (assistant_id.empty()) { + return cpp::fail("assistant_id is mandatory"); + } + + return {}; + } + + static cpp::result FromJson(Json::Value&& json) { + try { + RunCreateDto dto; + + if (!json.isMember("assistant_id") || !json["assistant_id"].isString()) { + return cpp::fail("Missing or invalid 'assistant_id' field"); + } + dto.assistant_id = json["assistant_id"].asString(); + dto.model = json["model"].asString(); + dto.instructions = json["instructions"].asString(); + dto.additional_instructions = json["additional_instructions"].asString(); + + if (json.isMember("additional_messages") && + json["additional_messages"].isArray()) { + + std::vector msgs; + auto additional_messages_array = json["additional_messages"]; + for (auto& additional_message : additional_messages_array) { + auto result = OpenAi::AdditionalMessage::FromJson( + std::move(additional_message)); + if (result.has_value()) { + msgs.push_back(std::move(result.value())); + } else { + CTL_WRN("Failed to parse additional message: " + result.error()); + } + } + if (!msgs.empty()) { + dto.additional_messages = std::move(msgs); + } + } + + if (json.isMember("tools") && json["tools"].isArray()) { + auto tools_array = json["tools"]; + std::vector> parsed_tools; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid 
type"); + continue; + } + + auto tool_type = tool["type"].asString(); + + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + parsed_tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + parsed_tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + parsed_tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + if (!parsed_tools.empty()) { + dto.tools = std::move(parsed_tools); + } + } + + // Parse metadata + if (json.isMember("metadata") && json["metadata"].isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + dto.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } + } + + if (json.isMember("temperature") && json["temperature"].isDouble()) { + dto.temperature = json["temperature"].asFloat(); + } + + if (json.isMember("top_p") && json["top_p"].isDouble()) { + dto.top_p = json["top_p"].asFloat(); + } + + if (json.isMember("stream") && json["stream"].isBool()) { + dto.stream = json["stream"].asBool(); + } + + if (json.isMember("max_prompt_tokens") && + json["max_prompt_tokens"].isUInt()) { + dto.max_prompt_tokens = json["max_prompt_tokens"].asUInt(); + } + + if (json.isMember("max_completion_tokens") && + json["max_completion_tokens"].isUInt()) { + dto.max_completion_tokens = 
json["max_completion_tokens"].asUInt(); + } + + if (json.isMember("truncation_strategy") && + json["truncation_strategy"].isObject()) { + dto.truncation_strategy = + std::move(OpenAi::TruncationStrategy::FromJson( + std::move(json["truncation_strategy"])) + .value()); + } + + if (json.isMember("tool_choice")) { + if (json["tool_choice"].isString()) { + if (json["tool_choice"].asString() != "none" && + json["tool_choice"].asString() != "auto" && + json["tool_choice"].asString() != "required") { + return cpp::fail( + "tool_choice must be either none, auto or required"); + } + + dto.tool_choice = json["tool_choice"].asString(); + } else if (json["tool_choice"].isObject()) { + dto.tool_choice = std::move( + OpenAi::ToolChoice::FromJson(std::move(json["tool_choice"])) + .value()); + } else { + return cpp::fail("tool_choice must be either a string or an object"); + } + } + + if (json.isMember("parallel_tool_calls") && + json["parallel_tool_calls"].isBool()) { + dto.parallel_tool_calls = json["parallel_tool_calls"].asBool(); + } + + if (json.isMember("response_format")) { + const auto& response_format = json["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } + return dto; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } +}; +} // namespace dto diff --git a/engine/common/dto/run_update_dto.h b/engine/common/dto/run_update_dto.h new file mode 100644 index 000000000..31bc3bc20 --- /dev/null +++ b/engine/common/dto/run_update_dto.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include "common/dto/base_dto.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct RunUpdateDto : public BaseDto { + RunUpdateDto() = default; + + 
~RunUpdateDto() = default; + + RunUpdateDto(const RunUpdateDto&) = delete; + + RunUpdateDto& operator=(const RunUpdateDto&) = delete; + + RunUpdateDto(RunUpdateDto&& other) noexcept + : metadata{std::move(other.metadata)} {} + + RunUpdateDto& operator=(RunUpdateDto&& other) noexcept { + if (this != &other) { + metadata = std::move(other.metadata); + } + return *this; + } + + /** + * Set of 16 key-value pairs that can be attached to an object. + * This can be useful for storing additional information about the object + * in a structured format. Keys can be a maximum of 64 characters long + * and values can be a maximum of 512 characters long. + */ + std::optional metadata; + + cpp::result Validate() const { + if (!metadata.has_value()) { + return cpp::fail("Nothing to update"); + } + + return {}; + } + + static cpp::result FromJson(Json::Value&& json) { + try { + RunUpdateDto dto; + + // Parse metadata + if (json.isMember("metadata") && json["metadata"].isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + dto.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } + } + + return dto; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } +}; +} // namespace dto diff --git a/engine/common/events/assistant_stream_event.h b/engine/common/events/assistant_stream_event.h new file mode 100644 index 000000000..958b88fd1 --- /dev/null +++ b/engine/common/events/assistant_stream_event.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include "utils/result.hpp" + +namespace OpenAi { +/** + * Represents an event emitted when streaming a Run. + * + * Each event in a server-sent events stream has an event and data property: + * event: thread.created + * data: {"id": "thread_123", "object": "thread", ...} + * + * We emit events whenever a new object is created, transitions to a new state, + * or is being streamed in parts (deltas). 
+ * + * For example, we emit thread.run.created when a new run is created, + * thread.run.completed when a run completes, and so on. When an Assistant chooses + * to create a message during a run, we emit a thread.message.created event, + * a thread.message.in_progress event, many thread.message.delta events, + * and finally a thread.message.completed event. + */ +struct AssistantStreamEvent { + AssistantStreamEvent(const std::string& event) : event{std::move(event)} {} + + AssistantStreamEvent(const AssistantStreamEvent&) = delete; + + AssistantStreamEvent& operator=(const AssistantStreamEvent&) = delete; + + AssistantStreamEvent(AssistantStreamEvent&& other) noexcept + : event{std::move(other.event)} {} + + AssistantStreamEvent& operator=(AssistantStreamEvent&& other) noexcept { + if (this != &other) { + event = std::move(other.event); + } + return *this; + } + + virtual ~AssistantStreamEvent() = default; + + std::string event; + + virtual auto SingleLineJsonData() const + -> cpp::result = 0; + + auto ToEvent() const -> cpp::result { + auto data = SingleLineJsonData(); + if (data.has_error()) { + return cpp::fail(data.error()); + } + return "event: " + event + "\n" + "data: " + data.value() + "\n\n"; + } +}; +} // namespace OpenAi diff --git a/engine/common/events/done.h b/engine/common/events/done.h new file mode 100644 index 000000000..b93ee005f --- /dev/null +++ b/engine/common/events/done.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include "common/events/assistant_stream_event.h" + +namespace OpenAi { +struct DoneEvent : public AssistantStreamEvent { + DoneEvent() : AssistantStreamEvent{"done"} {} + + ~DoneEvent() = default; + + std::string data{"[DONE]"}; + + auto SingleLineJsonData() const + -> cpp::result override { + return data; + } +}; +} // namespace OpenAi diff --git a/engine/common/events/error.h b/engine/common/events/error.h new file mode 100644 index 000000000..2069fac34 --- /dev/null +++ b/engine/common/events/error.h @@ -0,0 +1,32 @@ +#pragma 
once + +#include +#include "common/events/assistant_stream_event.h" +#include "utils/result.hpp" + +namespace OpenAi { +struct ErrorEvent : public AssistantStreamEvent { + ErrorEvent(const std::string& error) + : error{std::move(error)}, AssistantStreamEvent("error") {} + + ~ErrorEvent() = default; + + std::string error; + + auto SingleLineJsonData() const + -> cpp::result override { + Json::Value json; + json["error"] = error; + Json::FastWriter writer; + try { + std::string json_str = writer.write(json); + if (!json_str.empty() && json_str.back() == '\n') { + json_str.pop_back(); + } + return json_str; + } catch (const std::exception& e) { + return cpp::fail(std::string("Failed to write JSON: ") + e.what()); + } + } +}; +} // namespace OpenAi diff --git a/engine/common/events/thread_created.h b/engine/common/events/thread_created.h new file mode 100644 index 000000000..438214f1c --- /dev/null +++ b/engine/common/events/thread_created.h @@ -0,0 +1,12 @@ +#pragma once + +#include "common/events/assistant_stream_event.h" +#include "common/thread.h" + +namespace OpenAi { +struct ThreadRunCreatedEvent : public AssistantStreamEvent { + std::string event{"thread.created"}; + + Thread thread; +}; +} // namespace OpenAi diff --git a/engine/common/events/thread_message_completed.h b/engine/common/events/thread_message_completed.h new file mode 100644 index 000000000..c8e1b6c22 --- /dev/null +++ b/engine/common/events/thread_message_completed.h @@ -0,0 +1,20 @@ +#pragma once + +#include "common/events/assistant_stream_event.h" + +namespace OpenAi { +struct ThreadMessageCompletedEvent : public AssistantStreamEvent { + explicit ThreadMessageCompletedEvent(const std::string& json_message) + : AssistantStreamEvent("thread.message.completed"), + json_message{std::move(json_message)} {} + + ~ThreadMessageCompletedEvent() = default; + + std::string json_message; + + auto SingleLineJsonData() const + -> cpp::result override { + return json_message; + } +}; +} // namespace OpenAi 
diff --git a/engine/common/events/thread_message_created.h b/engine/common/events/thread_message_created.h
new file mode 100644
index 000000000..b0e0c7204
--- /dev/null
+++ b/engine/common/events/thread_message_created.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadMessageCreatedEvent : public AssistantStreamEvent {
+  // By-value + std::move (the old const& + std::move always copied).
+  explicit ThreadMessageCreatedEvent(std::string json_message)
+      : AssistantStreamEvent{"thread.message.created"},
+        json_message{std::move(json_message)} {}
+
+  ~ThreadMessageCreatedEvent() override = default;
+
+  std::string json_message;
+
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_message;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_message_delta.h b/engine/common/events/thread_message_delta.h
new file mode 100644
index 000000000..d7ddc6c97
--- /dev/null
+++ b/engine/common/events/thread_message_delta.h
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+#include "common/message_delta.h"
+
+namespace OpenAi {
+// Streams an incremental change (delta) of a message being generated.
+struct ThreadMessageDeltaEvent : public AssistantStreamEvent {
+  explicit ThreadMessageDeltaEvent(MessageDelta::Delta&& delta_msg)
+      : AssistantStreamEvent{"thread.message.delta"},
+        delta{std::move(delta_msg)} {}
+
+  ThreadMessageDeltaEvent(const ThreadMessageDeltaEvent&) = delete;
+
+  ThreadMessageDeltaEvent& operator=(const ThreadMessageDeltaEvent&) = delete;
+
+  ThreadMessageDeltaEvent(ThreadMessageDeltaEvent&& other) noexcept
+      : AssistantStreamEvent(std::move(other)), delta{std::move(other.delta)} {}
+
+  ThreadMessageDeltaEvent& operator=(ThreadMessageDeltaEvent&& other) noexcept {
+    if (this != &other) {
+      AssistantStreamEvent::operator=(std::move(other));
+      delta = std::move(other.delta);
+    }
+    return *this;
+  }
+
+  ~ThreadMessageDeltaEvent() override = default;
+
+  MessageDelta delta;
+
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    try {
+      auto delta_json = delta.ToJson();
+      if (!delta_json.has_value()) {
+        return cpp::fail("Failed to convert delta to JSON");
+      }
+
+      Json::FastWriter writer;
+      std::string json_str = writer.write(delta_json.value());
+      // Strip FastWriter's trailing newline to keep a single-line payload.
+      if (!json_str.empty() && json_str.back() == '\n') {
+        json_str.pop_back();
+      }
+      return json_str;
+    } catch (const std::exception& e) {
+      return cpp::fail(std::string("Failed to write JSON: ") + e.what());
+    }
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_message_in_progress.h b/engine/common/events/thread_message_in_progress.h
new file mode 100644
index 000000000..07c7751d3
--- /dev/null
+++ b/engine/common/events/thread_message_in_progress.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadMessageInProgressEvent : public AssistantStreamEvent {
+  // By-value + std::move (the old const& + std::move always copied).
+  explicit ThreadMessageInProgressEvent(std::string json_message)
+      : AssistantStreamEvent{"thread.message.in_progress"},
+        json_message{std::move(json_message)} {}
+
+  ~ThreadMessageInProgressEvent() override = default;
+
+  std::string json_message;
+
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_message;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_message_incomplete.h b/engine/common/events/thread_message_incomplete.h
new file mode 100644
index 000000000..4d0eb863e
--- /dev/null
+++ b/engine/common/events/thread_message_incomplete.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+// NOTE(review): the original stub shadowed the base `event` member, never
+// called the base constructor and never overrode the pure virtual
+// SingleLineJsonData, so it could not be instantiated. Rewritten to the
+// pre-serialized-JSON pattern used by the other thread.message.* events.
+struct ThreadMessageIncompleteEvent : public AssistantStreamEvent {
+  explicit ThreadMessageIncompleteEvent(std::string json_message)
+      : AssistantStreamEvent{"thread.message.incomplete"},
+        json_message{std::move(json_message)} {}
+
+  ~ThreadMessageIncompleteEvent() override = default;
+
+  std::string json_message;
+
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_message;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_cancelled.h b/engine/common/events/thread_run_cancelled.h
new file mode 100644
index 000000000..f1f868bef
--- /dev/null
+++ b/engine/common/events/thread_run_cancelled.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include
"common/events/assistant_stream_event.h"
+#include <string>
+
+namespace OpenAi {
+// NOTE(review): every thread.run.* / thread.run.step.* event below was an
+// uninstantiable stub: no base-class initialization (the base has no default
+// constructor), a shadowing `event` member, and the pure virtual
+// SingleLineJsonData left unoverridden. All are rewritten uniformly to the
+// pre-serialized-JSON pattern used by the thread.message.* events.
+struct ThreadRunCancelledEvent : public AssistantStreamEvent {
+  explicit ThreadRunCancelledEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.cancelled"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunCancelledEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_cancelling.h b/engine/common/events/thread_run_cancelling.h
new file mode 100644
index 000000000..19f118cfd
--- /dev/null
+++ b/engine/common/events/thread_run_cancelling.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunCancellingEvent : public AssistantStreamEvent {
+  explicit ThreadRunCancellingEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.cancelling"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunCancellingEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_completed.h b/engine/common/events/thread_run_completed.h
new file mode 100644
index 000000000..3c47b3957
--- /dev/null
+++ b/engine/common/events/thread_run_completed.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunCompletedEvent : public AssistantStreamEvent {
+  explicit ThreadRunCompletedEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.completed"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunCompletedEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_created_event.h b/engine/common/events/thread_run_created_event.h
new file mode 100644
index 000000000..2d5fef120
--- /dev/null
+++ b/engine/common/events/thread_run_created_event.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunCreatedEvent : public AssistantStreamEvent {
+  explicit ThreadRunCreatedEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.created"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunCreatedEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_expired.h b/engine/common/events/thread_run_expired.h
new file mode 100644
index 000000000..69592c0dc
--- /dev/null
+++ b/engine/common/events/thread_run_expired.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunExpiredEvent : public AssistantStreamEvent {
+  explicit ThreadRunExpiredEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.expired"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunExpiredEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_failed.h b/engine/common/events/thread_run_failed.h
new file mode 100644
index 000000000..aff6506cf
--- /dev/null
+++ b/engine/common/events/thread_run_failed.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunFailedEvent : public AssistantStreamEvent {
+  explicit ThreadRunFailedEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.failed"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunFailedEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_in_progress.h b/engine/common/events/thread_run_in_progress.h
new file mode 100644
index 000000000..fe3a98da3
--- /dev/null
+++ b/engine/common/events/thread_run_in_progress.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunInProgressEvent : public AssistantStreamEvent {
+  explicit ThreadRunInProgressEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.in_progress"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunInProgressEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_incomplete.h b/engine/common/events/thread_run_incomplete.h
new file mode 100644
index 000000000..f75791007
--- /dev/null
+++ b/engine/common/events/thread_run_incomplete.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunIncompleteEvent : public AssistantStreamEvent {
+  explicit ThreadRunIncompleteEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.incomplete"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunIncompleteEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_queued.h b/engine/common/events/thread_run_queued.h
new file mode 100644
index 000000000..cbe834080
--- /dev/null
+++ b/engine/common/events/thread_run_queued.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunQueuedEvent : public AssistantStreamEvent {
+  explicit ThreadRunQueuedEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.queued"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunQueuedEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_requires_action.h b/engine/common/events/thread_run_requires_action.h
new file mode 100644
index 000000000..6ba073f98
--- /dev/null
+++ b/engine/common/events/thread_run_requires_action.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunRequiresActionEvent : public AssistantStreamEvent {
+  explicit ThreadRunRequiresActionEvent(std::string json_run)
+      : AssistantStreamEvent{"thread.run.requires_action"},
+        json_run{std::move(json_run)} {}
+  ~ThreadRunRequiresActionEvent() override = default;
+  std::string json_run;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_cancelled.h b/engine/common/events/thread_run_step_cancelled.h
new file mode 100644
index 000000000..1eb24f15d
--- /dev/null
+++ b/engine/common/events/thread_run_step_cancelled.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepCancelledEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepCancelledEvent(std::string json_run_step)
+      : AssistantStreamEvent{"thread.run.step.cancelled"},
+        json_run_step{std::move(json_run_step)} {}
+  ~ThreadRunStepCancelledEvent() override = default;
+  std::string json_run_step;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run_step;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_completed.h b/engine/common/events/thread_run_step_completed.h
new file mode 100644
index 000000000..2df6b9aac
--- /dev/null
+++ b/engine/common/events/thread_run_step_completed.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepCompletedEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepCompletedEvent(std::string json_run_step)
+      : AssistantStreamEvent{"thread.run.step.completed"},
+        json_run_step{std::move(json_run_step)} {}
+  ~ThreadRunStepCompletedEvent() override = default;
+  std::string json_run_step;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run_step;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_created.h b/engine/common/events/thread_run_step_created.h
new file mode 100644
index 000000000..3402bb022
--- /dev/null
+++ b/engine/common/events/thread_run_step_created.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepCreatedEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepCreatedEvent(std::string json_run_step)
+      : AssistantStreamEvent{"thread.run.step.created"},
+        json_run_step{std::move(json_run_step)} {}
+  ~ThreadRunStepCreatedEvent() override = default;
+  std::string json_run_step;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run_step;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_delta.h b/engine/common/events/thread_run_step_delta.h
new file mode 100644
index 000000000..ae6673ec7
--- /dev/null
+++ b/engine/common/events/thread_run_step_delta.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepDeltaEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepDeltaEvent(std::string json_delta)
+      : AssistantStreamEvent{"thread.run.step.delta"},
+        json_delta{std::move(json_delta)} {}
+  ~ThreadRunStepDeltaEvent() override = default;
+  std::string json_delta;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_delta;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_expired.h b/engine/common/events/thread_run_step_expired.h
new file mode 100644
index 000000000..f24b2f195
--- /dev/null
+++ b/engine/common/events/thread_run_step_expired.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepExpiredEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepExpiredEvent(std::string json_run_step)
+      : AssistantStreamEvent{"thread.run.step.expired"},
+        json_run_step{std::move(json_run_step)} {}
+  ~ThreadRunStepExpiredEvent() override = default;
+  std::string json_run_step;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run_step;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_failed.h b/engine/common/events/thread_run_step_failed.h
new file mode 100644
index 000000000..c14db3060
--- /dev/null
+++ b/engine/common/events/thread_run_step_failed.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepFailedEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepFailedEvent(std::string json_run_step)
+      : AssistantStreamEvent{"thread.run.step.failed"},
+        json_run_step{std::move(json_run_step)} {}
+  ~ThreadRunStepFailedEvent() override = default;
+  std::string json_run_step;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run_step;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/events/thread_run_step_in_progress.h b/engine/common/events/thread_run_step_in_progress.h
new file mode 100644
index 000000000..bfd01965b
--- /dev/null
+++ b/engine/common/events/thread_run_step_in_progress.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+#include "common/events/assistant_stream_event.h"
+
+namespace OpenAi {
+struct ThreadRunStepInProgressEvent : public AssistantStreamEvent {
+  explicit ThreadRunStepInProgressEvent(std::string json_run_step)
+      : AssistantStreamEvent{"thread.run.step.in_progress"},
+        json_run_step{std::move(json_run_step)} {}
+  ~ThreadRunStepInProgressEvent() override = default;
+  std::string json_run_step;
+  auto SingleLineJsonData() const
+      -> cpp::result<std::string, std::string> override {
+    return json_run_step;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/file.h b/engine/common/file.h
index 3096023c5..379a9f74b 100644
--- a/engine/common/file.h
+++ b/engine/common/file.h
@@ -55,7 +55,7 @@ struct File : public JsonSerializable {
     return file;
   }
 
-  cpp::result<Json::Value, std::string> ToJson() {
+  cpp::result<Json::Value, std::string> ToJson() const override {
     Json::Value root;
 
     root["id"] = id;
diff --git a/engine/common/json_serializable.h b/engine/common/json_serializable.h
index 4afec92c5..9d8597562 100644
--- a/engine/common/json_serializable.h
+++ b/engine/common/json_serializable.h
@@ -5,7 +5,7 @@
 struct JsonSerializable {
-  virtual cpp::result<Json::Value, std::string> ToJson() = 0;
+  virtual cpp::result<Json::Value, std::string> ToJson() const = 0;
 
   virtual ~JsonSerializable() = default;
 };
diff --git a/engine/common/last_error.h b/engine/common/last_error.h
new file mode 100644
index 000000000..b28c0480c
--- /dev/null
+++ b/engine/common/last_error.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include <json/json.h>
+#include <string>
+#include "common/json_serializable.h"
+
+namespace OpenAi {
+
+enum class LastErrorType { SERVER_ERROR, RATE_LIMIT_EXCEEDED, INVALID_PROMPT };
+
+// Maps a LastErrorType to its OpenAI API wire string.
+inline std::string LastErrorTypeToString(LastErrorType error_type) {
+  switch (error_type) {
+    case LastErrorType::SERVER_ERROR:
+      return "server_error";
+    case LastErrorType::RATE_LIMIT_EXCEEDED:
+      return "rate_limit_exceeded";
+    case LastErrorType::INVALID_PROMPT:
+      return "invalid_prompt";
+    default:
+      return "unknown error type: #" +
+             std::to_string(static_cast<int>(error_type));
+  }
+}
+
+inline LastErrorType LastErrorTypeFromString(const std::string& input) {
+  if (input ==
"server_error") {
+    return LastErrorType::SERVER_ERROR;
+  } else if (input == "rate_limit_exceeded") {
+    return LastErrorType::RATE_LIMIT_EXCEEDED;
+  } else if (input == "invalid_prompt") {
+    return LastErrorType::INVALID_PROMPT;
+  } else {
+    // Unknown strings deliberately default to server_error instead of
+    // throwing, so malformed persisted data cannot crash deserialization.
+    return LastErrorType::SERVER_ERROR;
+  }
+}
+
+struct LastError : public JsonSerializable {
+  LastErrorType code;
+
+  std::string message;
+
+  static cpp::result<LastError, std::string> FromJsonString(
+      std::string&& json_str) {
+    Json::Value root;
+    Json::Reader reader;
+    if (!reader.parse(json_str, root)) {
+      return cpp::fail("Failed to parse JSON: " +
+                       reader.getFormattedErrorMessages());
+    }
+
+    LastError last_error;
+
+    try {
+      // asString() already returns a temporary; the previous std::move
+      // wrappers were redundant and are dropped.
+      last_error.message = root.get("message", "").asString();
+      last_error.code =
+          LastErrorTypeFromString(root.get("code", "").asString());
+
+      return last_error;
+    } catch (const std::exception& e) {
+      return cpp::fail(std::string("FromJsonString failed: ") + e.what());
+    }
+  }
+
+  cpp::result<Json::Value, std::string> ToJson() const override {
+    Json::Value root;
+    // Serialize the code as its API string (e.g. "server_error") so ToJson
+    // round-trips with FromJsonString, which parses string codes; the old
+    // integer cast broke that round trip.
+    root["code"] = LastErrorTypeToString(code);
+    root["message"] = message;
+    return root;
+  }
+};
+}  // namespace OpenAi
diff --git a/engine/common/message.h b/engine/common/message.h
index d31c4f0d3..dd161d771 100644
--- a/engine/common/message.h
+++ b/engine/common/message.h
@@ -36,11 +36,13 @@ inline std::string ExtractFileId(const std::string& path) {
 // Represents a message within a thread.
 struct Message : JsonSerializable {
   Message() = default;
-  // Delete copy operations
+
+  Message(const Message&) = delete;
+  Message& operator=(const Message&) = delete;
 
-  // Allow move operations
+
+  Message(Message&&) = default;
+  Message& operator=(Message&&) = default;
 
   // The identifier, which can be referenced in API endpoints.
@@ -208,21 +210,30 @@ struct Message : JsonSerializable {
     }
   }
 
-  cpp::result<std::string, std::string> ToSingleLineJsonString() {
-    auto json_result = ToJson();
-    if (json_result.has_error()) {
-      return cpp::fail(json_result.error());
+  cpp::result<std::string, std::string> ToSingleLineJsonString(
+      bool add_new_line = true) const {
+    auto json = ToJson();
+    if (json.has_error()) {
+      return cpp::fail(json.error());
     }
 
     Json::FastWriter writer;
     try {
-      return writer.write(json_result.value());
+      if (add_new_line) {
+        return writer.write(json.value());
+      } else {
+        auto json_str = writer.write(json.value());
+        if (!json_str.empty() && json_str.back() == '\n') {
+          json_str.pop_back();
+        }
+        return json_str;
+      }
     } catch (const std::exception& e) {
       return cpp::fail(std::string("Failed to write JSON: ") + e.what());
     }
   }
 
-  cpp::result<Json::Value, std::string> ToJson() override {
+  cpp::result<Json::Value, std::string> ToJson() const override {
     try {
       Json::Value json;
 
@@ -295,5 +306,20 @@ struct Message : JsonSerializable {
       return cpp::fail(std::string("ToJson failed: ") + e.what());
     }
   }
+
+  static cpp::result<Message, std::string> Clone(const Message& original) {
+    // First convert to JSON
+    auto json_result = original.ToJson();
+    if (json_result.has_error()) {
+      return cpp::fail("Failed to convert to JSON: " + json_result.error());
+    }
+
+    // Convert JSON back to string
+    Json::FastWriter writer;
+    std::string json_str = writer.write(json_result.value());
+
+    // Create new Message from JSON string
+    return Message::FromJsonString(std::move(json_str));
+  }
 };
 };  // namespace OpenAi
diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h
index 6a0fb02e9..06ce4c9bc 100644
--- a/engine/common/message_attachment.h
+++ b/engine/common/message_attachment.h
@@ -35,7 +35,7 @@ struct Attachment : JsonSerializable {
 
   std::vector<Tool> tools;
 
-  cpp::result<Json::Value, std::string> ToJson() override {
+  cpp::result<Json::Value, std::string> ToJson() const override {
     try {
       Json::Value json;
       json["file_id"] = file_id;
diff --git a/engine/common/message_content_image_file.h b/engine/common/message_content_image_file.h
index
c3ec57853..f4cec69bc 100644 --- a/engine/common/message_content_image_file.h +++ b/engine/common/message_content_image_file.h @@ -54,7 +54,7 @@ struct ImageFileContent : Content { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["type"] = type; diff --git a/engine/common/message_content_image_url.h b/engine/common/message_content_image_url.h index 336cf01d3..0670068e3 100644 --- a/engine/common/message_content_image_url.h +++ b/engine/common/message_content_image_url.h @@ -28,7 +28,7 @@ struct ImageUrl : public JsonSerializable { ImageUrl& operator=(const ImageUrl&) = delete; - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value root; root["url"] = url; @@ -75,7 +75,7 @@ struct ImageUrlContent : Content { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["type"] = type; diff --git a/engine/common/message_content_refusal.h b/engine/common/message_content_refusal.h index c2537ccbf..dc2efe4ca 100644 --- a/engine/common/message_content_refusal.h +++ b/engine/common/message_content_refusal.h @@ -32,7 +32,7 @@ struct Refusal : Content { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["type"] = type; diff --git a/engine/common/message_content_text.h b/engine/common/message_content_text.h index 5ede2582d..50ae6265a 100644 --- a/engine/common/message_content_text.h +++ b/engine/common/message_content_text.h @@ -58,7 +58,7 @@ struct FileCitationWrapper : Annotation { FileCitation file_citation; - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["text"] = text; @@ -105,7 +105,7 @@ struct FilePathWrapper : Annotation { FilePath file_path; - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["text"] = text; @@ -124,14 +124,18 @@ struct Text : JsonSerializable { // The 
data that makes up the text. Text() = default; - Text(Text&&) noexcept = default; + Text(const std::string& text_content) : value{text_content} {} - Text& operator=(Text&&) noexcept = default; + ~Text() = default; Text(const Text&) = delete; Text& operator=(const Text&) = delete; + Text(Text&&) noexcept = default; + + Text& operator=(Text&&) noexcept = default; + std::string value; std::vector> annotations; @@ -178,7 +182,7 @@ struct Text : JsonSerializable { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["value"] = value; @@ -192,7 +196,7 @@ struct Text : JsonSerializable { } json["annotations"] = annotations_json_arr; return json; - } catch (const std::exception e) { + } catch (const std::exception& e) { return cpp::fail(std::string("ToJson failed: ") + e.what()); } }; @@ -203,6 +207,10 @@ struct TextContent : Content { // Always text. TextContent() : Content("text") {} + TextContent(const std::string& text_value) : Content("text") { + text = Text(text_value); + } + TextContent(TextContent&&) noexcept = default; TextContent& operator=(TextContent&&) noexcept = default; @@ -229,7 +237,7 @@ struct TextContent : Content { } } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["type"] = type; diff --git a/engine/common/message_delta.h b/engine/common/message_delta.h new file mode 100644 index 000000000..612adeea6 --- /dev/null +++ b/engine/common/message_delta.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include "common/message_content.h" +#include "common/message_role.h" +#include "utils/logging_utils.h" + +namespace OpenAi { +struct MessageDelta : public JsonSerializable { + struct Delta : public JsonSerializable { + Delta(Role role, std::vector> content) + : role{role}, content{std::move(content)} {}; + + Delta(const Delta&) = delete; + + Delta& operator=(const Delta&) = delete; + + Delta(Delta&& other) noexcept + : role{other.role}, 
content{std::move(other.content)} {} + + Delta& operator=(Delta&& other) noexcept { + if (this != &other) { + role = other.role; + content = std::move(other.content); + } + return *this; + } + + ~Delta() = default; + + Role role; + + std::vector> content; + + cpp::result ToJson() const override { + Json::Value json; + json["role"] = RoleToString(role); + Json::Value content_json_arr{Json::arrayValue}; + for (auto& child_content : content) { + if (auto it = child_content->ToJson(); it.has_value()) { + content_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + json["content"] = content_json_arr; + return json; + } + }; + + explicit MessageDelta(Delta&& delta) : delta{std::move(delta)} {} + + MessageDelta(const MessageDelta&) = delete; + + MessageDelta& operator=(const MessageDelta&) = delete; + + MessageDelta(MessageDelta&& other) noexcept + : id{std::move(other.id)}, + object{std::move(other.object)}, + delta{std::move(other.delta)} {} + + MessageDelta& operator=(MessageDelta&& other) noexcept { + if (this != &other) { + id = std::move(other.id); + object = std::move(other.object); + delta = std::move(other.delta); + } + return *this; + } + + ~MessageDelta() = default; + + /** + * The identifier of the message, which can be referenced in API endpoints. + */ + std::string id; + + /** + * The object type, which is always thread.message.delta. + */ + std::string object{"thread.message.delta"}; + + /** + * The delta containing the fields that have changed on the Message. 
+ */ + Delta delta; + + auto ToJson() const -> cpp::result override { + Json::Value json; + json["id"] = id; + json["object"] = object; + json["delta"] = delta.ToJson().value(); + return json; + } +}; +} // namespace OpenAi diff --git a/engine/common/message_incomplete_detail.h b/engine/common/message_incomplete_detail.h index 98e6ff56b..bb84bd596 100644 --- a/engine/common/message_incomplete_detail.h +++ b/engine/common/message_incomplete_detail.h @@ -1,25 +1,40 @@ #pragma once +#include #include "common/json_serializable.h" namespace OpenAi { // On an incomplete message, details about why the message is incomplete. struct IncompleteDetail : JsonSerializable { - // The reason the message is incomplete. + IncompleteDetail(const std::string& reason) : reason{reason} {} + + /** + * The reason the message is incomplete. + */ std::string reason; + static cpp::result FromJsonString( + std::string&& json_str) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(json_str, root)) { + return cpp::fail("Failed to parse JSON: " + + reader.getFormattedErrorMessages()); + } + + return IncompleteDetail(root.get("reason", "").asString()); + } + static cpp::result, std::string> FromJson( Json::Value&& json) { if (json.empty()) { return std::nullopt; } - IncompleteDetail incomplete_detail; - incomplete_detail.reason = json["reason"].asString(); - return incomplete_detail; + return IncompleteDetail(json["reason"].asString()); } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; json["reason"] = reason; diff --git a/engine/common/repository/assistant_repository.h b/engine/common/repository/assistant_repository.h index d0ff1908d..c3bc351f9 100644 --- a/engine/common/repository/assistant_repository.h +++ b/engine/common/repository/assistant_repository.h @@ -10,7 +10,7 @@ class AssistantRepository { const std::string& after, const std::string& before) const = 0; virtual cpp::result CreateAssistant( - OpenAi::Assistant& 
assistant) = 0; + const OpenAi::Assistant& assistant) = 0; virtual cpp::result RetrieveAssistant( const std::string assistant_id) const = 0; diff --git a/engine/common/repository/message_repository.h b/engine/common/repository/message_repository.h index a8a971fd8..e0285fb49 100644 --- a/engine/common/repository/message_repository.h +++ b/engine/common/repository/message_repository.h @@ -6,7 +6,7 @@ class MessageRepository { public: virtual cpp::result CreateMessage( - OpenAi::Message& message) = 0; + const OpenAi::Message& message) = 0; virtual cpp::result, std::string> ListMessages( const std::string& thread_id, uint8_t limit, const std::string& order, @@ -17,7 +17,7 @@ class MessageRepository { const std::string& thread_id, const std::string& message_id) const = 0; virtual cpp::result ModifyMessage( - OpenAi::Message& message) = 0; + const OpenAi::Message& message) = 0; virtual cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id) = 0; diff --git a/engine/common/repository/run_repository.h b/engine/common/repository/run_repository.h new file mode 100644 index 000000000..fed1f26d5 --- /dev/null +++ b/engine/common/repository/run_repository.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/run.h" +#include "utils/result.hpp" + +class RunRepository { + public: + virtual cpp::result CreateRun(const OpenAi::Run& run) = 0; + + virtual cpp::result, std::string> ListRuns( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const = 0; + + virtual cpp::result RetrieveRun( + const std::string& run_id) const = 0; + + virtual cpp::result ModifyRun(OpenAi::Run& run) = 0; + + virtual cpp::result DeleteRun( + const std::string& run_id) = 0; + + virtual ~RunRepository() = default; +}; diff --git a/engine/common/repository/thread_repository.h b/engine/common/repository/thread_repository.h index c7bb9e7cf..0cc867a67 100644 --- a/engine/common/repository/thread_repository.h +++ 
b/engine/common/repository/thread_repository.h @@ -6,7 +6,7 @@ class ThreadRepository { public: virtual cpp::result CreateThread( - OpenAi::Thread& thread) = 0; + const OpenAi::Thread& thread) = 0; virtual cpp::result, std::string> ListThreads( uint8_t limit, const std::string& order, const std::string&, diff --git a/engine/common/required_action.h b/engine/common/required_action.h new file mode 100644 index 000000000..de7be3970 --- /dev/null +++ b/engine/common/required_action.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace OpenAi { +struct ToolCall { + /** + * The ID of the tool call. This ID must be referenced when you submit + * the tool outputs in using the Submit tool outputs to run endpoint. + */ + std::string id; + + /** + * The type of tool call the output is required for. For now, this is + * always function. + */ + std::string type{"function"}; + + // function TODO: NamH implement this +}; + +struct SubmitToolOutputs { + std::vector tool_calls; +}; + +struct RequiredAction { + std::string type{"submit_tool_outputs"}; + + SubmitToolOutputs submit_tool_outputs; +}; +} // namespace OpenAi diff --git a/engine/common/run.h b/engine/common/run.h new file mode 100644 index 000000000..dc28dfd6d --- /dev/null +++ b/engine/common/run.h @@ -0,0 +1,343 @@ +#pragma once + +#include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/assistant_tool.h" +#include "common/json_serializable.h" +#include "common/last_error.h" +#include "common/message_incomplete_detail.h" +#include "common/required_action.h" +#include "common/run_usage.h" +#include "common/truncation_strategy.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" + +namespace OpenAi { + +enum class RunStatus { + QUEUED, + IN_PROGRESS, + REQUIRES_ACTION, + CANCELLING, + CANCELLED, + FAILED, + COMPLETED, + INCOMPLETE, + EXPIRED +}; 
+ +inline std::string RunStatusToString(RunStatus status) { + switch (status) { + case RunStatus::QUEUED: + return "queued"; + case RunStatus::IN_PROGRESS: + return "in_progress"; + case RunStatus::REQUIRES_ACTION: + return "requires_action"; + case RunStatus::CANCELLING: + return "cancelling"; + case RunStatus::CANCELLED: + return "cancelled"; + case RunStatus::FAILED: + return "failed"; + case RunStatus::COMPLETED: + return "completed"; + case RunStatus::INCOMPLETE: + return "incomplete"; + case RunStatus::EXPIRED: + return "expired"; + default: + return "completed"; + } +} + +inline RunStatus RunStatusFromString(const std::string& input) { + if (string_utils::EqualsIgnoreCase(input, "queued")) { + return RunStatus::QUEUED; + } else if (string_utils::EqualsIgnoreCase(input, "in_progress")) { + return RunStatus::IN_PROGRESS; + } else if (string_utils::EqualsIgnoreCase(input, "requires_action")) { + return RunStatus::REQUIRES_ACTION; + } else if (string_utils::EqualsIgnoreCase(input, "cancelling")) { + return RunStatus::CANCELLING; + } else if (string_utils::EqualsIgnoreCase(input, "cancelled")) { + return RunStatus::CANCELLED; + } else if (string_utils::EqualsIgnoreCase(input, "failed")) { + return RunStatus::FAILED; + } else if (string_utils::EqualsIgnoreCase(input, "incomplete")) { + return RunStatus::INCOMPLETE; + } else if (string_utils::EqualsIgnoreCase(input, "expired")) { + return RunStatus::EXPIRED; + } else { + return RunStatus::COMPLETED; + } +} + +struct Run : public JsonSerializable { + Run() = default; + + ~Run() = default; + + Run(const Run&) = delete; + + Run& operator=(const Run&) = delete; + + Run(Run&& other) noexcept + : id{std::move(other.id)}, + object{std::move(other.object)}, + created_at{other.created_at}, + thread_id{std::move(other.thread_id)}, + assistant_id{std::move(other.assistant_id)}, + status{std::move(other.status)}, + required_action{std::move(other.required_action)}, + last_error{std::move(other.last_error)}, + 
expired_at{other.expired_at}, + started_at{other.started_at}, + cancelled_at{other.cancelled_at}, + failed_at{other.failed_at}, + completed_at{other.completed_at}, + incomplete_detail{std::move(other.incomplete_detail)}, + model{std::move(other.model)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + metadata{std::move(other.metadata)}, + usage{std::move(other.usage)}, + temperature{other.temperature}, + top_p{other.top_p}, + max_prompt_tokens{other.max_prompt_tokens}, + max_completion_tokens{other.max_completion_tokens}, + truncation_strategy{std::move(other.truncation_strategy)}, + tool_choice{std::move(other.tool_choice)}, + parallel_tool_calls{other.parallel_tool_calls}, + response_format{std::move(other.response_format)} {} + + Run& operator=(Run&& other) noexcept { + if (this != &other) { + id = std::move(other.id); + object = std::move(other.object); + created_at = other.created_at; + thread_id = std::move(other.thread_id); + assistant_id = std::move(other.assistant_id); + status = std::move(other.status); + required_action = std::move(other.required_action); + last_error = std::move(other.last_error); + expired_at = other.expired_at; + started_at = other.started_at; + cancelled_at = other.cancelled_at; + failed_at = other.failed_at; + completed_at = other.completed_at; + incomplete_detail = std::move(other.incomplete_detail); + model = std::move(other.model); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + metadata = std::move(other.metadata); + usage = std::move(other.usage); + temperature = other.temperature; + top_p = other.top_p; + max_prompt_tokens = other.max_prompt_tokens; + max_completion_tokens = other.max_completion_tokens; + truncation_strategy = std::move(other.truncation_strategy); + tool_choice = std::move(other.tool_choice); + parallel_tool_calls = other.parallel_tool_calls; + response_format = std::move(other.response_format); + } + return *this; + } + + /** + * The 
identifier, which can be referenced in API endpoints. + */ + std::string id; + + /** + * The object type, which is always thread.run. + */ + std::string object{"thread.run"}; + + uint32_t created_at; + + std::string thread_id; + + std::string assistant_id; + + RunStatus status; + + /** + * Details on the action required to continue the run. Will be null if no + * action is required. + */ + std::optional required_action; + + /** + * The last error associated with this run. Will be null if there are no errors. + */ + std::optional last_error{std::nullopt}; + + /** + * The Unix timestamp (in seconds) for when the run will expire. + */ + uint32_t expired_at; + + /** + * The Unix timestamp (in seconds) for when the run was started. + */ + uint32_t started_at; + + /** + * The Unix timestamp (in seconds) for when the run was cancelled. + */ + uint32_t cancelled_at; + + /** + * The Unix timestamp (in seconds) for when the run failed. + */ + uint32_t failed_at; + + /** + * The Unix timestamp (in seconds) for when the run was completed. + */ + uint32_t completed_at; + + /** + * Details on why the run is incomplete. Will be null if the run is + * not incomplete. + */ + std::optional incomplete_detail; + + /** + * The model that the assistant used for this run. + */ + std::string model; + + std::string instructions; + + std::vector> tools; + + /** + * Set of 16 key-value pairs that can be attached to an object. This can + * be useful for storing additional information about the object in a + * structured format. Keys can be a maximum of 64 characters long and + * values can be a maximum of 512 characters long. + */ + Cortex::VariantMap metadata; + + /** + * Usage statistics related to the run. This value will be null if the run + * is not in a terminal state (i.e. in_progress, queued, etc.). + */ + std::optional usage; + + /** + * What sampling temperature to use, between 0 and 2. 
Higher values like + * 0.8 will make the output more random, while lower values like 0.2 will + * make it more focused and deterministic. + */ + std::optional temperature; + + /** + * An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p + * probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. + * + * We generally recommend altering this or temperature but not both. + */ + std::optional top_p; + + /** + * The maximum number of prompt tokens specified to have been used over + * the course of the run. + */ + std::optional max_prompt_tokens; + + /** + * The maximum number of completion tokens specified to have been used over + * the course of the run. + */ + std::optional max_completion_tokens; + + TruncationStrategy truncation_strategy; + + std::variant tool_choice; + + bool parallel_tool_calls; + + std::variant response_format; + + cpp::result ToJson() const override { + Json::Value root; + // TODO: NamH implement this + return root; + } + + static std::vector> ToolsFromJsonString( + std::string&& json_str) { + Json::Value json; + Json::Reader reader; + std::vector> tools; + if (!reader.parse(json_str, json)) { + return tools; + } + + if (json.isMember("tools") && json["tools"].isArray()) { + auto tools_array = json["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = AssistantCodeInterpreterTool::FromJson(); + if 
(result.has_value()) { + tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + + return tools; + } + + static cpp::result ToolsToJsonString( + const std::vector>& tools) { + Json::Value array(Json::arrayValue); + + for (const auto& tool : tools) { + auto json_result = tool->ToJson(); + if (json_result.has_error()) { + return cpp::fail("Failed to convert tool to JSON: " + + json_result.error()); + } + array.append(json_result.value()); + } + + return array.toStyledString(); + } +}; +} // namespace OpenAi diff --git a/engine/common/run_step.h b/engine/common/run_step.h new file mode 100644 index 000000000..c3c852a9a --- /dev/null +++ b/engine/common/run_step.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include "common/last_error.h" +#include "common/run_step_detail.h" +#include "common/run_usage.h" +#include "common/variant_map.h" + +namespace OpenAi { + +enum class RunStepStatus { IN_PROGRESS, CANCELLED, FAILED, COMPLETED, EXPIRED }; + +enum class RunStepType { MESSAGE_CREATION, TOOL_CALLS }; + +struct RunStep { + std::string id; + + std::string object{"thread.run.step"}; + + uint32_t created_at; + + std::string assistant_id; + + std::string thread_id; + + std::string run_id; + + RunStepType type; + + RunStepStatus status; + + /** + * The details of the run step. + */ + std::unique_ptr step_details; + + /** + * The last error associated with this run. Will be null if there are no errors. + */ + std::optional last_error{std::nullopt}; + + /** + * The Unix timestamp (in seconds) for when the run will expire. 
+ */ + uint32_t expired_at; + + /** + * The Unix timestamp (in seconds) for when the run was cancelled. + */ + uint32_t cancelled_at; + + /** + * The Unix timestamp (in seconds) for when the run failed. + */ + uint32_t failed_at; + + /** + * The Unix timestamp (in seconds) for when the run was completed. + */ + uint32_t completed_at; + + /** + * Set of 16 key-value pairs that can be attached to an object. This can + * be useful for storing additional information about the object in a + * structured format. Keys can be a maximum of 64 characters long and + * values can be a maximum of 512 characters long. + */ + Cortex::VariantMap metadata; + + /** + * Usage statistics related to the run. This value will be null if the run + * is not in a terminal state (i.e. in_progress, queued, etc.). + */ + std::optional usage; +}; +} // namespace OpenAi diff --git a/engine/common/run_step_delta.h b/engine/common/run_step_delta.h new file mode 100644 index 000000000..4303f3fb0 --- /dev/null +++ b/engine/common/run_step_delta.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include "common/run_step_detail.h" + +namespace OpenAi { +struct RunStepDelta { + struct Delta { + /** + * The details of the run step. + */ + std::unique_ptr step_details; + }; + + /** + * The identifier of the run step, which can be referenced in API endpoints. + */ + std::string id; + + /** + * The object type, which is always thread.run.step.delta. + */ + std::string object{"thread.run.step.delta"}; + + /** + * The delta containing the fields that have changed on the run step. 
/**
 * Polymorphic payload describing what a run step did.
 *
 * The `type` discriminator lives on the base so it reads correctly through a
 * RunStepDetails pointer/reference. Previously each subclass declared its own
 * shadowing `type` member, which left the base copy empty when the object was
 * accessed polymorphically.
 */
struct RunStepDetails {
  virtual ~RunStepDetails() = default;

  /// Discriminator: "message_creation" or "tool_calls".
  std::string type;

 protected:
  RunStepDetails() = default;
  explicit RunStepDetails(std::string t) : type{std::move(t)} {}
};

/// Details of a run step that created a message on the thread.
struct MessageCreationDetail : public RunStepDetails {
  struct MessageCreation {
    /// ID of the message created by this run step.
    std::string message_id;
  };

  MessageCreationDetail() : RunStepDetails{"message_creation"} {}

  MessageCreation message_creation;
};

/// Details of a run step that invoked tools.
struct ToolCalls : public RunStepDetails {
  ToolCalls() : RunStepDetails{"tool_calls"} {}
  // TODO: namh implement toolcalls later
};
+} // namespace OpenAi diff --git a/engine/common/thread.h b/engine/common/thread.h index dc57ba32d..8026b83af 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -116,7 +116,7 @@ struct Thread : JsonSerializable { return thread; } - cpp::result ToJson() override { + cpp::result ToJson() const override { try { Json::Value json; @@ -130,7 +130,7 @@ struct Thread : JsonSerializable { if (it == metadata.end()) { json["title"] = ""; } else { - json["title"] = std::get(metadata["title"]); + json["title"] = std::get(it->second); } } catch (const std::bad_variant_access& ex) { diff --git a/engine/common/tool_choice.h b/engine/common/tool_choice.h new file mode 100644 index 000000000..72913accd --- /dev/null +++ b/engine/common/tool_choice.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include "utils/result.hpp" + +namespace OpenAi { +struct ToolChoice { + struct Function { + /** + * The name of the function to call. + */ + std::string name; + }; + + /** + * The type of the tool. Currently, only function is supported. + */ + std::string type; + + /** + * Specifies a tool the model should use. Use to force the model + * to call a specific function. 
+ */ + Function function; + + static cpp::result FromJson(Json::Value&& json) { + try { + ToolChoice tool_choice; + if (json.isMember("type") && json["type"].isString()) { + tool_choice.type = json["type"].asString(); + } else { + return cpp::fail("Missing or invalid 'type' field"); + } + if (json.isMember("function") && json["function"].isObject()) { + auto function = json["function"]; + if (function.isMember("name") && function["name"].isString()) { + tool_choice.function.name = function["name"].asString(); + } else { + return cpp::fail("Missing or invalid 'name' field in 'function'"); + } + } else { + return cpp::fail("Missing or invalid 'function' field"); + } + return tool_choice; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } +}; +} // namespace OpenAi diff --git a/engine/common/tool_resources.h b/engine/common/tool_resources.h index 5aadb3f8b..1fb2fa976 100644 --- a/engine/common/tool_resources.h +++ b/engine/common/tool_resources.h @@ -19,7 +19,7 @@ struct ToolResources : JsonSerializable { virtual ~ToolResources() = default; - virtual cpp::result ToJson() override = 0; + virtual cpp::result ToJson() const override = 0; }; struct CodeInterpreter : ToolResources { @@ -55,7 +55,7 @@ struct CodeInterpreter : ToolResources { return code_interpreter; } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; Json::Value file_ids_json{Json::arrayValue}; for (auto& file_id : file_ids) { @@ -101,7 +101,7 @@ struct FileSearch : ToolResources { return file_search; } - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; Json::Value vector_store_ids_json{Json::arrayValue}; for (auto& vector_store_id : vector_store_ids) { diff --git a/engine/common/truncation_strategy.h b/engine/common/truncation_strategy.h new file mode 100644 index 000000000..5bd66eb5c --- /dev/null +++ b/engine/common/truncation_strategy.h @@ -0,0 +1,88 @@ +#pragma 
once + +#include +#include +#include +#include "common/json_serializable.h" +#include "utils/result.hpp" + +namespace OpenAi { +struct TruncationStrategy : public JsonSerializable { + TruncationStrategy() = default; + + TruncationStrategy(const TruncationStrategy&) = delete; + + TruncationStrategy& operator=(const TruncationStrategy&) = delete; + + TruncationStrategy(TruncationStrategy&& other) noexcept + : type{std::move(other.type)}, + last_messages{std::move(other.last_messages)} {} + + TruncationStrategy& operator=(TruncationStrategy&& other) noexcept { + if (this != &other) { + type = std::move(other.type); + last_messages = std::move(other.last_messages); + } + return *this; + } + + /** + * The truncation strategy to use for the thread. The default is auto. + * If set to last_messages, the thread will be truncated to the n most + * recent messages in the thread. + * + * When set to auto, messages in the middle of the thread will be dropped + * to fit the context length of the model, max_prompt_tokens. + */ + std::string type{"auto"}; + + /** + * The number of most recent messages from the thread when constructing + * the context for the run. 
+ */ + std::optional last_messages; + + static cpp::result FromJsonString( + std::string&& json_str) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(json_str, root)) { + return cpp::fail("Failed to parse JSON: " + + reader.getFormattedErrorMessages()); + } + + try { + return FromJson(std::move(root)); + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJsonString failed: ") + e.what()); + } + } + + static cpp::result FromJson( + Json::Value&& json) { + try { + + TruncationStrategy truncation_strategy; + if (json.isMember("type") && json["type"].isString()) { + truncation_strategy.type = json["type"].asString(); + } + if (json.isMember("last_messages") && json["last_messages"].isInt()) { + truncation_strategy.last_messages = json["last_messages"].asInt(); + } + return truncation_strategy; + + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } + + cpp::result ToJson() const { + Json::Value json; + json["type"] = type; + if (last_messages.has_value()) { + json["last_messages"] = last_messages.value(); + } + return json; + } +}; +} // namespace OpenAi diff --git a/engine/common/variant_map.h b/engine/common/variant_map.h index c8da77317..19fa23282 100644 --- a/engine/common/variant_map.h +++ b/engine/common/variant_map.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -59,4 +60,39 @@ inline cpp::result ConvertJsonValueToMap( return result; } + +inline cpp::result VariantMapFromJsonString( + std::string&& json_str) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(json_str, root)) { + return cpp::fail("Failed to parse JSON: " + + reader.getFormattedErrorMessages()); + } + return ConvertJsonValueToMap(root); +} + +inline cpp::result VariantMapToString( + const VariantMap& map) { + Json::Value root(Json::objectValue); + + for (const auto& [key, value] : map) { + std::visit( + [&root, &key](auto&& arg) { + using T = std::decay_t; + if 
#include "runs.h"
#include "common/events/assistant_stream_event.h"
#include "utils/cortex_utils.h"

// POST /v1/threads/{thread_id}/runs
// Creates a run and streams its lifecycle back to the client as
// Server-Sent Events (text/event-stream).
void Runs::CreateRun(const HttpRequestPtr& req,
                     std::function<void(const HttpResponsePtr&)>&& callback,
                     const std::string& thread_id) {
  // Reject requests without a parseable JSON body.
  auto json_body = req->getJsonObject();
  if (json_body == nullptr) {
    Json::Value ret;
    ret["message"] = "Body can't be empty";
    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
    resp->setStatusCode(k400BadRequest);
    callback(resp);
    return;
  }

  // Validate and convert the payload into the run-creation DTO.
  auto run_create_dto = dto::RunCreateDto::FromJson(std::move(*json_body));
  if (run_create_dto.has_error()) {
    Json::Value ret;
    ret["message"] = run_create_dto.error();
    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
    resp->setStatusCode(k400BadRequest);
    callback(resp);
    return;
  }

  // Keeps the DTO, thread id and the SSE stream alive for as long as the
  // asynchronous CreateRunStream callback may fire: both lambdas below
  // capture `state` by shared_ptr, so the last outstanding callback owns it.
  struct SharedState {
    dto::RunCreateDto dto;
    std::string thread_id;
    std::shared_ptr<ResponseStream> stream;

    explicit SharedState(dto::RunCreateDto&& d) : dto(std::move(d)) {}
  };

  auto state = std::make_shared<SharedState>(std::move(run_create_dto.value()));
  state->thread_id = thread_id;

  // Invoked by drogon once the async response stream is ready for writes.
  auto sendEvents = [this, state](ResponseStreamPtr res) {
    // Promote the unique stream handle to shared so the inner event
    // callback can keep it alive across invocations.
    state->stream = std::shared_ptr<ResponseStream>(std::move(res));

    run_srv_->CreateRunStream(
        std::move(state->dto), state->thread_id,
        [state](const OpenAi::AssistantStreamEvent& event, bool disconnect) {
          auto parse_event_res = event.ToEvent();
          if (parse_event_res.has_value()) {
            state->stream->send(parse_event_res.value());
            // `disconnect` marks the terminal event; close the SSE stream.
            if (disconnect) {
              state->stream->close();
            }
          }
          // NOTE(review): when ToEvent() fails the event is silently dropped,
          // and if that event carried disconnect == true the stream is never
          // closed — confirm this is intended.
        });
  };

  // TODO: namh add the date time
  auto resp = HttpResponse::newAsyncStreamResponse(std::move(sendEvents));
  resp->setContentTypeString("text/event-stream");
  resp->addHeader("Cache-Control", "no-cache");
  resp->addHeader("Connection", "keep-alive");
  callback(resp);
}

// GET /v1/threads/{thread_id}/runs/{run_id} — not implemented yet.
void Runs::RetrieveRun(const HttpRequestPtr& req,
                       std::function<void(const HttpResponsePtr&)>&& callback,
                       const std::string& thread_id,
                       const std::string& run_id) {}

// POST /v1/threads/{thread_id}/runs/{run_id}/cancel — not implemented yet.
void Runs::CancelRun(const HttpRequestPtr& req,
                     std::function<void(const HttpResponsePtr&)>&& callback,
                     const std::string& thread_id, const std::string& run_id) {}

// POST /v1/threads/{thread_id}/runs/{run_id} — not implemented yet.
void Runs::ModifyRun(const HttpRequestPtr& req,
                     std::function<void(const HttpResponsePtr&)>&& callback,
                     const std::string& thread_id, const std::string& run_id) {}

// POST /v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs —
// not implemented yet.
void Runs::SubmitToolOutput(
    const HttpRequestPtr& req,
    std::function<void(const HttpResponsePtr&)>&& callback,
    const std::string& thread_id, const std::string& run_id) {}
std::function&& callback, + const std::string& thread_id); + + void RetrieveRun(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& run_id); + + void CancelRun(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& run_id); + + void ModifyRun(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& run_id); + + void SubmitToolOutput(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& run_id); + + private: + std::shared_ptr run_srv_; + std::shared_ptr inference_srv_; + + void ProcessStreamRes(std::function cb, + std::shared_ptr q, + const std::string& engine_type, + const std::string& model_id); +}; diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 83eaddb4e..f3b4d2fdb 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -19,8 +19,6 @@ server::server(std::shared_ptr inference_service, #endif }; -server::~server() {} - void server::ChatCompletion( const HttpRequestPtr& req, std::function&& callback) { @@ -203,7 +201,6 @@ void server::RouteRequest( ProcessNonStreamRes(callback, *q); LOG_TRACE << "Done route request"; } - } void server::LoadModel(const HttpRequestPtr& req, diff --git a/engine/controllers/server.h b/engine/controllers/server.h index 42214a641..e183ea597 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -29,7 +29,7 @@ class server : public drogon::HttpController, public: server(std::shared_ptr inference_service, std::shared_ptr engine_service); - ~server(); + METHOD_LIST_BEGIN // list path definitions here; METHOD_ADD(server::ChatCompletion, "chat_completion", Options, Post); diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index abb80b94e..13a547454 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -2,7 +2,7 
@@ #include "cortex_openapi.h" #include "utils/cortex_utils.h" -Json::Value SwaggerController::GenerateOpenApiSpec() const { +auto SwaggerController::GenerateOpenApiSpec() const -> Json::Value { Json::Value root; Json::Reader reader; reader.parse(CortexOpenApi::GetOpenApiJson(), root); diff --git a/engine/database/database.h b/engine/database/database.h index dbe58cc4b..a5c981cfa 100644 --- a/engine/database/database.h +++ b/engine/database/database.h @@ -1,7 +1,6 @@ #pragma once -#include -#include "SQLiteCpp/SQLiteCpp.h" +#include #include "utils/file_manager_utils.h" namespace cortex::db { diff --git a/engine/database/runs.cc b/engine/database/runs.cc new file mode 100644 index 000000000..af059cf89 --- /dev/null +++ b/engine/database/runs.cc @@ -0,0 +1,437 @@ +#include "runs.h" +#include "common/run.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" +#include "utils/scope_exit.h" + +namespace cortex::db { + +OpenAi::Run Runs::ParseRunFromQuery(SQLite::Statement& query) const { + OpenAi::Run entry; + int colIdx = 0; + + entry.id = query.getColumn(colIdx++).getString(); + entry.object = query.getColumn(colIdx++).getString(); + entry.created_at = query.getColumn(colIdx++).getInt64(); + entry.assistant_id = query.getColumn(colIdx++).getString(); + entry.thread_id = query.getColumn(colIdx++).getString(); + entry.status = + OpenAi::RunStatusFromString(query.getColumn(colIdx++).getString()); + + if (!query.getColumn(colIdx).isNull()) { + entry.started_at = query.getColumn(colIdx).getInt64(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.expired_at = query.getColumn(colIdx).getInt64(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.cancelled_at = query.getColumn(colIdx).getInt64(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.failed_at = query.getColumn(colIdx).getInt64(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.completed_at = 
query.getColumn(colIdx).getInt64(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + auto last_error = + OpenAi::LastError::FromJsonString(query.getColumn(colIdx).getString()); + if (last_error.has_value()) { + entry.last_error = last_error.value(); + } + } + colIdx++; + + entry.model = query.getColumn(colIdx++).getString(); + + if (!query.getColumn(colIdx).isNull()) { + entry.instructions = query.getColumn(colIdx).getString(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.tools = + OpenAi::Run::ToolsFromJsonString(query.getColumn(colIdx).getString()); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.metadata = + Cortex::VariantMapFromJsonString(query.getColumn(colIdx).getString()) + .value(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.incomplete_detail = OpenAi::IncompleteDetail::FromJsonString( + query.getColumn(colIdx).getString()) + .value(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.usage = + OpenAi::RunUsage::FromJsonString(query.getColumn(colIdx).getString()) + .value(); + } + colIdx++; + + entry.temperature = query.getColumn(colIdx++).getDouble(); + entry.top_p = query.getColumn(colIdx++).getDouble(); + entry.max_prompt_tokens = query.getColumn(colIdx++).getInt(); + entry.max_completion_tokens = query.getColumn(colIdx++).getInt(); + + if (!query.getColumn(colIdx).isNull()) { + entry.truncation_strategy = OpenAi::TruncationStrategy::FromJsonString( + query.getColumn(colIdx).getString()) + .value(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.response_format = query.getColumn(colIdx).getString(); + } + colIdx++; + + if (!query.getColumn(colIdx).isNull()) { + entry.tool_choice = query.getColumn(colIdx).getString(); + } + colIdx++; + + entry.parallel_tool_calls = query.getColumn(colIdx++).getInt() != 0; + + return entry; +} + +cpp::result, std::string> Runs::ListRuns( + uint8_t limit, const std::string& order, const std::string& 
after, + const std::string& before) const { + try { + db_.exec("BEGIN TRANSACTION;"); + cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); + std::vector runs; + + std::string sql = + "SELECT id, object, created_at, assistant_id, thread_id, status, " + "started_at, expired_at, cancelled_at, failed_at, completed_at, " + "last_error, model, instructions, tools, metadata, incomplete_details, " + "usage, temperature, top_p, max_prompt_tokens, max_completion_tokens, " + "truncation_strategy, response_format, tool_choice, " + "parallel_tool_calls " + "FROM runs"; + + std::vector where; + + if (!after.empty()) { + where.push_back( + "created_at < (SELECT created_at FROM runs WHERE id = ?)"); + } + if (!before.empty()) { + where.push_back( + "created_at > (SELECT created_at FROM runs WHERE id = ?)"); + } + + if (!where.empty()) { + sql += " WHERE "; + for (size_t i = 0; i < where.size(); ++i) { + if (i > 0) + sql += " AND "; + sql += where[i]; + } + } + + sql += " ORDER BY created_at "; + sql += (order == "asc" || order == "ASC") ? 
"ASC" : "DESC"; + sql += " LIMIT ?"; + + SQLite::Statement query(db_, sql); + + int bindIndex = 1; + if (!after.empty()) { + query.bind(bindIndex++, after); + } + if (!before.empty()) { + query.bind(bindIndex++, before); + } + query.bind(bindIndex, static_cast(limit)); + + while (query.executeStep()) { + runs.push_back(ParseRunFromQuery(query)); + } + return runs; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result Runs::UpdateRun(const OpenAi::Run& run) { + try { + SQLite::Statement update( + db_, + "UPDATE runs SET object = ?, created_at = ?, assistant_id = ?, " + "thread_id = ?, status = ?, started_at = ?, expired_at = ?, " + "cancelled_at = ?, failed_at = ?, completed_at = ?, last_error = ?, " + "model = ?, instructions = ?, tools = ?, metadata = ?, " + "incomplete_details = ?, usage = ?, temperature = ?, top_p = ?, " + "max_prompt_tokens = ?, max_completion_tokens = ?, truncation_strategy " + "= ?, " + "response_format = ?, tool_choice = ?, parallel_tool_calls = ? 
" + "WHERE id = ?"); + + int idx = 1; + update.bind(idx++, run.object); + update.bind(idx++, static_cast(run.created_at)); + update.bind(idx++, run.assistant_id); + update.bind(idx++, run.thread_id); + update.bind(idx++, OpenAi::RunStatusToString(run.status)); + + if (run.started_at) { + update.bind(idx++, static_cast(run.started_at)); + } else { + update.bind(idx++); + } + + if (run.expired_at) { + update.bind(idx++, static_cast(run.expired_at)); + } else { + update.bind(idx++); + } + + if (run.cancelled_at) { + update.bind(idx++, static_cast(run.cancelled_at)); + } else { + update.bind(idx++); + } + + if (run.failed_at) { + update.bind(idx++, static_cast(run.failed_at)); + } else { + update.bind(idx++); + } + + if (run.completed_at) { + update.bind(idx++, static_cast(run.completed_at)); + } else { + update.bind(idx++); + } + + if (run.last_error) { + update.bind(idx++, run.last_error->ToJson()->toStyledString()); + } else { + update.bind(idx++); + } + + update.bind(idx++, run.model); + update.bind(idx++, run.instructions); + update.bind(idx++, OpenAi::Run::ToolsToJsonString(run.tools).value()); + update.bind(idx++, Cortex::VariantMapToString(run.metadata).value()); + + if (run.incomplete_detail) { + update.bind(idx++, run.incomplete_detail->ToJson()->toStyledString()); + } else { + update.bind(idx++); + } + + if (run.usage) { + update.bind(idx++, run.usage->ToJson()->toStyledString()); + } else { + update.bind(idx++); + } + + update.bind(idx++, run.temperature.value()); + update.bind(idx++, run.top_p.value()); + update.bind(idx++, run.max_prompt_tokens.value()); + update.bind(idx++, run.max_completion_tokens.value()); + update.bind(idx++, run.truncation_strategy.ToJson()->toStyledString()); + + if (std::holds_alternative(run.response_format)) { + update.bind(idx++, std::get(run.response_format)); + } else { + update.bind(idx++, + std::get(run.response_format).toStyledString()); + } + + if (std::holds_alternative(run.tool_choice)) { + update.bind(idx++, 
std::get(run.tool_choice)); + } else { + update.bind(idx++, + std::get(run.tool_choice).toStyledString()); + } + + update.bind(idx++, run.parallel_tool_calls ? 1 : 0); + update.bind(idx++, run.id); + + if (update.exec() == 0) { + return cpp::fail("Run not found: " + run.id); + } + + CTL_INF("Updated: " << run.ToJson()->toStyledString()); + return {}; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result Runs::RetrieveRun( + const std::string& run_id) const { + try { + SQLite::Statement query( + db_, + "SELECT id, object, created_at, assistant_id, thread_id, status, " + "started_at, expired_at, cancelled_at, failed_at, completed_at, " + "last_error, model, instructions, tools, metadata, incomplete_details, " + "usage, temperature, top_p, max_prompt_tokens, max_completion_tokens, " + "truncation_strategy, response_format, tool_choice, " + "parallel_tool_calls " + "FROM runs WHERE id = ?"); + + query.bind(1, run_id); + + if (query.executeStep()) { + return ParseRunFromQuery(query); + } + return cpp::fail("Run not found: " + run_id); + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result Runs::AddRunEntry(const OpenAi::Run& run) { + try { + SQLite::Statement insert( + db_, + "INSERT INTO runs (id, object, created_at, assistant_id, thread_id, " + "status, started_at, expired_at, cancelled_at, failed_at, " + "completed_at, " + "last_error, model, instructions, tools, metadata, incomplete_details, " + "usage, temperature, top_p, max_prompt_tokens, max_completion_tokens, " + "truncation_strategy, response_format, tool_choice, " + "parallel_tool_calls) " + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)"); + + int idx = 1; + insert.bind(idx++, run.id); + insert.bind(idx++, run.object); + insert.bind(idx++, static_cast(run.created_at)); + insert.bind(idx++, run.assistant_id); + insert.bind(idx++, run.thread_id); + insert.bind(idx++, 
OpenAi::RunStatusToString(run.status)); + + if (run.started_at) + insert.bind(idx++, static_cast(run.started_at)); + else + insert.bind(idx++); + + if (run.expired_at) + insert.bind(idx++, static_cast(run.expired_at)); + else + insert.bind(idx++); + + if (run.cancelled_at) + insert.bind(idx++, static_cast(run.cancelled_at)); + else + insert.bind(idx++); + + if (run.failed_at) + insert.bind(idx++, static_cast(run.failed_at)); + else + insert.bind(idx++); + + if (run.completed_at) + insert.bind(idx++, static_cast(run.completed_at)); + else + insert.bind(idx++); + + if (run.last_error.has_value()) { + insert.bind(idx++, run.last_error.value().ToJson()->toStyledString()); + } else { + insert.bind(idx++); + } + insert.bind(idx++, run.model); + insert.bind(idx++, run.instructions); + insert.bind(idx++, OpenAi::Run::ToolsToJsonString(run.tools).value()); + insert.bind(idx++, Cortex::VariantMapToString(run.metadata).value()); + if (run.incomplete_detail) { + insert.bind(idx++, run.incomplete_detail->ToJson()->toStyledString()); + } + if (run.last_error.has_value()) { + insert.bind(idx++, run.usage->ToJson()->toStyledString()); + } else { + insert.bind(idx++); + } + + if (run.temperature) + insert.bind(idx++, *run.temperature); + else + insert.bind(idx++); + + if (run.top_p) + insert.bind(idx++, *run.top_p); + else + insert.bind(idx++); + + if (run.max_prompt_tokens) + insert.bind(idx++, *run.max_prompt_tokens); + else + insert.bind(idx++); + + if (run.max_completion_tokens) + insert.bind(idx++, *run.max_completion_tokens); + else + insert.bind(idx++); + + insert.bind(idx++, run.truncation_strategy.ToJson()->toStyledString()); + if (std::holds_alternative(run.response_format)) { + insert.bind(idx++, std::get(run.response_format)); + } else { + insert.bind(idx++, + std::get(run.response_format).toStyledString()); + } + + if (std::holds_alternative(run.tool_choice)) { + insert.bind(idx++, std::get(run.tool_choice)); + } else { + insert.bind(idx++, + 
std::get(run.tool_choice).toStyledString()); + } + insert.bind(idx++, run.parallel_tool_calls ? 1 : 0); + + insert.exec(); + CTL_INF("Inserted: " << run.ToJson()->toStyledString()); + return {}; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result Runs::DeleteRun(const std::string& run_id) { + try { + SQLite::Statement del(db_, "DELETE FROM runs WHERE id = ?"); + del.bind(1, run_id); + + if (del.exec() == 0) { + return cpp::fail("Run not found: " + run_id); + } + + CTL_INF("Deleted run: " << run_id); + return {}; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} +} // namespace cortex::db diff --git a/engine/database/runs.h b/engine/database/runs.h new file mode 100644 index 000000000..b9894cb85 --- /dev/null +++ b/engine/database/runs.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include +#include "common/run.h" +#include "database.h" +#include "utils/result.hpp" + +namespace cortex::db { +class Runs { + SQLite::Database& db_; + + public: + Runs(SQLite::Database& db) : db_{db} {}; + + Runs() : db_(cortex::db::Database::GetInstance().db()) {} + + ~Runs() {} + + cpp::result, std::string> ListRuns( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveRun( + const std::string& run_id) const; + + cpp::result UpdateRun(const OpenAi::Run& run); + + cpp::result AddRunEntry(const OpenAi::Run& run); + + cpp::result DeleteRun(const std::string& run_id); + + private: + OpenAi::Run ParseRunFromQuery(SQLite::Statement& query) const; +}; +} // namespace cortex::db diff --git a/engine/main.cc b/engine/main.cc index 59ec49873..2e79b7234 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -10,6 +10,7 @@ #include "controllers/messages.h" #include "controllers/models.h" #include "controllers/process_manager.h" +#include "controllers/runs.h" #include "controllers/server.h" #include 
"controllers/swagger.h" #include "controllers/threads.h" @@ -18,6 +19,7 @@ #include "repositories/assistant_fs_repository.h" #include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" +#include "repositories/run_sqlite_repository.h" #include "repositories/thread_fs_repository.h" #include "services/assistant_service.h" #include "services/config_service.h" @@ -26,6 +28,7 @@ #include "services/message_service.h" #include "services/model_service.h" #include "services/model_source_service.h" +#include "services/run_service.h" #include "services/thread_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" @@ -148,6 +151,7 @@ void RunServer(std::optional host, std::optional port, auto thread_repo = std::make_shared(data_folder_path); auto assistant_repo = std::make_shared(data_folder_path); + auto run_repo = std::make_shared(db_service); auto file_srv = std::make_shared(file_repo); auto assistant_srv = @@ -166,6 +170,9 @@ void RunServer(std::optional host, std::optional port, auto model_service = std::make_shared( db_service, hw_service, download_service, inference_svc, engine_service); inference_svc->SetModelService(model_service); + auto run_srv = + std::make_shared(run_repo, assistant_srv, model_service, + message_srv, inference_svc, thread_srv); auto file_watcher_srv = std::make_shared( model_dir_path.string(), model_service); @@ -174,6 +181,7 @@ void RunServer(std::optional host, std::optional port, // initialize custom controllers auto swagger_ctl = std::make_shared(config.apiServerHost, config.apiServerPort); + auto run_ctl = std::make_shared(run_srv, inference_svc); auto file_ctl = std::make_shared(file_srv, message_srv); auto assistant_ctl = std::make_shared(assistant_srv); auto thread_ctl = std::make_shared(thread_srv, message_srv); @@ -189,6 +197,7 @@ void RunServer(std::optional host, std::optional port, auto config_ctl = std::make_shared(config_service); drogon::app().registerController(swagger_ctl); + 
drogon::app().registerController(run_ctl); drogon::app().registerController(file_ctl); drogon::app().registerController(assistant_ctl); drogon::app().registerController(thread_ctl); @@ -236,18 +245,16 @@ void RunServer(std::optional host, std::optional port, auto allowed_origins = config_service->GetApiServerConfiguration()->allowed_origins; - auto is_contains_asterisk = - std::find(allowed_origins.begin(), allowed_origins.end(), "*"); - if (is_contains_asterisk != allowed_origins.end()) { + if (auto it = std::ranges::find(allowed_origins, "*"); + it != allowed_origins.end()) { resp->addHeader("Access-Control-Allow-Origin", "*"); resp->addHeader("Access-Control-Allow-Methods", "*"); return; } // Check if the origin is in our allowed list - auto it = - std::find(allowed_origins.begin(), allowed_origins.end(), origin); - if (it != allowed_origins.end()) { + if (auto it = std::ranges::find(allowed_origins, origin); + it != allowed_origins.end()) { resp->addHeader("Access-Control-Allow-Origin", origin); } else if (allowed_origins.empty()) { resp->addHeader("Access-Control-Allow-Origin", "*"); diff --git a/engine/migrations/db_helper.h b/engine/migrations/db_helper.h index 867e871ff..5b7768d61 100644 --- a/engine/migrations/db_helper.h +++ b/engine/migrations/db_helper.h @@ -1,11 +1,9 @@ #pragma once + #include namespace cortex::mgr { -#include -#include #include -#include inline bool ColumnExists(SQLite::Database& db, const std::string& table_name, const std::string& column_name) { diff --git a/engine/migrations/migration_manager.cc b/engine/migrations/migration_manager.cc index 26197115d..935b3399d 100644 --- a/engine/migrations/migration_manager.cc +++ b/engine/migrations/migration_manager.cc @@ -1,6 +1,7 @@ #include "migration_manager.h" #include #include "assert.h" +#include "migrations/v4/migration.h" #include "schema_version.h" #include "utils/file_manager_utils.h" #include "utils/scope_exit.h" @@ -149,6 +150,8 @@ cpp::result 
MigrationManager::DoUpFolderStructure( return v2::MigrateFolderStructureUp(); case 3: return v3::MigrateFolderStructureUp(); + case 4: + return v4::MigrateFolderStructureUp(); default: return true; @@ -165,6 +168,8 @@ cpp::result MigrationManager::DoDownFolderStructure( return v2::MigrateFolderStructureDown(); case 3: return v3::MigrateFolderStructureDown(); + case 4: + return v4::MigrateFolderStructureDown(); default: return true; @@ -203,6 +208,8 @@ cpp::result MigrationManager::DoUpDB(int version) { return v2::MigrateDBUp(db_); case 3: return v3::MigrateDBUp(db_); + case 4: + return v4::MigrateDBUp(db_); default: return true; @@ -219,6 +226,8 @@ cpp::result MigrationManager::DoDownDB(int version) { return v2::MigrateDBDown(db_); case 3: return v3::MigrateDBDown(db_); + case 4: + return v3::MigrateDBDown(db_); default: return true; diff --git a/engine/migrations/v3/migration.h b/engine/migrations/v3/migration.h index 3bed802fb..395a8f0ef 100644 --- a/engine/migrations/v3/migration.h +++ b/engine/migrations/v3/migration.h @@ -53,11 +53,11 @@ inline cpp::result MigrateDBUp(SQLite::Database& db) { inline cpp::result MigrateDBDown(SQLite::Database& db) { try { - // hardware + // files { SQLite::Statement query(db, "SELECT name FROM sqlite_master WHERE " - "type='table' AND name='hardware'"); + "type='table' AND name='files'"); auto table_exists = query.executeStep(); if (table_exists) { db.exec("DROP TABLE files"); diff --git a/engine/migrations/v4/migration.h b/engine/migrations/v4/migration.h new file mode 100644 index 000000000..11bcd6708 --- /dev/null +++ b/engine/migrations/v4/migration.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace cortex::migr::v4 { +inline cpp::result MigrateFolderStructureUp() { + return true; +} + +inline cpp::result MigrateFolderStructureDown() { + return true; +} + +// Database +inline cpp::result MigrateDBUp(SQLite::Database& db) { + try { + db.exec( + 
"CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY " + "KEY);"); + + // runs + { + // Check if the table exists + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='runs'"); + auto table_exists = query.executeStep(); + + if (!table_exists) { + // Create new table + db.exec( + "CREATE TABLE runs (" + "id TEXT PRIMARY KEY," + "object TEXT," + "created_at INTEGER," + "assistant_id TEXT," + "thread_id TEXT," + "status TEXT," + "started_at INTEGER," + "expired_at INTEGER," + "cancelled_at INTEGER," + "failed_at INTEGER," + "completed_at INTEGER," + "last_error TEXT," + "model TEXT," + "instructions TEXT," + "tools TEXT," + "metadata TEXT," + "incomplete_details TEXT," + "usage TEXT," + "temperature REAL," + "top_p REAL," + "max_prompt_tokens INTEGER," + "max_completion_tokens INTEGER," + "truncation_strategy TEXT," + "response_format TEXT," + "tool_choice TEXT," + "parallel_tool_calls BOOL" + ")"); + } + } + + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration up failed: " << e.what()); + return cpp::fail(e.what()); + } +}; + +inline cpp::result MigrateDBDown(SQLite::Database& db) { + try { + // runs + { + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='runs'"); + auto table_exists = query.executeStep(); + if (table_exists) { + db.exec("DROP TABLE runs"); + } + } + + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration down failed: " << e.what()); + return cpp::fail(e.what()); + } +} +}; // namespace cortex::migr::v4 diff --git a/engine/repositories/assistant_fs_repository.cc b/engine/repositories/assistant_fs_repository.cc index 87b4174fd..82f0e3434 100644 --- a/engine/repositories/assistant_fs_repository.cc +++ b/engine/repositories/assistant_fs_repository.cc @@ -113,7 +113,7 @@ cpp::result AssistantFsRepository::DeleteAssistant( } cpp::result -AssistantFsRepository::CreateAssistant(OpenAi::Assistant& assistant) { 
+AssistantFsRepository::CreateAssistant(const OpenAi::Assistant& assistant) { CTL_INF("CreateAssistant: " + assistant.id); { std::unique_lock lock(GrabAssistantMutex(assistant.id)); @@ -139,7 +139,7 @@ AssistantFsRepository::CreateAssistant(OpenAi::Assistant& assistant) { } cpp::result AssistantFsRepository::SaveAssistant( - OpenAi::Assistant& assistant) { + const OpenAi::Assistant& assistant) { auto path = GetAssistantPath(assistant.id) / kAssistantFileName; if (!std::filesystem::exists(path)) { std::filesystem::create_directories(path); diff --git a/engine/repositories/assistant_fs_repository.h b/engine/repositories/assistant_fs_repository.h index f310bd54e..9bead6260 100644 --- a/engine/repositories/assistant_fs_repository.h +++ b/engine/repositories/assistant_fs_repository.h @@ -15,7 +15,7 @@ class AssistantFsRepository : public AssistantRepository { const std::string& before) const override; cpp::result CreateAssistant( - OpenAi::Assistant& assistant) override; + const OpenAi::Assistant& assistant) override; cpp::result RetrieveAssistant( const std::string assistant_id) const override; @@ -43,7 +43,8 @@ class AssistantFsRepository : public AssistantRepository { std::shared_mutex& GrabAssistantMutex(const std::string& assistant_id) const; - cpp::result SaveAssistant(OpenAi::Assistant& assistant); + cpp::result SaveAssistant( + const OpenAi::Assistant& assistant); cpp::result LoadAssistant( const std::string& assistant_id) const; diff --git a/engine/repositories/file_fs_repository.cc b/engine/repositories/file_fs_repository.cc index e6c28b38e..71b9d0f88 100644 --- a/engine/repositories/file_fs_repository.cc +++ b/engine/repositories/file_fs_repository.cc @@ -2,7 +2,6 @@ #include #include #include -#include "database/file.h" #include "utils/logging_utils.h" #include "utils/result.hpp" diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index db6f5dd6e..6e49670f6 100644 --- 
a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -11,7 +11,7 @@ std::filesystem::path MessageFsRepository::GetMessagePath( } cpp::result MessageFsRepository::CreateMessage( - OpenAi::Message& message) { + const OpenAi::Message& message) { CTL_INF("CreateMessage for thread " + message.thread_id); auto path = GetMessagePath(message.thread_id); @@ -133,7 +133,7 @@ cpp::result MessageFsRepository::RetrieveMessage( } cpp::result MessageFsRepository::ModifyMessage( - OpenAi::Message& message) { + const OpenAi::Message& message) { auto mutex = GrabMutex(message.thread_id); std::unique_lock lock(*mutex); diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index 0ca6e89b3..690c3960b 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -11,7 +11,7 @@ class MessageFsRepository : public MessageRepository { public: cpp::result CreateMessage( - OpenAi::Message& message) override; + const OpenAi::Message& message) override; cpp::result, std::string> ListMessages( const std::string& thread_id, uint8_t limit, const std::string& order, @@ -23,7 +23,7 @@ class MessageFsRepository : public MessageRepository { const std::string& message_id) const override; cpp::result ModifyMessage( - OpenAi::Message& message) override; + const OpenAi::Message& message) override; cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id) override; diff --git a/engine/repositories/run_sqlite_repository.cc b/engine/repositories/run_sqlite_repository.cc new file mode 100644 index 000000000..00c2946d3 --- /dev/null +++ b/engine/repositories/run_sqlite_repository.cc @@ -0,0 +1,28 @@ +#include "run_sqlite_repository.h" + +cpp::result RunSqliteRepository::CreateRun( + const OpenAi::Run& run) { + return db_service_->AddRunEntry(run); +} + +cpp::result, std::string> +RunSqliteRepository::ListRuns(uint8_t limit, const 
std::string& order, + const std::string& after, + const std::string& before) const { + return db_service_->ListRuns(limit, order, after, before); +} + +cpp::result RunSqliteRepository::RetrieveRun( + const std::string& run_id) const { + return db_service_->RetrieveRun(run_id); +} + +cpp::result RunSqliteRepository::ModifyRun( + OpenAi::Run& run) { + return db_service_->ModifyRun(run); +} + +cpp::result RunSqliteRepository::DeleteRun( + const std::string& run_id) { + return db_service_->DeleteRun(run_id); +} diff --git a/engine/repositories/run_sqlite_repository.h b/engine/repositories/run_sqlite_repository.h new file mode 100644 index 000000000..87eb53800 --- /dev/null +++ b/engine/repositories/run_sqlite_repository.h @@ -0,0 +1,28 @@ +#pragma once + +#include "common/repository/run_repository.h" +#include "services/database_service.h" + +class RunSqliteRepository : public RunRepository { + public: + cpp::result CreateRun(const OpenAi::Run& run) override; + + cpp::result, std::string> ListRuns( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result RetrieveRun( + const std::string& run_id) const override; + + cpp::result ModifyRun(OpenAi::Run& run) override; + + cpp::result DeleteRun(const std::string& run_id) override; + + ~RunSqliteRepository() = default; + + explicit RunSqliteRepository(std::shared_ptr db_service) + : db_service_{db_service} {} + + private: + std::shared_ptr db_service_; +}; diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc index 6b75db8e4..007ef4bce 100644 --- a/engine/repositories/thread_fs_repository.cc +++ b/engine/repositories/thread_fs_repository.cc @@ -116,7 +116,7 @@ cpp::result ThreadFsRepository::LoadThread( } cpp::result ThreadFsRepository::CreateThread( - OpenAi::Thread& thread) { + const OpenAi::Thread& thread) { CTL_INF("CreateThread: " + thread.id); std::unique_lock lock(GrabThreadMutex(thread.id)); auto 
thread_path = GetThreadPath(thread.id); @@ -134,7 +134,7 @@ cpp::result ThreadFsRepository::CreateThread( } cpp::result ThreadFsRepository::SaveThread( - OpenAi::Thread& thread) { + const OpenAi::Thread& thread) { auto path = GetThreadPath(thread.id) / kThreadFileName; if (!std::filesystem::exists(path)) { return cpp::fail("Path does not exist: " + path.string()); diff --git a/engine/repositories/thread_fs_repository.h b/engine/repositories/thread_fs_repository.h index b6f6032fa..1502e7065 100644 --- a/engine/repositories/thread_fs_repository.h +++ b/engine/repositories/thread_fs_repository.h @@ -46,7 +46,7 @@ class ThreadFsRepository : public ThreadRepository, cpp::result LoadThread( const std::string& thread_id) const; - cpp::result SaveThread(OpenAi::Thread& thread); + cpp::result SaveThread(const OpenAi::Thread& thread); public: explicit ThreadFsRepository(const std::filesystem::path& data_folder_path) @@ -59,7 +59,8 @@ class ThreadFsRepository : public ThreadRepository, } } - cpp::result CreateThread(OpenAi::Thread& thread) override; + cpp::result CreateThread( + const OpenAi::Thread& thread) override; cpp::result, std::string> ListThreads( uint8_t limit, const std::string& order, const std::string& after, diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc index 08a5a743f..64c7d49fb 100644 --- a/engine/services/assistant_service.cc +++ b/engine/services/assistant_service.cc @@ -1,6 +1,7 @@ #include "assistant_service.h" #include #include "utils/logging_utils.h" +#include "utils/time_utils.h" #include "utils/ulid_generator.h" cpp::result @@ -95,11 +96,7 @@ cpp::result AssistantService::CreateAssistantV2( if (create_dto.response_format) { assistant.response_format = *create_dto.response_format; } - auto seconds_since_epoch = - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); - assistant.created_at = seconds_since_epoch; + assistant.created_at = cortex_utils::SecondsSinceEpoch(); 
return assistant_repository_->CreateAssistant(assistant); } cpp::result diff --git a/engine/services/database_service.cc b/engine/services/database_service.cc index d4cd977a9..e41c30c53 100644 --- a/engine/services/database_service.cc +++ b/engine/services/database_service.cc @@ -1,4 +1,6 @@ #include "database_service.h" +#include "database/file.h" +#include "database/runs.h" // begin engines std::optional DatabaseService::UpsertEngine( @@ -127,4 +129,31 @@ cpp::result, std::string> DatabaseService::GetModels( const std::string& model_src) const { return cortex::db::Models().GetModels(model_src); } -// end models \ No newline at end of file +// end models + +// runs +cpp::result DatabaseService::AddRunEntry( + const OpenAi::Run& run) { + return cortex::db::Runs().AddRunEntry(run); +} + +cpp::result DatabaseService::RetrieveRun( + const std::string& run_id) const { + return cortex::db::Runs().RetrieveRun(run_id); +} + +cpp::result DatabaseService::ModifyRun( + const OpenAi::Run& run) { + return cortex::db::Runs().UpdateRun(run); +} + +cpp::result DatabaseService::DeleteRun( + const std::string& run_id) { + return cortex::db::Runs().DeleteRun(run_id); +} + +cpp::result, std::string> DatabaseService::ListRuns( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) { + return cortex::db::Runs().ListRuns(limit, order, after, before); +} diff --git a/engine/services/database_service.h b/engine/services/database_service.h index 4fb4f7be0..41ceb4485 100644 --- a/engine/services/database_service.h +++ b/engine/services/database_service.h @@ -1,6 +1,8 @@ #pragma once + +#include "common/file.h" +#include "common/run.h" #include "database/engines.h" -#include "database/file.h" #include "database/hardware.h" #include "database/models.h" @@ -64,5 +66,15 @@ class DatabaseService { cpp::result, std::string> GetModels( const std::string& model_src) const; + // runs + cpp::result AddRunEntry(const OpenAi::Run& run); + cpp::result 
RetrieveRun( + const std::string& run_id) const; + cpp::result ModifyRun(const OpenAi::Run& run); + cpp::result DeleteRun(const std::string& run_id); + cpp::result, std::string> ListRuns( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before); + private: -}; \ No newline at end of file +}; diff --git a/engine/services/file_service.cc b/engine/services/file_service.cc index 3341227e0..b29bb15c6 100644 --- a/engine/services/file_service.cc +++ b/engine/services/file_service.cc @@ -1,22 +1,18 @@ #include "file_service.h" #include +#include "utils/time_utils.h" #include "utils/ulid_generator.h" cpp::result FileService::UploadFile( const std::string& filename, const std::string& purpose, const char* content, uint64_t content_length) { - auto seconds_since_epoch = - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); - auto file_id{"file-" + ulid::GenerateUlid()}; OpenAi::File file; file.id = file_id; file.object = "file"; file.bytes = content_length; - file.created_at = seconds_since_epoch; + file.created_at = cortex_utils::SecondsSinceEpoch(); file.filename = filename; file.purpose = purpose; diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index f23be3f23..0767a3e69 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -1,36 +1,10 @@ #pragma once -#include -#include -#include -#include "extensions/remote-engine/remote_engine.h" +#include "common/cortex/sync_queue.h" #include "services/engine_service.h" #include "services/model_service.h" #include "utils/result.hpp" -// Status and result -using InferResult = std::pair; - -struct SyncQueue { - void push(InferResult&& p) { - std::unique_lock l(mtx); - q.push(p); - cond.notify_one(); - } - - InferResult wait_and_pop() { - std::unique_lock l(mtx); - cond.wait(l, [this] { return !q.empty(); }); - auto res = q.front(); - q.pop(); - return res; - } - - std::mutex 
mtx; - std::condition_variable cond; - std::queue q; -}; - class InferenceService { public: explicit InferenceService(std::shared_ptr engine_service) @@ -47,7 +21,7 @@ class InferenceService { cpp::result HandleRouteRequest( std::shared_ptr q, std::shared_ptr json_body); - + InferResult LoadModel(std::shared_ptr json_body); InferResult UnloadModel(const std::string& engine, diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc index 9b57e0215..024148bcc 100644 --- a/engine/services/message_service.cc +++ b/engine/services/message_service.cc @@ -1,6 +1,7 @@ #include "services/message_service.h" #include "utils/logging_utils.h" #include "utils/result.hpp" +#include "utils/time_utils.h" #include "utils/ulid_generator.h" cpp::result MessageService::CreateMessage( @@ -8,32 +9,33 @@ cpp::result MessageService::CreateMessage( std::variant>>&& content, std::optional> attachments, - std::optional metadata) { + std::optional metadata, OpenAi::Status status) { LOG_TRACE << "CreateMessage for thread " << thread_id; - uint32_t seconds_since_epoch = - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); + auto seconds_since_epoch = cortex_utils::SecondsSinceEpoch(); std::vector> content_list{}; - // if content is string if (std::holds_alternative(content)) { - auto text_content = std::make_unique(); - text_content->text.value = std::get(content); + auto text_content = + std::make_unique(std::get(content)); content_list.push_back(std::move(text_content)); } else { content_list = std::move( std::get>>(content)); } + std::optional completed_at = + (status == OpenAi::Status::COMPLETED) + ? 
std::optional(seconds_since_epoch) + : std::nullopt; + OpenAi::Message msg; msg.id = ulid::GenerateUlid(); msg.object = "thread.message"; msg.created_at = seconds_since_epoch; msg.thread_id = thread_id; - msg.status = OpenAi::Status::COMPLETED; - msg.completed_at = seconds_since_epoch; + msg.status = status; + msg.completed_at = completed_at; msg.incomplete_at = std::nullopt; msg.incomplete_details = std::nullopt; msg.role = role; @@ -136,3 +138,14 @@ cpp::result MessageService::InitializeMessages( return message_repository_->InitializeMessages(thread_id, std::move(messages)); } + +cpp::result MessageService::ModifyMessage( + const OpenAi::Message& message) { + auto res = message_repository_->ModifyMessage(message); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + return cpp::fail("Failed to modify message: " + res.error()); + } + + return RetrieveMessage(message.thread_id, message.id); +} diff --git a/engine/services/message_service.h b/engine/services/message_service.h index 456cdb3a3..4870b6697 100644 --- a/engine/services/message_service.h +++ b/engine/services/message_service.h @@ -14,7 +14,8 @@ class MessageService { std::variant>>&& content, std::optional> attachments, - std::optional metadata); + std::optional metadata, + OpenAi::Status status = OpenAi::Status::COMPLETED); cpp::result InitializeMessages( const std::string& thread_id, @@ -35,6 +36,9 @@ class MessageService { std::vector>>> content); + cpp::result ModifyMessage( + const OpenAi::Message& message); + cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 74767a9b2..057fb52f0 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -10,10 +10,8 @@ #include "config/yaml_config.h" #include "database/models.h" #include "hardware_service.h" -#include "utils/archive_utils.h" - #include "services/inference_service.h" - 
+#include "utils/archive_utils.h" #include "utils/cli_selection_utils.h" #include "utils/engine_constants.h" #include "utils/file_manager_utils.h" @@ -232,11 +230,11 @@ cpp::result ModelService::HandleCortexsoModel( continue; } auto model_id = modelName + ":" + branch.second.name; - if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), - model_id) != - downloaded_model_ids.end()) { // if downloaded, we skip it + if (auto it = std::ranges::find(downloaded_model_ids, model_id); + it != downloaded_model_ids.end()) { continue; } + avai_download_opts.emplace_back(model_id); } diff --git a/engine/services/run_service.cc b/engine/services/run_service.cc new file mode 100644 index 000000000..418ff4d03 --- /dev/null +++ b/engine/services/run_service.cc @@ -0,0 +1,324 @@ +#include "run_service.h" +#include +#include "common/cortex/sync_queue.h" +#include "common/events/done.h" +#include "common/events/error.h" +#include "common/events/thread_message_completed.h" +#include "common/events/thread_message_created.h" +#include "common/events/thread_message_delta.h" +#include "common/events/thread_message_in_progress.h" +#include "common/message_delta.h" +#include "utils/logging_utils.h" +#include "utils/time_utils.h" +#include "utils/ulid_generator.h" + +auto RunService::CreateRun(const dto::RunCreateDto& create_dto) + -> cpp::result { + + auto assistant = assistant_srv_->RetrieveAssistantV2(create_dto.assistant_id); + if (assistant.has_error()) { + return cpp::fail(assistant.error()); + } + + auto model = create_dto.model.has_value() ? create_dto.model.value() + : assistant->model; + // todo: check if model exists + + auto additional_inst = create_dto.additional_instructions.value_or(""); + auto instructions = create_dto.instructions.has_value() + ? create_dto.instructions.value() + : assistant->instructions.value_or(""); + instructions += ". 
" + additional_inst; + + // parsing messages and store it + // tools + // metadata + + auto temperature = create_dto.temperature.has_value() + ? create_dto.temperature.value() + : assistant->temperature.value_or(1.0f); + + auto top_p = create_dto.top_p.has_value() ? create_dto.top_p.value() + : assistant->top_p.value_or(1.0f); + + auto stream = + create_dto.stream.has_value() ? create_dto.stream.value() : true; + + // max_prompt_tokens + // max_completion_tokens + // truncation_strategy + // tool_choice + // parallel_tool_calls + // response_format + + auto id{"run_" + ulid::GenerateUlid()}; + + OpenAi::Run run; + run.id = id; + run.assistant_id = create_dto.assistant_id; + run.model = model; + run.instructions = instructions; + run.temperature = temperature; + run.top_p = top_p; + + // TODO: submit run + // if submit success then we return run + + return run; +} + +auto RunService::CreateRunStream( + dto::RunCreateDto&& create_dto, const std::string& thread_id, + std::function callback) + -> void { + std::thread([this, &create_dto, thread_id, callback = std::move(callback)]() { + auto assistant = + assistant_srv_->RetrieveAssistantV2(create_dto.assistant_id); + if (assistant.has_error()) { + callback(OpenAi::ErrorEvent(assistant.error()), true); + return; + } + + auto thread = thread_srv_->RetrieveThread(thread_id); + if (thread.has_error()) { + callback(OpenAi::ErrorEvent(thread.error()), true); + return; + } + + auto model = create_dto.model.has_value() && !create_dto.model->empty() + ? create_dto.model.value() + : assistant->model; + if (model.empty()) { + callback(OpenAi::ErrorEvent( + "Model to use is empty. 
Please recheck the assistant"), + true); + return; + } + + // TODO: check if engine and model is loaded yet + + auto additional_messages = std::move(*create_dto.additional_messages); + for (auto& msg : additional_messages) { + auto create_msg_res = message_srv_->CreateMessage( + thread_id, msg.role, std::move(msg.content), + std::move(msg.attachments), std::move(msg.metadata)); + + if (create_msg_res.has_error()) { + auto err_msg = + "Create additional message error: " + create_msg_res.error(); + CTL_WRN(err_msg); + callback(OpenAi::ErrorEvent(err_msg), true); + return; + } + } + + auto additional_inst = create_dto.additional_instructions.value_or(""); + auto instructions = create_dto.instructions.has_value() + ? create_dto.instructions.value() + : assistant->instructions.value_or(""); + instructions += ". " + additional_inst; + + auto temperature = create_dto.temperature.has_value() + ? create_dto.temperature.value() + : assistant->temperature.value_or(1.0f); + + auto top_p = create_dto.top_p.has_value() ? 
create_dto.top_p.value() + : assistant->top_p.value_or(1.0f); + + auto q = std::make_shared(); + { + auto request_json = GetMessageListAsJson(thread_id); + if (request_json.has_error()) { + callback(OpenAi::ErrorEvent(request_json.error()), true); + return; + } + (*request_json.value())["model"] = model; + (*request_json.value())["stream"] = true; + (*request_json.value())["top_p"] = top_p; + (*request_json.value())["temperature"] = temperature; + + auto ir = inference_srv_->HandleChatCompletion( + q, std::shared_ptr(std::move(request_json.value()))); + if (ir.has_error()) { + auto error = + OpenAi::ErrorEvent(std::get<1>(ir.error()).toStyledString()); + callback(error, true); + return; + } + } + + auto assistant_msg = message_srv_->CreateMessage( + thread_id, OpenAi::Role::ASSISTANT, "", std::nullopt, std::nullopt, + OpenAi::Status::IN_PROGRESS); + if (assistant_msg.has_error()) { + callback(OpenAi::ErrorEvent(assistant_msg.error()), true); + return; + } + + { + assistant_msg->assistant_id = assistant->id; + assistant_msg->run_id = + "run_test123"; // TODO: create and store the run id + auto res = message_srv_->ModifyMessage(assistant_msg.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + } + callback(OpenAi::ThreadMessageCreatedEvent( + assistant_msg->ToSingleLineJsonString(false).value()), + false); + } + + callback(OpenAi::ThreadMessageInProgressEvent( + assistant_msg->ToSingleLineJsonString(false).value()), + false); + + while (true) { + auto [status, res] = q->wait_and_pop(); + + if (status["has_error"].asBool()) { + auto err_msg{res["message"].asString()}; + { + assistant_msg->incomplete_details = OpenAi::IncompleteDetail(err_msg); + assistant_msg->incomplete_at = cortex_utils::SecondsSinceEpoch(); + assistant_msg->status = OpenAi::Status::INCOMPLETE; + auto res = message_srv_->ModifyMessage(assistant_msg.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + } + } + 
callback(OpenAi::ErrorEvent(err_msg), true); + return; + } + + if (status["is_done"].asBool()) { + { + assistant_msg->completed_at = cortex_utils::SecondsSinceEpoch(); + assistant_msg->status = OpenAi::Status::COMPLETED; + auto res = message_srv_->ModifyMessage(assistant_msg.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + } + callback(OpenAi::ThreadMessageCompletedEvent( + assistant_msg->ToSingleLineJsonString(false).value()), + false); + } + break; + } + + auto str = res["data"].asString(); + if (str.substr(0, 6) == "data: ") { + str = str.substr(6); + } + Json::Value json; + Json::Reader reader; + bool parse_success = reader.parse(str, json); + if (!parse_success) { + CTL_ERR("Failed to parse JSON: " + reader.getFormattedErrorMessages()); + continue; + } + + if (!json.isMember("choices") || json["choices"].empty() || + !json["choices"][0].isMember("delta") || + !json["choices"][0]["delta"].isMember("content")) { + CTL_WRN("Missing required fields in JSON"); + CTL_WRN("Has choices: " + std::to_string(json.isMember("choices"))); + if (json.isMember("choices")) { + CTL_WRN("Choices size: " + std::to_string(json["choices"].size())); + CTL_WRN("First choice has delta: " + + std::to_string(json["choices"][0].isMember("delta"))); + } + continue; + } + + { + auto text_content = std::make_unique( + json["choices"][0]["delta"]["content"].asString()); + + auto content = std::vector>(); + content.push_back(std::move(text_content)); + // Get existing content or create new if empty + if (!assistant_msg->content.empty()) { + // Find the last text content to append to + for (auto it = assistant_msg->content.rbegin(); + it != assistant_msg->content.rend(); ++it) { + if (auto* text_content = + dynamic_cast(it->get())) { + // Append the new delta text to existing text content + text_content->text.value += + json["choices"][0]["delta"]["content"].asString(); + break; + } + } + } else { + // Create new text content if message is empty + 
auto text_content = std::make_unique( + json["choices"][0]["delta"]["content"].asString()); + assistant_msg->content.push_back(std::move(text_content)); + } + + // Update message + assistant_msg->status = OpenAi::Status::IN_PROGRESS; + auto res = message_srv_->ModifyMessage(assistant_msg.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + } + } + + auto text_content = std::make_unique( + json["choices"][0]["delta"]["content"].asString()); + + auto content = std::vector>(); + content.push_back(std::move(text_content)); + + auto delta = OpenAi::MessageDelta::Delta(OpenAi::Role::ASSISTANT, + std::move(content)); + auto msg_delta_evt = OpenAi::ThreadMessageDeltaEvent(std::move(delta)); + callback(msg_delta_evt, false); + } + + // TODO: emit run completed + + callback(OpenAi::DoneEvent(), true); + }).detach(); +} + +auto RunService::GetMessageListAsJson(const std::string& thread_id) + -> cpp::result, std::string> { + auto messages_res = + message_srv_->ListMessages(thread_id, -1, "asc", "", "", ""); + if (messages_res.has_error()) { + return cpp::fail(messages_res.error()); + } + + Json::Value messages(Json::arrayValue); + // TODO: namh check what if the message is not text based? + // TODO: namh what if message have multiple text array item? 
+ for (const auto& msg : messages_res.value()) { + Json::Value message; + for (const auto& content : msg.content) { + if (content->type == "text") { + if (auto* text_content = + dynamic_cast(content.get())) { + auto text = text_content->text.value; + message["content"] = std::move(text); + break; + } + } + } + + message["role"] = RoleToString(msg.role); + messages.append(message); + } + + auto json = std::make_unique(); + (*json)["messages"] = std::move(messages); + CTL_INF("GetMessageListAsJson: " + json->toStyledString()); + return json; +} + +auto RunService::ListRuns(const std::string& thread_id, uint8_t limit, + const std::string& order, const std::string& after, + const std::string& before) const + -> cpp::result, std::string> { + return run_repository_->ListRuns(limit, order, after, before); +} diff --git a/engine/services/run_service.h b/engine/services/run_service.h new file mode 100644 index 000000000..58efc1459 --- /dev/null +++ b/engine/services/run_service.h @@ -0,0 +1,76 @@ +#pragma once + +#include "common/dto/run_create_dto.h" +#include "common/dto/run_update_dto.h" +#include "common/events/assistant_stream_event.h" +#include "common/repository/run_repository.h" +#include "common/run.h" +#include "services/assistant_service.h" +#include "services/inference_service.h" +#include "services/message_service.h" +#include "services/model_service.h" +#include "services/thread_service.h" +#include "utils/result.hpp" + +class RunService { + public: + auto CreateRun(const dto::RunCreateDto& create_dto) + -> cpp::result; + + auto CreateRunStream( + dto::RunCreateDto&& create_dto, const std::string& thread_id, + std::function + callback) -> void; + + /** + * Retrieves a run. + */ + auto RetrieveRun(const std::string& thread_id, + const std::string& run_id) const + -> cpp::result; + + /** + * Cancels a run that is in_progress. + */ + auto CancelRun(const std::string& thread_id, const std::string& run_id) + -> cpp::result; + + /** + * Modifies a run. 
+ */ + auto ModifyRun(const std::string& thread_id, const std::string& run_id, + const dto::RunUpdateDto& update_dto) -> void; + + /** + * Returns a list of runs belonging to a thread. + */ + auto ListRuns(const std::string& thread_id, uint8_t limit, + const std::string& order, const std::string& after, + const std::string& before) const + -> cpp::result, std::string>; + + explicit RunService(std::shared_ptr run_repo, + std::shared_ptr assistant_srv, + std::shared_ptr model_srv, + std::shared_ptr message_srv, + std::shared_ptr inference_srv, + std::shared_ptr thread_srv) + : run_repository_{run_repo}, + assistant_srv_{assistant_srv}, + model_srv_{model_srv}, + message_srv_{message_srv}, + inference_srv_{inference_srv}, + thread_srv_{thread_srv} {} + + private: + std::shared_ptr run_repository_; + + std::shared_ptr assistant_srv_; + std::shared_ptr model_srv_; + std::shared_ptr message_srv_; + std::shared_ptr inference_srv_; + std::shared_ptr thread_srv_; + + auto GetMessageListAsJson(const std::string& thread_id) + -> cpp::result, std::string>; +}; diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 9c5e7e857..a03670988 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -1,6 +1,6 @@ #include "thread_service.h" -#include #include "utils/logging_utils.h" +#include "utils/time_utils.h" #include "utils/ulid_generator.h" cpp::result ThreadService::CreateThread( @@ -8,15 +8,10 @@ cpp::result ThreadService::CreateThread( std::optional metadata) { LOG_TRACE << "CreateThread"; - auto seconds_since_epoch = - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); - OpenAi::Thread thread; thread.id = ulid::GenerateUlid(); thread.object = "thread"; - thread.created_at = seconds_since_epoch; + thread.created_at = cortex_utils::SecondsSinceEpoch(); if (tool_resources) { thread.tool_resources = std::move(tool_resources); diff --git 
a/engine/test/components/test_function_calling.cc b/engine/test/components/test_function_calling.cc index 7a4810b29..18a7e8f57 100644 --- a/engine/test/components/test_function_calling.cc +++ b/engine/test/components/test_function_calling.cc @@ -1,6 +1,5 @@ #include #include "gtest/gtest.h" -#include "json/json.h" #include "utils/function_calling/common.h" class FunctionCallingUtilsTest : public ::testing::Test { @@ -154,4 +153,4 @@ TEST_F(FunctionCallingUtilsTest, PostProcessResponse) { ["arguments"] .asString(), "{\"arg\":\"value\"}"); -} \ No newline at end of file +} diff --git a/engine/test/components/test_tool_resources.cc b/engine/test/components/test_tool_resources.cc index 2b78e6494..0aed37c00 100644 --- a/engine/test/components/test_tool_resources.cc +++ b/engine/test/components/test_tool_resources.cc @@ -8,7 +8,7 @@ namespace { // Mock class for testing abstract ToolResources class MockToolResources : public ToolResources { public: - cpp::result ToJson() override { + cpp::result ToJson() const override { Json::Value json; json["mock"] = "value"; return json; diff --git a/engine/utils/cpuid/cpu_info.cc b/engine/utils/cpuid/cpu_info.cc index 3d4a56ffc..e2c0c2a2f 100644 --- a/engine/utils/cpuid/cpu_info.cc +++ b/engine/utils/cpuid/cpu_info.cc @@ -26,7 +26,7 @@ CpuInfo::CpuInfo() : impl(new Impl()) { init_cpuinfo(*impl); } -CpuInfo::~CpuInfo() {} +CpuInfo::~CpuInfo() = default; // x86 member functions bool CpuInfo::has_fpu() const { diff --git a/engine/utils/time_utils.h b/engine/utils/time_utils.h new file mode 100644 index 000000000..d4e69ae15 --- /dev/null +++ b/engine/utils/time_utils.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace cortex_utils { +inline auto SecondsSinceEpoch() -> uint32_t { + + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} +} // namespace cortex_utils