feat(atenlib): add ops(cross_entropy_loss) #444
Changes from all commits
@@ -239,17 +239,45 @@ def aten_conv_depthwise3d(
     raise NotImplementedError()


+@torch_op("aten::cross_entropy_loss", trace_only=True)
Review comments on this line:
- I would add a comment on why trace_only is needed.
- Possible to combine the functions?
 def aten_cross_entropy_loss(
-    self: TensorType,
-    target: TensorType,
-    weight: Optional[TensorType] = None,
-    reduction: int = 1,
-    ignore_index: INT64 = -100,
-    label_smoothing: float = 0.0,
-) -> TensorType:
+    self: TFloatOrBFloat16,
+    target: Sequence[int],
+    weight: Optional[TFloatOrBFloat16] = None,
+    reduction: int = 1,  # default is "mean"
+    ignore_index: int = -100,
+    label_smoothing: float = 0.0,  # ignored: ONNX SoftmaxCrossEntropyLoss does not support label smoothing
+) -> TFloatOrBFloat16:
     """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

-    raise NotImplementedError()
+    if reduction == 0:  # "none"
+        result = _aten_cross_entropy_loss_onnx(self, target, weight, "none", ignore_index)
+    elif reduction == 1:  # "mean"
+        result = _aten_cross_entropy_loss_onnx(self, target, weight, "mean", ignore_index)
+    else:  # "sum"
+        result = _aten_cross_entropy_loss_onnx(self, target, weight, "sum", ignore_index)
+
+    return result
+
+
+@torch_op("aten::cross_entropy_loss", overload=True)
+def _aten_cross_entropy_loss_onnx(
+    self: TFloatOrBFloat16,
+    target: Sequence[int],
+    weight: Optional[TFloatOrBFloat16],
+    reduction_str: str,
+    ignore_index: int,
+):
+    if op.OptionalHasElement(weight):
+        result, _ = op.SoftmaxCrossEntropyLoss(
+            self, target, weight, reduction=reduction_str, ignore_index=ignore_index
+        )
+    else:
+        result, _ = op.SoftmaxCrossEntropyLoss(
+            self, target, reduction=reduction_str, ignore_index=ignore_index
+        )
+
+    return result


 @torch_op("aten::elu")
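For reviewers who want to sanity-check the mapping outside the repo's test harness, the sketch below (not part of this PR) builds a bare ONNX SoftmaxCrossEntropyLoss model with onnx.helper and compares it against torch.nn.functional.cross_entropy for the three reduction codes the wrapper dispatches on (0 = "none", 1 = "mean", 2 = "sum") with the default ignore_index of -100. It assumes onnx, onnxruntime, torch, and numpy are installed; the helper name _make_sce_model is illustrative, and the optional weight input is omitted for brevity.

```python
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn.functional as F
from onnx import TensorProto, helper


def _make_sce_model(reduction: str, ignore_index: int) -> onnx.ModelProto:
    # Single SoftmaxCrossEntropyLoss node; the weight input and the optional
    # log_prob output are omitted for brevity.
    node = helper.make_node(
        "SoftmaxCrossEntropyLoss",
        inputs=["scores", "labels"],
        outputs=["loss"],
        reduction=reduction,
        ignore_index=ignore_index,
    )
    graph = helper.make_graph(
        [node],
        "sce_check",
        [
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, [3, 5]),
            helper.make_tensor_value_info("labels", TensorProto.INT64, [3]),
        ],
        [helper.make_tensor_value_info("loss", TensorProto.FLOAT, None)],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])


scores = np.random.randn(3, 5).astype(np.float32)
labels = np.array([0, 2, -100], dtype=np.int64)  # -100 marks an ignored sample

# reduction codes used by aten::cross_entropy_loss: 0 = "none", 1 = "mean", 2 = "sum"
for name in ("none", "mean", "sum"):
    sess = ort.InferenceSession(_make_sce_model(name, -100).SerializeToString())
    (onnx_loss,) = sess.run(None, {"scores": scores, "labels": labels})
    torch_loss = F.cross_entropy(
        torch.from_numpy(scores),
        torch.from_numpy(labels),
        reduction=name,
        ignore_index=-100,
    )
    np.testing.assert_allclose(onnx_loss, torch_loss.numpy(), rtol=1e-5, atol=1e-6)
```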