diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index e2e7097f6b..ecde051e36 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -18,7 +18,7 @@ from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp +from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._tensor import DTensor from torch.testing._internal.common_cuda import TEST_CUDA diff --git a/torchao/testing/__init__.py b/torchao/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/float8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/float8/test_fsdp2/fsdp2_common.py b/torchao/testing/float8/fsdp2_utils.py similarity index 90% rename from test/float8/test_fsdp2/fsdp2_common.py rename to torchao/testing/float8/fsdp2_utils.py index 333206ba41..f558bb11f9 100644 --- a/test/float8/test_fsdp2/fsdp2_common.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -49,9 +49,9 @@ def check_parity_no_mp( precompute_float8_dynamic_scale_for_fsdp(model) if compile_transformer_block: - test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4) + test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") else: - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") def check_parity_bf16_mp( @@ -86,4 +86,4 @@ def check_parity_bf16_mp( ref_model.parameters(), 
ref_model_bf16.parameters() ): param_bf16.detach().copy_(param_fp32) - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")