Skip to content

Commit d125de7

Browse files
committed
rename torchao.testing.float8 to torchao.testing.training
Summary: Most of the code here is applicable to MX, so making the dir name more generic. Test Plan: ```bash ./test/float8/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: e4c8486 ghstack-comment-id: 2991842790 Pull Request resolved: #2415
1 parent c412625 commit d125de7

File tree

13 files changed

+11
-8
lines changed

13 files changed

+11
-8
lines changed

benchmarks/float8/bench_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
get_name_to_shapes_iter,
1717
)
1818

19-
from torchao.testing.float8.roofline_utils import get_specs
19+
from torchao.testing.training.roofline_utils import get_specs
2020

2121

2222
def benchmark_fn_in_sec(f, *args, **kwargs):

benchmarks/float8/float8_roofline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
)
6464
from torchao.prototype.mx_formats import MXLinearConfig
6565
from torchao.quantization import quantize_
66-
from torchao.testing.float8.roofline_utils import (
66+
from torchao.testing.training.roofline_utils import (
6767
get_float8_mem_sympy,
6868
get_gemm_time_sympy,
6969
)

test/float8/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
fp8_tensor_statistics,
5656
tensor_to_scale,
5757
)
58-
from torchao.testing.float8.test_utils import get_test_float8_linear_config
58+
from torchao.testing.training.test_utils import get_test_float8_linear_config
5959
from torchao.utils import is_MI300, is_ROCM
6060

6161
random.seed(0)

test/float8/test_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
hp_tensor_to_float8_dynamic,
3838
)
3939
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
40-
from torchao.testing.float8.test_utils import get_test_float8_linear_config
40+
from torchao.testing.training.test_utils import get_test_float8_linear_config
4141

4242

4343
def _test_compile_base(

test/float8/test_dtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
)
5858
from torchao.float8.float8_utils import tensor_to_scale
5959
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
60-
from torchao.testing.float8.dtensor_utils import ToyModel
60+
from torchao.testing.training.dtensor_utils import ToyModel
6161

6262
torch.set_float32_matmul_precision("high")
6363

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@
4343
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
4444
from torchao.float8.float8_tensor import GemmInputRole
4545
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
46-
from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
46+
from torchao.testing.training.fsdp2_utils import (
47+
check_parity_bf16_mp,
48+
check_parity_no_mp,
49+
)
4750

4851
if not is_sm_at_least_89():
4952
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

test/float8/test_fsdp2_tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
Float8ColwiseParallel,
3333
Float8RowwiseParallel,
3434
)
35-
from torchao.testing.float8.dtensor_utils import ToyModel
35+
from torchao.testing.training.dtensor_utils import ToyModel
3636

3737

3838
def setup_distributed():

test/float8/test_numerics_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
convert_to_float8_training,
3434
)
3535
from torchao.float8.float8_utils import IS_ROCM, compute_error
36-
from torchao.testing.float8.test_utils import get_test_float8_linear_config
36+
from torchao.testing.training.test_utils import get_test_float8_linear_config
3737

3838
torch.manual_seed(0)
3939

0 commit comments

Comments
 (0)