Skip incompatible dtypes in op_test using OpSchema | test(torchlib) #700

Merged · 12 commits · May 2, 2023
13 changes: 13 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -162,6 +162,11 @@ def run_test_output_match(
    # An example is nn.functional.upsample_nearest2d, which has a different signature
    # than the aten operator upsample_nearest2d
    onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler)
    if not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema):
        test_suite.skipTest(
            f"dtype '{dtype}' is not supported by the op '{op.name}'. "
            f"Type constraints: {onnx_function.op_schema.type_constraints}"
        )
Contributor:
I have a concern about skipping tests based on the manually annotated type constraints. If we wrongly annotate an op's inputs, we could skip tests we shouldn't, which makes us miss the very first opportunity to catch the mistake.

Collaborator Author:

That's true. Any suggestions?

Contributor:

I wonder if we can xfail instead.

Contributor (@titaiwangms, May 2, 2023):

Based on what I found in pytorch/pytorch#100265, if the type constraints come from the ONNX spec, the case is most likely xfailable. I'm not sure it's an easy implementation, though...

Collaborator Author:

Sometimes it will xpass if we xfail, because ORT can choose to implement the op for some types, intentionally or unintentionally. So that's a little tricky.
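
For illustration, a minimal plain-`unittest` sketch (not this PR's harness) of the xpass behavior described above: `expectedFailure` reports an unexpected success when the wrapped test passes, which is what would happen whenever ORT supports the dtype after all.

```python
import unittest


class XfailDemo(unittest.TestCase):
    @unittest.expectedFailure
    def test_dtype_marked_unsupported(self):
        # If ORT actually supports the dtype, this body passes and the runner
        # reports an "unexpected success" (xpass) instead of a quiet skip.
        self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
    unittest.main()
```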

Contributor (@titaiwangms, May 2, 2023):

How about putting the logic after the run? Just an initial thought: if the test fails, we then check whether the dtype is supported.

Collaborator Author:

Hmm, but that's still going to skip the tests where the annotation is incorrect?

Collaborator Author:

I thought about this again. Maybe that's OK for now: if we declare via the annotations that something is not supported, the exporter will just need to bridge the types with its type-promotion logic. When we see unnecessary type promotion in the graph, we can come back and adjust the annotations.

Collaborator Author:

Merging for now. Feel free to share more thoughts!


    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
@@ -251,6 +256,10 @@ def run_test_output_match(
raise


@unittest.skipIf(
    version_utils.onnx_older_than("1.14"),
    "OpSchema not available for functions before ONNX 1.14",
)
class TestOutputConsistencyEager(unittest.TestCase):
    """Test output consistency between the ONNX op run with ONNX eager mode and PyTorch eager mode.

Expand Down Expand Up @@ -279,6 +288,10 @@ def test_output_match_opinfo_(
run_test_output_match(self, device, dtype, op, ops_test_common.eager_executor)


@unittest.skipIf(
    version_utils.onnx_older_than("1.14"),
    "OpSchema not available for functions before ONNX 1.14",
)
class TestOutputConsistencyFullGraph(unittest.TestCase):
    """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.

73 changes: 73 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -355,6 +355,79 @@ def _format_model_and_input_information(onnx_model, inputs):
)


TORCH_DTYPE_TO_ONNX_STRING = {
    torch.bool: "tensor(bool)",
Contributor:

nit: I commented on this elsewhere before but don't recall precisely where. Does it have to be a plain string for ONNX types?

Collaborator Author:

Yes, it needs to be a plain string in ONNX's type constraints.

Contributor:

I was wondering if there is anything like `str(onnx.BoolTensorType)` / `onnx.BoolTensorType.to_string()`.

Collaborator Author:

I see. There probably is, but we would still need a mapping from torch types to ONNX types. Any suggestions?

    torch.uint8: "tensor(uint8)",
    torch.int8: "tensor(int8)",
    torch.int16: "tensor(int16)",
    torch.int32: "tensor(int32)",
    torch.int64: "tensor(int64)",
    torch.float16: "tensor(float16)",
    torch.float32: "tensor(float)",
    torch.float64: "tensor(double)",
    torch.complex64: "tensor(complex64)",
    torch.complex128: "tensor(complex128)",
    torch.bfloat16: "tensor(bfloat16)",
}
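
Following up on the thread above: rather than writing the `tensor(...)` spellings out by hand, they can in principle be derived from the `TensorProto` enum. A minimal sketch, assuming a torch-to-`TensorProto` map is still maintained by hand (the helper name and the map subset below are made up for illustration):

```python
import onnx
import torch

# Illustrative subset; a torch dtype -> TensorProto enum map is still unavoidable.
_TORCH_TO_TENSORPROTO = {
    torch.bool: onnx.TensorProto.BOOL,
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float64: onnx.TensorProto.DOUBLE,
}


def _torch_dtype_to_onnx_string(dtype: torch.dtype) -> str:
    # TensorProto.DataType.Name is the generated protobuf enum helper
    # (e.g. FLOAT -> "FLOAT"); lowercased, it matches the OpSchema spelling.
    name = onnx.TensorProto.DataType.Name(_TORCH_TO_TENSORPROTO[dtype])
    return f"tensor({name.lower()})"


assert _torch_dtype_to_onnx_string(torch.float32) == "tensor(float)"
assert _torch_dtype_to_onnx_string(torch.float64) == "tensor(double)"
```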


def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
Contributor:

I think it's worth adding a comment to explain the context for dtype: it is the "dtype setting" introduced by the OpInfo tests, not just any input dtype. It feels like it carries the assumption that, to some extent, all arguments and outputs should share the same dtype.

It would also help to briefly state what it means for a schema to be compatible with a dtype. Do all arguments and outputs have to share that same type?

Collaborator Author:

Added: "When a dtype is 'compatible' with the schema, it means we can use the dtype to create sample inputs by OpInfo to test the ONNX function and expect outputs to match."

Contributor:

I'm very curious whether there are edge cases where the sample inputs are tensors of different types, and how OpInfo handles them. Thanks for the comments; they are very clear now.

Contributor:

I think this is good as a quick, general guard for the tests, and we can evolve it from there.

What are your thoughts on more fine-grained type checking that validates the type of each concrete input against the type of the matching input in the schema?

Collaborator Author:

I think we will need to build the same logic in the dispatcher; we can evolve it there and reuse it here. For now, I think it's good to use logic that is easy to understand and to get right for running the test cases. It is always better to declare inputs compatible than not, so we don't miss test cases; we can always put an xfail on them.

Collaborator Author (@justinchuby, May 2, 2023):

Also, this logic is for reducing the number of tests we need to manually skip, so accuracy isn't super important here (as long as we don't skip more than we should).

"""Checks if the dtype is compatible with the schema.

When a dtype is "compatible" with the schema, it means we can use the dtype
to create sample inputs by OpInfo to test the ONNX function and expect outputs to match.

Args:
dtype: The torch dtype used to create sample inputs by OpInfo.
schema: The ONNX schema of the function.

Returns:
True if the dtype is compatible with the schema.
"""
if not schema.inputs:
# If there are no inputs, we can't check compatibility. Assume it is compatible.
# e.g. aten_randn has only attributes.
return True
if schema.inputs[0].name not in {"self", "input"}:
# If the name of the first input is not "self" or "input",
# it is usually an input that is not of the same type as the output.
# We assume support in this case.
#
# For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
# has the first input as `size`, which is an integer, but it can support
# any dtype.
return True

# Otherwise we check the type constraints of the first input.
# For example, when dtype=torch.float32, and the op being tested has the schema
# ```
# OpSchema(
# name='aten_abs',
# domain='onnxscript.atenlib',
# since_version=1,
# doc='abs(Tensor self) -> Tensor',
# type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
# inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
# outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
# attributes={}
# )
# ```
# we see the first input type is "TReal", corresponding to the type constraint
# with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)',
# 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
# 'tensor(bfloat16)'].
# Since torch.float32 (tensor(float)) is in the allowed types, we return True.

first_input_type_name = schema.inputs[0].type_str
# Find the type constraint for the first input by matching the parameter name
first_input_type_constraint = next(
(x for x in schema.type_constraints if x.type_param_str == first_input_type_name), None
)
assert first_input_type_constraint is not None
allowed_type_strs = first_input_type_constraint.allowed_type_strs
return TORCH_DTYPE_TO_ONNX_STRING[dtype] in allowed_type_strs
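
A minimal usage sketch of the helper above. It assumes the `aten_abs` schema shown in the code comment and that torch-lib functions expose `.op_schema` (as the tests in this PR rely on); the import path is an assumption for illustration:

```python
import torch

from onnxscript.function_libs.torch_lib.ops import core as core_ops

schema = core_ops.aten_abs.op_schema  # TReal: float/int tensors, no bool
assert dtype_op_schema_compatible(torch.float32, schema)  # "tensor(float)" is in TReal
assert not dtype_op_schema_compatible(torch.bool, schema)  # "tensor(bool)" is not
```

And a hedged sketch of the fine-grained per-input check discussed in the review thread (not what this PR implements): each concrete input's dtype is validated against the constraint of its matching formal parameter. It reuses `TORCH_DTYPE_TO_ONNX_STRING` from above; the function name is made up:

```python
from typing import Sequence


def inputs_match_schema(args: Sequence[torch.Tensor], schema: onnx.defs.OpSchema) -> bool:
    """Check every concrete input's dtype against its formal parameter's constraint."""
    constraints = {c.type_param_str: set(c.allowed_type_strs) for c in schema.type_constraints}
    for arg, formal in zip(args, schema.inputs):
        # formal.type_str may be a concrete type like "tensor(int64)" rather
        # than a type variable; those are skipped here for simplicity.
        allowed = constraints.get(formal.type_str)
        if allowed and TORCH_DTYPE_TO_ONNX_STRING[arg.dtype] not in allowed:
            return False
    return True
```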


def graph_executor(
outputs: Sequence[Any],
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]: