Skip to content

feat(atenlib): add ops(cross_entropy_loss) #444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
dafe80a
Update core.py
xiaowuhu Feb 16, 2023
2ca53e2
Update core.py
xiaowuhu Feb 16, 2023
f15ec49
fix kwargs issue
xiaowuhu Feb 16, 2023
3437ab4
update
xiaowuhu Feb 16, 2023
aa50a53
Update nn.py
xiaowuhu Feb 16, 2023
02c6a39
Update nn.py
xiaowuhu Feb 16, 2023
810a7f7
Update ops_correctness_test.py
xiaowuhu Feb 16, 2023
23a2bb3
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Feb 16, 2023
5f69ad5
Update ops_correctness_test.py
xiaowuhu Feb 16, 2023
2c61513
Merge branch 'xiaowu/addOps(0216)' of https://github.com/xiaowuhu/onn…
xiaowuhu Feb 16, 2023
601b81f
Update nn.py
xiaowuhu Feb 17, 2023
e6dd73f
Merge branch 'xiaowu/addOps(0216)' of https://github.com/xiaowuhu/onn…
xiaowuhu Feb 17, 2023
6a4228f
Update nn.py
xiaowuhu Feb 17, 2023
a617691
Update ops_correctness_test.py
xiaowuhu Feb 17, 2023
50cabda
Update nn.py
xiaowuhu Feb 17, 2023
96d0a02
Update ops_correctness_test.py
xiaowuhu Feb 17, 2023
4f32ab8
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Feb 17, 2023
0eb5c73
Update nn.py
xiaowuhu Feb 18, 2023
365e58d
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Feb 18, 2023
a8a8d5d
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(0216)
xiaowuhu Feb 20, 2023
818d75c
fix comment
xiaowuhu Feb 20, 2023
253ee00
Update ops_correctness_test.py
xiaowuhu Feb 21, 2023
ea18c3f
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(0216)
xiaowuhu Feb 22, 2023
785f34a
update
xiaowuhu Feb 22, 2023
b9a94c5
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Feb 23, 2023
00a606f
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(0216)
xiaowuhu Feb 28, 2023
ecc6077
Update nn.py
xiaowuhu Feb 28, 2023
b11d894
Update ops_correctness_test.py
xiaowuhu Feb 28, 2023
f3baa71
Update ops_correctness_test.py
xiaowuhu Feb 28, 2023
164884c
Update ops_correctness_test.py
xiaowuhu Feb 28, 2023
dc0bf6e
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Feb 28, 2023
b70c4d5
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Mar 1, 2023
b610562
Update ops_correctness_test.py
xiaowuhu Mar 1, 2023
d4edfe8
Update nn.py
xiaowuhu Mar 1, 2023
6be5eb3
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Mar 1, 2023
58ba823
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Mar 2, 2023
94a0f5c
Merge branch 'main' into xiaowu/addOps(0216)
xiaowuhu Mar 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions onnxscript/function_libs/torch_aten/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,45 @@ def aten_conv_depthwise3d(
raise NotImplementedError()


@torch_op("aten::cross_entropy_loss", trace_only=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a comment on why trace only is needed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible to combine the functions?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def aten_cross_entropy_loss(
self: TensorType,
target: TensorType,
weight: Optional[TensorType] = None,
reduction: int = 1,
ignore_index: INT64 = -100,
label_smoothing: float = 0.0,
) -> TensorType:
self: TFloatOrBFloat16,
target: Sequence[int],
weight: Optional[TFloatOrBFloat16] = None,
reduction: int = 1, # default is 'mean'
ignore_index: int = -100,
label_smoothing: float = 0.0, # this was ignored due to ONNX not support
) -> TFloatOrBFloat16:
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

raise NotImplementedError()
if reduction == 0: # "none"
result = _aten_cross_entropy_loss_onnx(self, target, weight, "none", ignore_index)
elif reduction == 1: # "mean"
result = _aten_cross_entropy_loss_onnx(self, target, weight, "mean", ignore_index)
else: # "sum"
result = _aten_cross_entropy_loss_onnx(self, target, weight, "sum", ignore_index)

return result


@torch_op("aten::cross_entropy_loss", overload=True)
def _aten_cross_entropy_loss_onnx(
self: TFloatOrBFloat16,
target: Sequence[int],
weight: Optional[TFloatOrBFloat16],
reduction_str: str,
ignore_index: int,
):
if op.OptionalHasElement(weight):
result, _ = op.SoftmaxCrossEntropyLoss(
self, target, weight, reduction=reduction_str, ignore_index=ignore_index
)
else:
result, _ = op.SoftmaxCrossEntropyLoss(
self, target, reduction=reduction_str, ignore_index=ignore_index
)

return result


@torch_op("aten::elu")
Expand Down
21 changes: 21 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,17 @@ def _cat_input_wrangler(
return args, kwargs


def _cross_entropy_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
reduction_vals = ["none", "mean", "sum"]
value = kwargs["reduction"]
idx = reduction_vals.index(value)
kwargs["reduction"] = idx
return args, kwargs


def _dropout_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -408,6 +419,11 @@ def _where_input_wrangler(
"nn.functional.conv1d": core_ops.aten_conv1d,
"nn.functional.conv2d": core_ops.aten_conv2d,
"nn.functional.conv3d": core_ops.aten_conv3d,
# use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB)
"nn.functional.cross_entropy": (
nn_ops.aten_cross_entropy_loss,
_cross_entropy_input_wrangler,
),
"nn.functional.gelu": nn_ops.aten_gelu,
"nn.functional.linear": nn_ops.aten_linear,
"nn.functional.upsample_nearest2d": (
Expand Down Expand Up @@ -534,6 +550,11 @@ def _where_input_wrangler(
matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
reason="String padding is not accepted by aten::conv2d",
),
skip(
"nn.functional.cross_entropy",
matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int),
reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type",
),
skip(
"nn.functional.dropout",
matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0,
Expand Down