diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 9f5dc0de66..0c4e19d902 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -400,6 +400,26 @@ def log_softmax_decomposition(
     )
 
 
+@register_torch_trt_decomposition(aten.instance_norm, registry=TORCH_TRT_DECOMPOSITIONS)
+def instance_norm_decomposition(
+    input: torch.Tensor,
+    weight: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    running_mean: Optional[torch.Tensor],
+    running_var: Optional[torch.Tensor],
+    use_input_stats: bool,
+    momentum: float,
+    eps: float,
+    cudnn_enabled: bool,
+) -> torch.Tensor:
+    if use_input_stats:
+        return torch.nn.functional.group_norm(input, input.shape[1], weight, bias, eps)
+    else:
+        return torch.nn.functional.batch_norm(
+            input, running_mean, running_var, weight, bias, False, momentum, eps
+        )
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 0b9695e616..b5225dc5c9 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -1587,6 +1587,75 @@ def forward(self, x):
             f"Log_softmax TRT outputs don't match with the original model.",
         )
 
+    @parameterized.expand(
+        [
+            ((1, 3, 5), True),
+            ((1, 3, 5), False),
+            ((2, 4, 6, 8), True),
+            ((2, 4, 6, 8), False),
+            ((3, 6, 9, 12, 15), True),
+            ((3, 6, 9, 12, 15), False),
+        ]
+    )
+    def test_lowering_instance_norm(self, shape, use_input_stats):
+        class TestModule(torch.nn.Module):
+            def forward(self, input, weight, bias, running_mean=None, running_var=None):
+                return torch.ops.aten.instance_norm.default(
+                    input,
+                    weight,
+                    bias,
+                    running_mean,
+                    running_var,
+                    use_input_stats,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        # Operations expected to be removed in the traced graph after decompositions
+        unexpected_ops = {torch.ops.aten.instance_norm.default}
+
+        inputs = [
+            torch.randn(shape, device="cuda"),
+            torch.randn(shape[1], device="cuda"),
+            torch.randn(shape[1], device="cuda"),
+        ]
+        if not use_input_stats:
+            inputs += [
+                torch.randn(shape[1], device="cuda"),
+                torch.rand(shape[1], device="cuda"),
+            ]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, _ = lower_graph_testing(
+            fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph, "dynamo", inputs, min_block_size=1
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Instance_norm TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
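
Note (not part of the patch): a minimal standalone sketch of the two identities the decomposition relies on, checked through the public torch.nn.functional API. The shapes, tolerance, and momentum value below are illustrative, not taken from the patch.

# Sanity check for the decomposition's two paths:
#   1. instance_norm with use_input_stats=True matches group_norm with one
#      group per channel (num_groups == C), since both normalize each
#      (sample, channel) slice over its spatial dimensions.
#   2. instance_norm with use_input_stats=False matches batch_norm in eval
#      mode, since both normalize with the per-channel running statistics.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)
weight, bias = torch.randn(4), torch.randn(4)
running_mean, running_var = torch.randn(4), torch.rand(4)
momentum, eps = 0.1, 1e-5

# Path 1: statistics computed from the input itself (running stats unused).
ref = F.instance_norm(x, None, None, weight, bias, True, momentum, eps)
dec = F.group_norm(x, x.shape[1], weight, bias, eps)
assert torch.allclose(ref, dec, atol=1e-5)

# Path 2: normalization with the provided running statistics.
ref = F.instance_norm(x, running_mean, running_var, weight, bias, False, momentum, eps)
dec = F.batch_norm(x, running_mean, running_var, weight, bias, False, momentum, eps)
assert torch.allclose(ref, dec, atol=1e-5)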