From 03d23ee90728be359d04d965cae469680a51bdb8 Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 20:33:33 -0700
Subject: [PATCH 1/8] Add mixed dtype check for XNNPACK partitioner

---
 .../xnnpack/partition/config/xnnpack_config.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index 20018610fce..1214ec5fbda 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -144,9 +144,10 @@ def check_common_constraints(
         return True

     def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
-        # Check inputs are valid dtypes
+        # Check inputs are valid and have the same dtype
         # Gather all args which are nodes
         args_to_check = []
+        reference_dtype = None
         for arg in node.args:
             if isinstance(arg, list) or isinstance(arg, tuple):
                 for item in arg:
@@ -174,11 +175,17 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
             if arg_val.dtype not in valid_dtypes:
                 return False

+            if reference_dtype is None:
+                reference_dtype = arg_val.dtype
+            elif arg_val.dtype != reference_dtype:
+                return False
+
         return True

     def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
-        # Check outputs are valid dtype
+        # Check outputs are valid and have the same dtype
         node_val = node.meta.get("val", None)
+        reference_dtype = None
         if node_val is None:
             return True

@@ -192,6 +199,11 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
             if val.dtype not in valid_dtypes:
                 return False

+            if reference_dtype is None:
+                reference_dtype = val.dtype
+            elif val.dtype != reference_dtype:
+                return False
+
         return True

     def _check_node_has_valid_dtype(self, node):

From e308e61e3c76b81aca87ee001d1fee3112b3e27e Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 20:40:13 -0700
Subject: [PATCH 2/8] Add tests for mixed dtype cat ops

---
 backends/xnnpack/test/ops/test_cat.py | 32 ++++++++++++++++++++++-----
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py
index dd551ea3fa7..5f8581c143d 100644
--- a/backends/xnnpack/test/ops/test_cat.py
+++ b/backends/xnnpack/test/ops/test_cat.py
@@ -23,7 +23,7 @@ def forward(self, *args):
             x = torch.cat(xs, dim=self.dim)
             return x + x  # Quantize by propagation.

-    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
+    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2, mixed_dtype=False):
         for legacy_mode in (True, False):
             tester = Tester(module, inputs)

@@ -53,12 +53,16 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
             if quant:
                 tester.check_not(["torch.ops.quantized_decomposed"])

+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            if mixed_dtype:
+                tester.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
+                tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+            else:
+                tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                tester.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
+
             (
-                tester.check_count(
-                    {"torch.ops.higher_order.executorch_call_delegate": 1}
-                )
-                .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
-                .to_executorch()
+                tester.to_executorch()
                 .serialize()
                 .run_method_and_compare_outputs()
             )
@@ -249,3 +253,19 @@ def forward(self, x, y):
     def _test_qs8_cat_nhwc2(self):
         inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
         self._test_cat(self.CatNhwc(), inputs, quant=True, quant_ops=4)
+
+    def test_fp32_cat_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 2, 3).to(torch.float32),
+                    torch.randn(3, 2, 3).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_cat(self.Cat(), inputs, mixed_dtype=True)
\ No newline at end of file

From 7740b73bc22e13b8415a5c702fa10aa133acff57 Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 21:32:43 -0700
Subject: [PATCH 3/8] Add test for mixed dtype add ops

---
 backends/xnnpack/test/ops/test_add.py | 35 +++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/backends/xnnpack/test/ops/test_add.py b/backends/xnnpack/test/ops/test_add.py
index 29a87df1303..0a59628d0f9 100644
--- a/backends/xnnpack/test/ops/test_add.py
+++ b/backends/xnnpack/test/ops/test_add.py
@@ -42,15 +42,24 @@ def forward(self, x):
         out2 = x + self._constant2 + self._constant3
         return out1, out2

-    def _test_add(self, inputs):
-        (
+    def _test_add(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Add(), inputs)
             .export()
             .check_count({"torch.ops.aten.add.Tensor": 4})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
-            .to_executorch()
+        )
+
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4})
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
+
+        (
+            tester.to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
         )
@@ -237,3 +246,19 @@ def forward(self, x, z):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_add_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 1, 4, 4).to(torch.float32),
+                    torch.randn(1, 1, 4, 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_add(inputs, mixed_dtype=True)
\ No newline at end of file

From 5d40cd62970e5488c501c34c28bffaaf8661354a Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 21:33:03 -0700
Subject: [PATCH 4/8] Update comments and cat test inputs

---
 backends/xnnpack/test/ops/test_cat.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py
index 5f8581c143d..c401b4e1c94 100644
--- a/backends/xnnpack/test/ops/test_cat.py
+++ b/backends/xnnpack/test/ops/test_cat.py
@@ -53,8 +53,8 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2, mixed_d
             if quant:
                 tester.check_not(["torch.ops.quantized_decomposed"])

-            # Inverse check for mixed-dtype: original node remains and no delegate node
             if mixed_dtype:
+                # Inverse check for mixed-dtype: original node remains and no delegate node
                 tester.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
                 tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
             else:
@@ -264,7 +264,7 @@ def test_fp32_cat_with_mixed_dtype(self):
             with self.subTest(dtype=str(dtype)):
                 inputs = (
                     torch.randn(1, 2, 3).to(torch.float32),
-                    torch.randn(3, 2, 3).to(dtype),
+                    torch.randn(1, 2, 3).to(dtype),
                 )
                 # Set mixed_dtype=True to verify that
                 # no delegate node is inserted and the original node remains in the graph

From fb352770cb30ee4ba825f0894aa9ee1ac4885281 Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 21:33:37 -0700
Subject: [PATCH 5/8] Add test for mixed dtype div ops

---
 backends/xnnpack/test/ops/test_div.py | 35 +++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/backends/xnnpack/test/ops/test_div.py b/backends/xnnpack/test/ops/test_div.py
index 9bca5feed48..ca3bcf7b290 100644
--- a/backends/xnnpack/test/ops/test_div.py
+++ b/backends/xnnpack/test/ops/test_div.py
@@ -27,15 +27,24 @@ def forward(self, x):
         z = x / x
         return z

-    def _test_div(self, inputs):
-        (
+    def _test_div(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Div(), inputs)
             .export()
             .check_count({"torch.ops.aten.div.Tensor": 1})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])
-            .to_executorch()
+        )
+
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count({"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1})
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])
+
+        (
+            tester.to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
         )
@@ -67,3 +76,19 @@ def test_fp32_div_single_input(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_div_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    (torch.randn(1) + 4).to(torch.float32),
+                    (torch.randn(1) + 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_div(inputs, mixed_dtype=True)
\ No newline at end of file

From 821551fe0dc94ca3840a88907ee597d0ab5af6a7 Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 21:34:36 -0700
Subject: [PATCH 6/8] Add test for mixed dtype mul ops

---
 backends/xnnpack/test/ops/test_multiply.py | 35 ++++++++++++++++++----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/backends/xnnpack/test/ops/test_multiply.py b/backends/xnnpack/test/ops/test_multiply.py
index db50bc5dd44..54b22e74333 100644
--- a/backends/xnnpack/test/ops/test_multiply.py
+++ b/backends/xnnpack/test/ops/test_multiply.py
@@ -31,15 +31,24 @@ def forward(self, x, y):
         z = x * y
         return torch.nn.functional.relu(z)

-    def _test_mul(self, inputs):
-        (
+    def _test_mul(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Mul(), inputs)
             .export()
             .check_count({"torch.ops.aten.mul.Tensor": 1})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
-            .to_executorch()
+        )
+
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1})
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
+
+        (
+            tester.to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
         )
@@ -144,3 +153,19 @@ def test_qs8_mul_relu(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_mul_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 1, 4, 4).to(torch.float32),
+                    torch.randn(1, 1, 4, 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_mul(inputs, mixed_dtype=True)
\ No newline at end of file

From 79bd48a75ed0482ae1b1fb7acc13b03f77b79e6c Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 21:35:11 -0700
Subject: [PATCH 7/8] Add test for mixed dtype sub ops

---
 backends/xnnpack/test/ops/test_sub.py | 35 +++++++++++++++++++++++----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/backends/xnnpack/test/ops/test_sub.py b/backends/xnnpack/test/ops/test_sub.py
index fb3d3d3f948..3a81595a80e 100644
--- a/backends/xnnpack/test/ops/test_sub.py
+++ b/backends/xnnpack/test/ops/test_sub.py
@@ -27,15 +27,24 @@ def forward(self, x):
         z = x - x
         return z

-    def _test_sub(self, inputs):
-        (
+    def _test_sub(self, inputs, mixed_dtype=False):
+        tester = (
             Tester(self.Sub(), inputs)
             .export()
             .check_count({"torch.ops.aten.sub.Tensor": 1})
             .to_edge_transform_and_lower()
-            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
-            .to_executorch()
+        )
+
+        if mixed_dtype:
+            # Inverse check for mixed-dtype: original node remains and no delegate node
+            tester.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
+            tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
+        else:
+            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            tester.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
+
+        (
+            tester.to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
         )
@@ -149,3 +158,19 @@ def forward(self, x, y):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_sub_with_mixed_dtype(self):
+        test_cases = [
+            torch.bfloat16,
+            torch.float16,
+            torch.int8,
+        ]
+        for dtype in test_cases:
+            with self.subTest(dtype=str(dtype)):
+                inputs = (
+                    torch.randn(1, 1, 4, 4).to(torch.float32),
+                    torch.randn(1, 1, 4, 4).to(dtype),
+                )
+                # Set mixed_dtype=True to verify that
+                # no delegate node is inserted and the original node remains in the graph
+                self._test_sub(inputs, mixed_dtype=True)
\ No newline at end of file

From 75218f0bb27c8b028cd0dbde34b39ce040de974d Mon Sep 17 00:00:00 2001
From: Zuby A
Date: Sun, 23 Mar 2025 21:50:16 -0700
Subject: [PATCH 8/8] Apply lintrunner suggestions

---
 .../xnnpack/partition/config/xnnpack_config.py |  2 +-
 backends/xnnpack/test/ops/test_add.py          | 12 +++++-------
 backends/xnnpack/test/ops/test_cat.py          | 16 ++++++++--------
 backends/xnnpack/test/ops/test_div.py          | 12 +++++-------
 backends/xnnpack/test/ops/test_multiply.py     | 12 +++++-------
 backends/xnnpack/test/ops/test_sub.py          | 12 +++++-------
 6 files changed, 29 insertions(+), 37 deletions(-)

diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index 1214ec5fbda..f247f0631cf 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -202,7 +202,7 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
             if reference_dtype is None:
                 reference_dtype = val.dtype
             elif val.dtype != reference_dtype:
-                return False
+                return False

         return True

diff --git a/backends/xnnpack/test/ops/test_add.py b/backends/xnnpack/test/ops/test_add.py
index 0a59628d0f9..4cd6532756d 100644
--- a/backends/xnnpack/test/ops/test_add.py
+++ b/backends/xnnpack/test/ops/test_add.py
@@ -52,17 +52,15 @@ def _test_add(self, inputs, mixed_dtype=False):

         if mixed_dtype:
             # Inverse check for mixed-dtype: original node remains and no delegate node
-            tester.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4})
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}
+            )
             tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
         else:
             tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             tester.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])

-        (
-            tester.to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
-        )
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())

     def test_fp16_add(self):
         inputs = (torch.randn(1).to(torch.float16), torch.randn(1).to(torch.float16))
@@ -261,4 +259,4 @@ def test_fp32_add_with_mixed_dtype(self):
                 )
                 # Set mixed_dtype=True to verify that
                 # no delegate node is inserted and the original node remains in the graph
-                self._test_add(inputs, mixed_dtype=True)
\ No newline at end of file
+                self._test_add(inputs, mixed_dtype=True)

diff --git a/backends/xnnpack/test/ops/test_cat.py b/backends/xnnpack/test/ops/test_cat.py
index c401b4e1c94..4455667e952 100644
--- a/backends/xnnpack/test/ops/test_cat.py
+++ b/backends/xnnpack/test/ops/test_cat.py
@@ -23,7 +23,9 @@ def forward(self, *args):
             x = torch.cat(xs, dim=self.dim)
             return x + x  # Quantize by propagation.

-    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2, mixed_dtype=False):
+    def _test_cat(
+        self, module, inputs, cat_num=1, quant=False, quant_ops=2, mixed_dtype=False
+    ):
         for legacy_mode in (True, False):
             tester = Tester(module, inputs)

@@ -58,14 +60,12 @@
                 tester.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
                 tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
             else:
-                tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                tester.check_count(
+                    {"torch.ops.higher_order.executorch_call_delegate": 1}
+                )
                 tester.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])

-            (
-                tester.to_executorch()
-                .serialize()
-                .run_method_and_compare_outputs()
-            )
+            (tester.to_executorch().serialize().run_method_and_compare_outputs())

     def test_fp16_cat2(self):
         """
@@ -268,4 +268,4 @@ def test_fp32_cat_with_mixed_dtype(self):
                 )
                 # Set mixed_dtype=True to verify that
                 # no delegate node is inserted and the original node remains in the graph
-                self._test_cat(self.Cat(), inputs, mixed_dtype=True)
\ No newline at end of file
+                self._test_cat(self.Cat(), inputs, mixed_dtype=True)

diff --git a/backends/xnnpack/test/ops/test_div.py b/backends/xnnpack/test/ops/test_div.py
index ca3bcf7b290..4dce39d6dfa 100644
--- a/backends/xnnpack/test/ops/test_div.py
+++ b/backends/xnnpack/test/ops/test_div.py
@@ -37,17 +37,15 @@ def _test_div(self, inputs, mixed_dtype=False):

         if mixed_dtype:
             # Inverse check for mixed-dtype: original node remains and no delegate node
-            tester.check_count({"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1})
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1}
+            )
             tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
         else:
             tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             tester.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])

-        (
-            tester.to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
-        )
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())

     def test_fp16_div(self):
         # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough
@@ -91,4 +89,4 @@ def test_fp32_div_with_mixed_dtype(self):
                 )
                 # Set mixed_dtype=True to verify that
                 # no delegate node is inserted and the original node remains in the graph
-                self._test_div(inputs, mixed_dtype=True)
\ No newline at end of file
+                self._test_div(inputs, mixed_dtype=True)

diff --git a/backends/xnnpack/test/ops/test_multiply.py b/backends/xnnpack/test/ops/test_multiply.py
index 54b22e74333..99d78ee28e1 100644
--- a/backends/xnnpack/test/ops/test_multiply.py
+++ b/backends/xnnpack/test/ops/test_multiply.py
@@ -41,17 +41,15 @@ def _test_mul(self, inputs, mixed_dtype=False):

         if mixed_dtype:
             # Inverse check for mixed-dtype: original node remains and no delegate node
-            tester.check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1})
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}
+            )
             tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
         else:
             tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             tester.check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])

-        (
-            tester.to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
-        )
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())

     def test_fp16_mul(self):
         inputs = (
@@ -168,4 +166,4 @@ def test_fp32_mul_with_mixed_dtype(self):
                 )
                 # Set mixed_dtype=True to verify that
                 # no delegate node is inserted and the original node remains in the graph
-                self._test_mul(inputs, mixed_dtype=True)
\ No newline at end of file
+                self._test_mul(inputs, mixed_dtype=True)

diff --git a/backends/xnnpack/test/ops/test_sub.py b/backends/xnnpack/test/ops/test_sub.py
index 3a81595a80e..952cce20cfc 100644
--- a/backends/xnnpack/test/ops/test_sub.py
+++ b/backends/xnnpack/test/ops/test_sub.py
@@ -37,17 +37,15 @@ def _test_sub(self, inputs, mixed_dtype=False):

         if mixed_dtype:
             # Inverse check for mixed-dtype: original node remains and no delegate node
-            tester.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
+            tester.check_count(
+                {"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1}
+            )
             tester.check_not(["torch.ops.higher_order.executorch_call_delegate"])
         else:
             tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             tester.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])

-        (
-            tester.to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
-        )
+        (tester.to_executorch().serialize().run_method_and_compare_outputs())

     def test_fp16_sub(self):
         inputs = (
@@ -173,4 +171,4 @@ def test_fp32_sub_with_mixed_dtype(self):
                 )
                 # Set mixed_dtype=True to verify that
                 # no delegate node is inserted and the original node remains in the graph
-                self._test_sub(inputs, mixed_dtype=True)
\ No newline at end of file
+                self._test_sub(inputs, mixed_dtype=True)
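
Note on the rule this series adds: the partitioner now delegates a node only
when every tensor input/output dtype is individually supported *and* all of
them agree. Of the three dtypes each new test sweeps, float16 is the one that
specifically exercises the new uniformity branch, since fp16 inputs are
otherwise delegated by these same test files. Below is a minimal standalone
sketch of that reference-dtype rule; the function name and the valid_dtypes
default are illustrative assumptions, not the actual partitioner API.

    import torch

    def has_uniform_valid_dtype(tensors, valid_dtypes=(torch.float32, torch.float16)):
        # Mirrors the reference_dtype logic in _check_inputs_are_valid_dtypes:
        # reject any unsupported dtype, then let the first tensor fix the
        # reference dtype that every subsequent tensor must match.
        reference_dtype = None
        for t in tensors:
            if t.dtype not in valid_dtypes:
                return False
            if reference_dtype is None:
                reference_dtype = t.dtype
            elif t.dtype != reference_dtype:
                return False
        return True

    # Uniform fp32 inputs pass; mixed fp32/fp16 is rejected and stays on the
    # portable ops, which is what the mixed_dtype=True tests assert.
    assert has_uniform_valid_dtype([torch.randn(2), torch.randn(2)])
    assert not has_uniform_valid_dtype([torch.randn(2), torch.randn(2).half()])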