Skip to content

Commit c10e88f

Browse files
committed
conv3d
1 parent fc0f8e7 commit c10e88f

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def sample_inputs_conv3d(op_info, device, dtype, requires_grad, **kwargs):
6969
(32,),
7070
{
7171
"stride": (3, 3, 3),
72-
"padding": 2,
72+
"padding": (2, 2, 2),
7373
"dilation": (1, 1, 1),
7474
"groups": 1,
7575
},
@@ -1394,7 +1394,7 @@ def sample_inputs__native_batch_norm_legit_no_stats(
13941394
supports_out=False,
13951395
),
13961396
opinfo_core.OpInfo(
1397-
"nn.functional.conv3d",
1397+
"ops.aten.conv3d",
13981398
aten_name="conv3d",
13991399
dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
14001400
sample_inputs_func=sample_inputs_conv3d,

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,7 @@ def _where_input_wrangler(
18791879
reason="String padding is not accepted by aten::conv2d",
18801880
),
18811881
TorchLibOpInfo(
1882-
"nn.functional.conv3d",
1882+
"ops.aten.conv3d",
18831883
core_ops.aten_conv3d,
18841884
trace_only=True,
18851885
tolerance={torch.float32: (3.7e-5, 1.8e-4)},

0 commit comments

Comments
 (0)