diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py index 90ad502378c..21317c23a8a 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -13,7 +13,6 @@ from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized -from torch.testing._internal import optests def test_rescale_op(): @@ -64,7 +63,7 @@ def test_nonzero_zp_for_int32(): ), ] for sample_input in sample_inputs: - with pytest.raises(optests.generate_tests.OpCheckError): + with pytest.raises(Exception, match="opcheck"): torch.library.opcheck(torch.ops.tosa._rescale, sample_input) @@ -87,7 +86,7 @@ def test_zp_outside_range(): ), ] for sample_input in sample_inputs: - with pytest.raises(optests.generate_tests.OpCheckError): + with pytest.raises(Exception, match="opcheck"): torch.library.opcheck(torch.ops.tosa._rescale, sample_input) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 5a0bfe2c37c..d8e5970040b 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -34,12 +34,33 @@ from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict from tosa import TosaGraph logger = logging.getLogger(__name__) logger.setLevel(logging.CRITICAL) +# Copied from PyTorch. +# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict +# To avoid a dependency on _internal stuff. +_torch_to_numpy_dtype_dict = { + torch.bool: np.bool_, + torch.uint8: np.uint8, + torch.uint16: np.uint16, + torch.uint32: np.uint32, + torch.uint64: np.uint64, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.bfloat16: np.float32, + torch.complex32: np.complex64, + torch.complex64: np.complex64, + torch.complex128: np.complex128, +} + class QuantizationParams: __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"] @@ -335,7 +356,7 @@ def run_corstone( output_dtype = node.meta["val"].dtype tosa_ref_output = np.fromfile( os.path.join(intermediate_path, f"out-{i}.bin"), - torch_to_numpy_dtype_dict[output_dtype], + _torch_to_numpy_dtype_dict[output_dtype], ) output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) @@ -349,7 +370,7 @@ def prep_data_for_save( ): if isinstance(data, torch.Tensor): data_np = np.array(data.detach(), order="C").astype( - torch_to_numpy_dtype_dict[data.dtype] + _torch_to_numpy_dtype_dict[data.dtype] ) else: data_np = np.array(data)