Skip to content

Commit

Permalink
feat: rendering chat_template (#1814)
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai authored Dec 23, 2024
1 parent a8b2503 commit e408f78
Show file tree
Hide file tree
Showing 17 changed files with 4,402 additions and 168 deletions.
13 changes: 6 additions & 7 deletions engine/cli/commands/chat_completion_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {

return data_length;
}

} // namespace

void ChatCompletionCmd::Exec(const std::string& host, int port,
Expand Down Expand Up @@ -103,7 +102,7 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
return;
}

std::string url = "http://" + address + "/v1/chat/completions";
auto url = "http://" + address + "/v1/chat/completions";
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);

Expand Down Expand Up @@ -151,18 +150,18 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
json_data["model"] = model_handle;
json_data["stream"] = true;

std::string json_payload = json_data.toStyledString();

curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
auto json_str = json_data.toStyledString();
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_str.length());
curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L);

std::string ai_chat;
StreamingCallback callback;
callback.ai_chat = &ai_chat;

curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &callback);

CURLcode res = curl_easy_perform(curl);
auto res = curl_easy_perform(curl);

if (res != CURLE_OK) {
CLI_LOG("CURL request failed: " << curl_easy_strerror(res));
Expand Down
29 changes: 29 additions & 0 deletions engine/common/model_metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <sstream>
#include "common/tokenizer.h"

struct ModelMetadata {
uint32_t version;
uint64_t tensor_count;
uint64_t metadata_kv_count;
std::shared_ptr<Tokenizer> tokenizer;

std::string ToString() const {
std::ostringstream ss;
ss << "ModelMetadata {\n"
<< "version: " << version << "\n"
<< "tensor_count: " << tensor_count << "\n"
<< "metadata_kv_count: " << metadata_kv_count << "\n"
<< "tokenizer: ";

if (tokenizer) {
ss << "\n" << tokenizer->ToString();
} else {
ss << "null";
}

ss << "\n}";
return ss.str();
}
};
72 changes: 72 additions & 0 deletions engine/common/tokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

#include <sstream>
#include <string>

struct Tokenizer {
std::string eos_token = "";
bool add_eos_token = true;

std::string bos_token = "";
bool add_bos_token = true;

std::string unknown_token = "";
std::string padding_token = "";

std::string chat_template = "";

bool add_generation_prompt = true;

// Helper function for common fields
std::string BaseToString() const {
std::ostringstream ss;
ss << "eos_token: \"" << eos_token << "\"\n"
<< "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
<< "bos_token: \"" << bos_token << "\"\n"
<< "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
<< "unknown_token: \"" << unknown_token << "\"\n"
<< "padding_token: \"" << padding_token << "\"\n"
<< "chat_template: \"" << chat_template << "\"\n"
<< "add_generation_prompt: "
<< (add_generation_prompt ? "true" : "false") << "\"";
return ss.str();
}

virtual ~Tokenizer() = default;

virtual std::string ToString() = 0;
};

struct GgufTokenizer : public Tokenizer {
std::string pre = "";

~GgufTokenizer() override = default;

std::string ToString() override {
std::ostringstream ss;
ss << "GgufTokenizer {\n";
// Add base class members
ss << BaseToString() << "\n";
// Add derived class members
ss << "pre: \"" << pre << "\"\n";
ss << "}";
return ss.str();
}
};

struct SafeTensorTokenizer : public Tokenizer {
bool add_prefix_space = true;

~SafeTensorTokenizer() = default;

std::string ToString() override {
std::ostringstream ss;
ss << "SafeTensorTokenizer {\n";
// Add base class members
ss << BaseToString() << "\n";
// Add derived class members
ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
ss << "}";
return ss.str();
}
};
17 changes: 5 additions & 12 deletions engine/controllers/files.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,8 @@ void Files::RetrieveFileContent(
return;
}

auto [buffer, size] = std::move(res.value());
auto resp = HttpResponse::newHttpResponse();
resp->setBody(std::string(buffer.get(), size));
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
auto resp =
cortex_utils::CreateCortexContentResponse(std::move(res.value()));
callback(resp);
} else {
if (!msg_res->rel_path.has_value()) {
Expand All @@ -243,10 +241,8 @@ void Files::RetrieveFileContent(
return;
}

auto [buffer, size] = std::move(content_res.value());
auto resp = HttpResponse::newHttpResponse();
resp->setBody(std::string(buffer.get(), size));
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
auto resp = cortex_utils::CreateCortexContentResponse(
std::move(content_res.value()));
callback(resp);
}
}
Expand All @@ -261,9 +257,6 @@ void Files::RetrieveFileContent(
return;
}

auto [buffer, size] = std::move(res.value());
auto resp = HttpResponse::newHttpResponse();
resp->setBody(std::string(buffer.get(), size));
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
auto resp = cortex_utils::CreateCortexContentResponse(std::move(res.value()));
callback(resp);
}
9 changes: 8 additions & 1 deletion engine/controllers/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "trantor/utils/Logger.h"
#include "utils/cortex_utils.h"
#include "utils/function_calling/common.h"
#include "utils/http_util.h"

using namespace inferences;

