1616import onnxscript
1717from onnxscript .function_libs .torch_aten .ops import core as core_ops
1818from onnxscript .function_libs .torch_aten .ops import nn as nn_ops
19+ from onnxscript .function_libs .torch_aten .ops import special as special_ops
1920
2021T = TypeVar ("T" )
2122
@@ -161,20 +162,39 @@ def duplicate_opinfo(opinfos: list[opinfo_core.OpInfo], name: str, new_names: tu
161162# Modify this section ##########################################################
162163
163164
164- def _amax_amin_kwargs_wrangler (kwargs : dict [str , Any ]) -> dict [str , Any ]:
165+ def _amax_amin_input_wrangler (
166+ args : list [Any ], kwargs : dict [str , Any ]
167+ ) -> tuple [list [Any ], dict [str , Any ]]:
165168 if "dim" not in kwargs :
166169 kwargs ["dim" ] = None
167- return kwargs
170+ return args , kwargs
168171
169172
170- def _upsample_kwargs_wrangler (kwargs : dict [str , Any ]) -> dict [str , Any ]:
173+ def _full_input_wrangler (
174+ args : list [Any ], kwargs : dict [str , Any ]
175+ ) -> tuple [list [Any ], dict [str , Any ]]:
176+ # Remove the self argument
177+ args .pop (0 )
178+ return args , kwargs
179+
180+
181+ def _upsample_input_wrangler (
182+ args : list [Any ], kwargs : dict [str , Any ]
183+ ) -> tuple [list [Any ], dict [str , Any ]]:
171184 if "scale_factor" in kwargs :
172185 kwargs ["scales_h" ] = kwargs ["scale_factor" ]
173186 kwargs ["scales_w" ] = kwargs ["scale_factor" ]
174187 del kwargs ["scale_factor" ]
175188 if "size" in kwargs :
176189 kwargs ["size" ] = np .array (kwargs ["size" ])
177- return kwargs
190+ return args , kwargs
191+
192+
193+ def _logcumsumexp_input_wrangler (
194+ args : list [Any ], kwargs : dict [str , Any ]
195+ ) -> tuple [list [Any ], dict [str , Any ]]:
196+ kwargs ["keepdim" ] = args .pop ()
197+ return args , kwargs
178198
179199
180200# Ops to be tested for numerical consistency between onnx and pytorch
@@ -185,16 +205,16 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
185205 | Callable [..., Any ]
186206 | tuple [
187207 onnxscript .OnnxFunction | Callable [..., Any ],
188- Callable [[dict [str , Any ]], dict [str , Any ]],
208+ Callable [[list [ Any ], dict [str , Any ]], tuple [ list [ Any ], dict [str , Any ] ]],
189209 ],
190210] = {
191211 "abs" : core_ops .aten_abs ,
192212 "acos" : core_ops .aten_acos ,
193213 "acosh" : core_ops .aten_acosh ,
194214 "add" : core_ops .aten_add ,
195215 "addmm" : core_ops .aten_addmm ,
196- "amax" : (core_ops .aten_amax , _amax_amin_kwargs_wrangler ),
197- "amin" : (core_ops .aten_amin , _amax_amin_kwargs_wrangler ),
216+ "amax" : (core_ops .aten_amax , _amax_amin_input_wrangler ),
217+ "amin" : (core_ops .aten_amin , _amax_amin_input_wrangler ),
198218 "arange_start_step" : core_ops .aten_arange_start_step ,
199219 "arange_start" : core_ops .aten_arange_start ,
200220 "arange" : core_ops .aten_arange ,
@@ -219,11 +239,20 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
219239 "expand" : core_ops .aten_expand ,
220240 "erf" : core_ops .aten_erf ,
221241 "fmod" : core_ops .aten_fmod ,
222- # TODO(justinchuby): Test aten::full
242+ "full" : ( core_ops . aten_full , _full_input_wrangler ),
223243 "full_like" : core_ops .aten_full_like ,
224244 "gt" : core_ops .aten_gt ,
225245 "index_select" : core_ops .aten_index_select ,
226246 "isinf" : core_ops .aten_isinf ,
247+ "log" : core_ops .aten_log ,
248+ "log10" : core_ops .aten_log10 ,
249+ "log1p" : core_ops .aten_log1p ,
250+ "log2" : core_ops .aten_log2 ,
251+ "logaddexp" : core_ops .aten_logaddexp ,
252+ "logaddexp2" : core_ops .aten_logaddexp2 ,
253+ "logcumsumexp" : core_ops .aten_logcumsumexp ,
254+ "logdet" : core_ops .aten_logdet ,
255+ "logsumexp" : (core_ops .aten_logsumexp , _logcumsumexp_input_wrangler ),
227256 "lt" : core_ops .aten_lt ,
228257 "matmul" : core_ops .aten_matmul ,
229258 "mm" : core_ops .aten_mm ,
@@ -237,12 +266,13 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
237266 "nn.functional.elu" : nn_ops .aten_elu ,
238267 "nn.functional.leaky_relu" : nn_ops .aten_leaky_relu ,
239268 "nn.functional.linear" : nn_ops .aten_linear ,
269+ "nn.functional.logsigmoid" : nn_ops .aten_log_sigmoid ,
240270 "nn.functional.relu" : nn_ops .aten_relu ,
241271 "nn.functional.relu6" : nn_ops .aten_relu6 ,
242272 "nn.functional.selu" : core_ops .aten_selu ,
243273 "nn.functional.upsample_nearest2d" : (
244274 nn_ops .aten_upsample_nearest2d ,
245- _upsample_kwargs_wrangler ,
275+ _upsample_input_wrangler ,
246276 ),
247277 "nonzero" : core_ops .aten_nonzero ,
248278 "ones_like" : core_ops .aten_ones_like ,
@@ -267,6 +297,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
267297 "unsqueeze" : core_ops .aten_unsqueeze ,
268298 "view" : core_ops .aten_view ,
269299 "where" : core_ops .aten_where ,
300+ "xlogy" : special_ops .aten_special_xlogy ,
270301 "zeros" : core_ops .aten_zeros ,
271302 "zeros_like" : core_ops .aten_zeros_like ,
272303}
@@ -276,7 +307,9 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
276307EXPECTED_SKIPS_OR_FAILS = (
277308 xfail ("amax" , reason = "ONNX Runtime 1.13 does not support ReduceMax-18" ),
278309 xfail ("amin" , reason = "ONNX Runtime 1.13 does not support ReduceMin-18" ),
279- skip ("clamp" , reason = "Enable when onnxscript supports optional inputs" ),
310+ skip ("clamp" , reason = "enable when onnxscript supports optional inputs" ),
311+ xfail ("logcumsumexp" , reason = "naive implementation not numerically stable" ),
312+ xfail ("logsumexp" , reason = "ONNX Runtime 1.13 does not support ReduceLogSumExp-18" ),
280313 xfail (
281314 "nn.functional.linear" ,
282315 reason = "ONNX Runtime thinks the graph is invalid" ,
@@ -358,23 +391,25 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
358391
# Register aliased OpInfo entries so each overload of aten::arange
# (default / start / start+step) can be mapped and tested independently.
duplicate_opinfo(
    OPS_DB,
    "arange",
    (
        "arange_start",
        "arange_start_step",
    ),
)
368400
# torch exposes a single nn.functional.upsample_nearest OpInfo; alias it so
# the 1d/2d/3d aten variants can each be mapped separately.
duplicate_opinfo(
    OPS_DB,
    "nn.functional.upsample_nearest",
    (
        "nn.functional.upsample_nearest1d",
        "nn.functional.upsample_nearest2d",
        "nn.functional.upsample_nearest3d",
    ),
)
377410
411+ duplicate_opinfo (OPS_DB , "new_full" , ("full" ,))
412+
378413
379414# END OF SECTION TO MODIFY #####################################################
380415
@@ -477,13 +512,13 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
477512 )
478513
479514 onnx_function_and_wrangler = OPINFO_FUNCTION_MAPPING [op .name ]
480- kwarg_wrangler = None
515+ input_wrangler = None
481516 if isinstance (onnx_function_and_wrangler , tuple ):
482- # Obtain the kwarg_wrangler that manipulates the OpInfo inputs
517+ # Obtain the input_wrangler that manipulates the OpInfo inputs
483518 # to match the aten operator signature
484519 # An example is nn.functional.upsample_nearest2d, which has a different signature
485520 # than the aten operator upsample_nearest2d
486- onnx_function , kwarg_wrangler = onnx_function_and_wrangler
521+ onnx_function , input_wrangler = onnx_function_and_wrangler
487522 else :
488523 assert callable (onnx_function_and_wrangler )
489524 onnx_function = onnx_function_and_wrangler
@@ -503,8 +538,8 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
503538 continue
504539 input_onnx = [_convert_tensor_to_numpy (x ) for x in inputs ]
505540 kwargs_onnx = _convert_kwargs_for_onnx (cpu_sample .kwargs )
506- if kwarg_wrangler :
507- kwargs_onnx = kwarg_wrangler ( kwargs_onnx )
541+ if input_wrangler :
542+ input_onnx , kwargs_onnx = input_wrangler ( input_onnx , kwargs_onnx )
508543 torch_output = op (* inputs , ** cpu_sample .kwargs )
509544 function_output = onnx_function (* input_onnx , ** kwargs_onnx )
510545
@@ -524,7 +559,9 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
524559 # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
525560 torch .testing .assert_close (
526561 torch .tensor (function_output ),
527- torch .tensor (torch_output ),
562+ torch_output
563+ if isinstance (torch_output , torch .Tensor )
564+ else torch .tensor (torch_output ),
528565 rtol = rtol ,
529566 atol = atol ,
530567 )
0 commit comments