Commit ac7b5f5
[Torch] Enable dtype inference for operations with list-type operands (#4406)
The type inference task consists of two subtasks: shape inference and dtype inference. For operations with list-type operands, such as `torch.aten.cat`, shape inference works correctly but dtype inference fails: the `torch.prim.Loop` that walks the operand list inside the `DtypeCalculateOp` region is never unrolled, so the element dtypes never become visible to the simplification pass. This commit fixes that by extending the `FullyUnrollPrimLoopOp` pattern to also unroll loops contained in `DtypeCalculateOp` regions, in addition to `ShapeCalculateOp` regions, enabling the pass to propagate dtype information through operations like `torch.aten.cat`.
1 parent 68e74f1 commit ac7b5f5
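
To illustrate the effect of the change, here is a hand-written sketch (not part of the commit) that reuses the ops from the test added below; `%tensors`, `%dtypes`, `%int1`, and `%true` are assumed to be defined as in that test, and the tuple bookkeeping is simplified to a plain list of dtype ints:

// Before: the element dtype is only reachable through the loop body.
torch.prim.Loop %int1, %true, init() {
^bb0(%i: !torch.int):
  %t = torch.aten.__getitem__.t %tensors, %i : !torch.list<vtensor>, !torch.int -> !torch.vtensor
  %d = torch.prim.dtype %t : !torch.vtensor -> !torch.int
  %0 = torch.aten.append.t %dtypes, %d : !torch.list<int>, !torch.int -> !torch.list<int>
  torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()

// After unrolling (constant trip count 1): the body is inlined with the
// constant index 0, so later folding can resolve %d to the element's dtype.
%int0 = torch.constant.int 0
%t = torch.aten.__getitem__.t %tensors, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
%d = torch.prim.dtype %t : !torch.vtensor -> !torch.int
%0 = torch.aten.append.t %dtypes, %d : !torch.list<int>, !torch.int -> !torch.list<int>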

2 files changed: +36 -3 lines


lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp

Lines changed: 5 additions & 3 deletions
@@ -39,12 +39,14 @@ class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     MLIRContext *context = op->getContext();
-    // Only unroll loops if they are contained in a shape calculate region.
+    // Only unroll loops if they are contained in a shape or dtype calculate
+    // region.
     Region *region = op->getParentRegion();
     Operation *parentOp = region->getParentOp();
-    if (!parentOp || !isa<Torch::ShapeCalculateOp>(parentOp))
+    if (!parentOp ||
+        !isa<Torch::ShapeCalculateOp, Torch::DtypeCalculateOp>(parentOp))
       return rewriter.notifyMatchFailure(
-          op, "Loop is not contained in a shape calculation region.");
+          op, "Loop is not contained in a shape or dtype calculation region.");
     if (!op.isForLike())
       return rewriter.notifyMatchFailure(op, "Loop is not for-like");
     int64_t maxTripCount;
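
For context on the checks that follow: the pattern only fires for for-like loops, i.e. `torch.prim.Loop` ops whose continuation condition is the constant `true` (how PyTorch models `for` loops in `prim::Loop`) and whose trip count is a known constant. A minimal loop that qualifies looks like this (a hand-written sketch, not from the commit):

%int2 = torch.constant.int 2
%true = torch.constant.bool true
torch.prim.Loop %int2, %true, init() {
^bb0(%i: !torch.int):
  torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()

A loop whose condition is a runtime value is while-like and is rejected with "Loop is not for-like".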

test/Dialect/Torch/simplify-dtype-calculations.mlir

Lines changed: 31 additions & 0 deletions
@@ -215,6 +215,37 @@ func.func @promote_dtypes$scalar_scalar_same_category(%arg0: !torch.int, %arg1:
 
 // -----
 
+// CHECK-LABEL: func.func @promote_dtypes$list_tensors
+// CHECK: {{.*}} = torch.aten.cat {{.*}} : !torch.list<vtensor>, !torch.int -> !torch.vtensor<*,f32>
+func.func @promote_dtypes$list_tensors(%arg0: !torch.vtensor<[1,8,320,384],f32>) -> !torch.vtensor {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %true = torch.constant.bool true
+  %int-3 = torch.constant.int -3
+  %0 = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[1,8,320,384],f32>) -> !torch.list<vtensor>
+  %1 = torch.dtype.calculate {
+    %2 = torch.aten.cat %0, %int-3 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
+    torch.dtype.calculate.yield %2 : !torch.vtensor
+  } dtypes {
+    %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>
+    torch.prim.Loop %int1, %true, init() {
+    ^bb0(%arg1: !torch.int):
+      %5 = torch.aten.__getitem__.t %0, %arg1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor
+      %6 = torch.aten.dim %5 : !torch.vtensor -> !torch.int
+      %7 = torch.prim.dtype %5 : !torch.vtensor -> !torch.int
+      %8 = torch.prim.TupleConstruct %6, %7 : !torch.int, !torch.int -> !torch.tuple<int, int>
+      %9 = torch.aten.append.t %2, %8 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    %3 = torch.aten.__getitem__.t %2, %int0 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>
+    %4:2 = torch.prim.TupleUnpack %3 : !torch.tuple<int, int> -> !torch.int, !torch.int
+    torch.dtype.calculate.yield.dtypes %4#1 : !torch.int
+  } : !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
 // CHECK-LABEL: func.func @refine_dtype$invalid_dtype_for_scalar(
 // CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.number
 func.func @refine_dtype$invalid_dtype_for_scalar(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
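
For reference, lit tests in this file are driven by a RUN line at the top of the file; the exact flag below is an assumption inferred from the file name and may differ:

// RUN: torch-mlir-opt -torch-simplify-dtype-calculations -split-input-file %s | FileCheck %s

With the unrolling change in place, the `dtypes` region above folds down to a constant dtype, so the CHECK line can match the refined result type `!torch.vtensor<*,f32>`.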
