Skip to content

Commit cf3288b

Browse files
fix(atenlib): combine cross_entropy_loss (#555)
- Fix static optional support in ParamSchema
- Make `cross_entropy_loss` a script function

Follow-up to https://github.com/microsoft/onnx-script/pull/444/files#diff-647f5c6f5a51c850e0a6bfd534cc8d1357d5b617c432693af121830b65a324c2R242

---------

Co-authored-by: xiaowuhu <[email protected]>
1 parent 9fe15b2 commit cf3288b

File tree

3 files changed

+20
-28
lines changed

3 files changed

+20
-28
lines changed

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 8 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ def aten_conv_depthwise3d(
239239
raise NotImplementedError()
240240

241241

242-
@torch_op("aten::cross_entropy_loss", trace_only=True)
242+
@torch_op("aten::cross_entropy_loss")
243243
def aten_cross_entropy_loss(
244244
self: TFloatOrBFloat16,
245245
target: Sequence[int],
@@ -251,30 +251,16 @@ def aten_cross_entropy_loss(
251251
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""
252252

253253
if reduction == 0: # "none"
254-
result = _aten_cross_entropy_loss_onnx(self, target, weight, "none", ignore_index)
255-
elif reduction == 1: # "mean"
256-
result = _aten_cross_entropy_loss_onnx(self, target, weight, "mean", ignore_index)
257-
else: # "sum"
258-
result = _aten_cross_entropy_loss_onnx(self, target, weight, "sum", ignore_index)
259-
260-
return result
261-
262-
263-
@torch_op("aten::cross_entropy_loss", private=True)
264-
def _aten_cross_entropy_loss_onnx(
265-
self: TFloatOrBFloat16,
266-
target: Sequence[int],
267-
weight: Optional[TFloatOrBFloat16],
268-
reduction_str: str,
269-
ignore_index: int,
270-
):
271-
if op.OptionalHasElement(weight):
272254
result, _ = op.SoftmaxCrossEntropyLoss(
273-
self, target, weight, reduction=reduction_str, ignore_index=ignore_index
255+
self, target, weight, reduction="none", ignore_index=ignore_index
274256
)
275-
else:
257+
elif reduction == 2: # "sum"
258+
result, _ = op.SoftmaxCrossEntropyLoss(
259+
self, target, weight, reduction="sum", ignore_index=ignore_index
260+
)
261+
else: # "mean", default
276262
result, _ = op.SoftmaxCrossEntropyLoss(
277-
self, target, reduction=reduction_str, ignore_index=ignore_index
263+
self, target, weight, reduction="mean", ignore_index=ignore_index
278264
)
279265

280266
return result

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -486,6 +486,11 @@ def _where_input_wrangler(
486486
"nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d,
487487
"nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d,
488488
"nn.functional.celu": nn_ops.aten_celu,
489+
# use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB)
490+
"nn.functional.cross_entropy": (
491+
nn_ops.aten_cross_entropy_loss,
492+
_cross_entropy_input_wrangler,
493+
),
489494
"nn.functional.dropout": (core_ops.aten_dropout, _dropout_input_wrangler),
490495
"nn.functional.elu": nn_ops.aten_elu,
491496
"nn.functional.embedding": (core_ops.aten_embedding, _embedding_input_wrangler),
@@ -566,11 +571,6 @@ def _where_input_wrangler(
566571
"nn.functional.conv1d": core_ops.aten_conv1d,
567572
"nn.functional.conv2d": core_ops.aten_conv2d,
568573
"nn.functional.conv3d": core_ops.aten_conv3d,
569-
# use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB)
570-
"nn.functional.cross_entropy": (
571-
nn_ops.aten_cross_entropy_loss,
572-
_cross_entropy_input_wrangler,
573-
),
574574
"nn.functional.gelu": nn_ops.aten_gelu,
575575
"nn.functional.linear": nn_ops.aten_linear,
576576
"nn.functional.upsample_nearest2d": (

onnxscript/values.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,13 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
315315
# args with default value are attributes
316316
schemas = []
317317
for arg in inputs:
318-
param_schema = ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True)
318+
if isinstance(arg.typeinfo, onnx.TypeProto.Optional):
319+
required = False
320+
else:
321+
required = True
322+
param_schema = ParamSchema(
323+
name=arg.name, type=arg.typeinfo, is_input=True, required=required
324+
)
319325
schemas.append(param_schema)
320326

321327
for attr_name in attributes:

0 commit comments

Comments
 (0)