diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index c274d95f9f..87575cea74 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1308,21 +1308,21 @@ def sample_inputs_scaled_dot_product_flash_attention( opinfo_core.OpInfo( "ops.aten.randint", aten_name="randint", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint, supports_out=False, ), opinfo_core.OpInfo( "ops.aten.randint.low", aten_name="randint.low", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_low, supports_out=False, ), opinfo_core.OpInfo( "ops.aten.randint_like", aten_name="randint_like", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like, supports_out=False, ), @@ -1330,14 +1330,14 @@ def sample_inputs_scaled_dot_product_flash_attention( "ops.aten.randint_like__dtype", op=torch.ops.aten.randint_like, aten_name="randint_like", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like_dtype, supports_out=False, ), opinfo_core.OpInfo( "ops.aten.randint_like.low_dtype", aten_name="randint_like.low_dtype", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like_low_dtype, supports_out=False, ), @@ -1345,7 +1345,7 @@ def sample_inputs_scaled_dot_product_flash_attention( "ops.aten.randint_like.low_dtype__dtype", op=torch.ops.aten.randint_like.low_dtype, aten_name="randint_like.low_dtype", - dtypes=common_dtype.floating_types_and(torch.bfloat16), + dtypes=common_dtype.integral_types(), sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype, supports_out=False, ),