 from modules.module.quantized.mixin.QuantizedLinearMixin import QuantizedLinearMixin
 from modules.module.quantized.mixin.QuantizedModuleMixin import QuantizedModuleMixin
+from modules.util.mm_8bit import mm_8bit as mm_8bit
 from modules.util.quantization_util import (
     dequantize,
     quantize_fp8_axiswise,
     quantize_fp8_tensorwise,
     quantize_int8_axiswise,
     quantize_int8_tensorwise,
 )
-from modules.util.triton_mm_8bit import mm_8bit as triton_mm_8bit
 
 import torch
 from torch import Tensor, nn
@@ -37,13 +37,13 @@ def fp8_forward_tokenwise(x: Tensor, weight: Tensor, weight_scale: float, bias:
 def int8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor:
     output_8, output_scale = quantize_int8_axiswise(output, dim=-1)
     # almost always, grad outputs are already contiguous and this is a no-op. But there are some grad outputs from SDXL that are non-contiguous:
-    mm_res = triton_mm_8bit(output_8.contiguous(), weight)
+    mm_res = mm_8bit(output_8.contiguous(), weight)
     return mm_res.float().mul_(weight_scale * output_scale).to(output.dtype)
 
 @torch.no_grad()
 def fp8_backward_axiswise(output: Tensor, weight: Tensor, weight_scale: float) -> Tensor:
     output_8, output_scale = quantize_fp8_axiswise(output, dim=-1)
-    mm_res = triton_mm_8bit(output_8.contiguous(), weight)
+    mm_res = mm_8bit(output_8.contiguous(), weight)
     return mm_res.float().mul_(weight_scale * output_scale).to(output.dtype)
 
 
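
The backward helpers above follow a standard 8-bit GEMM dequantization pattern: quantize the incoming gradient per row (axiswise over the last dimension), multiply it against the already-quantized weight with an 8-bit matmul, then scale the int32/float result by the product of the two scales to return to the original dtype. Below is a rough, self-contained sketch of that pattern, not part of this commit, using torch._int_mm in place of the Triton mm_8bit kernel; the stand-in quantizer only assumes that the real quantize_int8_axiswise returns an (int8 tensor, per-row scale) pair.

import torch

def quantize_int8_axiswise_sketch(x: torch.Tensor, dim: int = -1):
    # Per-row absmax scale so values map into the int8 range [-127, 127].
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12) / 127.0
    x_8 = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return x_8, scale

if torch.cuda.is_available():
    grad_out = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)            # incoming gradient
    weight_8 = torch.randint(-127, 128, (64, 32), device="cuda", dtype=torch.int8) # quantized weight (out, in)
    weight_scale = 0.01                                                            # hypothetical weight dequant scale

    out_8, out_scale = quantize_int8_axiswise_sketch(grad_out.float())
    # The benchmark's torch baseline re-lays the weight out as column-major before
    # torch._int_mm (via .T.contiguous().T), so this sketch does the same.
    mm_res = torch._int_mm(out_8.contiguous(), weight_8.T.contiguous().T)          # int8 x int8 -> int32
    grad_in = mm_res.float().mul_(weight_scale * out_scale).to(grad_out.dtype)     # undo both scales
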
@@ -158,11 +158,11 @@ def benchmark_int8(m, k, n, device = 'cuda'):
 
 
     run_benchmark(lambda: torch._int_mm(x_8, w_8.T), "torch mm int")
-    run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm int")
+    run_benchmark(lambda: mm_8bit(x_8, w_8.T), "triton mm int")
     def torch_backward(a, b):
         torch._int_mm(a, b.T.contiguous().T)
     run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward int8")
-    run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward int8")
+    run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward int8")
 
     run_benchmark(lambda: int8_forward_tokenwise(x, w_8, w_scale), "torch forward int", compile=True)
     run_benchmark(lambda: int8_backward_axiswise(y, w_8, w_scale), "triton backward int", compile=True)
@@ -179,11 +179,11 @@ def benchmark_fp8(m, k, n, device = 'cuda'):
     one_scale = torch.ones(1, device=device)
 
     run_benchmark(lambda: torch._scaled_mm(x_8, w_8.T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float()), "torch mm fp8")
-    run_benchmark(lambda: triton_mm_8bit(x_8, w_8.T), "triton mm fp8")
+    run_benchmark(lambda: mm_8bit(x_8, w_8.T), "triton mm fp8")
     def torch_backward(a, b):
         torch._scaled_mm(a, b.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=one_scale.float(), scale_b=w_scale.float())
     run_benchmark(lambda: torch_backward(y_8, w_8), "torch mm backward fp8")
-    run_benchmark(lambda: triton_mm_8bit(y_8, w_8), "triton mm backward fp8")
+    run_benchmark(lambda: mm_8bit(y_8, w_8), "triton mm backward fp8")
     run_benchmark(lambda: fp8_forward_tokenwise(x, w_8, w_scale), "torch forward fp8", compile=True)
     run_benchmark(lambda: fp8_backward_axiswise(y, w_8, w_scale), "triton backward fp8", compile=True)
 
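
A note on the torch_backward baselines in both benchmarks: they pass b.T.contiguous().T, which keeps the same values but materializes a column-major copy of the weight, the layout torch._scaled_mm expects for its second operand, so that copy is counted in the baseline timings, while the backward functions above pass the strided weight straight to the Triton mm_8bit kernel. A rough sketch of the fp8 backward-style call, not part of this commit; the shapes and the fp8-capable-GPU requirement are assumptions:

import torch

# Requires an fp8-capable GPU (compute capability >= 8.9); shapes chosen to satisfy
# torch._scaled_mm's multiple-of-16 dimension constraints.
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    m, k, n = 32, 64, 32
    y_8 = torch.randn(m, n, device="cuda").to(torch.float8_e4m3fn)   # stand-in fp8 grad output
    w_8 = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)   # weight stored row-major as (n, k)
    one_scale = torch.ones(1, device="cuda")

    w_col_major = w_8.T.contiguous().T   # same values, column-major layout, at the cost of a copy
    grad_x = torch._scaled_mm(y_8, w_col_major, out_dtype=torch.bfloat16,
                              scale_a=one_scale, scale_b=one_scale)  # (m, k)
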