
Commit b804fb0

Authored by: erastorgueva-nv, chtruong814, cuichenx
Force activations and weights cast to FP32 Jasper Encoder Squeeze-Excite (#14715)
* comment out x = x.float()

  Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* Update Reasoning-SFT.ipynb (#14716) (#14717)

  Signed-off-by: Chen Cui <chcui@nvidia.com>
  Co-authored-by: Chen Cui <chcui@nvidia.com>
  Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* cast SE weights and activations to fp32

  Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

---------

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Charlie Truong <chtruong@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
1 parent 422deee commit b804fb0

File tree

1 file changed: +9 −1 lines changed
  • nemo/collections/asr/parts/submodules


nemo/collections/asr/parts/submodules/jasper.py

Lines changed: 9 additions & 1 deletion
@@ -477,7 +477,13 @@ def forward_for_export(self, x, lengths):
         # Create sample mask - 1 represents value, 0 represents pad
         mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
         mask = ~mask  # 0 represents value, 1 represents pad
-        x = x.float()  # For stable AMP, SE must be computed at fp32.
+
+        # Ensure SE runs in FP32: cast fc weights and activations to float32
+        if self.fc[0].weight.dtype != torch.float32:
+            self.fc.float()
+        if x.dtype != torch.float32:
+            x = x.float()
+
         x = x.masked_fill(mask, 0.0)  # mask padded values explicitly to 0
         y = self._se_pool_step(x, mask)  # [B, C, 1]
         y = y.transpose(1, -1)  # [B, 1, C]
@@ -490,6 +496,8 @@ def forward_for_export(self, x, lengths):

         y = torch.sigmoid(y)
         y = x * y
+        # Cast back to original dtype for downstream consistency
+        y = y.to(dtype)
         return y, lengths

     def _se_pool_step(self, x, mask):
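The pattern this commit applies can be illustrated in isolation: run a Squeeze-Excite block in FP32 (casting both its weights and the activations), then cast the result back to the input's original dtype. The sketch below is a minimal, hypothetical SE module, not the NeMo `jasper.py` implementation; the pooling and `fc` layout are simplified assumptions for illustration.

```python
# Minimal sketch (NOT the NeMo implementation) of forcing a Squeeze-Excite
# block to compute in FP32 and restoring the caller's dtype afterwards.
import torch
import torch.nn as nn


class SqueezeExcite(nn.Module):
    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        # Simplified excitation MLP; the real module's fc stack may differ.
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T]
        dtype = x.dtype  # remember the incoming dtype (e.g. fp16 under AMP)

        # Ensure SE runs in FP32: cast fc weights and activations, as in
        # the commit. The dtype checks avoid redundant casts.
        if self.fc[0].weight.dtype != torch.float32:
            self.fc.float()
        if x.dtype != torch.float32:
            x = x.float()

        y = x.mean(dim=-1)                   # global average pool -> [B, C]
        y = self.fc(y)                       # excitation -> [B, C]
        y = torch.sigmoid(y).unsqueeze(-1)   # channel gates -> [B, C, 1]
        y = x * y                            # rescale channels

        # Cast back to original dtype for downstream consistency
        return y.to(dtype)


se = SqueezeExcite(16)
out = se(torch.randn(2, 16, 50).half())
print(out.dtype)  # torch.float16 (internals computed in float32)
```

Guarding the casts with dtype checks keeps them no-ops in pure-FP32 runs, while under mixed precision the sigmoid and the small excitation MLP are evaluated at full precision, which is the numerical-stability motivation stated in the removed comment.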

0 commit comments
