Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,12 @@ class ConvertAtenTransposeIntOp
Value outVector = tensor::EmptyOp::create(
rewriter, loc, getAsOpFoldResult(outputDims), elementType);

// Note: The empty tensor type may not match `outType` due to folding
// performed by `getAsOpFoldResult` of `tensor::DimOp`.
// Cast to `outType` if needed to ensure type consistency.
if (outVector.getType() != outType)
outVector = tensor::CastOp::create(rewriter, loc, outType, outVector);

SmallVector<int64_t> permutation(inputRank);
std::iota(permutation.begin(), permutation.end(), 0);
permutation[dim0] = dim1;
Expand Down
25 changes: 25 additions & 0 deletions test/Conversion/TorchToLinalg/datamovement.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,28 @@ func.func @torch.aten.reflection_pad2d(%arg0: !torch.vtensor<[1,1,4,4],f32>) ->
%1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int> -> !torch.vtensor<[1,1,8,9],f32>
return %1 : !torch.vtensor<[1,1,8,9],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.transpose.int$dynamic_dims(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,56,56,96],f32> -> tensor<1x56x56x96xf32>
// CHECK: %[[VAL_9:.*]] = tensor.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3, 4], [5]] output_shape [1, 8, 7, 8, 7, 96] : tensor<1x56x56x96xf32> into tensor<1x8x7x8x7x96xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x8x8x7x7x96xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[VAL_9]] {{.*}} outs(%[[EMPTY]] {{.*}} permutation = [0, 1, 3, 2, 4, 5]
// CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[TRANSPOSE]] : tensor<1x8x8x7x7x96xf32> to tensor<?x?x?x?x?x?xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<?x?x?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?,?,?],f32>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?,?,?,?,?],f32>
// CHECK: }
func.func @torch.aten.transpose.int$dynamic_dims(%arg0: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int8 = torch.constant.int 8
%int7 = torch.constant.int 7
%int96 = torch.constant.int 96
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
%2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>
return %2 : !torch.vtensor<[?,?,?,?,?,?],f32>
}
Loading