Skip to content

Commit 9bae2b5

Browse files
authored
[torchlib] Fix _log_softmax (#1789)
Fix _log_softmax by moving the IsScalar call to the top so it can be eagerly evaluated. Also specify the squeeze axis explicitly to improve compatibility with ORT (microsoft/onnxruntime#21661). This should fix a runtime error in XGLMForCausalLM.
1 parent b1f4942 commit 9bae2b5

File tree

1 file changed

+8
-4
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+8
-4
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -82,11 +82,15 @@ def aten__log_softmax_half(
8282
) -> FLOAT:
8383
"""_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor"""
8484

85-
# trace_only because we need to cast conditionally based on half_to_float
85+
self_is_scalar = IsScalar(self)
8686
if half_to_float:
8787
self = op.Cast(self, to=FLOAT.dtype)
88-
89-
return aten__log_softmax(self, dim, half_to_float)
88+
if self_is_scalar:
89+
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
90+
result = op.LogSoftmax(self, axis=dim)
91+
if self_is_scalar:
92+
result = op.Squeeze(result, op.Constant(value_ints=[0]))
93+
return result
9094

9195

9296
@torch_op("aten::_log_softmax", traceable=True)
@@ -101,7 +105,7 @@ def aten__log_softmax(
101105
if self_is_scalar:
102106
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
103107
result = op.LogSoftmax(self, axis=dim)
104-
if self_is_scalar: # squeeze to scalar due to input is scalar
108+
if self_is_scalar:
105109
result = op.Squeeze(result)
106110
return result
107111

0 commit comments

Comments
 (0)