Expand All @@ -27,6 +26,14 @@ void server::ChatCompletion(
std::function<void(const HttpResponsePtr&)>&& callback) {
LOG_DEBUG << "Start chat completion";
auto json_body = req->getJsonObject();
if (json_body == nullptr) {
Json::Value ret;
ret["message"] = "Body can't be empty";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
return;
}
bool is_stream = (*json_body).get("stream", false).asBool();
auto model_id = (*json_body).get("model", "invalid_model").asString();
auto engine_type = [this, &json_body]() -> std::string {
Expand Down
1 change: 1 addition & 0 deletions engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
auto model_src_svc = std::make_shared<services::ModelSourceService>();
auto model_service = std::make_shared<ModelService>(
download_service, inference_svc, engine_service);
inference_svc->SetModelService(model_service);

auto file_watcher_srv = std::make_shared<FileWatcherService>(
model_dir_path.string(), model_service);
Expand Down
20 changes: 8 additions & 12 deletions engine/services/engine_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

Expand All @@ -17,7 +16,6 @@
#include "utils/cpuid/cpu_info.h"
#include "utils/dylib.h"
#include "utils/dylib_path_manager.h"
#include "utils/engine_constants.h"
#include "utils/github_release_utils.h"
#include "utils/result.hpp"
#include "utils/system_info_utils.h"
Expand Down Expand Up @@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
struct EngineInfo {
std::unique_ptr<cortex_cpp::dylib> dl;
EngineV engine;
#if defined(_WIN32)
DLL_DIRECTORY_COOKIE cookie;
DLL_DIRECTORY_COOKIE cuda_cookie;
#endif
};

std::mutex engines_mutex_;
Expand Down Expand Up @@ -106,21 +100,23 @@ class EngineService : public EngineServiceI {

cpp::result<DefaultEngineVariant, std::string> SetDefaultEngineVariant(
const std::string& engine, const std::string& version,
const std::string& variant);
const std::string& variant) override;

cpp::result<DefaultEngineVariant, std::string> GetDefaultEngineVariant(
const std::string& engine);
const std::string& engine) override;

cpp::result<std::vector<EngineVariantResponse>, std::string>
GetInstalledEngineVariants(const std::string& engine) const;
GetInstalledEngineVariants(const std::string& engine) const override;

cpp::result<EngineV, std::string> GetLoadedEngine(
const std::string& engine_name);

std::vector<EngineV> GetLoadedEngines();

cpp::result<void, std::string> LoadEngine(const std::string& engine_name);
cpp::result<void, std::string> UnloadEngine(const std::string& engine_name);
cpp::result<void, std::string> LoadEngine(
const std::string& engine_name) override;
cpp::result<void, std::string> UnloadEngine(
const std::string& engine_name) override;

cpp::result<github_release_utils::GitHubRelease, std::string>
GetLatestEngineVersion(const std::string& engine) const;
Expand All @@ -138,7 +134,7 @@ class EngineService : public EngineServiceI {

cpp::result<cortex::db::EngineEntry, std::string> GetEngineByNameAndVariant(
const std::string& engine_name,
const std::optional<std::string> variant = std::nullopt);
const std::optional<std::string> variant = std::nullopt) override;

cpp::result<cortex::db::EngineEntry, std::string> UpsertEngine(
const std::string& engine_name, const std::string& type,
Expand Down
42 changes: 41 additions & 1 deletion engine/services/inference_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <drogon/HttpTypes.h>
#include "utils/engine_constants.h"
#include "utils/function_calling/common.h"
#include "utils/jinja_utils.h"

namespace services {
cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
Expand All @@ -24,6 +25,45 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
return cpp::fail(std::make_pair(stt, res));
}

{
auto model_id = json_body->get("model", "").asString();
if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
if (metadata_ptr != nullptr &&
!metadata_ptr->tokenizer->chat_template.empty()) {
auto tokenizer = metadata_ptr->tokenizer;
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
tokenizer->chat_template, template_data_json,
tokenizer->bos_token, tokenizer->eos_token,
tokenizer->add_bos_token, tokenizer->add_eos_token,
tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
Json::Value stops(Json::arrayValue);
stops.append(tokenizer->eos_token);
(*json_body)["stop"] = stops;
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
}
}
}
}

CTL_INF("Json body inference: " + json_body->toStyledString());

auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
res["tool_choice"] = tool_choice;
Expand Down Expand Up @@ -297,4 +337,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
}
return true;
}
} // namespace services
} // namespace services
9 changes: 8 additions & 1 deletion engine/services/inference_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#include <mutex>
#include <queue>
#include "services/engine_service.h"
#include "services/model_service.h"
#include "utils/result.hpp"
#include "extensions/remote-engine/remote_engine.h"

namespace services {

// Status and result
using InferResult = std::pair<Json::Value, Json::Value>;

Expand Down Expand Up @@ -58,7 +60,12 @@ class InferenceService {
bool HasFieldInReq(std::shared_ptr<Json::Value> json_body,
const std::string& field);

void SetModelService(std::shared_ptr<ModelService> model_service) {
model_service_ = model_service;
}

private:
std::shared_ptr<EngineService> engine_service_;
std::weak_ptr<ModelService> model_service_;
};
} // namespace services
Loading

0 comments on commit e408f78

Please sign in to comment.