@@ -355,6 +355,79 @@ def _format_model_and_input_information(onnx_model, inputs):
355
355
)
356
356
357
357
358
# Maps a torch dtype to the ONNX tensor-type string used in OpSchema type
# constraints. Note the ONNX spellings: float32 is "tensor(float)" and
# float64 is "tensor(double)".
TORCH_DTYPE_TO_ONNX_STRING = {
    torch.bool: "tensor(bool)",
    torch.uint8: "tensor(uint8)",
    torch.int8: "tensor(int8)",
    torch.int16: "tensor(int16)",
    torch.int32: "tensor(int32)",
    torch.int64: "tensor(int64)",
    torch.float16: "tensor(float16)",
    torch.float32: "tensor(float)",
    torch.float64: "tensor(double)",
    torch.complex64: "tensor(complex64)",
    torch.complex128: "tensor(complex128)",
    torch.bfloat16: "tensor(bfloat16)",
}


def dtype_op_schema_compatible(dtype: torch.dtype, schema: "onnx.defs.OpSchema") -> bool:
    """Checks if the dtype is compatible with the schema.

    When a dtype is "compatible" with the schema, it means we can use the dtype
    to create sample inputs by OpInfo to test the ONNX function and expect
    outputs to match.

    Args:
        dtype: The torch dtype used to create sample inputs by OpInfo.
        schema: The ONNX schema of the function.

    Returns:
        True if the dtype is compatible with the schema.

    Raises:
        ValueError: If the schema is malformed: the first input references a
            type-parameter name that is not declared in
            ``schema.type_constraints``.
    """
    if not schema.inputs:
        # If there are no inputs, we can't check compatibility. Assume it is
        # compatible. e.g. aten_randn has only attributes.
        return True
    if schema.inputs[0].name not in {"self", "input"}:
        # If the name of the first input is not "self" or "input",
        # it is usually an input that is not of the same type as the output.
        # We assume support in this case.
        #
        # For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)`
        # has the first input as `size`, which is an integer, but it can support
        # any dtype.
        return True

    # Otherwise we check the type constraints of the first input.
    # For example, `aten_abs` has a first input `self` typed "TReal", and the
    # "TReal" constraint allows ['tensor(float)', 'tensor(int8)',
    # 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)',
    # 'tensor(double)', 'tensor(bfloat16)']. Since torch.float32 maps to
    # "tensor(float)", which is in the allowed types, we return True.
    first_input_type_name = schema.inputs[0].type_str
    # Find the type constraint for the first input by matching the
    # type-parameter name.
    first_input_type_constraint = next(
        (
            constraint
            for constraint in schema.type_constraints
            if constraint.type_param_str == first_input_type_name
        ),
        None,
    )
    if first_input_type_constraint is None:
        if first_input_type_name.startswith("tensor("):
            # The first input is typed with a concrete tensor type (e.g.
            # "tensor(float)") rather than a type-parameter name; compare
            # against it directly instead of failing an assertion.
            return TORCH_DTYPE_TO_ONNX_STRING.get(dtype) == first_input_type_name
        # Previously a bare `assert` (stripped under `python -O`); raise an
        # explicit, descriptive error for a malformed schema instead.
        raise ValueError(
            f"Could not find the type constraint {first_input_type_name!r} "
            "in the schema."
        )
    # A torch dtype with no ONNX string representation here (e.g. quantized
    # dtypes) cannot appear in the allowed types — treat it as incompatible
    # rather than raising KeyError.
    onnx_type_str = TORCH_DTYPE_TO_ONNX_STRING.get(dtype)
    if onnx_type_str is None:
        return False
    return onnx_type_str in first_input_type_constraint.allowed_type_strs
429
+
430
+
358
431
def graph_executor (
359
432
outputs : Sequence [Any ],
360
433
) -> Callable [[Callable [..., Any ], tuple [Any ], dict [str , Any ]], None ]:
0 commit comments