Skip to content

Commit 437c621

Browse files
[mlir][memref] Remove redundant memref.tensor_store op (#71010)
`bufferization.materialize_in_destination` should be used instead. Both ops bufferize to a memcpy. This change also conceptually cleans up the memref dialect a bit: the memref dialect no longer contains ops that operate on tensor values.
1 parent 6529c9a commit 437c621

File tree

15 files changed

+53
-191
lines changed

15 files changed

+53
-191
lines changed

mlir/docs/LangRef.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@ func.func @mul(%A: tensor<100x?xf32>, %B: tensor<?x50xf32>) -> (tensor<100x50xf3
7777
7878
// Allocate addressable "buffers" and copy tensors %A and %B into them.
7979
%A_m = memref.alloc(%n) : memref<100x?xf32>
80-
memref.tensor_store %A to %A_m : memref<100x?xf32>
80+
bufferization.materialize_in_destination %A in writable %A_m
81+
: (tensor<100x?xf32>, memref<100x?xf32>) -> ()
8182
8283
%B_m = memref.alloc(%n) : memref<?x50xf32>
83-
memref.tensor_store %B to %B_m : memref<?x50xf32>
84+
bufferization.materialize_in_destination %B in writable %B_m
85+
: (tensor<?x50xf32>, memref<?x50xf32>) -> ()
8486
8587
// Call function @multiply passing memrefs as arguments,
8688
// and getting returned the result of the multiplication.

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

+13-11
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
103103

104104
```
105105
%alloc = memref.alloc() : memref<10xf32>
106-
memref.tensor_store %dest, %alloc : memref<10xf32>
106+
bufferization.materialize_in_destination %dest in %alloc
107107
memref.store %f, %alloc[%pos] : memref<10xf32>
108108
%0 = bufferization.to_tensor %alloc restrict writable : memref<10xf32>
109109
```
@@ -118,15 +118,16 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
118118
An optional memory space attribute can be specified for the materialized
119119
buffer allocation.
120120

121-
If a memory copy is needed, a "memref.tensor_store" is used when possible.
122-
This is an op with tensor semantics that will bufferize to a memory copy
123-
later. Which concrete op will be used for the memory copy is up to the
124-
bufferization framework. Alternatively, a custom memcpy op can be specified
125-
via `memcpy_op`. Currently supported are "memref.copy" and "linalg.copy".
126-
In that case, the source of each memcpy must not have a custom memory space.
127-
Furthermore, because the future buffer layout unknown for a given tensor,
128-
a fully dynamic layout is assumed for best compatibility. Users should use
129-
"memref.tensor_store" when possible.
121+
If a memory copy is needed, a "bufferization.materialize_in_destination" is
122+
used when possible. This is an op with tensor semantics that will bufferize
123+
to a memory copy later. Which concrete op will be used for the memory copy
124+
is up to the bufferization framework. Alternatively, a custom memcpy op can
125+
be specified via `memcpy_op`. Currently supported are "memref.copy" and
126+
"linalg.copy". In that case, the source of each memcpy must not have a
127+
custom memory space. Furthermore, because the future buffer layout unknown
128+
for a given tensor, a fully dynamic layout is assumed for best
129+
compatibility. Users should use "bufferization.materialize_in_destination"
130+
when possible.
130131

131132
"memref.alloc" is used for new buffer allocations. The buffer is deallocated
132133
at the end of the block if the "emit_dealloc" attribute is present. If this
@@ -148,7 +149,8 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
148149

149150
let arguments = (ins TransformHandleTypeInterface:$target,
150151
OptionalAttr<AnyAttr>:$memory_space,
151-
DefaultValuedAttr<StrAttr, "\"memref.tensor_store\"">:
152+
DefaultValuedAttr<StrAttr,
153+
"\"bufferization.materialize_in_destination\"">:
152154
$memcpy_op,
153155
DefaultValuedAttr<StrAttr, "\"memref.alloc\"">:
154156
$alloc_op,

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

+10-5
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ struct BufferizeToAllocationOptions {
5252
enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 };
5353
AllocOp allocOp = AllocOp::MemrefAlloc;
5454

55-
enum class MemcpyOp { MemrefTensorStore = 0, MemrefCopy = 1, LinalgCopy = 2 };
56-
MemcpyOp memcpyOp = MemcpyOp::MemrefTensorStore;
55+
enum class MemcpyOp {
56+
MaterializeInDestination = 0,
57+
MemrefCopy = 1,
58+
LinalgCopy = 2
59+
};
60+
MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination;
5761

5862
/// If set to "true", only the destination tensor operands are bufferized to
5963
/// a new allocation (and wrapped in "bufferization.to_tensor"), but not the
@@ -68,7 +72,8 @@ struct BufferizeToAllocationOptions {
6872
};
6973

7074
/// Materialize a buffer allocation for the given tensor.pad op and lower the
71-
/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
75+
/// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
76+
/// E.g.:
7277
///
7378
/// %0 = tensor.pad low[%l] high[%h] %t ...
7479
///
@@ -77,7 +82,7 @@ struct BufferizeToAllocationOptions {
7782
/// %alloc = memref.alloc
7883
/// linalg.fill ... outs(%alloc)
7984
/// %subview = memref.subview %alloc [%l] [...] [1]
80-
/// memref.tensor_store %t, %subview
85+
/// bufferization.materialize_in_destination %t in %subview
8186
/// %0 = bufferization.to_tensor %alloc restrict writable
8287
///
8388
/// In addition to rewriting the IR as shown above, this function returns the
@@ -98,7 +103,7 @@ Value bufferizeToAllocation(RewriterBase &rewriter,
98103
/// is lowered to:
99104
///
100105
/// %alloc = memref.alloc
101-
/// memref.tensor_store %t, %subview
106+
/// bufferization.materialize_in_destination %t in %subview
102107
/// vector.mask {
103108
/// vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
104109
/// } : vector<16xi1>

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

-31
Original file line numberDiff line numberDiff line change
@@ -2095,37 +2095,6 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20952095
let hasVerifier = 1;
20962096
}
20972097

2098-
//===----------------------------------------------------------------------===//
2099-
// TensorStoreOp
2100-
//===----------------------------------------------------------------------===//
2101-
2102-
def TensorStoreOp : MemRef_Op<"tensor_store",
2103-
[SameOperandsShape, SameOperandsElementType,
2104-
TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'",
2105-
"memref", "tensor",
2106-
"getTensorTypeFromMemRefType($_self)">]> {
2107-
let summary = "tensor store operation";
2108-
let description = [{
2109-
Stores the contents of a tensor into a memref. The first operand is a value
2110-
of tensor type, the second operand is a value of memref type. The shapes and
2111-
element types of these must match, and are specified by the memref type.
2112-
2113-
Example:
2114-
2115-
```mlir
2116-
%9 = dim %8, 1 : tensor<4x?xf32>
2117-
%10 = memref.alloc(%9) : memref<4x?xf32, #layout, memspace0>
2118-
memref.tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
2119-
```
2120-
}];
2121-
2122-
let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
2123-
"the reference to store to",
2124-
[MemWriteAt<0, FullEffect>]>:$memref);
2125-
2126-
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
2127-
}
2128-
21292098
//===----------------------------------------------------------------------===//
21302099
// TransposeOp
21312100
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h

-21
This file was deleted.

mlir/include/mlir/InitAllDialects.h

-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
5454
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
5555
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
56-
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
5756
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
5857
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
5958
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -157,7 +156,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
157156
linalg::registerTilingInterfaceExternalModels(registry);
158157
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
159158
memref::registerAllocationOpInterfaceExternalModels(registry);
160-
memref::registerBufferizableOpInterfaceExternalModels(registry);
161159
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
162160
memref::registerValueBoundsOpInterfaceExternalModels(registry);
163161
memref::registerMemorySlotExternalModels(registry);

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,11 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
585585
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
586586
buffer = getDest();
587587
}
588-
rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
588+
auto srcBuffer = getBuffer(rewriter, getSource(), options);
589+
if (failed(srcBuffer))
590+
return failure();
591+
if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
592+
return failure();
589593
replaceOpWithBufferizedValues(rewriter, getOperation(),
590594
tensorDest ? ValueRange(buffer) : ValueRange());
591595
return success();
@@ -682,8 +686,9 @@ LogicalResult MaterializeInDestinationOp::verify() {
682686
void MaterializeInDestinationOp::build(OpBuilder &builder,
683687
OperationState &state, Value source,
684688
Value dest) {
685-
assert(isa<TensorType>(dest.getType()) && "expected tensor type");
686-
build(builder, state, /*result=*/dest.getType(), source, dest);
689+
auto destTensorType = dyn_cast<TensorType>(dest.getType());
690+
build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
691+
source, dest);
687692
}
688693

689694
bool MaterializeInDestinationOp::isWritable(Value value,

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
241241
rewriter.setListener(&newOpsListener);
242242

243243
linalg::BufferizeToAllocationOptions options;
244-
if (getMemcpyOp() == "memref.tensor_store") {
245-
options.memcpyOp =
246-
linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefTensorStore;
244+
if (getMemcpyOp() == "bufferization.materialize_in_destination") {
245+
options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
246+
MaterializeInDestination;
247247
} else if (getMemcpyOp() == "memref.copy") {
248248
options.memcpyOp =
249249
linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
@@ -296,7 +296,7 @@ void transform::BufferizeToAllocationOp::getEffects(
296296
}
297297

298298
LogicalResult transform::BufferizeToAllocationOp::verify() {
299-
if (getMemcpyOp() != "memref.tensor_store" &&
299+
if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
300300
getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
301301
return emitOpError() << "unsupported memcpy op";
302302
if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,14 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
6363
assert(memrefDest.getType().isa<MemRefType>() && "expected ranked memref");
6464

6565
switch (options.memcpyOp) {
66-
case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefTensorStore:
66+
case linalg::BufferizeToAllocationOptions::MemcpyOp::
67+
MaterializeInDestination: {
6768
// Note: This is the preferred way of memcpy'ing because no layout map
6869
// and/or memory space must be specified for the source.
69-
b.create<memref::TensorStoreOp>(loc, tensorSource, memrefDest);
70-
break;
70+
auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
71+
loc, tensorSource, memrefDest);
72+
materializeOp.setWritable(true);
73+
} break;
7174
case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
7275
// TODO: Support custom memory space on source.
7376
// We do not know the layout map of the source yet, so use a fully dynamic
@@ -238,7 +241,7 @@ Value linalg::bufferizeToAllocation(
238241
rewriter.setInsertionPointAfter(fillOp);
239242
}
240243

241-
// Create memref.tensor_store.
244+
// Create memcpy.
242245
SmallVector<OpFoldResult> sizes =
243246
getMixedSizes(rewriter, loc, padOp.getSource());
244247
SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),

mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp

-63
This file was deleted.

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_dialect_library(MLIRMemRefTransforms
22
AllocationOpInterfaceImpl.cpp
3-
BufferizableOpInterfaceImpl.cpp
43
ComposeSubView.cpp
54
ExpandOps.cpp
65
ExpandRealloc.cpp

mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

+7-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// CHECK: linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)
1616
// CHECK: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
1717
// CHECK: %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
18-
// CHECK: memref.tensor_store %[[t]], %[[subview]]
18+
// CHECK: bufferization.materialize_in_destination %[[t]] in writable %[[subview]]
1919
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable : memref<?x?xindex>
2020
// CHECK: memref.dealloc %[[alloc]]
2121
// CHECK: return %[[r]]
@@ -40,17 +40,17 @@ module attributes {transform.with_named_sequence} {
4040
transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
4141

4242
// Ensure that one linalg.copy was generated.
43-
%tensor_store = transform.select "memref.tensor_store" in %new : (!transform.any_op) -> !transform.any_op
43+
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
4444
// expected-remark @below{{1}}
45-
transform.test_print_number_of_associated_payload_ir_ops %tensor_store : !transform.any_op
45+
transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
4646
transform.yield
4747
}
4848
}
4949

5050
// -----
5151

5252
// CHECK-LABEL: func @tensor_pad_constant_with_custom_copy(
53-
// CHECK-NOT: memref.tensor_store
53+
// CHECK-NOT: bufferization.materialize_in_destination
5454
// CHECK-NOT: memref.copy
5555
// CHECK: memref.alloca
5656
// CHECK: linalg.copy
@@ -194,7 +194,7 @@ module attributes {transform.with_named_sequence} {
194194
// CHECK-LABEL: func @vector_mask(
195195
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>,
196196
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, 4>
197-
// CHECK: memref.tensor_store %[[t]], %[[alloc]]
197+
// CHECK: bufferization.materialize_in_destination %[[t]] in writable %[[alloc]]
198198
// CHECK: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %[[alloc]]
199199
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
200200
// CHECK: memref.dealloc %[[alloc]]
@@ -217,7 +217,7 @@ module attributes {transform.with_named_sequence} {
217217
// CHECK-LABEL: func @tensor_insert_destination(
218218
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
219219
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?x10xindex, 4>
220-
// CHECK: memref.tensor_store %[[t]], %[[alloc]]
220+
// CHECK: bufferization.materialize_in_destination %[[t]] in writable %[[alloc]]
221221
// CHECK: %[[t2:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
222222
// CHECK: %[[inserted:.*]] = tensor.insert %{{.*}} into %[[t2]]
223223
// CHECK: memref.dealloc %[[alloc]]
@@ -240,7 +240,7 @@ module attributes {transform.with_named_sequence} {
240240
// CHECK-LABEL: func @scf_for_destination(
241241
// CHECK-SAME: %[[t:.*]]: tensor<?x10xindex>
242242
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?x10xindex, 4>
243-
// CHECK: memref.tensor_store %[[t]], %[[alloc]]
243+
// CHECK: bufferization.materialize_in_destination %[[t]] in writable %[[alloc]]
244244
// CHECK: %[[t2:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
245245
// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%{{.*}} = %[[t2]])
246246
// CHECK: memref.dealloc %[[alloc]]

mlir/test/Dialect/MemRef/bufferize.mlir

-11
This file was deleted.

0 commit comments

Comments
 (0)