Skip to content
This repository was archived by the owner on Dec 18, 2025. It is now read-only.
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jetstream_pt/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def forward(self, inputs):
self.weight,
(((2,), (1)), ((), ())),
None,
jnp.int32.dtype,
jnp.bfloat16.dtype,
)
result = result * self.weight_scaler
if self.quantize_activation:
Expand Down