@@ -1271,6 +1271,19 @@ def _where_input_wrangler(
1271
1271
),
1272
1272
TorchLibOpInfo ("polar" , core_ops .aten_polar ),
1273
1273
TorchLibOpInfo ("pow" , core_ops .aten_pow ),
1274
+ TorchLibOpInfo ("prod" , core_ops .aten_prod ).skip (
1275
+ matcher = lambda sample : sample .kwargs .get ("dim" ) is not None
1276
+ or sample .kwargs .get ("keepdim" ) is not None
1277
+ or sample .kwargs .get ("dtype" ) != - 1 ,
1278
+ reason = "this Aten overload only accept 1 inputs: self" ,
1279
+ ),
1280
+ TorchLibOpInfo ("prod_dim_int" , core_ops .aten_prod_dim_int ).skip (
1281
+ matcher = lambda sample : (
1282
+ sample .kwargs .get ("dim" ) is None and sample .kwargs .get ("keepdim" ) is None
1283
+ )
1284
+ or sample .kwargs .get ("dtype" ) != - 1 ,
1285
+ reason = "this Aten overload can accept 3 inputs:(self, dim, keepdim)" ,
1286
+ ),
1274
1287
TorchLibOpInfo ("nn.functional.prelu" , core_ops .aten_prelu ),
1275
1288
TorchLibOpInfo ("ops.aten.rand" , core_ops .aten_rand , nondeterministic = True ),
1276
1289
TorchLibOpInfo ("ops.aten.rand_like" , core_ops .aten_rand_like , nondeterministic = True ),
@@ -2203,6 +2216,7 @@ def _where_input_wrangler(
2203
2216
OPS_DB , "ops.aten._log_softmax" , ("ops.aten._log_softmax_half" ,)
2204
2217
)
2205
2218
ops_test_common .duplicate_opinfo (OPS_DB , "ops.aten._softmax" , ("ops.aten._softmax_half" ,))
2219
+ ops_test_common .duplicate_opinfo (OPS_DB , "prod" , ("prod_dim_int" ,))
2206
2220
ops_test_common .duplicate_opinfo (OPS_DB , "round" , ("round_decimals" ,))
2207
2221
ops_test_common .duplicate_opinfo (OPS_DB , "squeeze" , ("squeeze_dim" ,))
2208
2222
ops_test_common .duplicate_opinfo (OPS_DB , "view_as_complex" , ("view_as_complex_copy" ,))
0 commit comments