Implement aten::div.Tensor_mode | feat(torchlib) #988
Changes from all commits
@@ -2190,18 +2190,41 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
     raise NotImplementedError()


-@torch_op(("aten::div", "aten::div.Tensor"))
+@torch_op(
+    (
+        "aten::div",
+        "aten::div.Tensor",
+        "aten::div.Scalar",
+        # When rounding_mode is None, performs a true division
+        # https://pytorch.org/docs/stable/generated/torch.div.html
+        "aten::div.Tensor_mode",
+        "aten::div.Scalar_mode",
+        "aten::divide",
+        "aten::true_divide",
+    )
+)
 def aten_div(self: TFloat, other: TFloat) -> TFloat:
     """div.Tensor(Tensor self, Tensor other) -> Tensor"""

     # Int inputs will be promoted to float by PyTorch
     return op.Div(self, other)


-def aten_divide(self: TensorType, other: TensorType) -> TensorType:
-    """divide.Tensor(Tensor self, Tensor other) -> Tensor"""
+@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
+def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
[Review thread on attribute defaults]

Comment: Do you recall what kinds of attributes can have defaults?

Reply: str, int, float, and bool attributes can have defaults, I think. But I suppose any attribute should be able to have a default via the attribute proto. Is this what you are asking?

Reply: I wonder whether there is a situation where two ONNX variants differ only in one defaulted attribute. In that case the dispatcher won't be able to dispatch it:

aten_op_attr(X, Y, attr="Good"):
    ...

aten_op(X, Y):
    ...

Reply: True. I would just hope/make sure that we don't create variants like these. I wonder if there is a way to test it. I think the matching logic you created can come in handy here: we can use it to test that the variants registered in torchlib are not compatible with each other.

Reply: If we do see this case in the dispatcher, we can only pick one, I suppose?

Reply: Supposedly, picking any of them shouldn't be an issue, because they should be equal when no attribute is specified.
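The ambiguity discussed above can be sketched with a minimal, self-contained example. Everything here is hypothetical: aten_op, aten_op_attr, and matches_call are stand-ins for illustration, not torchlib or dispatcher APIs.

import inspect

def aten_op(x, y):
    ...

def aten_op_attr(x, y, attr="Good"):
    ...

def matches_call(fn, num_inputs, attr_names):
    # Very rough overload matching: fn accepts the call if the positional-input
    # count fits between its required and total parameters and every attribute
    # name exists, letting defaulted parameters be omitted.
    sig = inspect.signature(fn)
    params = list(sig.parameters.values())
    required = [p for p in params if p.default is inspect.Parameter.empty]
    if not (len(required) <= num_inputs <= len(params)):
        return False
    return all(name in sig.parameters for name in attr_names)

# A call with two inputs and no attributes is ambiguous: both variants match.
print(matches_call(aten_op, 2, []))       # True
print(matches_call(aten_op_attr, 2, []))  # True

Under these assumptions, a registry-wide test could assert that no two variants registered for the same ATen overload both match the same attribute-free call.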
+    """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""

-    raise NotImplementedError()
+    # TODO(justinchuby): trace_only=False when we use opset19 which supports string comparison
+    assert rounding_mode in {"trunc", "floor"}
+
+    if rounding_mode == "trunc":
+        # Rounds the results of the division towards zero.
+        # Equivalent to C-style integer division
+        result = aten_trunc(op.Div(self, other))
[Review thread on the aten_trunc call]

Comment: Will move to a common function when #834 is done.

Reply: I missed this. Could you share more about why we can use nested OnnxFunction now?

Reply: This is a trace-only function, so calling functions is fine. However, we still do not like calling other aten functions. When we can have nested OnnxFunction calls, I will extract the trunc logic to a common function and call it from aten_trunc and this. Right now I am doing this so
+    else:  # rounding_mode == "floor"
+        result = op.Floor(op.Div(self, other))
+
+    return result


 @torch_op("aten::dot")
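For reference (not part of the diff), the PyTorch semantics that aten_div_mode mirrors: the two rounding modes only disagree for negative, non-integral quotients.

import torch

a = torch.tensor([-7.0, 7.0])
b = torch.tensor([2.0, 2.0])

torch.div(a, b)                         # tensor([-3.5000,  3.5000])  true division (rounding_mode=None)
torch.div(a, b, rounding_mode="trunc")  # tensor([-3.,  3.])  rounds toward zero, like C integer division
torch.div(a, b, rounding_mode="floor")  # tensor([-4.,  3.])  rounds toward negative infinity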
@@ -2746,10 +2769,11 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Floor(self)


-def aten_floor_divide(self: TensorType, other: TensorType) -> TensorType:
+@torch_op("aten::floor_divide")
+def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""

-    raise NotImplementedError()
+    return op.Floor(op.Div(self, other))


 def aten_fmax(self: TensorType, other: TensorType) -> TensorType:
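An illustrative check (assuming current torch.floor_divide semantics, which round toward negative infinity): the Floor(Div(...)) composition used above reproduces floor division.

import numpy as np

a = np.array([-7.0, 7.0])
b = np.array([2.0, 2.0])
np.floor(a / b)  # array([-4.,  3.]), matching torch.floor_divide on the same values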
@@ -6918,12 +6942,6 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:
     raise NotImplementedError()


-def aten_true_divide(self: TensorType, other: TensorType) -> TensorType:
-    """true_divide.Tensor(Tensor self, Tensor other) -> Tensor"""
-
-    raise NotImplementedError()
-
-
 @torch_op("aten::trunc")
 def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """trunc(Tensor self) -> Tensor"""
[Review thread on dispatching when rounding_mode is None]

Comment: Is the dispatcher expected to filter out any attribute that is None?

Reply: I would consider this to be a better match, I think? Any suggestions?

Reply: I do think the dispatcher should strip None keyword args.

Reply: No, I think that makes sense. It's just that we are altering attributes, and param_schema matching diverges from the inputs/attributes actually sent into the OnnxFunction. There is a lot of indirection around dispatching/OnnxFunction param_schema, and it's not good for debugging. The dispatcher alters inputs/attributes based on hidden assumptions but never returns the altered inputs/attributes, so from the OnnxFunction's perspective it runs directly as the dispatched function with attributes it doesn't need (which won't error).
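A minimal sketch of the behavior proposed in this thread; strip_none_attributes is a hypothetical helper, not the actual dispatcher code. Dropping None-valued keyword attributes before matching lets div(self, other, rounding_mode=None) resolve to the attribute-free aten_div variant instead of aten_div_mode.

def strip_none_attributes(attributes: dict) -> dict:
    # Drop keyword attributes whose value is None so they do not participate
    # in overload matching or get passed along to the matched OnnxFunction.
    return {name: value for name, value in attributes.items() if value is not None}

attrs = {"rounding_mode": None}
print(strip_none_attributes(attrs))  # {} -> matches the plain aten_div overload

Returning or logging the stripped attributes would also help with the debuggability concern raised above, since the OnnxFunction would then see exactly what the dispatcher matched against.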