diff --git a/python/aitemplate/compiler/transform/fuse_utils.py b/python/aitemplate/compiler/transform/fuse_utils.py index dae123d4c..e7048db72 100644 --- a/python/aitemplate/compiler/transform/fuse_utils.py +++ b/python/aitemplate/compiler/transform/fuse_utils.py @@ -87,6 +87,11 @@ def _find_fusion_root(tensor: Tensor, fusion_patterns: List[Any]) -> int: dst_op_tensor = dst_op._attrs["outputs"] if len(dst_op_tensor) != 1: break + if dst_op_tensor[0]._attrs["is_output"]: + # if we don't break here, dst_op_tensor[0] will be + # eliminated as an intermediate tensor in the linear + # op pattern, but we can't eliminate a graph output + break curr_tensor = dst_op_tensor[0] if fusion_idx != -1: diff --git a/tests/unittest/compiler/test_fuse_mm_elementwise.py b/tests/unittest/compiler/test_fuse_mm_elementwise.py index bd64e5fd0..19c7eb96c 100644 --- a/tests/unittest/compiler/test_fuse_mm_elementwise.py +++ b/tests/unittest/compiler/test_fuse_mm_elementwise.py @@ -947,7 +947,14 @@ def _test_gemm_rcr_bias_activation( self.assertTrue(torch.allclose(Y_pt, y, atol=1e-1, rtol=1e-1)) def _test_gemm_rcr_bias_sigmoid_mul( - self, Ms, N, K, decomposed, testname, dtype="float16" + self, + Ms, + N, + K, + decomposed, + testname, + dtype="float16", + output_in_the_middle=False, ): m_dim = shape_utils.gen_int_var_min_max(Ms, name="M_size") D_shape = [m_dim, N] @@ -963,28 +970,34 @@ def _test_gemm_rcr_bias_sigmoid_mul( output._attrs["name"] = "output_0" output._attrs["is_output"] = True + outputs = [output] + if output_in_the_middle: + sigmoid_tensor._attrs["name"] = "output_1" + sigmoid_tensor._attrs["is_output"] = True + outputs.append(sigmoid_tensor) + # Check value correctness target = detect_target() - module = compile_model(output, target, "./tmp", testname) + module = compile_model(outputs, target, "./tmp", testname) - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "final_tensor": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "gemm_rcr_bias_sigmoid_mul") + if not output_in_the_middle: + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "final_tensor": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "gemm_rcr_bias_sigmoid_mul") for M in Ms: X_pt = get_random_torch_tensor([M, K], dtype) W_pt = get_random_torch_tensor([N, K], dtype) B_pt = get_random_torch_tensor([N], dtype) D_pt = get_random_torch_tensor([M, N], dtype) - Y_pt = torch.cos( - torch.sigmoid(torch.nn.functional.linear(X_pt, W_pt, B_pt)) * D_pt - ) + sigmoid_pt = torch.sigmoid(torch.nn.functional.linear(X_pt, W_pt, B_pt)) + Y_pt = [torch.cos(sigmoid_pt * D_pt)] input_name_to_index = module.get_input_name_to_index_map() inputs = [0, 0, 0, 0] @@ -993,9 +1006,15 @@ def _test_gemm_rcr_bias_sigmoid_mul( inputs[input_name_to_index["input_2"]] = B_pt inputs[input_name_to_index["input_3"]] = D_pt - y = get_torch_empty_tensor([M, N], dtype) - module.run_with_tensors(inputs, [y]) - self.assertTrue(torch.allclose(Y_pt, y, atol=1e-1, rtol=1e-1)) + y = [get_torch_empty_tensor([M, N], dtype)] + + if output_in_the_middle: + # add another tensor to capture sigmoid output from AIT + y.append(get_torch_empty_tensor([M, N], dtype)) + Y_pt.append(sigmoid_pt) + + module.run_with_tensors(inputs, y) + torch.testing.assert_close(Y_pt, y, atol=1e-1, rtol=1e-1) def _test_gemm_rcr_bias_sigmoid_mul_tanh( self, Ms, N, K, decomposed, testname, dtype="float16" @@ -1135,6 +1154,14 @@ def test_gemm_rcr_bias_sigmoid_mul(self): self._test_gemm_rcr_bias_sigmoid_mul( [8], 16, 3, False, "gemm_rcr_bias_sigmoid_mul_need_align" ) + self._test_gemm_rcr_bias_sigmoid_mul( + [8], + 16, + 3, + False, + "gemm_rcr_bias_sigmoid_mul_output_in_the_middle", + output_in_the_middle=True, + ) def test_gemm_rcr_bias_sigmoid_mul_tanh(self): self._test_gemm_rcr_bias_sigmoid_mul_tanh(