From d126b08cc91c4c23c37749d7d118657ea829bcfa Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Fri, 3 Nov 2023 19:04:59 -0700
Subject: [PATCH 01/11] Add Python backend based PyTorch runtime

---
 CMakeLists.txt |   7 ++
 src/model.py   | 318 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 325 insertions(+)
 create mode 100755 src/model.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 45ff129..8ce0689 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -504,6 +504,13 @@ install(
   ${INSTALL_CONFIGDIR}
 )
 
+install(
+  FILES
+    src/model.py
+  DESTINATION
+    ${CMAKE_INSTALL_PREFIX}/backends/pytorch
+)
+
 include(CMakePackageConfigHelpers)
 configure_package_config_file(
   ${CMAKE_CURRENT_LIST_DIR}/cmake/TritonPyTorchBackendConfig.cmake.in
diff --git a/src/model.py b/src/model.py
new file mode 100755
index 0000000..028bd66
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,318 @@
+#!/usr/bin/env python3
+
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import importlib.util
+import json
+import os
+
+try:
+    import torch
+except ModuleNotFoundError as error:
+    raise RuntimeError("Missing/Incomplete PyTorch package installation") from error
+
+# triton_python_backend_utils is available in every Triton Python model. You
+# need to use this module to create inference requests and responses. It also
+# contains some utility functions for extracting information from model_config
+# and converting Triton input/output types to numpy types.
+import triton_python_backend_utils as pb_utils
+
+
+def _get_model_path(config):
+    filenames = ["model.py", "model.pt"]
+    if config["default_model_filename"]:
+        filenames.insert(0, config["default_model_filename"])
+    for filename in filenames:
+        model_path = os.path.join(pb_utils.get_model_dir(), filename)
+        if os.path.exists(model_path):
+            return model_path
+    raise pb_utils.TritonModelException(
+        "No model found in " + pb_utils.get_model_dir() + "/" + str(filenames)
+    )
+
+
+def _get_model_data_path(model_path):
+    data_path_extensions = [".pt"]
+    model_path_no_extension = model_path[: -(len(model_path.split(".")[-1]) + 1)]
+    for extension in data_path_extensions:
+        data_path = model_path_no_extension + extension
+        if os.path.exists(data_path):
+            return data_path
+    # data file not provided
+    return ""
+
+
+def _is_py_class_model(model_path):
+    return model_path[-3:] == ".py"
+
+
+def _import_module_from_path(module_name, file_path):
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def _get_model_class_from_module(module):
+    names = dir(module)
+    for name in names:
+        attr = getattr(module, name)
+        try:
+            if issubclass(attr, torch.nn.Module):
+                return attr
+        except TypeError:
+            # attr may not be a class
+            pass
+    raise pb_utils.TritonModelException("Cannot find a subclass of torch.nn.Module")
+
+
+def _parse_io_config(io_config):
+    io = []
+    for conf in io_config:
+        io.append({"name": conf["name"]})
+    return io
+
+
+def _get_device_name(kind, device_id):
+    if kind == "GPU":
+        return "cuda:" + device_id
+    if kind == "CPU":
+        return "cpu"
+    # unspecified device
+    return ""
+
+
+def _get_device(kind, device_id, model):
+    device_name = _get_device_name(kind, device_id)
+    if device_name == "":
+        for param in model.parameters():
+            return param.device
+        raise pb_utils.TritonModelException("Cannot determine model device")
+    return torch.device(device_name)
+
+
+def _set_torch_parallelism(config):
+    log_msg = ""
+    parallelism_settings = ["NUM_THREADS", "NUM_INTEROP_THREADS"]
+    for setting in parallelism_settings:
+        val = "1"
+        if setting in config["parameters"]:
+            val = config["parameters"][setting]["string_value"]
+        getattr(torch, "set_" + setting.lower())(int(val))
+        log_msg += setting + " = " + val + "; "
+    return log_msg
+
+
+def _get_torch_compile_params(config):
+    params = {}
+    if "TORCH_COMPILE_OPTIONAL_PARAMETERS" in config["parameters"]:
+        val = config["parameters"]["TORCH_COMPILE_OPTIONAL_PARAMETERS"]["string_value"]
+        params = json.loads(val)
+        if "model" in params:
+            raise pb_utils.TritonModelException(
+                "'model' is not an optional parameter for 'torch.compile'"
+            )
+    return params
+
+
+def _gather_torch_tensors(scatter_tensors):
+    # Concatenate the j-th tensor of every request along the batch dimension,
+    # recording each request's batch size in sections for later splitting.
+    gather_tensors = []
+    sections = []
+    for i in range(len(scatter_tensors)):
+        tensors = scatter_tensors[i]
+        for j in range(len(tensors)):
+            tensor = tensors[j]
+            if j < len(gather_tensors):
+                # add to existing tensor
+                gather_tensors[j] = torch.cat((gather_tensors[j], tensor), 0)
+            else:
+                # start a new tensor
+                gather_tensors.append(tensor)
+        # record section
+        section_length = tensors[0].size()[0]
+        sections.append(section_length)
+    return gather_tensors, sections
+
+
+def _scatter_torch_tensors(gather_tensors, sections):
+    # Split each batched tensor back into per-request tensors, using the
+    # section lengths recorded by _gather_torch_tensors.
+    scatter_tensors = []
+    for j in range(len(gather_tensors)):
+        scatter_tensor = torch.split(gather_tensors[j], sections)
+        for i in range(len(scatter_tensor)):
+            tensor = scatter_tensor[i]
+            if i < len(scatter_tensors):
+                # add to existing response
+                scatter_tensors[i].append(tensor)
+            else:
+                # start a new response
+                scatter_tensors.append([tensor])
+    return scatter_tensors
+
+
+class TritonPythonModel:
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def initialize(self, args):
+        """`initialize` is called only once when the model is being loaded.
+        Implementing the `initialize` function is optional. This function
+        allows the model to initialize any state associated with this model.
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing model instance kind
+          * model_instance_device_id: A string containing model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
+        self._model_name = args["model_name"]
+        for_model = "for '" + self._model_name + "'"
+        self._logger = pb_utils.Logger
+        self._logger.log_info("Initializing model instance " + for_model)
+
+        self._model_config = json.loads(args["model_config"])
+        self._kind = args["model_instance_kind"]
+        self._device_id = args["model_instance_device_id"]
+        self._support_batching = self._model_config["max_batch_size"] > 0
+        self._inputs = _parse_io_config(self._model_config["input"])
+        self._outputs = _parse_io_config(self._model_config["output"])
+
+        setting_msg = _set_torch_parallelism(self._model_config)
+        self._logger.log_verbose(
+            "Torch parallelism settings " + for_model + ": " + setting_msg
+        )
+
+        self._infer_mode = torch.inference_mode(mode=True)
+        self._infer_mode.__enter__()
+
+        params = _get_torch_compile_params(self._model_config)
+        self._logger.log_verbose(
+            "'torch.compile' optional parameter(s) " + for_model + ": " + str(params)
+        )
+        if self._support_batching:
+            self._gather = torch.compile(_gather_torch_tensors, **params)
+            self._scatter = torch.compile(_scatter_torch_tensors, **params)
+
+        model_path = _get_model_path(self._model_config)
+        if not _is_py_class_model(model_path):
+            self._logger.log_info("Loading '" + self._model_name + "' as TorchScript")
+            self._model = torch.jit.load(model_path)
+            self._device = _get_device(self._kind, self._device_id, self._model)
+            self._model.to(self._device)
+            self._model.eval()
+            return
+
+        self._model_module = _import_module_from_path(self._model_name, model_path)
+        self._model_class = _get_model_class_from_module(self._model_module)
+        self._raw_model = self._model_class()
+        self._device = _get_device(self._kind, self._device_id, self._raw_model)
+        data_path = _get_model_data_path(model_path)
+        if data_path != "":
+            self._raw_model.load_state_dict(
+                torch.load(data_path, map_location=self._device)
+            )
+        else:
+            self._logger.log_info("Model parameter file not found " + for_model)
+        self._raw_model.to(self._device)
+        self._raw_model.eval()
+        self._model = torch.compile(self._raw_model, **params)
+
+    def execute(self, requests):
+        """`execute` MUST be implemented in every Python model. The `execute`
+        function receives a list of pb_utils.InferenceRequest as its only
+        argument. This function is called when an inference request is made
+        for this model. Depending on the batching configuration (e.g. Dynamic
+        Batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+        pb_utils.InferenceRequest in `requests`. If there is an error, you can
+        set the error argument when creating a pb_utils.InferenceResponse.
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+
+        responses = []
+
+        requests_tensors = []
+        for request in requests:
+            tensors = []
+            for io in self._inputs:
+                tensor = pb_utils.get_input_tensor_by_name(
+                    request, io["name"]
+                ).to_dlpack()
+                tensor = torch.from_dlpack(tensor).to(self._device)
+                tensors.append(tensor)
+            requests_tensors.append(tensors)
+
+        sections = None
+        if self._support_batching:
+            requests_tensors, sections = self._gather(requests_tensors)
+            requests_tensors = [requests_tensors]
+
+        responses_tensors = []
+        for input_tensors in requests_tensors:
+            output_tensors = self._model(*input_tensors)
+            if not isinstance(output_tensors, tuple) and not isinstance(
+                output_tensors, list
+            ):
+                output_tensors = [output_tensors]
+            responses_tensors.append(output_tensors)
+
+        if self._support_batching:
+            responses_tensors = self._scatter(responses_tensors[0], sections)
+
+        for response_tensors in responses_tensors:
+            output_tensors = []
+            for i in range(len(self._outputs)):
+                io = self._outputs[i]
+                tensor = response_tensors[i].detach()
+                tensor = pb_utils.Tensor.from_dlpack(io["name"], tensor)
+                output_tensors.append(tensor)
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=output_tensors
+            )
+            responses.append(inference_response)
+
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded.
+        Implementing the `finalize` function is OPTIONAL. This function allows
+        the model to perform any necessary cleanups before exit.
+        """
+        self._logger.log_info("Removing model instance for '" + self._model_name + "'")
+        self._infer_mode.__exit__(exc_type=None, exc_value=None, traceback=None)

From 8d5007192d2efcc6fa22c075288a8ab7933b3358 Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Mon, 13 Nov 2023 19:43:22 -0800
Subject: [PATCH 02/11] Add exec env build

---
 CMakeLists.txt           | 17 ++++++++-----
 tools/gen_pb_exec_env.sh | 52 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+), 6 deletions(-)
 create mode 100755 tools/gen_pb_exec_env.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8ce0689..e45e7fc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -47,6 +47,7 @@ option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
 option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." OFF)
 option(TRITON_PYTORCH_ENABLE_TORCHTRT "Enable TorchTRT support" OFF)
 option(TRITON_PYTORCH_ENABLE_TORCHVISION "Enable Torchvision support" ON)
+option(TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME "Enable Python backend runtime support" ON)
 set(TRITON_PYTORCH_DOCKER_IMAGE "" CACHE STRING "Docker image containing the PyTorch build required by backend.")
 set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes")
@@ -504,12 +505,16 @@ install(
   ${INSTALL_CONFIGDIR}
 )
 
-install(
-  FILES
-    src/model.py
-  DESTINATION
-    ${CMAKE_INSTALL_PREFIX}/backends/pytorch
-)
+if (${TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME})
+  install(CODE "execute_process(COMMAND bash -c ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_pb_exec_env.sh)")
+  install(
+    FILES
+      src/model.py
+      ${CMAKE_CURRENT_BINARY_DIR}/pb_exec_env_model.py.tar.gz
+    DESTINATION
+      ${CMAKE_INSTALL_PREFIX}/backends/pytorch
+  )
+endif()  # TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME
 
 include(CMakePackageConfigHelpers)
 configure_package_config_file(
diff --git a/tools/gen_pb_exec_env.sh b/tools/gen_pb_exec_env.sh
new file mode 100755
index 0000000..26914b6
--- /dev/null
+++ b/tools/gen_pb_exec_env.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
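+
+# This script builds a self-contained conda environment with PyTorch and packs
+# it into pb_exec_env_model.py.tar.gz, which the CMake install step above
+# places next to model.py as the Python execution environment.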
+
+# install conda
+rm -rf ./miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh
+bash Miniconda3-py310_23.3.1-0-Linux-x86_64.sh -p ./miniconda -b
+eval "$(./miniconda/bin/conda shell.bash hook)"
+
+# create conda environment
+conda create -n pt python=3.10 -y
+conda activate pt
+conda install -c conda-forge conda-pack -y
+
+# pre install step
+export PYTHONNOUSERSITE=True
+conda install -c conda-forge libstdcxx-ng=12 -y
+
+# install PyTorch
+conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
+
+# pack environment
+rm -f pb_exec_env_model.py.tar.gz
+conda pack -o pb_exec_env_model.py.tar.gz
+
+# deactivate conda
+conda deactivate
+conda deactivate

From 4381340963e20f1da8da395ff0df77e1ca292602 Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Fri, 8 Dec 2023 12:06:47 -0800
Subject: [PATCH 03/11] Add note for adding .pt2 model support

---
 src/model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/model.py b/src/model.py
index 028bd66..645a641 100755
--- a/src/model.py
+++ b/src/model.py
@@ -43,6 +43,7 @@ def _get_model_path(config):
+    # FIXME: Add support for torch.export IR models (.pt2)
     filenames = ["model.py", "model.pt"]
     if config["default_model_filename"]:
         filenames.insert(0, config["default_model_filename"])

From f71dd1775731a8049060225f183027ab02a8d0b9 Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Fri, 8 Dec 2023 13:37:05 -0800
Subject: [PATCH 04/11] Do not specify pytorch cuda version

---
 tools/gen_pb_exec_env.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/gen_pb_exec_env.sh b/tools/gen_pb_exec_env.sh
index 26914b6..17d22b0 100755
--- a/tools/gen_pb_exec_env.sh
+++ b/tools/gen_pb_exec_env.sh
@@ -41,7 +41,7 @@ export PYTHONNOUSERSITE=True
 conda install -c conda-forge libstdcxx-ng=12 -y
 
 # install PyTorch
-conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
+conda install pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia -y
 
 # pack environment
 rm -f pb_exec_env_model.py.tar.gz

From b459f73741a9d2131369c2f4dbacb50b6fd89a8a Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Fri, 8 Dec 2023 17:41:20 -0800
Subject: [PATCH 05/11] Do not install Python runtime on non x86

---
 CMakeLists.txt | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e45e7fc..f286260 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -506,14 +506,18 @@ install(
 )
 
 if (${TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME})
-  install(CODE "execute_process(COMMAND bash -c ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_pb_exec_env.sh)")
-  install(
-    FILES
-      src/model.py
-      ${CMAKE_CURRENT_BINARY_DIR}/pb_exec_env_model.py.tar.gz
-    DESTINATION
-      ${CMAKE_INSTALL_PREFIX}/backends/pytorch
-  )
+  if ((CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") OR (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "AMD64"))
+    install(CODE "execute_process(COMMAND bash -c ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_pb_exec_env.sh)")
+    install(
+      FILES
+        src/model.py
+        ${CMAKE_CURRENT_BINARY_DIR}/pb_exec_env_model.py.tar.gz
+      DESTINATION
+        ${CMAKE_INSTALL_PREFIX}/backends/pytorch
+    )
+  else()
+    message(WARNING "Skipped PyTorch Python runtime on unsupported architecture ${CMAKE_HOST_SYSTEM_PROCESSOR}")
+  endif()
 endif()  # TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME

From 78c47fe8b5901ba42283cba41702bc19bae18500 Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Wed, 3 Jan 2024 10:15:30 -0800
Subject: [PATCH 06/11] Remove legacy comment

---
 src/model.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/model.py b/src/model.py
index 645a641..33f6de8 100755
--- a/src/model.py
+++ b/src/model.py
@@ -35,10 +35,6 @@
 except ModuleNotFoundError as error:
     raise RuntimeError("Missing/Incomplete PyTorch package installation") from error
 
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
 import triton_python_backend_utils as pb_utils

From 8b856f628b23e7877d455d34a8ddaedafb365f28 Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Fri, 5 Jan 2024 17:57:22 -0800
Subject: [PATCH 07/11] User to build PyTorch env

---
 CMakeLists.txt | 21 ++++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a0e5a43..603677b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -47,7 +47,6 @@ option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
 option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." OFF)
 option(TRITON_PYTORCH_ENABLE_TORCHTRT "Enable TorchTRT support" OFF)
 option(TRITON_PYTORCH_ENABLE_TORCHVISION "Enable Torchvision support" ON)
-option(TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME "Enable Python backend runtime support" ON)
 set(TRITON_PYTORCH_DOCKER_IMAGE "" CACHE STRING "Docker image containing the PyTorch build required by backend.")
 set(TRITON_PYTORCH_INCLUDE_PATHS "" CACHE PATH "Paths to Torch includes")
@@ -503,20 +502,12 @@ install(
   ${INSTALL_CONFIGDIR}
 )
 
-if (${TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME})
-  if ((CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") OR (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "AMD64"))
-    install(CODE "execute_process(COMMAND bash -c ${CMAKE_CURRENT_SOURCE_DIR}/tools/gen_pb_exec_env.sh)")
-    install(
-      FILES
-        src/model.py
-        ${CMAKE_CURRENT_BINARY_DIR}/pb_exec_env_model.py.tar.gz
-      DESTINATION
-        ${CMAKE_INSTALL_PREFIX}/backends/pytorch
-    )
-  else()
-    message(WARNING "Skipped PyTorch Python runtime on unsupported architecture ${CMAKE_HOST_SYSTEM_PROCESSOR}")
-  endif()
-endif()  # TRITON_PYTORCH_ENABLE_PYTHON_RUNTIME
+install(
+  FILES
+    src/model.py
+  DESTINATION
+    ${CMAKE_INSTALL_PREFIX}/backends/pytorch
+)
 
 include(CMakePackageConfigHelpers)
 configure_package_config_file(

From 7d8a3a7f04f0bf1d358b799b4c59105e549dd004 Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Mon, 8 Jan 2024 15:10:32 -0800
Subject: [PATCH 08/11] Add docs

---
 README.md | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 109 insertions(+)

diff --git a/README.md b/README.md
index 2c14421..0da9036 100644
--- a/README.md
+++ b/README.md
@@ -243,3 +243,112 @@ instance in the
 [model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#instance-groups)
 to ensure that the model instance and the tensors used for inference are
 assigned to the same GPU device as on which the model was traced.
+
+# PyTorch 2.0 Backend \[Experimental\]
+
+> [!WARNING]
+> *This feature is subject to change and removal.*
+
+Starting from 24.01, PyTorch models can be served directly via the
+[Python runtime](src/model.py).
+By default, Triton will use the
+[LibTorch runtime](#pytorch-libtorch-backend) for PyTorch models. To use the
+Python runtime, provide the following
+[runtime setting](https://github.com/triton-inference-server/backend/blob/main/README.md#backend-shared-library)
+in the model configuration:
+
+```
+runtime: "model.py"
+```
+
+## Dependencies
+
+### Python backend dependency
+
+This feature depends on the
+[Python backend](https://github.com/triton-inference-server/python_backend);
+see
+[Python-based Backends](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md)
+for more details.
+
+### PyTorch dependency
+
+This feature takes advantage of the
+[`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
+optimization, so make sure the
+[PyTorch 2.0+ pip package](https://pypi.org/project/torch) is available in the
+same Python environment.
+
+Alternatively, a [Python Execution Environment](#using-custom-python-execution-environments)
+with the PyTorch dependency may be used. It can be created with the
+[provided script](tools/gen_pb_exec_env.sh). The resulting
+`pb_exec_env_model.py.tar.gz` file should be placed in the same
+[backend shared library](https://github.com/triton-inference-server/backend/blob/main/README.md#backend-shared-library)
+directory as the [Python runtime](src/model.py).
+
+## Model Layout
+
+The model repository should look like:
+
+```
+model_repository/
+`-- model_directory
+    |-- 1
+    |   |-- model.py
+    |   `-- model.pt
+    `-- config.pbtxt
+```
+
+The `model.py` file contains the class definition of the PyTorch model. The
+class should extend
+[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
+A `model.pt` file containing the saved
+[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference)
+of the model may optionally be provided. For serving TorchScript models, a
+`model.pt` TorchScript file can be provided in place of the `model.py` file.
+
+### Customization
+
+The following PyTorch settings may be customized by setting parameters in
+`config.pbtxt`.
+
+[`torch.set_num_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads)
+- Key: NUM_THREADS
+- Value: The number of threads used for intraop parallelism on CPU.
+
+[`torch.set_num_interop_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads.html#torch.set_num_interop_threads)
+- Key: NUM_INTEROP_THREADS
+- Value: The number of threads used for interop parallelism (e.g. in the JIT
+interpreter) on CPU.
+
+[`torch.compile()` parameters](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
+- Key: TORCH_COMPILE_OPTIONAL_PARAMETERS
+- Value: Any of the following parameter(s) encoded as a JSON object (see the
+sketch after this list for how they are applied).
+  - fullgraph (*bool*): If True, require the whole model to be captured as a
+    single graph; any graph break raises an error.
+  - dynamic (*bool*): Use dynamic shape tracing.
+  - backend (*str*): The backend to be used.
+  - mode (*str*): Can be either "default", "reduce-overhead" or "max-autotune".
+  - options (*dict*): A dictionary of options to pass to the backend.
+  - disable (*bool*): Turn `torch.compile()` into a no-op for testing.
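+
+Internally, the Python runtime decodes this JSON string and forwards it as
+keyword arguments to `torch.compile` (see `_get_torch_compile_params` in
+[src/model.py](src/model.py)). A minimal standalone sketch of that behavior,
+using a toy `TinyModel` class as a stand-in for the user-provided model:
+
+```python
+import json
+
+import torch
+
+
+class TinyModel(torch.nn.Module):
+    # Illustrative stand-in for the class defined in the user's model.py.
+    def forward(self, x):
+        return x + 1
+
+
+# The string_value of TORCH_COMPILE_OPTIONAL_PARAMETERS from config.pbtxt.
+val = '{"disable": true}'
+params = json.loads(val)  # -> {"disable": True}
+
+# The runtime rejects a "model" key, then passes the rest to torch.compile.
+model = torch.compile(TinyModel(), **params)
+print(model(torch.ones(2)))  # tensor([2., 2.])
+```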
+
+For example:
+```
+parameters: {
+  key: "NUM_THREADS"
+  value: { string_value: "4" }
+}
+parameters: {
+  key: "TORCH_COMPILE_OPTIONAL_PARAMETERS"
+  value: { string_value: "{\"disable\": true}" }
+}
+```
+
+## Limitations
+
+The following are a few known limitations of this feature:
+- Python functions optimizable by `torch.compile` may not be served directly
+in the `model.py` file; they need to be enclosed in a class extending
+[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
+- Model weights cannot be shared across multiple instances on the same GPU
+device.
+- When using `KIND_MODEL` as the model instance kind, the default device of
+the first parameter of the model is used.

From 9aa6b412335b6bdb9f26c2eb782613b5e7b3613e Mon Sep 17 00:00:00 2001
From: kthui <18255193+kthui@users.noreply.github.com>
Date: Mon, 8 Jan 2024 15:30:02 -0800
Subject: [PATCH 09/11] Update copyright

---
 CMakeLists.txt           | 2 +-
 README.md                | 2 +-
 src/model.py             | 2 +-
 tools/gen_pb_exec_env.sh | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 603677b..517481c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
diff --git a/README.md b/README.md
index 0da9036..c1b6436 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@