Skip to content

Commit

Permalink
add thread control for pytorch backend (#125)
Browse files Browse the repository at this point in the history
* add pytorch thread control

* use function overloading and update copyright years
  • Loading branch information
yongbinfeng authored Apr 18, 2024
1 parent 4fa7daa commit c50d65b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 3 deletions.
56 changes: 55 additions & 1 deletion src/libtorch.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -56,6 +56,12 @@
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU

// for thread control
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api
// https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
#include <ATen/Parallel.h>


//
// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
//
Expand Down Expand Up @@ -465,6 +471,54 @@ ModelState::ParseParameters()
" for model instance '" + Name() + "'")
.c_str());
}

// If 'INTRA_OP_THREAD_COUNT' is not present in 'parameters' then no update
// is made to 'intra_op_thread_count', which by default will take all
// threads
int intra_op_thread_count = -1;
err = ParseParameter(
params, "INTRA_OP_THREAD_COUNT", &intra_op_thread_count);
if (err != nullptr) {
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
return err;
} else {
TRITONSERVER_ErrorDelete(err);
}
} else {
if (intra_op_thread_count > 0) {
at::set_num_threads(intra_op_thread_count);
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Intra op thread count is set to ") +
std::to_string(intra_op_thread_count) + " for model instance '" +
Name() + "'")
.c_str());
}
}

// If 'INTER_OP_THREAD_COUNT' is not present in 'parameters' then no update
// is made to 'inter_op_thread_count', which by default will take all
// threads
int inter_op_thread_count = -1;
err = ParseParameter(
params, "INTER_OP_THREAD_COUNT", &inter_op_thread_count);
if (err != nullptr) {
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
return err;
} else {
TRITONSERVER_ErrorDelete(err);
}
} else {
if (inter_op_thread_count > 0) {
at::set_num_interop_threads(inter_op_thread_count);
LOG_MESSAGE(
TRITONSERVER_LOG_INFO,
(std::string("Inter op thread count is set to ") +
std::to_string(inter_op_thread_count) + " for model instance '" +
Name() + "'")
.c_str());
}
}
}

return nullptr;
Expand Down
15 changes: 14 additions & 1 deletion src/libtorch_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-21 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-24 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -149,6 +149,19 @@ ParseParameter(
return nullptr;
}

// Parse the value of parameter 'mkey' in 'params' as an integer and store it
// in 'value'. Returns a TRITONSERVER_ERROR_NOT_FOUND error (leaving 'value'
// untouched) when the key is absent, or a parse error when the value string
// is not a valid integer; returns nullptr on success.
TRITONSERVER_Error*
ParseParameter(
    triton::common::TritonJson::Value& params, const std::string& mkey,
    int* value)
{
  // Parameters arrive as JSON strings; fetch the raw text first, then
  // convert it to an integer in a second step.
  std::string raw_value;
  TRITONSERVER_Error* err = GetParameterValue(params, mkey, &raw_value);
  if (err != nullptr) {
    return err;
  }
  return ParseIntValue(raw_value, value);
}



#ifdef TRITON_ENABLE_GPU
TRITONSERVER_Error*
ConvertCUDAStatusToTritonError(
Expand Down
9 changes: 8 additions & 1 deletion src/libtorch_utils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -62,4 +62,11 @@ TRITONSERVER_Error* ParseParameter(
triton::common::TritonJson::Value& params, const std::string& mkey,
bool* value);

// If the key 'mkey' is present in 'params' then update 'value' with the
// integer value associated with that key. If 'mkey' is not present in
// 'params' then a TRITONSERVER_ERROR_NOT_FOUND error is returned and
// 'value' is left unmodified.
TRITONSERVER_Error* ParseParameter(
triton::common::TritonJson::Value& params, const std::string& mkey,
int* value);

}}} // namespace triton::backend::pytorch

0 comments on commit c50d65b

Please sign in to comment.