@@ -88,6 +88,13 @@ class fp8_gmm_weight_per_block_act_per_tile(torch.autograd.Function):
8888 def forward (ctx , x , w_fp8 , tokens_per_expert ):
8989 seq , din = x .shape
9090 ne , dout , din = w_fp8 .shape
91+ ctx .zero_token_dispatch = seq == 0
92+ ctx .input_shape = x .shape
93+ ctx .weight_shape = w_fp8 .shape
94+
95+ if ctx .zero_token_dispatch :
96+ return x .new_empty ((seq , dout ))
97+
9198 x_fp8 , x_scale = per_tile_quant (x )
9299 (
93100 x_trans_quant_fp8 ,
@@ -104,6 +111,11 @@ def forward(ctx, x, w_fp8, tokens_per_expert):
104111
105112 @staticmethod
106113 def backward (ctx , grad_output_hp ):
114+ if ctx .zero_token_dispatch :
115+ dx = grad_output_hp .new_empty (ctx .input_shape )
116+ dw = grad_output_hp .new_zeros (ctx .weight_shape )
117+ return dx , dw , None
118+
107119 (
108120 x_trans_quant_fp8 ,
109121 x_trans_quant_scale ,
@@ -278,9 +290,10 @@ def forward(self, input: torch.Tensor, tokens_per_expert, decoding: bool = False
278290 weight_fp8 = weight_to_per_block_float8_dynamic .apply (weight , torch .float8_e4m3fn , 128 )
279291
280292 orig_shape = input .shape
281- input = input .view (- 1 , input .shape [- 1 ])
293+ num_tokens = input .numel () // input .shape [- 1 ]
294+ input = input .view (num_tokens , input .shape [- 1 ])
282295 out = fp8_gmm_weight_per_block_act_per_tile .apply (input , weight_fp8 , tokens_per_expert )
283- out = out .view (* orig_shape [:- 1 ], - 1 )
296+ out = out .view (* orig_shape [:- 1 ], self . out_features )
284297 return out
285298
286299 @property
0 commit comments