Skip to content

Commit f141dae

Browse files
authored
Skip incompatible dtypes in op_test using OpSchema | test(torchlib) (#700)
This requires ONNX 1.14, so the op test is skipped altogether for ONNX<1.14. I tested locally to make sure no float32 test is skipped with this mechanism, and that it skips properly according to the type annotation. Currently no new dtypes are enabled because there are too many failures. Requires: #701, #698.
1 parent 6e182f6 commit f141dae

File tree

3 files changed

+87
-1
lines changed

3 files changed

+87
-1
lines changed

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def run_test_output_match(
162162
# An example is nn.functional.upsample_nearest2d, which has a different signature
163163
# than the aten operator upsample_nearest2d
164164
onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler)
165+
if not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema):
166+
test_suite.skipTest(
167+
f"dtype '{dtype}' is not supported by the op '{op.name}'. "
168+
f"Type constraints: {onnx_function.op_schema.type_constraints}"
169+
)
165170

166171
for i, cpu_sample in enumerate(samples):
167172
inputs = (cpu_sample.input, *cpu_sample.args)
@@ -251,6 +256,10 @@ def run_test_output_match(
251256
raise
252257

253258

259+
@unittest.skipIf(
260+
version_utils.onnx_older_than("1.14"),
261+
"OpSchema not available for functions before ONNX 1.14",
262+
)
254263
class TestOutputConsistencyEager(unittest.TestCase):
255264
"""Test output consistency between the ONNX op run with ONNX eager mode and PyTorch eager mode.
256265
@@ -279,6 +288,10 @@ def test_output_match_opinfo_(
279288
run_test_output_match(self, device, dtype, op, ops_test_common.eager_executor)
280289

281290

291+
@unittest.skipIf(
292+
version_utils.onnx_older_than("1.14"),
293+
"OpSchema not available for functions before ONNX 1.14",
294+
)
282295
class TestOutputConsistencyFullGraph(unittest.TestCase):
283296
"""Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.
284297

onnxscript/tests/function_libs/torch_lib/ops_test_common.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,79 @@ def _format_model_and_input_information(onnx_model, inputs):
355355
)
356356

357357

358+
# Maps each torch dtype to the string ONNX uses for the corresponding tensor
# element type in OpSchema type constraints (e.g. torch.float32 -> "tensor(float)").
TORCH_DTYPE_TO_ONNX_STRING = {
    torch_dtype: f"tensor({onnx_name})"
    for torch_dtype, onnx_name in (
        (torch.bool, "bool"),
        (torch.uint8, "uint8"),
        (torch.int8, "int8"),
        (torch.int16, "int16"),
        (torch.int32, "int32"),
        (torch.int64, "int64"),
        (torch.float16, "float16"),
        (torch.float32, "float"),
        (torch.float64, "double"),
        (torch.complex64, "complex64"),
        (torch.complex128, "complex128"),
        (torch.bfloat16, "bfloat16"),
    )
}
372+
373+
374+
def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
    """Checks if the dtype is compatible with the schema.

    When a dtype is "compatible" with the schema, it means we can use the dtype
    to create sample inputs by OpInfo to test the ONNX function and expect outputs to match.

    Args:
        dtype: The torch dtype used to create sample inputs by OpInfo.
        schema: The ONNX schema of the function.

    Returns:
        True if the dtype is compatible with the schema.

    Raises:
        ValueError: If the first input's type variable has no matching entry in
            the schema's type constraints (malformed schema).
    """
    if not schema.inputs:
        # If there are no inputs, we can't check compatibility. Assume it is compatible.
        # e.g. aten_randn has only attributes.
        return True
    if schema.inputs[0].name not in {"self", "input"}:
        # If the name of the first input is not "self" or "input",
        # it is usually an input that is not of the same type as the output.
        # We assume support in this case.
        #
        # For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
        # has the first input as `size`, which is an integer, but it can support
        # any dtype.
        return True

    # A dtype with no ONNX counterpart in the table is never compatible.
    # (Originally an unmapped dtype raised KeyError; returning False is the
    # correct answer to "is this dtype supported?".)
    onnx_dtype_string = TORCH_DTYPE_TO_ONNX_STRING.get(dtype)
    if onnx_dtype_string is None:
        return False

    first_input_type_name = schema.inputs[0].type_str
    if first_input_type_name.startswith("tensor("):
        # The input declares a concrete type (e.g. "tensor(float)") instead of a
        # type variable, so there is no entry in type_constraints to look up —
        # compare against the concrete type directly.
        return onnx_dtype_string == first_input_type_name

    # Otherwise we check the type constraints of the first input.
    # For example, when dtype=torch.float32, and the op being tested has the schema
    # ```
    # OpSchema(
    #     name='aten_abs',
    #     domain='onnxscript.atenlib',
    #     since_version=1,
    #     doc='abs(Tensor self) -> Tensor',
    #     type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
    #     inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    #     outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    #     attributes={}
    # )
    # ```
    # we see the first input type is "TReal", corresponding to the type constraint
    # with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)',
    # 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)',
    # 'tensor(bfloat16)'].
    # Since torch.float32 (tensor(float)) is in the allowed types, we return True.

    # Find the type constraint for the first input by matching the parameter name
    first_input_type_constraint = next(
        (x for x in schema.type_constraints if x.type_param_str == first_input_type_name),
        None,
    )
    if first_input_type_constraint is None:
        # Raise explicitly instead of `assert` (asserts are stripped under -O).
        raise ValueError(
            f"Malformed schema '{schema.name}': no type constraint found for "
            f"type variable '{first_input_type_name}'"
        )
    return onnx_dtype_string in first_input_type_constraint.allowed_type_strs
429+
430+
358431
def graph_executor(
359432
outputs: Sequence[Any],
360433
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
--index-url=https://download.pytorch.org/whl/nightly/cpu
22
--pre
3-
torch==2.1.0.dev20230418
3+
torch

0 commit comments

Comments
 (0)