TorchAO multi-gpu support tracker #3284

@jcaip

Description

This issue tracks multi-GPU support for various torchao configs.

Testing the following configs with this script: https://gist.github.com/jcaip/def30e1ca48b9f3e5b07b68749e7f90b
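For reference, a rough sketch of the kind of smoke test being run (this is a stand-in, not the linked gist itself): quantize a toy model with a given config on each visible GPU and run a forward pass.

import torch
from torchao.quantization import quantize_

def smoke_test(config):
    # apply the config to a toy model on each visible GPU and run a forward pass
    for i in range(torch.cuda.device_count()):
        device = f"cuda:{i}"
        model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).to(device)
        quantize_(model, config)
        x = torch.randn(32, 128, dtype=torch.bfloat16, device=device)
        with torch.no_grad():
            print(device, model(x).shape)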

# fails with a CUBLAS error, see: https://www.internalfb.com/phabricator/paste/view/P2021159792
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

# fails during CUTLASS initialization, see: https://www.internalfb.com/phabricator/paste/view/P2021172682
config = MXFPInferenceConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)

# TODO: no hardware available to test this config (see the capability check below)
config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.DYNAMIC,
    use_dynamic_per_tensor_scale=True,
)
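Since NVFP4 can't be exercised without the right hardware, a small guard like the one below can skip it gracefully (the Blackwell / compute capability 10.0 threshold is an assumption about the requirement):

import torch

def has_nvfp4_hardware() -> bool:
    # NVFP4 kernels target Blackwell-class GPUs (assumed: compute capability >= 10.0)
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 10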

# works out of the box for kernel_preference=TORCH
# need to quantize on hp_device for FBGEMM, see https://github.com/pytorch/ao/pull/3263/files#diff-3f82a055d7180652c52e018ed81fd39258855114f50dc25e80f8f3b2d89da2baR211 and the sketch after this block
# need CUDA_LAUNCH_BLOCKING=1 for bmm support
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
)
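A sketch of the two workarounds above (the KernelPreference import path is an assumption on my part): set CUDA_LAUNCH_BLOCKING before CUDA initializes, and move the high-precision model to its target device before quantizing rather than quantizing first and moving the quantized model.

import os

# must be set before any CUDA initialization to take effect (assumed needed for bmm)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    KernelPreference,  # assumed import path
    PerRow,
    quantize_,
)

model = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
# move the bf16 (hp_device) model to its target GPU *before* quantizing
model = model.to("cuda:1")
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        kernel_preference=KernelPreference.FBGEMM,
    ),
)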

# works, but needs tuple serialization support, see https://github.com/pytorch/ao/pull/3263/files#diff-d3429411a92a37b3a03051ce54abd1efcf76f5792adb35888c5793c6ab396ba1R147 and the round-trip sketch below
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
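To reproduce the serialization gap, a round-trip sketch (config_to_dict / config_from_dict from torchao.core.config are my assumption about the relevant entry points):

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerBlock

config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
# the tuple-of-PerBlock granularity is what currently lacks serialization support
restored = config_from_dict(config_to_dict(config))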
