diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index 3ecd875a..20deb4da 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -139,7 +139,8 @@ def forward(self, inputs):
     if not self.quantize_activation:
       result = F.linear(inputs, self.weight)
     else:
-      # We have to call jax because we need to do dot(int8, int8)->int32.
+      # We have to call jax because we need to specify the output dtype of the
+      # dot: dot(int8, int8)->bf16.
       # This semantic cannot be represented in torch. The inferred output dtype
       # will be int8 in torch, causing the dot result to overflow.
       result = torchjax.call_jax(
@@ -148,7 +149,7 @@ def forward(self, inputs):
           self.weight,
           (((2,), (1)), ((), ())),
           None,
-          jnp.int32.dtype,
+          jnp.bfloat16.dtype,
       )
     result = result * self.weight_scaler
     if self.quantize_activation:
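
For context, a minimal standalone sketch of what the modified call computes, written in plain JAX without the torchjax.call_jax interop. The use of jax.lax.dot_general, the function name quantized_matmul, and the [batch, seq, in_features] x [out_features, in_features] shapes are assumptions inferred from the dimension numbers in the hunk; only the output-dtype switch from jnp.int32 to jnp.bfloat16 is taken from the diff itself.

    # Sketch only: assumes the wrapped callable is jax.lax.dot_general,
    # inputs_int8 is [batch, seq, in_features] int8, and weight_int8 is
    # [out_features, in_features] int8.
    import jax
    import jax.numpy as jnp


    def quantized_matmul(inputs_int8, weight_int8, weight_scaler):
      # Contract inputs dim 2 with weight dim 1: [B, S, in] x [out, in] -> [B, S, out].
      # preferred_element_type pins the accumulation/output dtype; without it the
      # int8 x int8 dot would produce int8 and overflow, as the comment in the
      # diff describes. The diff changes this dtype from int32 to bfloat16.
      result = jax.lax.dot_general(
          inputs_int8,
          weight_int8,
          dimension_numbers=(((2,), (1,)), ((), ())),
          precision=None,
          preferred_element_type=jnp.bfloat16,
      )
      # Undo the weight quantization scale, mirroring the
      # `result = result * self.weight_scaler` context line above.
      return result * weight_scaler

With bfloat16 as the preferred element type, the matmul result comes back already in the activation dtype, so the subsequent multiplication by the weight scaler no longer needs an int32-to-float conversion.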