Unload model stop background #122

Merged 2 commits on Nov 13, 2023

Changes from all commits
33 changes: 32 additions & 1 deletion controllers/llamaCPP.cc
@@ -196,6 +196,25 @@ void llamaCPP::embedding(
  return;
}

void llamaCPP::unloadModel(
    const HttpRequestPtr &req,
    std::function<void(const HttpResponsePtr &)> &&callback) {
  Json::Value jsonResp;
  jsonResp["message"] = "No model loaded";
  if (model_loaded) {
    stopBackgroundTask();

    llama_free(llama.ctx);
    llama_free_model(llama.model);
    llama.ctx = nullptr;
    llama.model = nullptr;
    jsonResp["message"] = "Model unloaded successfully";
  }
  auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
  callback(resp);
  return;
}

void llamaCPP::loadModel(
    const HttpRequestPtr &req,
    std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -274,7 +293,19 @@ void llamaCPP::loadModel(

void llamaCPP::backgroundTask() {
  while (model_loaded) {
-    model_loaded = llama.update_slots();
+    // model_loaded =
+    llama.update_slots();
  }
  LOG_INFO << "Background task stopped!";
  return;
}

void llamaCPP::stopBackgroundTask() {
  if (model_loaded) {
    model_loaded = false;
    LOG_INFO << "changed to false";
    if (backgroundThread.joinable()) {
      backgroundThread.join();
    }
  }
}
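
The unload path depends on the stop sequence above: the handler clears model_loaded, joins backgroundThread so the update_slots() loop has fully exited, and only then frees the llama context and model. Below is a minimal standalone sketch of that flag-and-join pattern in plain C++ with placeholder work; it uses none of the project's actual types and is only meant to illustrate the idea.

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

// Stand-alone illustration of the flag-and-join stop pattern used above:
// the worker loop polls an atomic flag, the controller clears it and joins
// the thread, and only then is it safe to release shared resources.
struct Worker {
  std::atomic<bool> running{false};
  std::thread background;

  void start() {
    running = true;
    background = std::thread([this] {
      while (running) {
        // Placeholder for llama.update_slots(): one unit of work per pass.
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
      }
      std::cout << "Background task stopped!" << std::endl;
    });
  }

  void stop() {
    if (running) {
      running = false;            // loop exits on its next flag check
      if (background.joinable())
        background.join();        // wait for the loop before cleanup
    }
  }
};

int main() {
  Worker w;
  w.start();
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  w.stop();  // after this point the shared state could be freed safely
  return 0;
}

Joining before the llama_free/llama_free_model calls is what keeps the background loop from touching a context that has already been released; switching model_loaded to std::atomic<bool> (see the header change below) makes the cross-thread read and write of the flag well defined.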
8 changes: 7 additions & 1 deletion controllers/llamaCPP.h
@@ -2124,6 +2124,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
  METHOD_ADD(llamaCPP::chatCompletion, "chat_completion", Post);
  METHOD_ADD(llamaCPP::embedding, "embedding", Post);
  METHOD_ADD(llamaCPP::loadModel, "loadmodel", Post);
  METHOD_ADD(llamaCPP::unloadModel, "unloadmodel", Get);

  // PATH_ADD("/llama/chat_completion", Post);
  METHOD_LIST_END
  void chatCompletion(const HttpRequestPtr &req,
@@ -2132,13 +2134,17 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
                 std::function<void(const HttpResponsePtr &)> &&callback);
  void loadModel(const HttpRequestPtr &req,
                 std::function<void(const HttpResponsePtr &)> &&callback);
  void unloadModel(const HttpRequestPtr &req,
                   std::function<void(const HttpResponsePtr &)> &&callback);
  void warmupModel();

  void backgroundTask();

  void stopBackgroundTask();

private:
  llama_server_context llama;
-  bool model_loaded = false;
+  std::atomic<bool> model_loaded = false;
  size_t sent_count = 0;
  size_t sent_token_probs_index = 0;
  std::thread backgroundThread;
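
Once merged, the new route can be exercised with a plain GET request. Below is a hedged client sketch using Drogon's bundled HTTP client; the 127.0.0.1:3928 address and the /inferences/llamacpp/unloadmodel path are assumptions based on Nitro's usual defaults rather than anything stated in this diff, so adjust both to your deployment.

#include <drogon/drogon.h>
#include <iostream>

int main() {
  // Assumed defaults: Nitro listening on 127.0.0.1:3928 with the controller
  // mounted under /inferences/llamacpp (change both if your setup differs).
  auto client = drogon::HttpClient::newHttpClient("http://127.0.0.1:3928");

  auto req = drogon::HttpRequest::newHttpRequest();
  req->setMethod(drogon::Get);
  req->setPath("/inferences/llamacpp/unloadmodel");

  client->sendRequest(req, [](drogon::ReqResult result,
                              const drogon::HttpResponsePtr &resp) {
    if (result == drogon::ReqResult::Ok && resp) {
      // Expect a JSON body whose "message" field is either
      // "Model unloaded successfully" or "No model loaded".
      std::cout << resp->getBody() << std::endl;
    } else {
      std::cerr << "request failed" << std::endl;
    }
    drogon::app().quit();
  });

  drogon::app().run();  // run the event loop until the callback quits it
  return 0;
}

The response body should carry one of the two jsonResp["message"] values set in unloadModel, so the client can tell whether a model was actually resident when the unload request arrived.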