69 changes: 61 additions & 8 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -701,7 +701,55 @@ class ConvertAtenUnflattenIntOp
op, "Must be able to either infer expansion dims, or retrieve them "
"from list construct");

auto expandTy = getTypeConverter()->convertType(outputTensorType);
// Check if the output tensor type has all static shapes while the input
// tensor type doesn't. Note: unflatten changes the shape, so we need to
// account for dimension mapping:
// - Input dims [0:dimInt) map to output dims [0:dimInt)
// - Input dim [dimInt] is the flattened dimension
// - Output dims [dimInt:dimInt+numSizes) are the unflattened dimensions
// - Input dims [dimInt+1:] map to output dims [dimInt+numSizes:]
auto inputSizes = inputTensorType.getSizes();
bool inputAllStatic = llvm::none_of(
inputSizes, [](int64_t size) { return size == Torch::kUnknownSize; });
bool outputAllStatic = llvm::none_of(
outputSizes, [](int64_t size) { return size == Torch::kUnknownSize; });
Comment on lines +712 to +715
I think we can use type.areAllSizesKnown() here
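A minimal sketch of this suggestion, assuming the Torch tensor types used here expose areAllSizesKnown() (a hypothetical simplification, not part of this patch):

    // Would replace the two llvm::none_of checks over the size arrays.
    bool inputAllStatic = inputTensorType.areAllSizesKnown();
    bool outputAllStatic = outputTensorType.areAllSizesKnown();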


auto expandTy = cast<RankedTensorType>(
getTypeConverter()->convertType(outputTensorType));
auto inputTy = cast<RankedTensorType>(
getTypeConverter()->convertType(inputTensorType));
self = adaptor.getSelf();

inputSizes = inputTy.getShape();
outputSizes = expandTy.getShape();

if (!inputAllStatic && outputAllStatic) {
// The output tensor type is all static, but the input tensor type is not.
// Construct input with static shapes.
SmallVector<int64_t> refinedInputSizes;
// Copy dims before the flatten dimension from output
for (int64_t i = 0; i < dimInt; ++i) {
refinedInputSizes.push_back(outputSizes[i]);
}

int64_t unflattenedDimSize = 1;
for (int64_t i = dimInt; i < dimInt + numSizes; ++i) {
unflattenedDimSize *= outputSizes[i];
}

// Refine the flattened input dimension to the static product of the
// unflattened output dims computed above.
refinedInputSizes.push_back(unflattenedDimSize);
// Copy dims after the flatten dimension from output
for (int64_t i = dimInt + numSizes; i < outputRank; ++i) {
refinedInputSizes.push_back(outputSizes[i]);
}

auto staticInputType =
RankedTensorType::get(refinedInputSizes, inputTy.getElementType());
self = tensor::CastOp::create(rewriter, loc, staticInputType,
adaptor.getSelf());
}
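// Worked example (mirrors the unflatten test added below): for an input of
// type [1,?,96] with dimInt = 1 and sizes = [56, 56], refinedInputSizes
// becomes [1, 56*56, 96] = [1, 3136, 96], so the dynamic operand is cast to
// tensor<1x3136x96xf32> before being expanded to tensor<1x56x56x96xf32>.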

Value expand;
// When there are fewer than two dynamic reassociation dims, this will lower
// to tensor.expand_shape. Otherwise, this lowers to tensor.reshape.
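// For the static case exercised by the unflatten test below, this produces,
// e.g.:
//   tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [1, 56, 56, 96]
//       : tensor<1x3136x96xf32> into tensor<1x56x56x96xf32>
// With two or more dynamic reassociation dims, the runtime shape is instead
// assembled with tensor.from_elements and passed to tensor.reshape.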
@@ -718,14 +766,13 @@ class ConvertAtenUnflattenIntOp
for (int i = dimInt + numSizes; i < outputRank; ++i)
reassociations[i - numSizes + 1].push_back(i);
}
expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy,
adaptor.getSelf(), reassociations)
expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, self,
reassociations)
.getResult();
} else {
reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
reassocSizes);
SmallVector<Value> inputShape =
getTensorSizes(rewriter, loc, adaptor.getSelf());
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, self);
inputShape = castIndexVectorToInt64Vector(rewriter, loc, inputShape);
SmallVector<Value> outputShape(inputShape.begin(),
inputShape.begin() + dimInt);
@@ -740,9 +787,9 @@
ArrayRef<int64_t>{outputRank}, rewriter.getIntegerType(64));
Value shapeValue =
tensor::FromElementsOp::create(rewriter, loc, shapeType, outputShape);
expand = tensor::ReshapeOp::create(rewriter, loc, expandTy,
adaptor.getSelf(), shapeValue)
.getResult();
expand =
tensor::ReshapeOp::create(rewriter, loc, expandTy, self, shapeValue)
.getResult();
}
rewriter.replaceOp(op, expand);
return success();
@@ -1740,6 +1787,12 @@ class ConvertAtenTransposeIntOp
Value outVector = tensor::EmptyOp::create(
rewriter, loc, getAsOpFoldResult(outputDims), elementType);

// Note: The empty tensor's type may not match `outType` because
// `getAsOpFoldResult` can fold `tensor::DimOp` results into static
// constants. Cast to `outType` if needed to ensure type consistency.
if (outVector.getType() != outType)
outVector = tensor::CastOp::create(rewriter, loc, outType, outVector);

SmallVector<int64_t> permutation(inputRank);
std::iota(permutation.begin(), permutation.end(), 0);
permutation[dim0] = dim1;
15 changes: 15 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -702,6 +702,21 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
return rewriter.notifyMatchFailure(op, "dimension out of range");
}

// Check if dimensions are in the original order (no permutation needed).
bool isIdentityPermutation = true;
for (uint32_t i = 0; i < numDimensions; i++) {
if (dimensions[i] != static_cast<int64_t>(i)) {
isIdentityPermutation = false;
break;
}
}

// If no permutation is needed, return the input as result.
if (isIdentityPermutation) {
result = input;
return success();
}

SmallVector<Value> outputDims;
for (uint32_t i = 0; i < inputRank; i++)
outputDims.push_back(getDimOp(rewriter, loc, input, dimensions[i]));
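As a side note on the identity check above: a shorter form is possible if this tree provides mlir::isIdentityPermutation from mlir/Dialect/Utils/IndexingUtils.h; treat this as a hedged sketch, not part of the patch:

    // Hypothetical alternative to the hand-rolled loop.
    if (isIdentityPermutation(dimensions)) {
      result = input;
      return success();
    }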
62 changes: 62 additions & 0 deletions test/Conversion/TorchToLinalg/datamovement.mlir
@@ -71,3 +71,65 @@ func.func @torch.aten.reflection_pad2d(%arg0: !torch.vtensor<[1,1,4,4],f32>) ->
%1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int> -> !torch.vtensor<[1,1,8,9],f32>
return %1 : !torch.vtensor<[1,1,8,9],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.transpose.int$dynamic_dims(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,56,56,96],f32> -> tensor<1x56x56x96xf32>
// CHECK: %[[VAL_9:.*]] = tensor.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3, 4], [5]] output_shape [1, 8, 7, 8, 7, 96] : tensor<1x56x56x96xf32> into tensor<1x8x7x8x7x96xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x8x8x7x7x96xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[VAL_9]] {{.*}} outs(%[[EMPTY]] {{.*}} permutation = [0, 1, 3, 2, 4, 5]
// CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[TRANSPOSE]] : tensor<1x8x8x7x7x96xf32> to tensor<?x?x?x?x?x?xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<?x?x?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?,?,?],f32>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?,?,?,?,?],f32>
// CHECK: }
func.func @torch.aten.transpose.int$dynamic_dims(%arg0: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int8 = torch.constant.int 8
%int7 = torch.constant.int 7
%int96 = torch.constant.int 96
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
%2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>
return %2 : !torch.vtensor<[?,?,?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @aten.unflatten.int$dynamic_input_dim(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,?,96],f32>) -> !torch.vtensor<[1,56,56,96],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,?,96],f32> -> tensor<1x?x96xf32>
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<1x?x96xf32> to tensor<1x3136x96xf32>
// CHECK: %[[VAL_3:.*]] = tensor.expand_shape %[[VAL_2]] {{\[\[}}0], [1, 2], [3]] output_shape [1, 56, 56, 96] : tensor<1x3136x96xf32> into tensor<1x56x56x96xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x56x56x96xf32> -> !torch.vtensor<[1,56,56,96],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,56,56,96],f32>
// CHECK: }
func.func @aten.unflatten.int$dynamic_input_dim(%arg0: !torch.vtensor<[1,?,96],f32>) -> !torch.vtensor<[1,56,56,96],f32> {
%int1 = torch.constant.int 1
%int56 = torch.constant.int 56
%129 = torch.prim.ListConstruct %int56, %int56 : (!torch.int, !torch.int) -> !torch.list<int>
%130 = torch.aten.unflatten.int %arg0, %int1, %129 : !torch.vtensor<[1,?,96],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
return %130 : !torch.vtensor<[1,56,56,96],f32>
}

// -----

// CHECK-LABEL: func.func @aten.permute$identity_permutation(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,32,16,8,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64,32,16,8,4],f32> -> tensor<64x32x16x8x4xf32>
// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<64x32x16x8x4xf32> -> !torch.vtensor<[64,32,16,8,4],f32>
// CHECK: return %[[VAL_2]] : !torch.vtensor<[64,32,16,8,4],f32>
// CHECK: }
func.func @aten.permute$identity_permutation(%arg0: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,32,16,8,4],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[64,32,16,8,4],f32>, !torch.list<int> -> !torch.vtensor<[64,32,16,8,4],f32>
return %1 : !torch.vtensor<[64,32,16,8,4],f32>
}