From b0e78eb60ee6c3c24b83939397f514ea441206c0 Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Thu, 8 May 2025 21:24:31 +0000
Subject: [PATCH 1/2] cleanup

Signed-off-by: Amog Kamsetty
---
 .../layers/quantization/utils/w8a8_utils.py | 22 +++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index b50a94e4f07..3a0d9e4d185 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -344,8 +344,26 @@ def apply(
             out_dtype = input.dtype
 
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
-        qinput, x_scale = input_2d, input_scale
+        if self.cutlass_fp8_supported and input.dtype != current_platform.fp8_dtype():
+            assert input.dtype != current_platform.fp8_dtype(
+            ), "FP8 input to cutlass is not currently implemented"
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                scale_ub=input_scale_ub,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+        else:
+            if input.dtype != current_platform.fp8_dtype():
+                # Maybe apply padding to output, see comment in __init__
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic)
+            else:
+                # qinput, x_scale = input_2d, input_scale
+                qinput = input_2d
+                x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)

From 5375fc5e511bbbc738b895b6affbda30b60ccd59 Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Thu, 8 May 2025 21:27:52 +0000
Subject: [PATCH 2/2] update

Signed-off-by: Amog Kamsetty
---
 .../layers/quantization/utils/w8a8_utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 3a0d9e4d185..1c581401e5a 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -361,9 +361,11 @@ def apply(
                     num_token_padding=self.output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic)
             else:
-                # qinput, x_scale = input_2d, input_scale
-                qinput = input_2d
-                x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
+                if input_scale is not None:
+                    qinput, x_scale = input_2d, input_scale
+                else:
+                    qinput = input_2d
+                    x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
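
Note on the series: taken together, the two patches turn the activation
quantization step into a three-way dispatch: quantize on the cutlass path
(with an optional scale upper bound), quantize on the fallback path (with
optional token padding), or pass an already-fp8 activation straight through
with the caller's scale, falling back to a unit scale. The sketch below
restates that dispatch as a standalone function. It is illustrative only:
the helper name `quantize_input_for_fp8_gemm` and its explicit flag
arguments are hypothetical (the committed code reads these off `self`),
while `ops.scaled_fp8_quant` and `current_platform.fp8_dtype()` are the
vLLM helpers the diff itself calls, assumed to behave as used there
(returning a (quantized tensor, scale) pair, resp. the platform's torch
fp8 dtype).

    from typing import Optional, Tuple

    import torch

    from vllm import _custom_ops as ops
    from vllm.platforms import current_platform


    def quantize_input_for_fp8_gemm(
            input_2d: torch.Tensor,
            input_scale: Optional[torch.Tensor],
            input_scale_ub: Optional[torch.Tensor],
            cutlass_fp8_supported: bool,
            output_padding: Optional[int],
            use_per_token_if_dynamic: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (qinput, x_scale) ready for an fp8 scaled matmul."""
        if input_2d.dtype != current_platform.fp8_dtype():
            # Higher-precision activations must be quantized to fp8 first.
            if cutlass_fp8_supported:
                return ops.scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    scale_ub=input_scale_ub,
                    use_per_token_if_dynamic=use_per_token_if_dynamic)
            # Fallback path: optionally pad the token dimension.
            return ops.scaled_fp8_quant(
                input_2d,
                input_scale,
                num_token_padding=output_padding,
                use_per_token_if_dynamic=use_per_token_if_dynamic)
        # Activations arrive already in fp8: pass them through, reusing the
        # caller's scale when provided and a unit scale otherwise.
        if input_scale is not None:
            return input_2d, input_scale
        return input_2d, torch.tensor([1.0],
                                      dtype=torch.float32,
                                      device=input_2d.device)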