diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index 4c997e4d..c993888c 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -540,6 +540,8 @@ def _qtensor_dot_general( cfg: ..., # DotGeneralRaw, # dequant_dtype: DType, dequant_dtype: jnp.dtype, + lhs_scale_transpose_fn=transpose.lhs_scale_transpose_to_output, + rhs_scale_transpose_fn=transpose.rhs_scale_transpose_to_output, ) -> aqt_tensor.QTensor: """QTensor lax.dot_general replacement.""" @@ -621,17 +623,17 @@ def _maybe_dequant( if cfg.lhs.dequant_mode == DequantMode.OUTPUT: extend_scale = _get_scale_t( qt=lhs_qt, - transpose_fn=transpose.lhs_scale_transpose_to_output, + transpose_fn=lhs_scale_transpose_fn, dimension_numbers=dimension_numbers, lhs_shape=lhs_qin.shape, rhs_shape=rhs_qin.shape, ) - out.scale.extend(extend_scale) + if cfg.rhs.dequant_mode == DequantMode.OUTPUT: extend_scale = _get_scale_t( qt=rhs_qt, - transpose_fn=transpose.rhs_scale_transpose_to_output, + transpose_fn=rhs_scale_transpose_fn, dimension_numbers=dimension_numbers, lhs_shape=lhs_qin.shape, rhs_shape=rhs_qin.shape,