Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: add cpu_threads to model.yaml #1845

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) {
"ngl",
"ctx_len",
"n_parallel",
"cpu_threads",
"engine",
"prompt_template",
"system_template",
Expand Down
6 changes: 6 additions & 0 deletions engine/cli/commands/model_upd_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key,
data["n_parallel"] = static_cast<int>(f);
});
}},
{"cpu_threads",
[this](Json::Value &data, const std::string& k, const std::string& v) {
UpdateNumericField(k, v, [&data](float f) {
data["cpu_threads"] = static_cast<int>(f);
});
}},
{"tp",
[this](Json::Value &data, const std::string& k, const std::string& v) {
UpdateNumericField(k, v, [&data](float f) {
Expand Down
8 changes: 8 additions & 0 deletions engine/config/model_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ struct ModelConfig {
int ngl = std::numeric_limits<int>::quiet_NaN();
int ctx_len = std::numeric_limits<int>::quiet_NaN();
int n_parallel = 1;
int cpu_threads = -1;
std::string engine;
std::string prompt_template;
std::string system_template;
Expand Down Expand Up @@ -272,6 +273,8 @@ struct ModelConfig {
ctx_len = json["ctx_len"].asInt();
if (json.isMember("n_parallel"))
n_parallel = json["n_parallel"].asInt();
if (json.isMember("cpu_threads"))
cpu_threads = json["cpu_threads"].asInt();
if (json.isMember("engine"))
engine = json["engine"].asString();
if (json.isMember("prompt_template"))
Expand Down Expand Up @@ -362,6 +365,9 @@ struct ModelConfig {
obj["ngl"] = ngl;
obj["ctx_len"] = ctx_len;
obj["n_parallel"] = n_parallel;
if (cpu_threads > 0) {
obj["cpu_threads"] = cpu_threads;
}
obj["engine"] = engine;
obj["prompt_template"] = prompt_template;
obj["system_template"] = system_template;
Expand Down Expand Up @@ -474,6 +480,8 @@ struct ModelConfig {
format_utils::MAGENTA);
oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel),
format_utils::MAGENTA);
oss << format_utils::print_kv("cpu_threads", std::to_string(cpu_threads),
format_utils::MAGENTA);
if (ngl != std::numeric_limits<int>::quiet_NaN())
oss << format_utils::print_kv("ngl", std::to_string(ngl),
format_utils::MAGENTA);
Expand Down
124 changes: 65 additions & 59 deletions engine/config/yaml_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void YamlHandler::ModelConfigFromYaml() {
tmp.ctx_len = yaml_node_["ctx_len"].as<int>();
if (yaml_node_["n_parallel"])
tmp.n_parallel = yaml_node_["n_parallel"].as<int>();
if (yaml_node_["cpu_threads"])
tmp.cpu_threads = yaml_node_["cpu_threads"].as<int>();
if (yaml_node_["tp"])
tmp.tp = yaml_node_["tp"].as<int>();
if (yaml_node_["stream"])
Expand Down Expand Up @@ -224,6 +226,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
yaml_node_["ctx_len"] = model_config_.ctx_len;
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
yaml_node_["n_parallel"] = model_config_.n_parallel;
if (!std::isnan(static_cast<double>(model_config_.cpu_threads)))
yaml_node_["cpu_threads"] = model_config_.cpu_threads;
if (!std::isnan(static_cast<double>(model_config_.tp)))
yaml_node_["tp"] = model_config_.tp;
if (!std::isnan(static_cast<double>(model_config_.stream)))
Expand Down Expand Up @@ -283,110 +287,112 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
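// Writes the configuration held in yaml_node_ to `file_path` as YAML, grouped
// into commented BEGIN/END sections. Throws std::runtime_error if the file
// cannot be opened; std::exception failures are logged to stderr and rethrown.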
void YamlHandler::WriteYamlFile(const std::string& file_path) const {
try {
std::ofstream outFile(file_path);
if (!outFile) {
std::ofstream out_file(file_path);
if (!out_file) {
throw std::runtime_error("Failed to open output file.");
}
// Write GENERAL GGUF METADATA
outFile << "# BEGIN GENERAL GGUF METADATA\n";
outFile << format_utils::writeKeyValue(
out_file << "# BEGIN GENERAL GGUF METADATA\n";
out_file << format_utils::WriteKeyValue(
"id", yaml_node_["id"],
"Model ID unique between models (author / quantization)");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"model", yaml_node_["model"],
"Model ID which is used for request construct - should be "
"unique between models (author / quantization)");
outFile << format_utils::writeKeyValue("name", yaml_node_["name"],
out_file << format_utils::WriteKeyValue("name", yaml_node_["name"],
"metadata.general.name");
if (yaml_node_["version"]) {
outFile << "version: " << yaml_node_["version"].as<std::string>() << "\n";
out_file << "version: " << yaml_node_["version"].as<std::string>() << "\n";
}
if (yaml_node_["files"] && yaml_node_["files"].size()) {
outFile << "files: # Can be relative OR absolute local file "
out_file << "files: # Can be relative OR absolute local file "
"path\n";
for (const auto& source : yaml_node_["files"]) {
outFile << " - " << source << "\n";
out_file << " - " << source << "\n";
}
}

outFile << "# END GENERAL GGUF METADATA\n";
outFile << "\n";
out_file << "# END GENERAL GGUF METADATA\n";
out_file << "\n";
// Write INFERENCE PARAMETERS
outFile << "# BEGIN INFERENCE PARAMETERS\n";
outFile << "# BEGIN REQUIRED\n";
out_file << "# BEGIN INFERENCE PARAMETERS\n";
out_file << "# BEGIN REQUIRED\n";
if (yaml_node_["stop"] && yaml_node_["stop"].size()) {
outFile << "stop: # tokenizer.ggml.eos_token_id\n";
out_file << "stop: # tokenizer.ggml.eos_token_id\n";
for (const auto& stop : yaml_node_["stop"]) {
outFile << " - " << stop << "\n";
out_file << " - " << stop << "\n";
}
}

outFile << "# END REQUIRED\n";
outFile << "\n";
outFile << "# BEGIN OPTIONAL\n";
outFile << format_utils::writeKeyValue("size", yaml_node_["size"]);
outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"],
out_file << "# END REQUIRED\n";
out_file << "\n";
out_file << "# BEGIN OPTIONAL\n";
out_file << format_utils::WriteKeyValue("size", yaml_node_["size"]);
out_file << format_utils::WriteKeyValue("stream", yaml_node_["stream"],
"Default true?");
outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"],
out_file << format_utils::WriteKeyValue("top_p", yaml_node_["top_p"],
"Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"temperature", yaml_node_["temperature"], "Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"frequency_penalty", yaml_node_["frequency_penalty"], "Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"presence_penalty", yaml_node_["presence_penalty"], "Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"max_tokens", yaml_node_["max_tokens"],
"Should be default to context length");
outFile << format_utils::writeKeyValue("seed", yaml_node_["seed"]);
outFile << format_utils::writeKeyValue("dynatemp_range",
out_file << format_utils::WriteKeyValue("seed", yaml_node_["seed"]);
out_file << format_utils::WriteKeyValue("dynatemp_range",
yaml_node_["dynatemp_range"]);
outFile << format_utils::writeKeyValue("dynatemp_exponent",
out_file << format_utils::WriteKeyValue("dynatemp_exponent",
yaml_node_["dynatemp_exponent"]);
outFile << format_utils::writeKeyValue("top_k", yaml_node_["top_k"]);
outFile << format_utils::writeKeyValue("min_p", yaml_node_["min_p"]);
outFile << format_utils::writeKeyValue("tfs_z", yaml_node_["tfs_z"]);
outFile << format_utils::writeKeyValue("typ_p", yaml_node_["typ_p"]);
outFile << format_utils::writeKeyValue("repeat_last_n",
out_file << format_utils::WriteKeyValue("top_k", yaml_node_["top_k"]);
out_file << format_utils::WriteKeyValue("min_p", yaml_node_["min_p"]);
out_file << format_utils::WriteKeyValue("tfs_z", yaml_node_["tfs_z"]);
out_file << format_utils::WriteKeyValue("typ_p", yaml_node_["typ_p"]);
out_file << format_utils::WriteKeyValue("repeat_last_n",
yaml_node_["repeat_last_n"]);
outFile << format_utils::writeKeyValue("repeat_penalty",
out_file << format_utils::WriteKeyValue("repeat_penalty",
yaml_node_["repeat_penalty"]);
outFile << format_utils::writeKeyValue("mirostat", yaml_node_["mirostat"]);
outFile << format_utils::writeKeyValue("mirostat_tau",
out_file << format_utils::WriteKeyValue("mirostat", yaml_node_["mirostat"]);
out_file << format_utils::WriteKeyValue("mirostat_tau",
yaml_node_["mirostat_tau"]);
outFile << format_utils::writeKeyValue("mirostat_eta",
out_file << format_utils::WriteKeyValue("mirostat_eta",
yaml_node_["mirostat_eta"]);
outFile << format_utils::writeKeyValue("penalize_nl",
out_file << format_utils::WriteKeyValue("penalize_nl",
yaml_node_["penalize_nl"]);
outFile << format_utils::writeKeyValue("ignore_eos",
out_file << format_utils::WriteKeyValue("ignore_eos",
yaml_node_["ignore_eos"]);
outFile << format_utils::writeKeyValue("n_probs", yaml_node_["n_probs"]);
outFile << format_utils::writeKeyValue("min_keep", yaml_node_["min_keep"]);
outFile << format_utils::writeKeyValue("grammar", yaml_node_["grammar"]);
outFile << "# END OPTIONAL\n";
outFile << "# END INFERENCE PARAMETERS\n";
outFile << "\n";
out_file << format_utils::WriteKeyValue("n_probs", yaml_node_["n_probs"]);
out_file << format_utils::WriteKeyValue("min_keep", yaml_node_["min_keep"]);
out_file << format_utils::WriteKeyValue("grammar", yaml_node_["grammar"]);
out_file << "# END OPTIONAL\n";
out_file << "# END INFERENCE PARAMETERS\n";
out_file << "\n";
// Write MODEL LOAD PARAMETERS
outFile << "# BEGIN MODEL LOAD PARAMETERS\n";
outFile << "# BEGIN REQUIRED\n";
outFile << format_utils::writeKeyValue("engine", yaml_node_["engine"],
out_file << "# BEGIN MODEL LOAD PARAMETERS\n";
out_file << "# BEGIN REQUIRED\n";
out_file << format_utils::WriteKeyValue("engine", yaml_node_["engine"],
"engine to run model");
outFile << "prompt_template:";
outFile << " " << yaml_node_["prompt_template"] << "\n";
outFile << "# END REQUIRED\n";
outFile << "\n";
outFile << "# BEGIN OPTIONAL\n";
outFile << format_utils::writeKeyValue(
out_file << "prompt_template:";
out_file << " " << yaml_node_["prompt_template"] << "\n";
out_file << "# END REQUIRED\n";
out_file << "\n";
out_file << "# BEGIN OPTIONAL\n";
out_file << format_utils::WriteKeyValue(
"ctx_len", yaml_node_["ctx_len"],
"llama.context_length | 0 or undefined = loaded from model");
outFile << format_utils::writeKeyValue("n_parallel",
out_file << format_utils::WriteKeyValue("n_parallel",
yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
out_file << format_utils::WriteKeyValue("cpu_threads",
yaml_node_["cpu_threads"]);
out_file << format_utils::WriteKeyValue("ngl", yaml_node_["ngl"],
"Undefined = loaded from model");
outFile << "# END OPTIONAL\n";
outFile << "# END MODEL LOAD PARAMETERS\n";
out_file << "# END OPTIONAL\n";
out_file << "# END MODEL LOAD PARAMETERS\n";

outFile.close();
out_file.close();
} catch (const std::exception& e) {
std::cerr << "Error writing to file: " << e.what() << std::endl;
throw;
Expand Down
12 changes: 6 additions & 6 deletions engine/test/components/test_format_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,37 @@ TEST_F(FormatUtilsTest, WriteKeyValue) {
{
YAML::Node node;
std::string result =
format_utils::writeKeyValue("key", node["does_not_exist"]);
format_utils::WriteKeyValue("key", node["does_not_exist"]);
EXPECT_EQ(result, "");
}

{
YAML::Node node = YAML::Load("value");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: value\n");
}

{
YAML::Node node = YAML::Load("3.14159");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: 3.14159\n");
}

{
YAML::Node node = YAML::Load("3.000000");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: 3\n");
}

{
YAML::Node node = YAML::Load("3.140000");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: 3.14\n");
}

{
YAML::Node node = YAML::Load("value");
std::string result = format_utils::writeKeyValue("key", node, "comment");
std::string result = format_utils::WriteKeyValue("key", node, "comment");
EXPECT_EQ(result, "key: value # comment\n");
}
}
Expand Down
6 changes: 6 additions & 0 deletions engine/test/components/test_yaml_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ temperature: 0.7
max_tokens: 100
stream: true
n_parallel: 2
cpu_threads: 3
stop:
- "END"
files:
Expand All @@ -84,6 +85,7 @@ n_parallel: 2
EXPECT_EQ(config.max_tokens, 100);
EXPECT_TRUE(config.stream);
EXPECT_EQ(config.n_parallel, 2);
EXPECT_EQ(config.cpu_threads, 3);
EXPECT_EQ(config.stop.size(), 1);
EXPECT_EQ(config.stop[0], "END");
EXPECT_EQ(config.files.size(), 1);
Expand All @@ -104,6 +106,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
new_config.max_tokens = 200;
new_config.stream = false;
new_config.n_parallel = 2;
new_config.cpu_threads = 3;
new_config.stop = {"STOP", "END"};
new_config.files = {"updated_file1.gguf", "updated_file2.gguf"};

Expand All @@ -120,6 +123,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
EXPECT_EQ(config.max_tokens, 200);
EXPECT_FALSE(config.stream);
EXPECT_EQ(config.n_parallel, 2);
EXPECT_EQ(config.cpu_threads, 3);
EXPECT_EQ(config.stop.size(), 2);
EXPECT_EQ(config.stop[0], "STOP");
EXPECT_EQ(config.stop[1], "END");
Expand All @@ -140,6 +144,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
new_config.max_tokens = 150;
new_config.stream = true;
new_config.n_parallel = 2;
new_config.cpu_threads = 3;
new_config.stop = {"HALT"};
new_config.files = {"write_test_file.gguf"};

Expand All @@ -164,6 +169,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
EXPECT_EQ(read_config.max_tokens, 150);
EXPECT_TRUE(read_config.stream);
EXPECT_EQ(read_config.n_parallel, 2);
EXPECT_EQ(read_config.cpu_threads, 3);
EXPECT_EQ(read_config.stop.size(), 1);
EXPECT_EQ(read_config.stop[0], "HALT");
EXPECT_EQ(read_config.files.size(), 1);
Expand Down
18 changes: 9 additions & 9 deletions engine/utils/format_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ inline std::string print_float(const std::string& key, float value) {
} else
return "";
};
inline std::string writeKeyValue(const std::string& key,
inline std::string WriteKeyValue(const std::string& key,
const YAML::Node& value,
const std::string& comment = "") {
std::ostringstream outFile;
std::ostringstream out_file;
if (!value)
return "";
outFile << key << ": ";
out_file << key << ": ";

// Check if the value is a float and round it to 6 decimal places
if (value.IsScalar()) {
Expand All @@ -66,19 +66,19 @@ inline std::string writeKeyValue(const std::string& key,
if (strValue.back() == '.') {
strValue.pop_back();
}
outFile << strValue;
out_file << strValue;
} catch (const std::exception& e) {
outFile << value; // If not a float, write as is
out_file << value; // If not a float, write as is
}
} else {
outFile << value;
out_file << value;
}

if (!comment.empty()) {
outFile << " # " << comment;
out_file << " # " << comment;
}
outFile << "\n";
return outFile.str();
out_file << "\n";
return out_file.str();
};

inline std::string BytesToHumanReadable(uint64_t bytes) {
Expand Down
Loading