diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index b50a94e4f07..1c581401e5a 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -344,8 +344,28 @@ def apply(
         out_dtype = input.dtype
 
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
-        qinput, x_scale = input_2d, input_scale
+        if self.cutlass_fp8_supported and input.dtype != current_platform.fp8_dtype():
+            assert input.dtype != current_platform.fp8_dtype(
+            ), "FP8 input to cutlass is not currently implemented"
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                scale_ub=input_scale_ub,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+        else:
+            if input.dtype != current_platform.fp8_dtype():
+                # Maybe apply padding to output, see comment in __init__
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic)
+            else:
+                if input_scale is not None:
+                    qinput, x_scale = input_2d, input_scale
+                else:
+                    qinput = input_2d
+                    x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
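
Note: for readers following the control flow, the activation handling introduced by this hunk can be summarized as a standalone sketch. This is not vLLM's code: prepare_activation, quantize_fp8 and FP8_DTYPE are hypothetical stand-ins for the apply() logic, ops.scaled_fp8_quant and current_platform.fp8_dtype(); the real kernels, scale upper bounds, token padding and per-token scales are elided.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()


def quantize_fp8(x, scale=None):
    # Toy stand-in for ops.scaled_fp8_quant: per-tensor dynamic scale when
    # none is given (448.0 is the E4M3 max representable value).
    if scale is None:
        scale = (x.abs().amax().float() / 448.0).clamp(min=1e-12)
    scale = torch.as_tensor(scale, dtype=torch.float32)
    return (x.float() / scale).to(FP8_DTYPE), scale.reshape(1)


def prepare_activation(x, input_scale=None):
    """Mirror the new branching: quantize high-precision input, pass FP8 through."""
    if x.dtype != FP8_DTYPE:
        # bf16/fp16 activations are quantized before the scaled matmul.
        return quantize_fp8(x, input_scale)
    if input_scale is not None:
        # Already-quantized FP8 input arriving with its activation scale.
        return x, input_scale
    # Already-quantized FP8 input without a scale: fall back to a unit scale.
    return x, torch.ones(1, dtype=torch.float32, device=x.device)


# Usage: a bf16 activation is quantized and returned with its scale, while an
# FP8 activation is passed through untouched (unit scale if none is provided).
x = torch.randn(4, 8, dtype=torch.bfloat16)
qx, sx = prepare_activation(x)
qx2, sx2 = prepare_activation(qx)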