From 906baa58f4967bc314e711cc5aeebf9cd93e0396 Mon Sep 17 00:00:00 2001
From: Sai Kiran Polisetty
Date: Fri, 7 Jun 2024 10:13:13 +0530
Subject: [PATCH] ci: Support BF16 data type in TensorRT backend (#7310)

---
 docs/user_guide/model_configuration.md      |   4 +-
 qa/L0_trt_bf16_dtype/test.sh                |  91 +++++++++++++++++
 qa/L0_trt_bf16_dtype/trt_bf16_dtype_test.py | 103 ++++++++++++++++++++
 qa/common/gen_common.py                     |   8 ++
 qa/common/gen_qa_models.py                  |  30 ++++++
 qa/common/test_util.py                      |  52 +++++++++-
 6 files changed, 282 insertions(+), 6 deletions(-)
 create mode 100755 qa/L0_trt_bf16_dtype/test.sh
 create mode 100755 qa/L0_trt_bf16_dtype/trt_bf16_dtype_test.py

diff --git a/docs/user_guide/model_configuration.md b/docs/user_guide/model_configuration.md
index e8165081c1..40ec39ec03 100644
--- a/docs/user_guide/model_configuration.md
+++ b/docs/user_guide/model_configuration.md
@@ -447,12 +447,12 @@ library.
 |TYPE_INT8 | kINT8 |DT_INT8 |INT8 |kChar |INT8 |int8 |
 |TYPE_INT16 | |DT_INT16 |INT16 |kShort |INT16 |int16 |
 |TYPE_INT32 | kINT32 |DT_INT32 |INT32 |kInt |INT32 |int32 |
-|TYPE_INT64 | |DT_INT64 |INT64 |kLong |INT64 |int64 |
+|TYPE_INT64 | kINT64 |DT_INT64 |INT64 |kLong |INT64 |int64 |
 |TYPE_FP16 | kHALF |DT_HALF |FLOAT16 | |FP16 |float16 |
 |TYPE_FP32 | kFLOAT |DT_FLOAT |FLOAT |kFloat |FP32 |float32 |
 |TYPE_FP64 | |DT_DOUBLE |DOUBLE |kDouble |FP64 |float64 |
 |TYPE_STRING | |DT_STRING |STRING | |BYTES |dtype(object) |
-|TYPE_BF16 | | | | |BF16 | |
+|TYPE_BF16 | kBF16 | | | |BF16 | |
 
 For TensorRT each value is in the nvinfer1::DataType namespace. For
 example, nvinfer1::DataType::kFLOAT is the 32-bit floating-point
diff --git a/qa/L0_trt_bf16_dtype/test.sh b/qa/L0_trt_bf16_dtype/test.sh
new file mode 100755
index 0000000000..da787bc41a
--- /dev/null
+++ b/qa/L0_trt_bf16_dtype/test.sh
@@ -0,0 +1,91 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
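
# Runs trt_bf16_dtype_test.py (fixed- and dynamic-shape cases) against BF16
# addsub PLAN models copied from the QA model repositories, starting a fresh
# tritonserver instance for each model set.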

source ../common/util.sh

REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION}
if [ "$#" -ge 1 ]; then
    REPO_VERSION=$1
fi
if [ -z "$REPO_VERSION" ]; then
    echo -e "Repository version must be specified"
    echo -e "\n***\n*** Test Failed\n***"
    exit 1
fi
if [ ! -z "$TEST_REPO_ARCH" ]; then
    REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH}
fi

RET=0
TRT_TEST="trt_bf16_dtype_test.py"
TEST_RESULT_FILE="./test_results.txt"
SERVER=/opt/tritonserver/bin/tritonserver

rm -rf ./fixed_models/ ./dynamic_models/ *.log* && mkdir ./fixed_models/ ./dynamic_models/
cp -r /data/inferenceserver/${REPO_VERSION}/qa_model_repository/plan_*bf16_bf16_bf16 ./fixed_models/
cp -r /data/inferenceserver/${REPO_VERSION}/qa_variable_model_repository/plan_*bf16_bf16_bf16 ./dynamic_models/

for TEST in "fixed" "dynamic"; do
    MODELDIR="./${TEST}_models"
    CLIENT_LOG="./${TEST}_client.log"
    SERVER_LOG="./${TEST}_inference_server.log"
    SERVER_ARGS="--model-repository=${MODELDIR} --log-verbose=1"

    run_server
    if [ "$SERVER_PID" == "0" ]; then
        echo -e "\n***\n*** Failed to start $SERVER\n***"
        cat $SERVER_LOG
        exit 1
    fi

    set +e
    python3 $TRT_TEST TrtBF16DataTypeTest.test_${TEST} >>$CLIENT_LOG 2>&1
    if [ $? -ne 0 ]; then
        echo -e "\n***\n*** Running $TRT_TEST TrtBF16DataTypeTest.test_${TEST} Failed\n***"
        cat $CLIENT_LOG
        RET=1
    else
        check_test_results $TEST_RESULT_FILE 1
        if [ $? -ne 0 ]; then
            cat $CLIENT_LOG
            echo -e "\n***\n*** Test Result Verification Failed\n***"
            RET=1
        fi
    fi
    set -e

    kill $SERVER_PID
    wait $SERVER_PID
done

if [ $RET -eq 0 ]; then
    echo -e "\n***\n*** Test Passed\n***"
else
    echo -e "\n***\n*** Test Failed\n***"
fi

exit $RET
diff --git a/qa/L0_trt_bf16_dtype/trt_bf16_dtype_test.py b/qa/L0_trt_bf16_dtype/trt_bf16_dtype_test.py
new file mode 100755
index 0000000000..265c1930b0
--- /dev/null
+++ b/qa/L0_trt_bf16_dtype/trt_bf16_dtype_test.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
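
# Checks BF16 tensor I/O through the TensorRT backend: sends ones as BF16
# inputs to the plan(_nobatch)_bf16_bf16_bf16 addsub models over HTTP and
# verifies OUTPUT0 == INPUT0 + INPUT1 and OUTPUT1 == INPUT0 - INPUT1.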


import sys

sys.path.append("../common")

import unittest

import numpy as np
import test_util as tu
import tritonclient.http as client


class TrtBF16DataTypeTest(tu.TestResultCollector):
    def setUp(self):
        self.triton_client = client.InferenceServerClient(
            "localhost:8000", verbose=True
        )

    def _infer_helper(self, model_name, shape):
        inputs = []
        outputs = []
        inputs.append(client.InferInput("INPUT0", shape, "BF16"))
        inputs.append(client.InferInput("INPUT1", shape, "BF16"))

        input0_data = np.ones(shape=shape).astype(np.float32)
        input1_data = np.ones(shape=shape).astype(np.float32)

        inputs[0].set_data_from_numpy(input0_data, binary_data=True)
        inputs[1].set_data_from_numpy(input1_data, binary_data=True)

        outputs.append(client.InferRequestedOutput("OUTPUT0", binary_data=True))
        outputs.append(client.InferRequestedOutput("OUTPUT1", binary_data=True))

        results = self.triton_client.infer(model_name, inputs, outputs=outputs)

        output0_data = results.as_numpy("OUTPUT0")
        output1_data = results.as_numpy("OUTPUT1")

        np.testing.assert_equal(
            output0_data,
            input0_data + input1_data,
            "Result output does not match the expected output",
        )
        np.testing.assert_equal(
            output1_data,
            input0_data - input1_data,
            "Result output does not match the expected output",
        )

    def test_fixed(self):
        for bs in [1, 4, 8]:
            self._infer_helper(
                "plan_bf16_bf16_bf16",
                [bs, 16],
            )

        self._infer_helper(
            "plan_nobatch_bf16_bf16_bf16",
            [16],
        )

    def test_dynamic(self):
        for bs in [1, 4, 8]:
            self._infer_helper(
                "plan_bf16_bf16_bf16",
                [bs, 16, 16],
            )

        self._infer_helper(
            "plan_nobatch_bf16_bf16_bf16",
            [16, 16],
        )


if __name__ == "__main__":
    unittest.main()
diff --git a/qa/common/gen_common.py b/qa/common/gen_common.py
index d574627dfd..5bb751f3c8 100644
--- a/qa/common/gen_common.py
+++ b/qa/common/gen_common.py
@@ -31,6 +31,10 @@
 np_dtype_string = np.dtype(object)
 
+# Numpy does not support the BF16 datatype natively.
+# We use this dummy dtype as a representative for BF16.
+np_dtype_bfloat16 = np.dtype([("bf16", object)])
+
 
 def np_to_onnx_dtype(np_dtype):
     import onnx
@@ -83,6 +87,8 @@ def np_to_model_dtype(np_dtype):
         return "TYPE_FP64"
     elif np_dtype == np_dtype_string:
         return "TYPE_STRING"
+    elif np_dtype == np_dtype_bfloat16:
+        return "TYPE_BF16"
     return None
@@ -101,6 +107,8 @@ def np_to_trt_dtype(np_dtype):
         return trt.float16
     elif np_dtype == np.float32:
         return trt.float32
+    elif np_dtype == np_dtype_bfloat16:
+        return trt.bfloat16
     return None
diff --git a/qa/common/gen_qa_models.py b/qa/common/gen_qa_models.py
index 88e9ffc97c..efe3fca1e9 100755
--- a/qa/common/gen_qa_models.py
+++ b/qa/common/gen_qa_models.py
@@ -33,6 +33,7 @@
 import gen_ensemble_model_utils as emu
 import numpy as np
 from gen_common import (
+    np_dtype_bfloat16,
     np_to_model_dtype,
     np_to_onnx_dtype,
     np_to_tf_dtype,
@@ -2578,6 +2579,18 @@ def create_fixed_models(
         )
 
     if FLAGS.tensorrt:
+        if tu.check_gpus_compute_capability(min_capability=8.0):
+            create_fixed_models(
+                FLAGS.models_dir,
+                np_dtype_bfloat16,
+                np_dtype_bfloat16,
+                np_dtype_bfloat16,
+            )
+        else:
+            print(
+                "Skipping the generation of TensorRT PLAN models for the BF16 datatype!"
+ ) + for vt in [np.float32, np.float16, np.int32, np.uint8]: create_plan_modelfile( FLAGS.models_dir, 8, 2, (16,), (16,), (16,), vt, vt, vt, swap=True @@ -2854,6 +2867,23 @@ def create_fixed_models( 32, ) + if FLAGS.tensorrt: + if tu.check_gpus_compute_capability(min_capability=8.0): + create_models( + FLAGS.models_dir, + np_dtype_bfloat16, + np_dtype_bfloat16, + np_dtype_bfloat16, + (-1, -1), + (-1, -1), + (-1, -1), + 0, + ) + else: + print( + "Skipping the generation of TensorRT PLAN models for the BF16 datatype!" + ) + if FLAGS.ensemble: # Create utility models used in ensemble # nop (only creates model config, should add model file before use) diff --git a/qa/common/test_util.py b/qa/common/test_util.py index d0d7bda590..d241f5909b 100755 --- a/qa/common/test_util.py +++ b/qa/common/test_util.py @@ -33,6 +33,10 @@ _last_request_id = 0 +# Numpy does not support the BF16 datatype natively. +# We use this dummy dtype as a representative for BF16. +np_dtype_bfloat16 = np.dtype([("bf16", object)]) + def shape_element_count(shape): cnt = 0 @@ -103,7 +107,15 @@ def validate_for_trt_model( input_dtype, output0_dtype, output1_dtype, input_shape, output0_shape, output1_shape ): """Return True if input and output dtypes are supported by a TRT model.""" - supported_datatypes = [bool, np.int8, np.int32, np.uint8, np.float16, np.float32] + supported_datatypes = [ + bool, + np.int8, + np.int32, + np.uint8, + np.float16, + np.float32, + np_dtype_bfloat16, + ] # FIXME: Remove this check when jetson supports TRT 8.5 (DLIS-4256) if not support_trt_uint8(): supported_datatypes.remove(np.uint8) @@ -275,12 +287,19 @@ def validate_for_openvino_model( return True +def get_dtype_name(dtype): + if dtype == np_dtype_bfloat16: + return "bf16" + else: + return np.dtype(dtype).name + + def get_model_name(pf, input_dtype, output0_dtype, output1_dtype): return "{}_{}_{}_{}".format( pf, - np.dtype(input_dtype).name, - np.dtype(output0_dtype).name, - np.dtype(output1_dtype).name, + get_dtype_name(input_dtype), + get_dtype_name(output0_dtype), + get_dtype_name(output1_dtype), ) @@ -309,6 +328,31 @@ def support_trt_uint8(): return hasattr(trt, "uint8") +def check_gpus_compute_capability(min_capability): + """ + Check if all GPUs have a compute capability greater than or equal to the given value. + + Args: + min_capability (float): The minimum required compute capability (e.g., 8.0). + + Returns: + bool + """ + import pycuda.driver as cuda + + cuda.init() + + for device_index in range(cuda.Device.count()): + device = cuda.Device(device_index) + compute_capability = device.compute_capability() + compute_capability_value = compute_capability[0] + compute_capability[1] / 10.0 + + if compute_capability_value < min_capability: + return False + + return True + + class TestResultCollector(unittest.TestCase): # TestResultCollector stores test result and prints it to stdout. In order # to use this class, unit tests must inherit this class. Use
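
Taken together, the changes above wire BF16 end to end: gen_common.py maps the
dummy NumPy dtype to "TYPE_BF16" (model config) and trt.bfloat16 (TensorRT),
while test_util.py gates model generation on GPU support. Below is a minimal
sketch, not part of the patch, of how these helpers could be exercised
together; it assumes the qa/common layout above plus installed tensorrt and
pycuda packages (check_gpus_compute_capability imports pycuda internally).

import sys

sys.path.append("../common")  # qa/common helpers, as in the tests above

import test_util as tu
from gen_common import np_dtype_bfloat16, np_to_model_dtype, np_to_trt_dtype

# BF16 PLAN models are only generated when every visible GPU is Ampere
# (compute capability 8.0) or newer, mirroring the gen_qa_models.py guard.
if not tu.check_gpus_compute_capability(min_capability=8.0):
    print("Skipping the generation of TensorRT PLAN models for the BF16 datatype!")
    sys.exit(0)

print(np_to_model_dtype(np_dtype_bfloat16))  # -> "TYPE_BF16" for config.pbtxt
print(np_to_trt_dtype(np_dtype_bfloat16))  # -> trt.bfloat16 (DataType.BF16)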