Skip to content

Commit b4d21bc

Browse files
authored
Fix zero token issue (#1895)
fix zero token
1 parent 9b6791c commit b4d21bc

2 files changed

Lines changed: 15 additions & 6 deletions

File tree

xtuner/v1/float8/float8_gmm_tile_wise.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,13 @@ class fp8_gmm_weight_per_block_act_per_tile(torch.autograd.Function):
8888
def forward(ctx, x, w_fp8, tokens_per_expert):
8989
seq, din = x.shape
9090
ne, dout, din = w_fp8.shape
91+
ctx.zero_token_dispatch = seq == 0
92+
ctx.input_shape = x.shape
93+
ctx.weight_shape = w_fp8.shape
94+
95+
if ctx.zero_token_dispatch:
96+
return x.new_empty((seq, dout))
97+
9198
x_fp8, x_scale = per_tile_quant(x)
9299
(
93100
x_trans_quant_fp8,
@@ -104,6 +111,11 @@ def forward(ctx, x, w_fp8, tokens_per_expert):
104111

105112
@staticmethod
106113
def backward(ctx, grad_output_hp):
114+
if ctx.zero_token_dispatch:
115+
dx = grad_output_hp.new_empty(ctx.input_shape)
116+
dw = grad_output_hp.new_zeros(ctx.weight_shape)
117+
return dx, dw, None
118+
107119
(
108120
x_trans_quant_fp8,
109121
x_trans_quant_scale,
@@ -278,9 +290,10 @@ def forward(self, input: torch.Tensor, tokens_per_expert, decoding: bool = False
278290
weight_fp8 = weight_to_per_block_float8_dynamic.apply(weight, torch.float8_e4m3fn, 128)
279291

280292
orig_shape = input.shape
281-
input = input.view(-1, input.shape[-1])
293+
num_tokens = input.numel() // input.shape[-1]
294+
input = input.view(num_tokens, input.shape[-1])
282295
out = fp8_gmm_weight_per_block_act_per_tile.apply(input, weight_fp8, tokens_per_expert)
283-
out = out.view(*orig_shape[:-1], -1)
296+
out = out.view(*orig_shape[:-1], self.out_features)
284297
return out
285298

286299
@property

xtuner/v1/module/decoder_layer/moe_decoder_layer.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -184,10 +184,6 @@ def __init__(
184184
self.moe_act = moe_act_fn_cfg.build()
185185

186186
def forward(self, x, tokens_per_expert, decoding):
187-
# short cut for dispatching 0 token in ep_size >1 case
188-
if x.numel() == 0:
189-
return x
190-
191187
gate_up_out = self.fused_w1w3(x, tokens_per_expert, decoding)
192188
out = self.moe_act(gate_up_out, split_dim=-1)
193189
res = self.fused_w2(out, tokens_per_expert, decoding)

0 commit comments

Comments
 (0)