Skip to content

feat(atenlib): add ops(nll_loss) #453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 58 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
3a57f2a
update
xiaowuhu Feb 20, 2023
7bc6288
Update nn.py
xiaowuhu Feb 20, 2023
1f39d17
Update nn.py
xiaowuhu Feb 20, 2023
5eedb40
Update nn.py
xiaowuhu Feb 21, 2023
7c53b4f
add two more
xiaowuhu Feb 21, 2023
a14c98b
Update ops_correctness_test.py
xiaowuhu Feb 21, 2023
49ba59a
Update ops_correctness_test.py
xiaowuhu Feb 21, 2023
7bf4875
Update ops_correctness_test.py
xiaowuhu Feb 21, 2023
72752e7
Update nn.py
xiaowuhu Feb 21, 2023
f4bd4a7
Update nn.py
xiaowuhu Feb 21, 2023
3b76a80
Update nn.py
xiaowuhu Feb 21, 2023
aa1c71a
Update nn.py
xiaowuhu Feb 21, 2023
f79d155
Update nn.py
xiaowuhu Feb 22, 2023
df1f19f
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(0220)
xiaowuhu Feb 22, 2023
267f90c
Update ops_correctness_test.py
xiaowuhu Feb 22, 2023
c4b80c1
Update nn.py
xiaowuhu Feb 22, 2023
91fca0c
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Feb 23, 2023
bff6e2e
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Feb 24, 2023
a1a696c
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Feb 26, 2023
fa3bf8b
update
xiaowuhu Feb 27, 2023
f87e1b1
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Feb 27, 2023
688f1a9
update
xiaowuhu Feb 27, 2023
5c7a345
Merge branch 'xiaowu/addOps(0220)' of https://github.com/xiaowuhu/onn…
xiaowuhu Feb 27, 2023
cd76dd0
update
xiaowuhu Feb 28, 2023
e942a33
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 1, 2023
dc95786
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(0220)
xiaowuhu Mar 1, 2023
e819722
Update nn.py
xiaowuhu Mar 1, 2023
1a9da5f
Merge remote-tracking branch 'upstream/main' into xiaowu/addOps(0220)
xiaowuhu Mar 1, 2023
250ae1d
fix comments
xiaowuhu Mar 1, 2023
28d299f
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 1, 2023
fede72a
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 1, 2023
623f40f
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 2, 2023
a994600
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 2, 2023
17c05ad
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 6, 2023
3d7b382
Merge branch 'main' into pr/453
xiaowuhu Mar 7, 2023
2d13ae6
fix comment
xiaowuhu Mar 7, 2023
fce6009
Merge branch 'xiaowu/addOps(0220)' of https://github.com/xiaowuhu/onn…
xiaowuhu Mar 7, 2023
3819a27
Update nn.py
xiaowuhu Mar 7, 2023
6bc4ce7
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 7, 2023
4e8bf68
Update nn.py
xiaowuhu Mar 7, 2023
b334f48
Merge branch 'xiaowu/addOps(0220)' of https://github.com/xiaowuhu/onn…
xiaowuhu Mar 7, 2023
26baeb2
Update nn.py
xiaowuhu Mar 7, 2023
7588543
Update nn.py
xiaowuhu Mar 7, 2023
430abee
update
xiaowuhu Mar 7, 2023
6c943e0
Update onnxscript/_internal/param_manipulation.py
justinchuby Mar 7, 2023
ae46f06
Merge branch 'main' into pr/453
xiaowuhu Mar 7, 2023
1e5b9ac
Update ops_correctness_test.py
xiaowuhu Mar 7, 2023
fcf143b
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 8, 2023
cb6b4b1
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 8, 2023
25a86c7
Merge branch 'main' into pr/453
xiaowuhu Mar 14, 2023
56acc61
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 14, 2023
01c66e7
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 14, 2023
759b514
Merge branch 'main' into xiaowu/addOps(0220)
xiaowuhu Mar 15, 2023
d4f22df
update
xiaowuhu Mar 15, 2023
ab8bf2e
Update ops_correctness_test.py
xiaowuhu Mar 15, 2023
e051f47
Merge branch 'xiaowu/addOps(0220)' of https://github.com/xiaowuhu/onn…
xiaowuhu Mar 15, 2023
901764c
Update nn.py
xiaowuhu Mar 15, 2023
1558288
Update ops_correctness_test.py
xiaowuhu Mar 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions onnxscript/function_libs/torch_aten/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,16 +769,77 @@ def aten_multilabel_margin_loss_forward(
raise NotImplementedError()


@torch_op("aten::nll_loss")
def aten_nll_loss(
self: TensorType,
target: TensorType,
weight: Optional[TensorType] = None,
self: TFloat,
target: INT64,
reduction: int = 1,
ignore_index: INT64 = -100,
) -> TensorType:
ignore_index: int = -100,
) -> TFloat:
"""nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""

raise NotImplementedError()
rank_self = op.Size(op.Shape(self))
if rank_self == 1: # self rank should be at least 2
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

rank_target = op.Size(op.Shape(target))
if rank_target == 0: # target rank should be at least 1
target = op.Unsqueeze(target, op.Constant(value_ints=[0]))

if reduction == 0:
result = op.NegativeLogLikelihoodLoss(
self, target, ignore_index=ignore_index, reduction="none"
)
elif reduction == 1:
result = op.NegativeLogLikelihoodLoss(
self, target, ignore_index=ignore_index, reduction="mean"
)
else: # assert reduction == 2
result = op.NegativeLogLikelihoodLoss(
self, target, ignore_index=ignore_index, reduction="sum"
)

if rank_self == 1:
result = op.Squeeze(result)

return result


@torch_op("aten::nll_loss", overload=True)
def aten_nll_loss_weight(
self: TFloat,
target: INT64,
weight: TFloat,
reduction: int = 1,
ignore_index: int = -100,
) -> TFloat:
"""nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""

