Skip to content

Commit a981b8a

Browse files
authored
Fix randint dtypes | test(torchlib) (#1088)
Fix test input types for randint by replacing float input tests to integer inputs.
1 parent 1ee4ee1 commit a981b8a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,44 +1308,44 @@ def sample_inputs_scaled_dot_product_flash_attention(
13081308
opinfo_core.OpInfo(
13091309
"ops.aten.randint",
13101310
aten_name="randint",
1311-
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1311+
dtypes=common_dtype.integral_types(),
13121312
sample_inputs_func=sample_inputs_randint,
13131313
supports_out=False,
13141314
),
13151315
opinfo_core.OpInfo(
13161316
"ops.aten.randint.low",
13171317
aten_name="randint.low",
1318-
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1318+
dtypes=common_dtype.integral_types(),
13191319
sample_inputs_func=sample_inputs_randint_low,
13201320
supports_out=False,
13211321
),
13221322
opinfo_core.OpInfo(
13231323
"ops.aten.randint_like",
13241324
aten_name="randint_like",
1325-
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1325+
dtypes=common_dtype.integral_types(),
13261326
sample_inputs_func=sample_inputs_randint_like,
13271327
supports_out=False,
13281328
),
13291329
opinfo_core.OpInfo(
13301330
"ops.aten.randint_like__dtype",
13311331
op=torch.ops.aten.randint_like,
13321332
aten_name="randint_like",
1333-
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1333+
dtypes=common_dtype.integral_types(),
13341334
sample_inputs_func=sample_inputs_randint_like_dtype,
13351335
supports_out=False,
13361336
),
13371337
opinfo_core.OpInfo(
13381338
"ops.aten.randint_like.low_dtype",
13391339
aten_name="randint_like.low_dtype",
1340-
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1340+
dtypes=common_dtype.integral_types(),
13411341
sample_inputs_func=sample_inputs_randint_like_low_dtype,
13421342
supports_out=False,
13431343
),
13441344
opinfo_core.OpInfo(
13451345
"ops.aten.randint_like.low_dtype__dtype",
13461346
op=torch.ops.aten.randint_like.low_dtype,
13471347
aten_name="randint_like.low_dtype",
1348-
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1348+
dtypes=common_dtype.integral_types(),
13491349
sample_inputs_func=sample_inputs_randint_like_low_dtype_dtype,
13501350
supports_out=False,
13511351
),

0 commit comments

Comments
 (0)