Commit 06a0a5c
Split arg{max,min} into arg{max,min}_dim and improve function signatures | feat(atenlib) (#677)
To remove trace_only and allow the dispatcher to select the correct
function, we explicitly split argmax with argmax_dim and fixed the
function return type.
---
One potential problem as below code:
```python
@torch_op("aten::argmax", overload=True)
def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TInt:
self_is_scaler = op.Size(op.Shape(self)) == 0
if self_is_scaler:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.ArgMax(self, axis=dim, keepdims=keepdim)
if self_is_scaler:
result = op.Squeeze(result)
return result
```
The above code works well. But when it was rewritten to below (reduce
one if/else), the ShapeInference will fail:
```python
@torch_op("aten::argmax", overload=True)
def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TInt:
self_is_scaler = op.Size(op.Shape(self)) == 0
if self_is_scaler:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
result = op.ArgMax(self, axis=dim, keepdims=keepdim)
result = op.Squeeze(result)
else:
result = op.ArgMax(self, axis=dim, keepdims=keepdim)
return result
```
---------
Co-authored-by: Justin Chu <justinchu@microsoft.com>1 parent 9c69053 commit 06a0a5c
2 files changed
Lines changed: 59 additions & 21 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
17 | | - | |
| 17 | + | |
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| |||
520 | 520 | | |
521 | 521 | | |
522 | 522 | | |
523 | | - | |
524 | | - | |
525 | | - | |
526 | | - | |
| 523 | + | |
| 524 | + | |
527 | 525 | | |
528 | 526 | | |
529 | | - | |
530 | | - | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
531 | 532 | | |
532 | | - | |
| 533 | + | |
533 | 534 | | |
534 | 535 | | |
535 | | - | |
536 | | - | |
| 536 | + | |
| 537 | + | |
537 | 538 | | |
538 | 539 | | |
539 | 540 | | |
| |||
547 | 548 | | |
548 | 549 | | |
549 | 550 | | |
550 | | - | |
551 | | - | |
552 | | - | |
553 | | - | |
| 551 | + | |
| 552 | + | |
554 | 553 | | |
555 | 554 | | |
556 | | - | |
557 | | - | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
558 | 560 | | |
559 | | - | |
| 561 | + | |
560 | 562 | | |
561 | 563 | | |
562 | | - | |
563 | | - | |
| 564 | + | |
| 565 | + | |
564 | 566 | | |
565 | 567 | | |
566 | 568 | | |
| |||
Lines changed: 38 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1209 | 1209 | | |
1210 | 1210 | | |
1211 | 1211 | | |
1212 | | - | |
| 1212 | + | |
| 1213 | + | |
| 1214 | + | |
| 1215 | + | |
| 1216 | + | |
| 1217 | + | |
| 1218 | + | |
| 1219 | + | |
| 1220 | + | |
| 1221 | + | |
| 1222 | + | |
| 1223 | + | |
| 1224 | + | |
| 1225 | + | |
| 1226 | + | |
| 1227 | + | |
| 1228 | + | |
| 1229 | + | |
1213 | 1230 | | |
1214 | 1231 | | |
1215 | 1232 | | |
| |||
1218 | 1235 | | |
1219 | 1236 | | |
1220 | 1237 | | |
1221 | | - | |
| 1238 | + | |
| 1239 | + | |
| 1240 | + | |
| 1241 | + | |
| 1242 | + | |
| 1243 | + | |
| 1244 | + | |
| 1245 | + | |
| 1246 | + | |
| 1247 | + | |
| 1248 | + | |
| 1249 | + | |
| 1250 | + | |
| 1251 | + | |
| 1252 | + | |
| 1253 | + | |
| 1254 | + | |
| 1255 | + | |
1222 | 1256 | | |
1223 | 1257 | | |
1224 | 1258 | | |
| |||
1631 | 1665 | | |
1632 | 1666 | | |
1633 | 1667 | | |
| 1668 | + | |
| 1669 | + | |
1634 | 1670 | | |
1635 | 1671 | | |
1636 | 1672 | | |
| |||
0 commit comments