rank_self = op.Size(op.Shape(self))
if rank_self == 1: # self rank should be at least 2
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))

rank_target = op.Size(op.Shape(target))
if rank_target == 0: # target rank should be at least 1
target = op.Unsqueeze(target, op.Constant(value_ints=[0]))

if reduction == 0:
result = op.NegativeLogLikelihoodLoss(
self, target, weight, ignore_index=ignore_index, reduction="none"
)
elif reduction == 1:
result = op.NegativeLogLikelihoodLoss(
self, target, weight, ignore_index=ignore_index, reduction="mean"
)
else:
result = op.NegativeLogLikelihoodLoss(
self, target, weight, ignore_index=ignore_index, reduction="sum"
)

if rank_self == 1:
result = op.Squeeze(result)

return result


def aten_nll_loss2d(
Expand Down
27 changes: 27 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,17 @@ def _full_input_wrangler(
return args, kwargs


def _nll_loss_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "reduction" in kwargs:
# aten_nll_loss can only accept integer argument instead of string
reduction_vals = ["none", "mean", "sum"]
value = kwargs["reduction"]
kwargs["reduction"] = reduction_vals.index(value)
return args, kwargs


def _upsample_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -429,6 +440,8 @@ def _where_input_wrangler(
"nn.functional.embedding": (core_ops.aten_embedding, _embedding_input_wrangler),
"nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
"nn.functional.logsigmoid": nn_ops.aten_log_sigmoid,
"nn.functional.nll_loss_weight": (nn_ops.aten_nll_loss_weight, _nll_loss_input_wrangler),
"nn.functional.nll_loss": (nn_ops.aten_nll_loss, _nll_loss_input_wrangler),
"nn.functional.relu": nn_ops.aten_relu,
"nn.functional.relu6": nn_ops.aten_relu6,
"nn.functional.selu": core_ops.aten_selu,
Expand Down Expand Up @@ -752,6 +765,16 @@ def _where_input_wrangler(
matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0,
reason="dropout is random so the result not match",
),
skip(
"nn.functional.nll_loss",
matcher=lambda sample: "weight" in sample.kwargs,
reason="this Aten overload doesn't accept weight as kwargs",
),
skip(
"nn.functional.nll_loss_weight",
matcher=lambda sample: "weight" not in sample.kwargs,
reason="this Aten overload need weight as kwargs",
),
skip(
"nn.functional.upsample_nearest2d",
# Shape should be [N, C, H, W]
Expand Down Expand Up @@ -796,6 +819,8 @@ def _where_input_wrangler(
),
)

duplicate_opinfo(OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",))

duplicate_opinfo(
OPS_DB,
"min",
Expand Down Expand Up @@ -878,6 +903,8 @@ def _convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
continue
if key == "dtype":
value = TORCH_TYPE_TO_ONNX[value]
if isinstance(value, torch.Tensor):
value = np.array(value)
new_kwargs[key] = value
return new_kwargs

Expand Down