Skip incompatible dtypes in op_test using OpSchema | test(torchlib) #700
@@ -355,6 +355,79 @@ def _format_model_and_input_information(onnx_model, inputs):
```python
TORCH_DTYPE_TO_ONNX_STRING = {
    torch.bool: "tensor(bool)",
    torch.uint8: "tensor(uint8)",
    torch.int8: "tensor(int8)",
    torch.int16: "tensor(int16)",
    torch.int32: "tensor(int32)",
    torch.int64: "tensor(int64)",
    torch.float16: "tensor(float16)",
    torch.float32: "tensor(float)",
    torch.float64: "tensor(double)",
    torch.complex64: "tensor(complex64)",
    torch.complex128: "tensor(complex128)",
    torch.bfloat16: "tensor(bfloat16)",
}
```

**Reviewer:** nit: Commented elsewhere before, but I don't recall precisely. Does it have to be a plain string for onnx types?

**Author:** Yes, it needs to be a plain string in onnx's type constraints.

**Reviewer:** I was wondering if there is anything like …

**Author:** I see. There probably is, but we still need a mapping from torch types to onnx types. Any suggestions?
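A hedged aside, not from the PR: the reviewer's follow-up above is truncated in the source, but if the question was about deriving the strings programmatically, the plain `"tensor(...)"` strings can be built from onnx's `TensorProto` enum names. A torch-to-ONNX element-type table is still needed either way, which is the author's point.

```python
# Sketch only (not part of this PR): build the type strings from onnx's
# TensorProto enum. The torch -> elem_type table below is hypothetical
# and abbreviated.
import onnx
import torch

_TORCH_TO_ELEM_TYPE = {
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float64: onnx.TensorProto.DOUBLE,
    torch.int64: onnx.TensorProto.INT64,
}


def onnx_type_string(dtype: torch.dtype) -> str:
    # Enum names are upper case, e.g. FLOAT -> "tensor(float)".
    name = onnx.TensorProto.DataType.Name(_TORCH_TO_ELEM_TYPE[dtype]).lower()
    return f"tensor({name})"


assert onnx_type_string(torch.float32) == "tensor(float)"
assert onnx_type_string(torch.float64) == "tensor(double)"
```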
```python
def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
    """Checks if the dtype is compatible with the schema.

    When a dtype is "compatible" with the schema, it means we can use the dtype
    to create sample inputs by OpInfo to test the ONNX function and expect
    outputs to match.

    Args:
        dtype: The torch dtype used to create sample inputs by OpInfo.
        schema: The ONNX schema of the function.

    Returns:
        True if the dtype is compatible with the schema.
    """
    if not schema.inputs:
        # If there are no inputs, we can't check compatibility. Assume it is
        # compatible. e.g. aten_randn has only attributes.
        return True
    if schema.inputs[0].name not in {"self", "input"}:
        # If the name of the first input is not "self" or "input",
        # it is usually an input that is not of the same type as the output.
        # We assume support in this case.
        #
        # For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
        # has the first input as `size`, which is an integer, but it can
        # support any dtype.
        return True

    # Otherwise we check the type constraints of the first input.
    # For example, when dtype=torch.float32, and the op being tested has the schema
    # ```
    # OpSchema(
    #     name='aten_abs',
    #     domain='onnxscript.atenlib',
    #     since_version=1,
    #     doc='abs(Tensor self) -> Tensor',
    #     type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
    #     inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    #     outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    #     attributes={}
    # )
    # ```
    # we see the first input type is "TReal", corresponding to the type constraint
    # with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)',
    # 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
    # 'tensor(bfloat16)'].
    # Since torch.float32 (tensor(float)) is in the allowed types, we return True.

    first_input_type_name = schema.inputs[0].type_str
    # Find the type constraint for the first input by matching the parameter name
    first_input_type_constraint = next(
        (x for x in schema.type_constraints if x.type_param_str == first_input_type_name),
        None,
    )
    assert first_input_type_constraint is not None
    allowed_type_strs = first_input_type_constraint.allowed_type_strs
    return TORCH_DTYPE_TO_ONNX_STRING[dtype] in allowed_type_strs
```

**Reviewer:** I think it's worth a comment to explain the context for this check. It also helps to briefly state what it means for a schema to be compatible with a dtype. Do all arguments and outputs have to share that same type?

**Author:** Added: "When a dtype is 'compatible' with the schema, it means we can use the dtype to create sample inputs by OpInfo to test the ONNX function and expect outputs to match."

**Reviewer:** I'm very curious whether there are edge cases where the sample inputs are tensors of different types, and how OpInfo handles them. Thanks for the comments; they are very clear now.

**Reviewer:** I think this is a good quick and general guard for tests that we can evolve from there. What are your thoughts on more fine-grained type checking that validates the type of each concrete input against the type of the matching input in the schema?

**Author:** I think we will need to build the same logic in the dispatcher. We can evolve that and use it here. For now I think it's good to use logic that is easy to understand and get right for running test cases. It is always better to say inputs are compatible than not, so we don't miss test cases; we can always put an xfail on them.

**Author:** Also, this logic is for reducing the number of tests we need to manually skip, so accuracy isn't super important here (as long as we don't skip more than we should).
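As a self-contained illustration (not from the PR, which applies the check to torchlib function schemas), the same function can be run against a standard ONNX schema. `Softmax` works as a demo because its first input is named `input` and its type constraint allows only floating-point types:

```python
# Illustrative usage, assuming the definitions above are in scope.
import onnx.defs

softmax = onnx.defs.get_schema("Softmax")
print(dtype_op_schema_compatible(torch.float32, softmax))  # True
print(dtype_op_schema_compatible(torch.int64, softmax))  # False: int64 is
# not among Softmax's allowed tensor types.
```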
```python
def graph_executor(
    outputs: Sequence[Any],
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
```
**Reviewer:** I have a concern about skipping tests based on the manually annotated type constraints. If we wrongly annotate an op's inputs, it could skip tests it shouldn't, which makes us miss the very first opportunity to catch the mistake.
**Author:** That's true. Any suggestions?
**Reviewer:** Wondering if we can xfail.
**Reviewer:** Based on what I found in pytorch/pytorch#100265, indeed, if the type constraints come from the onnx spec, it's most likely xfailable. Not sure if it's an easy implementation though...
**Author:** Sometimes it will xpass if we do xfail, because ORT can choose to implement the op for some types, intentionally or unintentionally. So that's a little tricky.
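A minimal sketch of the xfail idea in pytest (an assumed test shape, not this repo's harness): a non-strict xfail tolerates the XPASS case described above, since ORT support can vary by type and version.

```python
# Sketch only: parametrize over dtypes and xfail the ones outside the
# declared type constraints. strict=False means an unexpected pass is
# reported as XPASS rather than failing the run.
import pytest
import torch


@pytest.mark.parametrize(
    "dtype",
    [
        torch.float32,
        pytest.param(
            torch.bfloat16,
            marks=pytest.mark.xfail(
                reason="dtype not in the declared type constraints",
                strict=False,
            ),
        ),
    ],
)
def test_op(dtype):
    ...  # run the ONNX function under test with sample inputs of `dtype`
```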
**Reviewer:** How about putting the logic after the run? Just an initial thought: if the test fails, then we check whether the dtype is supported or not.
**Author:** Hmm, but that's still going to skip the tests where the annotation is incorrect?
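A sketch of the "check after the run" idea with a hypothetical helper (`run_test_body` stands in for the actual test logic). Note that it has exactly the limitation raised above: a real failure is still skipped whenever the annotation claims the dtype is unsupported.

```python
# Hypothetical helper, not part of the PR: run first, consult the schema
# only on failure.
import pytest


def run_with_dtype_guard(run_test_body, dtype, schema):
    try:
        run_test_body()
    except AssertionError:
        if not dtype_op_schema_compatible(dtype, schema):
            # The failure is consistent with the annotation; treat the
            # dtype as unsupported.
            pytest.skip(f"{dtype} is not declared supported by {schema.name}")
        raise  # declared supported, so surface the real failure
```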
**Author:** I thought about this again. Maybe that's OK for now: if we declare that something is not supported via the annotations, then the exporter will just need to bridge the types via type promotion logic. When we see unnecessary type promotion in the graph, we can come back and change the annotations.
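A conceptual sketch of that bridging behavior (names and the promotion target are illustrative; the exporter's real type-promotion pass is more involved): when the annotation excludes a dtype, cast to an assumed supported type, run the op, and cast back. The extra casts are the "unnecessary type promotion" that would flag a wrong annotation.

```python
# Conceptual sketch only, assuming torch tensors and float32 as the
# promotion target.
import torch


def call_with_promotion(op, x: torch.Tensor, schema) -> torch.Tensor:
    if dtype_op_schema_compatible(x.dtype, schema):
        return op(x)
    # Annotation says unsupported: bridge through the promoted dtype.
    return op(x.to(torch.float32)).to(x.dtype)
```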
**Author:** Merging for now. Feel free to share more thoughts!