This issue tracks multi-GPU support for various torchao configs.

I'm testing the following configs with this script: https://gist.github.com/jcaip/def30e1ca48b9f3e5b07b68749e7f90b
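For context, each config below is exercised roughly like this (a minimal sketch, not the gist itself; the toy model, device names, and batch shape are placeholders, and `quantize_` is torchao's standard entry point):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# placeholder model; the real script uses a larger workload
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda:0")

# one of the configs under test (full list below)
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# quantize on the default device, then move to a second GPU to
# exercise the multi-GPU paths this issue is tracking
quantize_(model, config)
model = model.to("cuda:1")

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda:1")
with torch.no_grad():
    y = model(x)
```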
```python
# Import paths below reflect recent torchao; they may differ across versions.
import torch
from torchao.prototype.mx_formats import (
    MXFPInferenceConfig,
    MXGemmKernelChoice,
    NVFP4InferenceConfig,
    NVFP4MMConfig,
)
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    PerRow,
)

# fails with a CUBLAS error, see: https://www.internalfb.com/phabricator/paste/view/P2021159792
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

# fails on CUTLASS initialization, see: https://www.internalfb.com/phabricator/paste/view/P2021172682
config = MXFPInferenceConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)

# TODO: untested, no hardware available
config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.DYNAMIC,
    use_dynamic_per_tensor_scale=True,
)

# works out of the box for kernel_preference=TORCH
# need to quantize on hp_device for FBGEMM, see https://github.com/pytorch/ao/pull/3263/files#diff-3f82a055d7180652c52e018ed81fd39258855114f50dc25e80f8f3b2d89da2baR211
# need CUDA_LAUNCH_BLOCKING for bmm support
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
)

# works, but tuple serialization still needs support, see https://github.com/pytorch/ao/pull/3263/files#diff-d3429411a92a37b3a03051ce54abd1efcf76f5792adb35888c5793c6ab396ba1R147
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
```
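On the Float8 PerRow note above: a hedged sketch of pinning the kernel preference. The `kernel_preference=TORCH` behavior is what I observed; the `KernelPreference` enum import path is an assumption and may vary by torchao version:

```python
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization.quant_api import KernelPreference  # assumed path; verify for your version

# TORCH kernels work out of the box on multiple GPUs; FBGEMM currently
# needs the quantize-on-hp_device workaround linked above
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    kernel_preference=KernelPreference.TORCH,
)
```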