From 89ab580a4d7571deacd16298a607abcbb5b93911 Mon Sep 17 00:00:00 2001 From: Ashish Karale Date: Thu, 19 Sep 2024 08:30:00 +0000 Subject: [PATCH] Added new cmake flag TRITON_ENABLE_CIG to make the CiG support build conditional --- CMakeLists.txt | 14 +++++++++++++- src/instance_state.cc | 15 ++++++++++++--- src/model_state.cc | 21 +++++++++++++++++---- src/tensorrt.cc | 8 +++++++- src/tensorrt_model.cc | 9 ++++++++- src/tensorrt_model.h | 10 +++++++++- 6 files changed, 66 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b798d11..c88248d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,8 @@ set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which feat option(TRITON_ENABLE_GPU "Enable GPU support in backend." ON) option(TRITON_ENABLE_STATS "Include statistics collections in backend." ON) option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." OFF) +option(TRITON_ENABLE_CIG "Enable Cuda in Graphics (CiG) support in backend." OFF) + set(TRITON_TENSORRT_LIB_PATHS "" CACHE PATH "Paths to TensorRT libraries. Multiple paths may be specified by separating them with a semicolon.") set(TRITON_TENSORRT_INCLUDE_PATHS "" CACHE PATH "Paths to TensorRT includes. Multiple paths may be specified by separating them with a semicolon.") @@ -269,9 +271,19 @@ target_link_libraries( triton-tensorrt-backend PRIVATE CUDA::cudart - CUDA::cuda_driver ) +if(${TRITON_ENABLE_CIG}) + target_compile_definitions( + triton-tensorrt-backend + PRIVATE TRITON_ENABLE_CIG + ) + target_link_libraries( + triton-tensorrt-backend + PRIVATE + CUDA::cuda_driver + ) +endif() # # Install diff --git a/src/instance_state.cc b/src/instance_state.cc index e7113d9..4a4fbc1 100644 --- a/src/instance_state.cc +++ b/src/instance_state.cc @@ -257,8 +257,11 @@ ModelInstanceState::ModelInstanceState( ModelInstanceState::~ModelInstanceState() { +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!model_state_->isCiGEnabled()) { + if (!model_state_->isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { cudaSetDevice(DeviceId()); } for (auto& io_binding_infos : io_binding_infos_) { @@ -427,8 +430,11 @@ ModelInstanceState::Run( payload_.reset(new Payload(next_set_, requests, request_count)); SET_TIMESTAMP(payload_->compute_start_ns_); +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!model_state_->isCiGEnabled()) { + if (!model_state_->isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { cudaSetDevice(DeviceId()); } #ifdef TRITON_ENABLE_STATS @@ -1557,8 +1563,11 @@ ModelInstanceState::EvaluateTensorRTContext( TRITONSERVER_Error* ModelInstanceState::InitStreamsAndEvents() { +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!model_state_->isCiGEnabled()) { + if (!model_state_->isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { // Set the device before preparing the context. auto cuerr = cudaSetDevice(DeviceId()); if (cuerr != cudaSuccess) { diff --git a/src/model_state.cc b/src/model_state.cc index 0622a94..8b8b5d2 100644 --- a/src/model_state.cc +++ b/src/model_state.cc @@ -175,8 +175,11 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) ModelState::~ModelState() { for (auto& device_engine : device_engines_) { +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!isCiGEnabled()) { + if (!isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { cudaSetDevice(device_engine.first.first); } auto& runtime = device_engine.second.first; @@ -212,8 +215,12 @@ ModelState::CreateEngine( // We share the engine (for models that don't have dynamic shapes) and // runtime across instances that have access to the same GPU/NVDLA. if (eit->second.second == nullptr) { + +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!isCiGEnabled()) { + if (!isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { auto cuerr = cudaSetDevice(gpu_device); if (cuerr != cudaSuccess) { return TRITONSERVER_ErrorNew( @@ -326,8 +333,11 @@ ModelState::AutoCompleteConfig() " to auto-complete config for " + Name()) .c_str())); +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!isCiGEnabled()) { + if (!isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { cuerr = cudaSetDevice(device_id); if (cuerr != cudaSuccess) { return TRITONSERVER_ErrorNew( @@ -381,8 +391,11 @@ ModelState::AutoCompleteConfig() RETURN_IF_ERROR(AutoCompleteConfigHelper(model_path)); +#ifdef TRITON_ENABLE_CIG // Set device if CiG is disabled - if (!isCiGEnabled()) { + if (!isCiGEnabled()) +#endif //TRITON_ENABLE_CIG + { cuerr = cudaSetDevice(current_device); if (cuerr != cudaSuccess) { return TRITONSERVER_ErrorNew( diff --git a/src/tensorrt.cc b/src/tensorrt.cc index 1bd0266..6476313 100644 --- a/src/tensorrt.cc +++ b/src/tensorrt.cc @@ -318,7 +318,9 @@ TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) DeviceMemoryTracker::TrackThreadMemoryUsage(lusage.get()); } +#ifdef TRITON_ENABLE_CIG ScopedRuntimeCiGContext cig_scope(model_state); +#endif //TRITON_ENABLE_CIG // With each instance we create a ModelInstanceState object and // associate it with the TRITONBACKEND_ModelInstance. @@ -357,7 +359,9 @@ TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) if (!instance_state) { return nullptr; } +#ifdef TRITON_ENABLE_CIG ScopedRuntimeCiGContext cig_scope(instance_state->StateForModel()); +#endif //TRITON_ENABLE_CIG delete instance_state; @@ -382,7 +386,9 @@ TRITONBACKEND_ModelInstanceExecute( instance, reinterpret_cast(&instance_state))); ModelState* model_state = instance_state->StateForModel(); - ScopedRuntimeCiGContext cig_scope(instance_state->StateForModel()); +#ifdef TRITON_ENABLE_CIG + ScopedRuntimeCiGContext cig_scope(model_state); +#endif //TRITON_ENABLE_CIG // For TensorRT backend, the executing instance may not closely tie to // TRITONBACKEND_ModelInstance, the instance will be assigned based on diff --git a/src/tensorrt_model.cc b/src/tensorrt_model.cc index 8285189..71259e9 100644 --- a/src/tensorrt_model.cc +++ b/src/tensorrt_model.cc @@ -55,7 +55,10 @@ TensorRTModel::TensorRTModel(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model), priority_(Priority::DEFAULT), use_cuda_graphs_(false), gather_kernel_buffer_threshold_(0), separate_output_stream_(false), eager_batching_(false), - busy_wait_events_(false), cig_ctx_(nullptr) + busy_wait_events_(false) +#ifdef TRITON_ENABLE_CIG + ,cig_ctx_(nullptr) +#endif // TRITON_ENABLE_CIG { ParseModelConfig(); } @@ -91,6 +94,8 @@ TensorRTModel::ParseModelConfig() cuda.MemberAsBool("output_copy_stream", &separate_output_stream_)); } } + +#ifdef TRITON_ENABLE_CIG triton::common::TritonJson::Value parameters; if (model_config_.Find("parameters", ¶meters)) { triton::common::TritonJson::Value value; @@ -105,6 +110,8 @@ TensorRTModel::ParseModelConfig() LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "CiG Context pointer is set"); } } +#endif //TRITON_ENABLE_CIG + return nullptr; // Success } diff --git a/src/tensorrt_model.h b/src/tensorrt_model.h index 708a51a..27c1f2d 100644 --- a/src/tensorrt_model.h +++ b/src/tensorrt_model.h @@ -25,7 +25,9 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#ifdef TRITON_ENABLE_CIG #include +#endif //TRITON_ENABLE_CIG #include "triton/backend/backend_model.h" @@ -55,7 +57,7 @@ class TensorRTModel : public BackendModel { bool EagerBatching() { return eager_batching_; } bool BusyWaitEvents() { return busy_wait_events_; } - +#ifdef TRITON_ENABLE_CIG //! Following functions are related to CiG (Cuda in Graphics) context sharing //! for gaming use case. Creating a shared contexts reduces context switching //! overhead and leads to better performance of model execution along side @@ -88,6 +90,7 @@ class TensorRTModel : public BackendModel { } return nullptr; } +#endif //TRITON_ENABLE_CIG protected: common::TritonJson::Value graph_specs_; @@ -97,9 +100,13 @@ class TensorRTModel : public BackendModel { bool separate_output_stream_; bool eager_batching_; bool busy_wait_events_; +#ifdef TRITON_ENABLE_CIG CUcontext cig_ctx_; +#endif //TRITON_ENABLE_CIG + }; +#ifdef TRITON_ENABLE_CIG struct ScopedRuntimeCiGContext { ScopedRuntimeCiGContext(TensorRTModel* model_state) : model_state_(model_state) @@ -116,5 +123,6 @@ struct ScopedRuntimeCiGContext { } TensorRTModel* model_state_; }; +#endif //TRITON_ENABLE_CIG }}} // namespace triton::backend::tensorrt