Add ops(unflatten) | feat(atenlib) #612
@@ -424,6 +424,13 @@ def _sum_input_wrangler(
     return args, kwargs


+def _unflatten_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    args[1] = np.array(args[1], dtype=np.int64)
+    return args, kwargs
+
+
 def _where_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
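For context on the wrangler added above: PyTorch's OpInfo samples pass `sizes` as a plain Python sequence, and the wrangler casts it to an int64 array, presumably because the function under test consumes `sizes` as a tensor input rather than a Python list. A minimal sketch of the conversion, using a hypothetical sample (`dim=1`, `sizes=(2, 3)`):

    from typing import Any

    import numpy as np

    def _unflatten_input_wrangler(
        args: list[Any], kwargs: dict[str, Any]
    ) -> tuple[list[Any], dict[str, Any]]:
        # args = [dim, sizes] for aten::unflatten; cast `sizes` to an
        # int64 array so it can be passed along as a tensor.
        args[1] = np.array(args[1], dtype=np.int64)
        return args, kwargs

    args, kwargs = _unflatten_input_wrangler([1, (2, 3)], {})
    assert isinstance(args[1], np.ndarray) and args[1].dtype == np.int64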
@@ -609,6 +616,7 @@ def _where_input_wrangler(
     "tril": core_ops.aten_tril,
     "triu": core_ops.aten_triu,
     "trunc": core_ops.aten_trunc,
+    "unflatten": (core_ops.aten_unflatten, _unflatten_input_wrangler),
     "unsqueeze": core_ops.aten_unsqueeze,
     "view": core_ops.aten_view,
     "where": (core_ops.aten_where, _where_input_wrangler),
@@ -808,6 +816,11 @@ def _where_input_wrangler(
         reason="ORT Graph attribute inferencing failed on rank-1 input",
         test_class_name="TestOutputConsistencyFullGraph",
     ),
+    xfail(
+        "unflatten",
+        reason="fixme: ORT fails with invalid model: 'INVALID_ARGUMENT : Failed to load model with error: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1)'",
+        test_class_name="TestOutputConsistencyFullGraph",
+    ),
 )
@@ -1063,6 +1076,11 @@ def _where_input_wrangler(
         matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
         reason="this Aten overload only support one tensor as input and one int as args by design",
     ),
+    skip(
+        "unflatten",
+        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
+        reason="0 dim in ONNX is undefined behavior.",
+    ),
 )

 duplicate_opinfo(OPS_DB, "all", ("all_dim",))
Review thread on the `unflatten` skip above:

> Can we create a condition to handle this case?

> To raise a warning? Not sure what we should give back to users.

> What is PyTorch's behavior? I was suggesting handling it like PyTorch does if possible.

> PyTorch takes this because 0-d is in their spec:
>
>     a = torch.randn(4, 0)
>     # tensor([], size=(4, 0))
>     torch.unflatten(a, 0, (2, 2))
>     # tensor([], size=(2, 2, 0))

> I wonder if torch_libs has a certain way to inform users when the difference between specs is hit? In the TorchScript exporter, we could raise.

> I see. We could skip this if it is too complicated and come back to it as needed. I would update the skip reason to say something to the effect of "fixme: logic not implemented for size 0 inputs".

> We don't raise errors in atenlib. The exporter can raise errors when it sees a difference after OpSchema is implemented.
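To make the spec difference under discussion concrete: `unflatten` replaces one dimension with several whose product equals it, and a reshape-based reference already reproduces PyTorch's zero-size behavior. A minimal NumPy sketch (not the atenlib implementation; assumes a non-negative `dim` and that the product of `sizes` equals `x.shape[dim]`):

    import numpy as np

    def unflatten_reference(x: np.ndarray, dim: int, sizes: tuple[int, ...]) -> np.ndarray:
        # Replace dimension `dim` with the dimensions listed in `sizes`.
        new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1:]
        return x.reshape(new_shape)

    print(unflatten_reference(np.zeros((3, 4)), 1, (2, 2)).shape)  # (3, 2, 2)
    print(unflatten_reference(np.zeros((4, 0)), 0, (2, 2)).shape)  # (2, 2, 0), matching torch.unflatten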
> @gramalingam @justinchuby
> ORTEvaluator doesn't seem to understand: