@@ -882,19 +882,24 @@ def __init__(self, config: Glm4MoeConfig, device=None):
 
     @paddle.no_grad()
     def forward(self, x, position_ids):
-        inv_freq_expanded = (
-            self.inv_freq.unsqueeze(0)
-            .unsqueeze(-1)
-            .cast(paddle.float32)
-            .expand([position_ids.shape[0], -1, 1])
-            .to(x.place)
-        )
-        position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
+        # NOTE: Paddle's Automatic Mixed Precision (AMP) has a default op whitelist that may automatically cast
+        # certain operations (like matmul) to FP16/BF16 for performance optimization. However, in scenarios where
+        # numerical stability is critical (e.g., RoPE init/compute), this conversion can lead to precision loss.
+        # Disabling auto_cast here ensures the matmul operation runs in the original precision (FP32) as intended.
+        with paddle.amp.auto_cast(False):
+            inv_freq_expanded = (
+                self.inv_freq.unsqueeze(0)
+                .unsqueeze(-1)
+                .cast(paddle.float32)
+                .expand([position_ids.shape[0], -1, 1])
+                .to(x.place)
+            )
+            position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
 
-        freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
-        emb = paddle.cat((freqs, freqs), axis=-1)
-        cos = paddle.cos(emb) * self.attention_scaling
-        sin = paddle.sin(emb) * self.attention_scaling
+            freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
+            emb = paddle.cat((freqs, freqs), axis=-1)
+            cos = paddle.cos(emb) * self.attention_scaling
+            sin = paddle.sin(emb) * self.attention_scaling
 
         return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)
 
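For reference, a minimal standalone sketch of the behavior the NOTE describes (the tensor names and shapes are illustrative, not taken from this repo): under paddle.amp.auto_cast(), matmul sits on the default O1 whitelist and may execute in FP16/BF16 on GPU, while paddle.amp.auto_cast(False) keeps the same computation in FP32, which is what the RoPE path above relies on.

# Illustrative only: stand-ins for inv_freq_expanded / position_ids_expanded in the diff above.
import paddle

inv_freq = paddle.rand([1, 64, 1], dtype="float32")
position_ids = paddle.rand([1, 1, 4096], dtype="float32")

with paddle.amp.auto_cast():  # AMP on: matmul is on the default whitelist and may downcast on GPU
    freqs_amp = paddle.matmul(inv_freq, position_ids)

with paddle.amp.auto_cast(False):  # AMP off: matmul keeps the original FP32 inputs
    freqs_fp32 = paddle.matmul(inv_freq, position_ids)

print(freqs_amp.dtype, freqs_fp32.dtype)  # on a GPU run this is typically paddle.float16 vs paddle.float32

The commit applies the same idea by wrapping the whole frequency computation in a single auto_cast(False) block rather than guarding individual ops, so every intermediate stays in FP32 until the final cast back to x.dtype.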