ci: Support BF16 data type in TensorRT backend (#7310)
pskiran1 authored Jun 7, 2024
1 parent a821958 commit 906baa5
Showing 6 changed files with 282 additions and 6 deletions.
4 changes: 2 additions & 2 deletions docs/user_guide/model_configuration.md
@@ -447,12 +447,12 @@ library.
|TYPE_INT8 | kINT8 |DT_INT8 |INT8 |kChar |INT8 |int8 |
|TYPE_INT16 | |DT_INT16 |INT16 |kShort |INT16 |int16 |
|TYPE_INT32 | kINT32 |DT_INT32 |INT32 |kInt |INT32 |int32 |
-|TYPE_INT64 | |DT_INT64 |INT64 |kLong |INT64 |int64 |
+|TYPE_INT64 | kINT64 |DT_INT64 |INT64 |kLong |INT64 |int64 |
|TYPE_FP16 | kHALF |DT_HALF |FLOAT16 | |FP16 |float16 |
|TYPE_FP32 | kFLOAT |DT_FLOAT |FLOAT |kFloat |FP32 |float32 |
|TYPE_FP64 | |DT_DOUBLE |DOUBLE |kDouble |FP64 |float64 |
|TYPE_STRING | |DT_STRING |STRING | |BYTES |dtype(object) |
-|TYPE_BF16 | | | | |BF16 | |
+|TYPE_BF16 | kBF16 | | | |BF16 | |

For TensorRT, each value is in the nvinfer1::DataType namespace. For
example, nvinfer1::DataType::kFLOAT is the 32-bit floating-point
datatype.
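As a rough illustration of the new table row — a sketch, not part of this commit, assuming only the TensorRT Python bindings that qa/common/gen_common.py uses below:

```python
# Triton model-config dtype -> TensorRT dtype, per the table above.
import tensorrt as trt

TRITON_TO_TRT = {
    "TYPE_FP16": trt.float16,   # nvinfer1::DataType::kHALF
    "TYPE_FP32": trt.float32,   # nvinfer1::DataType::kFLOAT
    "TYPE_BF16": trt.bfloat16,  # nvinfer1::DataType::kBF16 (the new row)
}
```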
91 changes: 91 additions & 0 deletions qa/L0_trt_bf16_dtype/test.sh
@@ -0,0 +1,91 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

source ../common/util.sh

REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION}
if [ "$#" -ge 1 ]; then
REPO_VERSION=$1
fi
if [ -z "$REPO_VERSION" ]; then
echo -e "Repository version must be specified"
echo -e "\n***\n*** Test Failed\n***"
exit 1
fi
if [ ! -z "$TEST_REPO_ARCH" ]; then
REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH}
fi

RET=0
TRT_TEST="trt_bf16_dtype_test.py"
TEST_RESULT_FILE="./test_results.txt"
SERVER=/opt/tritonserver/bin/tritonserver

rm -rf ./fixed_models/ ./dynamic_models/ *.log* && mkdir ./fixed_models/ ./dynamic_models/
cp -r /data/inferenceserver/${REPO_VERSION}/qa_model_repository/plan_*bf16_bf16_bf16 ./fixed_models/
cp -r /data/inferenceserver/${REPO_VERSION}/qa_variable_model_repository/plan_*bf16_bf16_bf16 ./dynamic_models/

for TEST in "fixed" "dynamic"; do
    MODELDIR="./${TEST}_models"
    CLIENT_LOG="./${TEST}_client.log"
    SERVER_LOG="./${TEST}_inference_server.log"
    SERVER_ARGS="--model-repository=${MODELDIR} --log-verbose=1"

    run_server
    if [ "$SERVER_PID" == "0" ]; then
        echo -e "\n***\n*** Failed to start $SERVER\n***"
        cat $SERVER_LOG
        exit 1
    fi

    set +e
    python3 $TRT_TEST TrtBF16DataTypeTest.test_${TEST} >>$CLIENT_LOG 2>&1
    if [ $? -ne 0 ]; then
        echo -e "\n***\n*** Running $TRT_TEST TrtBF16DataTypeTest.test_${TEST} Failed\n***"
        cat $CLIENT_LOG
        RET=1
    else
        check_test_results $TEST_RESULT_FILE 1
        if [ $? -ne 0 ]; then
            cat $CLIENT_LOG
            echo -e "\n***\n*** Test Result Verification Failed\n***"
            RET=1
        fi
    fi
    set -e

    kill $SERVER_PID
    wait $SERVER_PID
done

if [ $RET -eq 0 ]; then
    echo -e "\n***\n*** Test Passed\n***"
else
    echo -e "\n***\n*** Test Failed\n***"
fi

exit $RET
103 changes: 103 additions & 0 deletions qa/L0_trt_bf16_dtype/trt_bf16_dtype_test.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import sys

sys.path.append("../common")

import unittest

import numpy as np
import test_util as tu
import tritonclient.http as client


class TrtBF16DataTypeTest(tu.TestResultCollector):
    def setUp(self):
        self.triton_client = client.InferenceServerClient(
            "localhost:8000", verbose=True
        )

    def _infer_helper(self, model_name, shape):
        inputs = []
        outputs = []
        inputs.append(client.InferInput("INPUT0", shape, "BF16"))
        inputs.append(client.InferInput("INPUT1", shape, "BF16"))

        # All-ones inputs are exactly representable in BF16, so the
        # outputs below can be compared with exact equality.
        input0_data = np.ones(shape=shape).astype(np.float32)
        input1_data = np.ones(shape=shape).astype(np.float32)

        inputs[0].set_data_from_numpy(input0_data, binary_data=True)
        inputs[1].set_data_from_numpy(input1_data, binary_data=True)

        outputs.append(client.InferRequestedOutput("OUTPUT0", binary_data=True))
        outputs.append(client.InferRequestedOutput("OUTPUT1", binary_data=True))

        results = self.triton_client.infer(model_name, inputs, outputs=outputs)

        output0_data = results.as_numpy("OUTPUT0")
        output1_data = results.as_numpy("OUTPUT1")

        # The QA models compute OUTPUT0 = INPUT0 + INPUT1 and
        # OUTPUT1 = INPUT0 - INPUT1.
        np.testing.assert_equal(
            output0_data,
            input0_data + input1_data,
            "Result output does not match the expected output",
        )
        np.testing.assert_equal(
            output1_data,
            input0_data - input1_data,
            "Result output does not match the expected output",
        )

    def test_fixed(self):
        for bs in [1, 4, 8]:
            self._infer_helper(
                "plan_bf16_bf16_bf16",
                [bs, 16],
            )

        self._infer_helper(
            "plan_nobatch_bf16_bf16_bf16",
            [16],
        )

    def test_dynamic(self):
        for bs in [1, 4, 8]:
            self._infer_helper(
                "plan_bf16_bf16_bf16",
                [bs, 16, 16],
            )

        self._infer_helper(
            "plan_nobatch_bf16_bf16_bf16",
            [16, 16],
        )


if __name__ == "__main__":
    unittest.main()
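A note on why exact equality holds in this test: BF16 keeps the upper 16 bits of an IEEE float32 bit pattern, so small integers such as 1.0 and 2.0 are exactly representable in both formats. A minimal sketch of that relationship — illustrative only, not code from this commit:

```python
import numpy as np

def float32_to_bf16_bits(x):
    # BF16 is the top half of a float32's bit pattern; dropping the low
    # 16 mantissa bits converts values like 1.0 without rounding error.
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

print(float32_to_bf16_bits(np.ones(4)))  # [16256 16256 16256 16256] == 0x3F80
```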
8 changes: 8 additions & 0 deletions qa/common/gen_common.py
@@ -31,6 +31,10 @@

np_dtype_string = np.dtype(object)

# Numpy does not support the BF16 datatype natively.
# We use this dummy dtype as a representative for BF16.
np_dtype_bfloat16 = np.dtype([("bf16", object)])


def np_to_onnx_dtype(np_dtype):
    import onnx
@@ -83,6 +87,8 @@ def np_to_model_dtype(np_dtype):
return "TYPE_FP64"
elif np_dtype == np_dtype_string:
return "TYPE_STRING"
elif np_dtype == np_dtype_bfloat16:
return "TYPE_BF16"
return None


@@ -101,6 +107,8 @@ def np_to_trt_dtype(np_dtype):
        return trt.float16
    elif np_dtype == np.float32:
        return trt.float32
    elif np_dtype == np_dtype_bfloat16:
        return trt.bfloat16
    return None
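
A quick sketch of why the structured dtype works as a sentinel — illustrative, using only numpy: it compares equal to itself but to no native dtype, so the elif chains above can dispatch on it unambiguously.

```python
import numpy as np

np_dtype_bfloat16 = np.dtype([("bf16", object)])  # same sentinel as above

assert np_dtype_bfloat16 == np.dtype([("bf16", object)])  # equal by structure
assert np_dtype_bfloat16 != np.float16                    # distinct from FP16
assert np_dtype_bfloat16 != np.dtype(object)              # distinct from the string dtype
```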


30 changes: 30 additions & 0 deletions qa/common/gen_qa_models.py
@@ -33,6 +33,7 @@
import gen_ensemble_model_utils as emu
import numpy as np
from gen_common import (
    np_dtype_bfloat16,
    np_to_model_dtype,
    np_to_onnx_dtype,
    np_to_tf_dtype,
@@ -2578,6 +2579,18 @@ def create_fixed_models(
    )

    if FLAGS.tensorrt:
        if tu.check_gpus_compute_capability(min_capability=8.0):
            create_fixed_models(
                FLAGS.models_dir,
                np_dtype_bfloat16,
                np_dtype_bfloat16,
                np_dtype_bfloat16,
            )
        else:
            print(
                "Skipping the generation of TensorRT PLAN models for the BF16 datatype!"
            )

        for vt in [np.float32, np.float16, np.int32, np.uint8]:
            create_plan_modelfile(
                FLAGS.models_dir, 8, 2, (16,), (16,), (16,), vt, vt, vt, swap=True
@@ -2854,6 +2867,23 @@ def create_fixed_models(
        32,
    )

    if FLAGS.tensorrt:
        if tu.check_gpus_compute_capability(min_capability=8.0):
            create_models(
                FLAGS.models_dir,
                np_dtype_bfloat16,
                np_dtype_bfloat16,
                np_dtype_bfloat16,
                (-1, -1),
                (-1, -1),
                (-1, -1),
                0,
            )
        else:
            print(
                "Skipping the generation of TensorRT PLAN models for the BF16 datatype!"
            )

    if FLAGS.ensemble:
        # Create utility models used in ensemble
        # nop (only creates model config, should add model file before use)
52 changes: 48 additions & 4 deletions qa/common/test_util.py
@@ -33,6 +33,10 @@

_last_request_id = 0

# Numpy does not support the BF16 datatype natively.
# We use this dummy dtype as a representative for BF16.
np_dtype_bfloat16 = np.dtype([("bf16", object)])


def shape_element_count(shape):
    cnt = 0
@@ -103,7 +107,15 @@ def validate_for_trt_model(
    input_dtype, output0_dtype, output1_dtype, input_shape, output0_shape, output1_shape
):
    """Return True if input and output dtypes are supported by a TRT model."""
-    supported_datatypes = [bool, np.int8, np.int32, np.uint8, np.float16, np.float32]
+    supported_datatypes = [
+        bool,
+        np.int8,
+        np.int32,
+        np.uint8,
+        np.float16,
+        np.float32,
+        np_dtype_bfloat16,
+    ]
    # FIXME: Remove this check when jetson supports TRT 8.5 (DLIS-4256)
    if not support_trt_uint8():
        supported_datatypes.remove(np.uint8)
@@ -275,12 +287,19 @@ def validate_for_openvino_model(
    return True


def get_dtype_name(dtype):
    if dtype == np_dtype_bfloat16:
        return "bf16"
    else:
        return np.dtype(dtype).name


def get_model_name(pf, input_dtype, output0_dtype, output1_dtype):
    return "{}_{}_{}_{}".format(
        pf,
-        np.dtype(input_dtype).name,
-        np.dtype(output0_dtype).name,
-        np.dtype(output1_dtype).name,
+        get_dtype_name(input_dtype),
+        get_dtype_name(output0_dtype),
+        get_dtype_name(output1_dtype),
    )
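
For reference, a sketch of the names these helpers now produce — a hypothetical usage assuming the sentinel dtype and the functions above; the first matches the plan_bf16_bf16_bf16 models loaded by trt_bf16_dtype_test.py:

```python
# Hypothetical usage of the helpers above (numpy already imported here).
print(get_model_name("plan", np_dtype_bfloat16, np_dtype_bfloat16, np_dtype_bfloat16))
# -> plan_bf16_bf16_bf16
print(get_model_name("plan", np.float32, np.float32, np.float32))
# -> plan_float32_float32_float32
```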


@@ -309,6 +328,31 @@ def support_trt_uint8():
    return hasattr(trt, "uint8")


def check_gpus_compute_capability(min_capability):
    """
    Check if all GPUs have a compute capability greater than or equal to the given value.
    Args:
        min_capability (float): The minimum required compute capability (e.g., 8.0).
    Returns:
        bool
    """
    import pycuda.driver as cuda

    cuda.init()

    for device_index in range(cuda.Device.count()):
        device = cuda.Device(device_index)
        compute_capability = device.compute_capability()
        compute_capability_value = compute_capability[0] + compute_capability[1] / 10.0

        if compute_capability_value < min_capability:
            return False

    return True
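
A usage sketch of this gate as gen_qa_models.py applies it above — TensorRT's BF16 kernels require Ampere-class GPUs, i.e. compute capability 8.0 or newer:

```python
# E.g., an RTX A6000 reports compute_capability() == (8, 6), which the
# helper scores as 8 + 6/10 = 8.6 and accepts for min_capability=8.0.
if check_gpus_compute_capability(min_capability=8.0):
    pass  # safe to generate bfloat16 TensorRT PLAN models
else:
    print("Skipping the generation of TensorRT PLAN models for the BF16 datatype!")
```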


class TestResultCollector(unittest.TestCase):
    # TestResultCollector stores test result and prints it to stdout. In order
    # to use this class, unit tests must inherit this class. Use
