Skip incompatible dtypes in op_test using OpSchema | test(torchlib) #700
Conversation
    @@ -355,6 +355,41 @@ def _format_model_and_input_information(onnx_model, inputs):
        )

    TORCH_DTYPE_TO_ONNX_STRING = {
        torch.bool: "tensor(bool)",
nit: I commented on this elsewhere before but don't recall precisely. Does it have to be a plain string for ONNX types?
Yes, it needs to be a plain string in onnx's type constraints
I was wondering if there is anything like str(onnx.BoolTensorType) / onnx.BoolTensorType.to_string().
I see. There probably is but we still need a mapping from torch types to onnx types. Any suggestions?
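To make the mapping concrete, here is a minimal sketch of the torch-to-ONNX type-string table being discussed. It keys on dtype *names* so the sketch stays dependency-free (the actual table in this PR keys on torch.dtype objects), and the entries beyond torch.bool are illustrative:

```python
# Sketch of the torch -> ONNX type-string mapping under discussion.
# ONNX type constraints are plain strings of the form "tensor(<elem>)",
# and the element names do not always match torch's dtype names
# (e.g. torch.float32 -> "tensor(float)"), so a lookup table is needed.
_TORCH_NAME_TO_ONNX_STRING = {
    "bool": "tensor(bool)",
    "uint8": "tensor(uint8)",
    "int64": "tensor(int64)",
    "float16": "tensor(float16)",
    "float32": "tensor(float)",   # note the renamed element type
    "float64": "tensor(double)",  # likewise
}

def to_onnx_type_string(torch_dtype_name: str) -> str:
    """Return the plain string used in ONNX type constraints."""
    return _TORCH_NAME_TO_ONNX_STRING[torch_dtype_name]
```

The renamed element types (float32 -> float, float64 -> double) are why a plain str() of the torch dtype is not enough.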
    }


    def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
I think it's worth a comment explaining the context for dtype. It is the "dtype setting" introduced by the OpInfo tests, not just any input dtype. It seems to carry the assumption that, to some extent, all arguments and outputs share the same dtype.
It would also help to briefly explain what it means for a schema to be "compatible" with a dtype. Do all arguments and outputs have to share that same type?
Added: "When a dtype is 'compatible' with the schema, it means we can use the dtype to create sample inputs by OpInfo to test the ONNX function and expect outputs to match."
I'm very curious whether there are edge cases where the sample inputs are tensors of different types, and how OpInfo handles them. Thanks for the comments; they are very clear now.
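One way to read that definition of "compatible" as code, under the shared-dtype assumption discussed above (a minimal sketch with an assumed {constraint-name: allowed-strings} shape, not the real onnx.defs.OpSchema API):

```python
def dtype_schema_compatible(onnx_type: str, type_constraints: dict[str, list[str]]) -> bool:
    """Sketch: a dtype is 'compatible' with a schema when every
    type-constraint parameter (covering all formal inputs and outputs)
    allows the dtype's ONNX string, so OpInfo sample inputs of that
    dtype can be fed to the ONNX function with matching outputs."""
    return all(onnx_type in allowed for allowed in type_constraints.values())
```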
    test_suite.skipTest(
        f"dtype '{dtype}' is not supported by the op '{op.name}'. "
        f"Type constraints: {onnx_function.op_schema.type_constraints}"
    )
I have a concern about skipping tests based on the manually annotated type constraints. If we wrongly annotate an op's inputs, it could skip tests it shouldn't, which makes us miss the very first opportunity to catch the error.
That's true. Any suggestions?
I wonder if we can xfail instead.
Based on what I found in pytorch/pytorch#100265, if the type constraints come from the ONNX spec, it is indeed most likely xfailable. Not sure if it's easy to implement, though...
Sometimes it will xpass if we xfail, because ORT can choose to implement the op for some types, intentionally or unintentionally. So that's a little tricky.
How about putting the logic after the run? Just an initial thought: if the test fails, we then check whether the dtype is supported or not.
Hmm, but that's still going to skip the tests where the annotation is incorrect?
I thought about this again. Maybe that’s ok for now since if we declare that something is not supported via the annotations, then the exporter will just need to bridge the types via type promotion logic. When we see unnecessary type promotion in the graph we can then come back and make changes to the annotations.
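The "put the logic after the run" idea could be sketched roughly like this; checked_run is a hypothetical helper (not part of this PR), and run_test / dtype_supported stand in for the real test callable and the constraint lookup:

```python
import unittest

def checked_run(run_test, dtype_supported: bool) -> None:
    """Hypothetical helper for checking constraints after the run:
    execute the test unconditionally, and consult the (manually
    annotated) type constraints only when it fails, so a wrong
    annotation cannot hide a dtype that actually works."""
    try:
        run_test()
    except unittest.SkipTest:
        raise  # the test skipped itself for an unrelated reason
    except Exception:
        if dtype_supported:
            raise  # annotation says this dtype should work: real failure
        # Annotated as unsupported and it did fail: treat as expected.
        raise unittest.SkipTest("dtype not in annotated type constraints")
```

As noted above, this still hides failures when the annotation wrongly marks a dtype unsupported, but a passing test is never skipped, which addresses half of the concern.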
Merging for now. Feel free to share more thoughts!
    }


    def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
I think this is good as a quick and general guard for tests, and we can evolve from there.
What are your thoughts on more fine-grained type checking that validates the type of each concrete input against the type of the matching input in the schema?
I think we will need to build the same logic into the dispatcher. We can evolve it there and reuse it here. For now I think it's good to use logic that is easy to understand and get right for running test cases. It is always better to say inputs are compatible than not, so we don't miss test cases; we can always put an xfail on them.
Also, this logic is for reducing the number of tests we need to manually skip, so accuracy isn't super important here (as long as we don't skip more than we should).
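For reference, the fine-grained variant floated above might look roughly like this; the data shapes (type string per concrete input, constraint-parameter name per formal input) are assumptions for the sketch, not the real OpSchema objects:

```python
def inputs_match_schema(
    input_types: list[str],                  # ONNX type string per concrete input
    formal_constraints: list[str],           # constraint param name per formal input
    type_constraints: dict[str, list[str]],  # allowed strings per constraint param
) -> bool:
    """Sketch of per-input checking: every concrete input's type must be
    allowed by the constraint bound to its formal input, and inputs that
    share a constraint parameter must agree on a single type."""
    bound: dict[str, str] = {}
    for onnx_type, param in zip(input_types, formal_constraints):
        if onnx_type not in type_constraints[param]:
            return False  # this input's type is outside its constraint
        if bound.setdefault(param, onnx_type) != onnx_type:
            return False  # two inputs disagree on a shared parameter
    return True
```

This is strictly more permissive than the single-dtype check when inputs are allowed to differ, which fits the "better to say compatible than not" direction above.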
This requires ONNX 1.14. I am skipping the op test for ONNX<1.14 altogether.
I tested locally to make sure no float32 test is skipped with this mechanism, and that it skips properly according to the type annotation.
Currently no new dtypes are enabled because there are too many failures.
Requires get_schema in Op | chore!(api) #698