Skip to content

Commit a05a40f

Browse files
committed
resolve conflict with latest main
Differential Revision: D63048850 Pull Request resolved: pytorch#912
1 parent ab3435c commit a05a40f

File tree

4 files changed

+3
-3
lines changed

4 files changed

+3
-3
lines changed

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
1919
from torchao.float8.float8_linear_utils import convert_to_float8_training
2020
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
21-
from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
21+
from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
2222
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
2323
from torch.distributed._tensor import DTensor
2424
from torch.testing._internal.common_cuda import TEST_CUDA

torchao/testing/__init__.py

Whitespace-only changes.

torchao/testing/float8/__init__.py

Whitespace-only changes.

test/float8/test_fsdp2/fsdp2_common.py renamed to torchao/testing/float8/fsdp2_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def check_parity_no_mp(
4848
):
4949
precompute_float8_dynamic_scale_for_fsdp(model)
5050

51-
test_cls.assertEqual(losses[0], losses[1])
51+
test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
5252

5353

5454
def check_parity_bf16_mp(
@@ -83,4 +83,4 @@ def check_parity_bf16_mp(
8383
ref_model.parameters(), ref_model_bf16.parameters()
8484
):
8585
param_bf16.detach().copy_(param_fp32)
86-
test_cls.assertEqual(losses[0], losses[1])
86+
test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")

0 commit comments

Comments
 (0)