Commit

chore: cleanup
sangjanai committed Jan 9, 2025
1 parent da7576d commit 4df0704
Showing 2 changed files with 129 additions and 124 deletions.
203 changes: 126 additions & 77 deletions engine/extensions/python-engine/python_engine.cc
@@ -4,52 +4,82 @@
#include <sstream>
#include <string>
namespace python_engine {
namespace {
constexpr const int k200OK = 200;
constexpr const int k400BadRequest = 400;
constexpr const int k409Conflict = 409;
constexpr const int k500InternalServerError = 500;
constexpr const int kFileLoggerOption = 0;

size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
void* userdata) {
auto* context = static_cast<StreamContext*>(userdata);
std::string chunk(ptr, size * nmemb);

context->buffer += chunk;

// Process complete lines
size_t pos;
while ((pos = context->buffer.find('\n')) != std::string::npos) {
std::string line = context->buffer.substr(0, pos);
context->buffer = context->buffer.substr(pos + 1);
LOG_DEBUG << "line: " << line;

// Skip empty lines
if (line.empty() || line == "\r")
continue;

if (line == "data: [DONE]") {
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = 200;
(*context->callback)(std::move(status), Json::Value());
break;
}

// Parse the JSON
Json::Value chunk_json;
chunk_json["data"] = line + "\n\n";
Json::Reader reader;

Json::Value status;
status["is_done"] = false;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = 200;
(*context->callback)(std::move(status), std::move(chunk_json));
}

return size * nmemb;
}
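
libcurl delivers the response body to a `CURLOPT_WRITEFUNCTION` callback in arbitrary chunks, so a chunk can end mid-line; that is why the callback above buffers on `'\n'` before emitting events. A minimal wiring sketch, with a `StreamContext` whose fields are assumed from the usage above (the real definition lives elsewhere in the engine):

```cpp
#include <curl/curl.h>
#include <functional>
#include <string>
#include <json/json.h>

// Sketch only: a StreamContext shaped to match the usage above (a buffer for
// partial lines, plus the downstream callback invoked per parsed line). The
// real definition lives elsewhere in the engine.
struct StreamContext {
  std::string buffer;
  std::function<void(Json::Value&&, Json::Value&&)>* callback;
};

// Wiring the callback into a curl easy handle. Returning anything other than
// size * nmemb from the callback makes curl abort with CURLE_WRITE_ERROR.
void StartStreaming(const std::string& url, StreamContext& ctx) {
  CURL* curl = curl_easy_init();
  if (!curl) return;
  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback);
  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ctx);  // delivered as `userdata`
  curl_easy_perform(curl);
  curl_easy_cleanup(curl);
}
```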

static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
std::string* data) {
data->append(ptr, size * nmemb);
return size * nmemb;
}
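
The non-streaming variant simply appends every chunk to a `std::string`. A self-contained usage sketch (hypothetical helper, not part of this engine's API):

```cpp
#include <curl/curl.h>
#include <string>

// Hypothetical helper: a blocking GET that pairs WriteCallback with a
// std::string sink.
std::string FetchBody(const std::string& url) {
  std::string body;
  CURL* curl = curl_easy_init();
  if (!curl) return body;
  curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &body);  // becomes `data`
  curl_easy_perform(curl);
  curl_easy_cleanup(curl);
  return body;
}
```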

PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {}

PythonEngine::~PythonEngine() {
curl_global_cleanup();
}

config::PythonModelConfig* PythonEngine::GetModelConfig(
const std::string& model) {
std::shared_lock lock(models_mutex_);
auto it = models_.find(model);
if (it != models_.end()) {
return &it->second;
}
return nullptr;
}
std::string constructWindowsCommandLine(const std::vector<std::string>& args) {
std::string cmdLine;
std::string ConstructWindowsCommandLine(const std::vector<std::string>& args) {
std::string cmd_line;
for (const auto& arg : args) {
// Simple escaping for Windows command line
std::string escapedArg = arg;
if (escapedArg.find(' ') != std::string::npos) {
std::string escaped_arg = arg;
if (escaped_arg.find(' ') != std::string::npos) {
// Wrap in quotes and escape existing quotes
for (char& c : escapedArg) {
for (char& c : escaped_arg) {
if (c == '"')
c = '\\';
}
escapedArg = "\"" + escapedArg + "\"";
escaped_arg = "\"" + escaped_arg + "\"";
}
cmdLine += escapedArg + " ";
cmd_line += escaped_arg + " ";
}
return cmdLine;
return cmd_line;
}
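
Note that the loop above replaces an embedded `"` with a bare backslash rather than escaping it, so quoted arguments come out mangled. A sketch of quote-preserving escaping in the spirit of the `CommandLineToArgvW` rules (an illustrative alternative, not what this commit ships):

```cpp
#include <string>

// Sketch of quote-preserving escaping, following the usual
// CommandLineToArgvW convention (runs of backslashes before a quote would
// also need doubling; omitted here for brevity).
std::string QuoteWindowsArg(const std::string& arg) {
  if (arg.find_first_of(" \t\"") == std::string::npos)
    return arg;  // nothing to quote
  std::string out = "\"";
  for (char c : arg) {
    if (c == '"')
      out += "\\\"";  // escape the quote instead of overwriting it
    else
      out += c;
  }
  out += '"';
  return out;
}
```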

std::vector<char*> convertToArgv(const std::vector<std::string>& args) {
std::vector<char*> ConvertToArgv(const std::vector<std::string>& args) {
std::vector<char*> argv;
for (const auto& arg : args) {
argv.push_back(const_cast<char*>(arg.c_str()));
@@ -58,58 +88,74 @@ std::vector<char*> convertToArgv(const std::vector<std::string>& args) {
return argv;
}
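
The `char*` entries returned here borrow from `args`, so the source vector must outlive any spawn call, and `posix_spawn` additionally requires the array to be NULL-terminated (the elided lines presumably append that terminator). A usage sketch under those assumptions, with example paths:

```cpp
#include <spawn.h>
#include <string>
#include <vector>

extern char** environ;

// Usage sketch: `args` must outlive the call because argv borrows its
// storage, and the array is assumed to be NULL-terminated by the elided
// lines above, as posix_spawn requires. Paths are example values.
pid_t SpawnExample() {
  std::vector<std::string> args = {"/usr/bin/python3", "server.py"};
  auto argv = ConvertToArgv(args);
  pid_t pid = -1;
  posix_spawn(&pid, args[0].c_str(), /*file_actions=*/nullptr,
              /*attrp=*/nullptr, argv.data(), environ);
  return pid;
}
```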

} // namespace

PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {}

PythonEngine::~PythonEngine() {
curl_global_cleanup();
}

config::PythonModelConfig* PythonEngine::GetModelConfig(
const std::string& model) {
std::shared_lock lock(models_mutex_);
auto it = models_.find(model);
if (it != models_.end()) {
return &it->second;
}
return nullptr;
}
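
`GetModelConfig` takes a `std::shared_lock` so concurrent readers of `models_` do not serialize, while mutating paths such as `UnloadModel` below take a `std::unique_lock` on the same mutex. A minimal sketch of that reader/writer pairing; note that handing out a pointer after the lock is released, as above, assumes the entry outlives the caller's use of it:

```cpp
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Minimal sketch of the reader/writer pairing used around models_.
struct Registry {
  std::shared_mutex mu;
  std::unordered_map<std::string, int> items;

  int* Find(const std::string& key) {  // readers share the lock
    std::shared_lock lock(mu);
    auto it = items.find(key);
    return it == items.end() ? nullptr : &it->second;
  }
  void Erase(const std::string& key) {  // writers take it exclusively
    std::unique_lock lock(mu);
    items.erase(key);
  }
};
```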

// TODO(sang) move to utils to re-use
pid_t PythonEngine::SpawnProcess(const std::string& model,
const std::vector<std::string>& command) {
try {
#ifdef _WIN32
#if defined(_WIN32)
// Windows process creation
STARTUPINFOA si = {0};
PROCESS_INFORMATION pi = {0};
si.cb = sizeof(si);

// Construct command line
std::string cmdLine = constructWindowsCommandLine(command);
std::string cmd_line = ConstructWindowsCommandLine(command);

// Convert string to char* for Windows API
char commandBuffer[4096];
strncpy_s(commandBuffer, cmdLine.c_str(), sizeof(commandBuffer));

if (!CreateProcessA(NULL, // lpApplicationName
commandBuffer, // lpCommandLine
NULL, // lpProcessAttributes
NULL, // lpThreadAttributes
FALSE, // bInheritHandles
0, // dwCreationFlags
NULL, // lpEnvironment
NULL, // lpCurrentDirectory
&si, // lpStartupInfo
&pi // lpProcessInformation
char command_buffer[4096];
strncpy_s(command_buffer, cmd_line.c_str(), sizeof(command_buffer));

if (!CreateProcessA(NULL, // lpApplicationName
command_buffer, // lpCommandLine
NULL, // lpProcessAttributes
NULL, // lpThreadAttributes
FALSE, // bInheritHandles
0, // dwCreationFlags
NULL, // lpEnvironment
NULL, // lpCurrentDirectory
&si, // lpStartupInfo
&pi // lpProcessInformation
)) {
throw std::runtime_error("Failed to create process on Windows");
}

// Store the process ID
pid_t pid = pi.dwProcessId;
processMap[model] = pid;
process_map_[model] = pid;

// Close handles to avoid resource leaks
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);

return pid;

#elif __APPLE__ || __linux__
#elif defined(__APPLE__) || defined(__linux__)
// POSIX process creation
pid_t pid;

// Convert command vector to char*[]
std::vector<char*> argv = convertToArgv(command);
// for (auto c : command) {
// std::cout << c << " " << std::endl;
// }
auto argv = ConvertToArgv(command);

// Use posix_spawn for cross-platform compatibility
int spawn_result = posix_spawn(&pid, // pid output
auto spawn_result = posix_spawn(&pid, // pid output
command[0].c_str(), // executable path
NULL, // file actions
NULL, // spawn attributes
@@ -122,7 +168,7 @@ pid_t PythonEngine::SpawnProcess(const std::string& model,
}

// Store the process ID
processMap[model] = pid;
process_map_[model] = pid;
return pid;

#else
@@ -133,16 +179,17 @@ pid_t PythonEngine::SpawnProcess(const std::string& model,
return -1;
}
}
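
Two details of the Windows branch worth noting: `CreateProcessA` may write into its `lpCommandLine` buffer, which is why the code copies into a mutable array, and the fixed 4096-byte `command_buffer` silently caps the command-line length (overlong input would trip `strncpy_s`'s constraint handling). A sketch that sizes the writable copy dynamically instead (illustrative alternative only):

```cpp
#include <string>
#include <vector>

// Illustrative sketch: give CreateProcessA a writable copy sized to the
// actual command line, avoiding the fixed 4096-byte cap.
std::vector<char> MakeWritableCommandLine(const std::string& cmd_line) {
  std::vector<char> buf(cmd_line.begin(), cmd_line.end());
  buf.push_back('\0');  // lpCommandLine must be NUL-terminated and writable
  return buf;           // pass buf.data() as lpCommandLine
}
```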

bool PythonEngine::TerminateModelProcess(const std::string& model) {
auto it = processMap.find(model);
if (it == processMap.end()) {
auto it = process_map_.find(model);
if (it == process_map_.end()) {
LOG_ERROR << "No process found for model: " << model
<< ", removing from list running models.";
models_.erase(model);
return false;
}

#ifdef _WIN32
#if defined(_WIN32)
HANDLE hProcess = OpenProcess(PROCESS_TERMINATE, FALSE, it->second);
if (hProcess == NULL) {
LOG_ERROR << "Failed to open process";
@@ -153,20 +200,21 @@ bool PythonEngine::TerminateModelProcess(const std::string& model) {
CloseHandle(hProcess);

if (terminated) {
processMap.erase(it);
process_map_.erase(it);
return true;
}

#elif __APPLE__ || __linux__
#elif defined(__APPLE__) || defined(__linux__)
int result = kill(it->second, SIGTERM);
if (result == 0) {
processMap.erase(it);
process_map_.erase(it);
return true;
}
#endif

return false;
}
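
On POSIX, `kill(pid, SIGTERM)` only delivers the signal: the child may ignore it, and once it exits it lingers as a zombie until reaped. A hedged sketch of terminate-then-reap with a `SIGKILL` escalation (an illustration, not part of this commit):

```cpp
#include <csignal>
#include <sys/wait.h>
#include <unistd.h>

// Sketch: SIGTERM first, give the child a grace period, escalate to SIGKILL,
// and reap it so it does not linger as a zombie.
bool TerminateAndReap(pid_t pid) {
  if (kill(pid, SIGTERM) != 0)
    return false;
  for (int i = 0; i < 50; ++i) {  // ~5 s grace period
    if (waitpid(pid, nullptr, WNOHANG) == pid)
      return true;
    usleep(100 * 1000);  // 100 ms
  }
  kill(pid, SIGKILL);  // last resort
  return waitpid(pid, nullptr, 0) == pid;
}
```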

CurlResponse PythonEngine::MakeGetRequest(const std::string& model,
const std::string& path) {
auto const& config = models_[model];
@@ -182,6 +230,7 @@ CurlResponse PythonEngine::MakeGetRequest(const std::string& model,
}
return response;
}
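
The elided body presumably performs the request and fills `response`; for reference, after a successful `curl_easy_perform` the HTTP status can be read back with `curl_easy_getinfo` (sketch only; whether `CurlResponse` carries a status field is an assumption):

```cpp
#include <curl/curl.h>

// Sketch only: recover the HTTP status after curl_easy_perform succeeds.
long GetHttpStatus(CURL* curl) {
  long status_code = 0;
  curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &status_code);
  return status_code;
}
```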

CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model,
const std::string& path) {
auto const& config = models_[model];
@@ -304,7 +353,7 @@ void PythonEngine::LoadModel(
auto data_folder_path =
std::filesystem::path(model_folder_path) / std::filesystem::path("venv");
try {
#ifdef _WIN32
#if defined(_WIN32)
auto executable = std::filesystem::path(data_folder_path) /
std::filesystem::path("Scripts");
#else
@@ -416,16 +465,16 @@ void PythonEngine::UnloadModel(
return;
}

const std::string& model = (*json_body)["model"].asString();
auto model = (*json_body)["model"].asString();

{
std::unique_lock lock(models_mutex_);
{
if (TerminateModelProcess(model)) {
std::unique_lock lock(models_mutex_);
models_.erase(model);
} else {
Json::Value error;
error["error"] = "Fail to terminate process with id: " +
std::to_string(processMap[model]);
std::to_string(process_map_[model]);
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
@@ -448,7 +497,9 @@ void PythonEngine::UnloadModel(

void PythonEngine::HandleChatCompletion(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {}
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
LOG_WARN << "Does not support yet!";
}
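
As committed, this stub only logs a warning and never invokes `callback`, which can leave a caller waiting indefinitely. A sketch of failing fast instead, reusing the error shape seen elsewhere in this file (a suggestion, not part of the commit):

```cpp
// Sketch: report "not supported" through the callback instead of only logging.
Json::Value error;
error["error"] = "Chat completions are not supported by the python engine yet";
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k400BadRequest;
callback(std::move(status), std::move(error));
```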

CurlResponse PythonEngine::MakeStreamPostRequest(
const std::string& model, const std::string& path, const std::string& body,
@@ -509,7 +560,7 @@ CurlResponse PythonEngine::MakeStreamPostRequest(
void PythonEngine::HandleInference(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
if (!json_body->isMember("model")) {
if (json_body && !json_body->isMember("model")) {
Json::Value error;
error["error"] = "Missing required field: model is required!";
Json::Value status;
@@ -520,14 +571,14 @@ void PythonEngine::HandleInference(
callback(std::move(status), std::move(error));
return;
}

std::string method = "post";
std::string path = "/inference";
std::string transform_request =
(*json_body).get("transform_request", "").asString();
std::string transform_response =
auto transform_request = (*json_body).get("transform_request", "").asString();
auto transform_response =
(*json_body).get("transform_response", "").asString();
std::string model = (*json_body)["model"].asString();
Json::Value body = (*json_body)["body"];
auto model = (*json_body)["model"].asString();
auto& body = (*json_body)["body"];

if (models_.find(model) == models_.end()) {
Json::Value error;
@@ -680,10 +731,13 @@ void PythonEngine::HandleInference(
callback(std::move(status), std::move(response_json));
}
}
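
One subtlety in the new guard at the top of this function: `if (json_body && !json_body->isMember("model"))` returns early only when a body is present but lacks `model`; a null `json_body` slips past the check and is dereferenced a few lines later. A stricter variant, sketched with the same error shape (a suggestion, not what the commit contains):

```cpp
// Sketch: reject a missing body outright as well as a missing "model" field.
if (!json_body || !json_body->isMember("model")) {
  Json::Value error;
  error["error"] = "Missing required field: model is required!";
  Json::Value status;
  status["is_done"] = true;
  status["has_error"] = true;
  status["is_stream"] = false;
  status["status_code"] = k400BadRequest;
  callback(std::move(status), std::move(error));
  return;
}
```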

Json::Value PythonEngine::GetRemoteModels() {
return Json::Value();
}

void PythonEngine::StopInferencing(const std::string& model_id) {}

void PythonEngine::HandleRouteRequest(
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
@@ -700,14 +754,13 @@ void PythonEngine::HandleRouteRequest(
callback(std::move(status), std::move(error));
return;
}
std::string method = (*json_body)["method"].asString();
std::string path = (*json_body)["path"].asString();
std::string transform_request =
(*json_body).get("transform_request", "").asString();
std::string transform_response =
auto method = (*json_body)["method"].asString();
auto path = (*json_body)["path"].asString();
auto transform_request = (*json_body).get("transform_request", "").asString();
auto transform_response =
(*json_body).get("transform_response", "").asString();
std::string model = (*json_body)["model"].asString();
Json::Value body = (*json_body)["body"];
auto model = (*json_body)["model"].asString();
auto& body = (*json_body)["body"];

if (models_.find(model) == models_.end()) {
Json::Value error;
@@ -864,6 +917,7 @@ void PythonEngine::GetModelStatus(
callback(std::move(status), std::move(error));
return;
}

auto model = json_body->get("model", "").asString();
auto model_config = models_[model];
auto health_endpoint = model_config.heath_check;
@@ -947,9 +1001,4 @@ void PythonEngine::Unload(EngineUnloadOption opts) {
}
};

// extern "C" {
// EngineI* get_engine() {
// return new PythonEngine();
// }
// }
} // namespace python_engine