diff --git a/onnxscript/function_libs/torch_aten/ops/nn.py b/onnxscript/function_libs/torch_aten/ops/nn.py index beb5f16ddd..f7813ba8ec 100644 --- a/onnxscript/function_libs/torch_aten/ops/nn.py +++ b/onnxscript/function_libs/torch_aten/ops/nn.py @@ -239,7 +239,7 @@ def aten_conv_depthwise3d( raise NotImplementedError() -@torch_op("aten::cross_entropy_loss", trace_only=True) +@torch_op("aten::cross_entropy_loss") def aten_cross_entropy_loss( self: TFloatOrBFloat16, target: Sequence[int], @@ -251,30 +251,16 @@ def aten_cross_entropy_loss( """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor""" if reduction == 0: # "none" - result = _aten_cross_entropy_loss_onnx(self, target, weight, "none", ignore_index) - elif reduction == 1: # "mean" - result = _aten_cross_entropy_loss_onnx(self, target, weight, "mean", ignore_index) - else: # "sum" - result = _aten_cross_entropy_loss_onnx(self, target, weight, "sum", ignore_index) - - return result - - -@torch_op("aten::cross_entropy_loss", private=True) -def _aten_cross_entropy_loss_onnx( - self: TFloatOrBFloat16, - target: Sequence[int], - weight: Optional[TFloatOrBFloat16], - reduction_str: str, - ignore_index: int, -): - if op.OptionalHasElement(weight): result, _ = op.SoftmaxCrossEntropyLoss( - self, target, weight, reduction=reduction_str, ignore_index=ignore_index + self, target, weight, reduction="none", ignore_index=ignore_index ) - else: + elif reduction == 2: # "sum" + result, _ = op.SoftmaxCrossEntropyLoss( + self, target, weight, reduction="sum", ignore_index=ignore_index + ) + else: # "mean", default result, _ = op.SoftmaxCrossEntropyLoss( - self, target, reduction=reduction_str, ignore_index=ignore_index + self, target, weight, reduction="mean", ignore_index=ignore_index ) return result diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py index aa9696a635..f1967dde98 100644 --- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py @@ -486,6 +486,11 @@ def _where_input_wrangler( "nn.functional.adaptive_avg_pool2d": nn_ops.aten_adaptive_avg_pool2d, "nn.functional.adaptive_avg_pool3d": nn_ops.aten_adaptive_avg_pool3d, "nn.functional.celu": nn_ops.aten_celu, + # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) + "nn.functional.cross_entropy": ( + nn_ops.aten_cross_entropy_loss, + _cross_entropy_input_wrangler, + ), "nn.functional.dropout": (core_ops.aten_dropout, _dropout_input_wrangler), "nn.functional.elu": nn_ops.aten_elu, "nn.functional.embedding": (core_ops.aten_embedding, _embedding_input_wrangler), @@ -566,11 +571,6 @@ def _where_input_wrangler( "nn.functional.conv1d": core_ops.aten_conv1d, "nn.functional.conv2d": core_ops.aten_conv2d, "nn.functional.conv3d": core_ops.aten_conv3d, - # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) - "nn.functional.cross_entropy": ( - nn_ops.aten_cross_entropy_loss, - _cross_entropy_input_wrangler, - ), "nn.functional.gelu": nn_ops.aten_gelu, "nn.functional.linear": nn_ops.aten_linear, "nn.functional.upsample_nearest2d": ( diff --git a/onnxscript/values.py b/onnxscript/values.py index f9e035abf2..9c8dd80e1a 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -315,7 +315,13 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # args with default value are attributes schemas = [] for arg in inputs: - param_schema = ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True) + if isinstance(arg.typeinfo, onnx.TypeProto.Optional): + required = False + else: + required = True + param_schema = ParamSchema( + name=arg.name, type=arg.typeinfo, is_input=True, required=required + ) schemas.append(param_schema) for attr_name in attributes: