Skip to content

Commit 8ee38f3

Browse files
[mlir][bufferization] Follow up for #68074 (#68488)
Address additional comments in #68074. This should have been part of #68074.
1 parent 635eb5f commit 8ee38f3

File tree

6 files changed

+98
-22
lines changed

6 files changed

+98
-22
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

+17-8
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,14 @@ def Bufferization_MaterializeInDestinationOp
227227
let summary = "copy a tensor";
228228

229229
let description = [{
230-
This op indicates that the data of the `source` tensor should materialize
231-
in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
232-
should materialize in the future buffer of `dest` and a the updated
233-
destination tensor is returned. In case of a memref, `source` should
234-
materialize in `dest`, which is already a buffer. The op has no results in
235-
that case.
230+
This op indicates that the data of the `source` tensor is guaranteed to
231+
materialize in `dest`, which can be a tensor or a memref. In case of a
232+
tensor, `source` materializes in the future buffer of `dest` and the
233+
updated destination tensor is returned. If this is not possible, e.g.,
234+
because the destination tensor is read-only or because its original
235+
contents are still read later, the input IR fails to bufferize. In case of a
236+
memref, `source` materializes in `dest`, which is already a buffer. The op
237+
has no results in that case.
236238

237239
`source`, `dest` and `result` (if present) must have the same shape and
238240
element type. If the op has a result, the types of `result` and `dest` must
@@ -252,7 +254,8 @@ def Bufferization_MaterializeInDestinationOp
252254
indicates that this op is the only way for the tensor IR to access `dest`
253255
(or an alias thereof). E.g., there must be no other `to_tensor` ops with
254256
`dest` or with an alias of `dest`. Such IR is not supported by
255-
One-Shot Bufferize.
257+
One-Shot Bufferize. Ops that have incorrect usage of `restrict` may
258+
bufferize incorrectly.
256259

257260
Note: `restrict` and `writable` could be removed from this op because they
258261
must always be set for memref destinations. This op has these attributes to
@@ -262,7 +265,9 @@ def Bufferization_MaterializeInDestinationOp
262265
Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
263266
same purpose, but since tensor dialect ops only indicate *what* should be
264267
computed but not *where*, it could fold away, causing the computation to
265-
materialize in a different buffer.
268+
materialize in a different buffer. It is also possible that the
269+
`tensor.insert_slice` destination bufferizes out-of-place, which would also
270+
cause the computation to materialize in a different buffer.
266271
}];
267272

268273
let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
@@ -282,6 +287,9 @@ def Bufferization_MaterializeInDestinationOp
282287
bool bufferizesToElementwiseAccess(const AnalysisState &state,
283288
ArrayRef<OpOperand *> opOperands);
284289

290+
bool mustBufferizeInPlace(OpOperand &opOperand,
291+
const AnalysisState &state);
292+
285293
AliasingValueList getAliasingValues(
286294
OpOperand &opOperand, const AnalysisState &state);
287295

@@ -408,6 +416,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
408416
Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
409417
by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
410418
without `restrict`, One-Shot Bufferize would have to analyze memref IR.)
419+
Ops that have incorrect usage of `restrict` may bufferize incorrectly.
411420

412421
Example:
413422

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,14 @@ bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
549549
return false;
550550
}
551551

552+
bool MaterializeInDestinationOp::mustBufferizeInPlace(
553+
OpOperand &opOperand, const AnalysisState &state) {
554+
// The source is only read and not written, so it always bufferizes in-place
555+
// by default. The destination is written and is forced to bufferize in-place
556+
// (if it is a tensor).
557+
return true;
558+
}
559+
552560
AliasingValueList
553561
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
554562
const AnalysisState &state) {

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1313
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15+
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1516
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1617
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/IR/Dominance.h"
@@ -184,12 +185,25 @@ struct EmptyTensorElimination
184185

185186
void EmptyTensorElimination::runOnOperation() {
186187
Operation *op = getOperation();
188+
auto moduleOp = dyn_cast<ModuleOp>(op);
187189
OneShotBufferizationOptions options;
188190
options.allowReturnAllocsFromLoops = true;
191+
if (moduleOp)
192+
options.bufferizeFunctionBoundaries = true;
189193
OneShotAnalysisState state(op, options);
190-
if (failed(analyzeOp(op, state))) {
191-
signalPassFailure();
192-
return;
194+
if (moduleOp) {
195+
// Module analysis takes into account function boundaries.
196+
if (failed(analyzeModuleOp(moduleOp, state))) {
197+
signalPassFailure();
198+
return;
199+
}
200+
} else {
201+
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,
202+
// func.return.
203+
if (failed(analyzeOp(op, state))) {
204+
signalPassFailure();
205+
return;
206+
}
193207
}
194208

195209
IRRewriter rewriter(op->getContext());

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

+16-5
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040

4141
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
4242

43-
#include <random>
4443
#include <optional>
44+
#include <random>
4545

4646
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
4747
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -1182,8 +1182,8 @@ checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
11821182
// not handled in the analysis.
11831183
if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
11841184
if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
1185-
op->emitError("to_tensor ops without `restrict` are not supported by "
1186-
"One-Shot Analysis");
1185+
op->emitOpError("to_tensor ops without `restrict` are not supported by "
1186+
"One-Shot Analysis");
11871187
return WalkResult::interrupt();
11881188
}
11891189
}
@@ -1195,8 +1195,19 @@ checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
11951195
/*checkConsistencyOnly=*/true)) {
11961196
// This error can happen if certain "mustBufferizeInPlace" interface
11971197
// methods are implemented incorrectly, such that the IR already has
1198-
// a RaW conflict before making any bufferization decisions.
1199-
op->emitError("input IR has RaW conflict");
1198+
// a RaW conflict before making any bufferization decisions. It can
1199+
// also happen if the bufferization.materialize_in_destination op is used
1200+
// in such a way that a RaW conflict is not avoidable.
1201+
op->emitOpError("not bufferizable under the given constraints: "
1202+
"cannot avoid RaW conflict");
1203+
return WalkResult::interrupt();
1204+
}
1205+
1206+
if (state.isInPlace(opOperand) &&
1207+
wouldCreateWriteToNonWritableBuffer(
1208+
opOperand, state, /*checkConsistencyOnly=*/true)) {
1209+
op->emitOpError("not bufferizable under the given constraints: would "
1210+
"write to read-only buffer");
12001211
return WalkResult::interrupt();
12011212
}
12021213
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir

+39-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
// RUN: mlir-opt %s -one-shot-bufferize="allow-unknown-ops" -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -one-shot-bufferize="allow-unknown-ops" -verify-diagnostics -split-input-file | FileCheck %s
22

33
// Run fuzzer with different seeds.
4-
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
5-
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
6-
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
4+
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=23" -verify-diagnostics -split-input-file -o /dev/null
5+
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=59" -verify-diagnostics -split-input-file -o /dev/null
6+
// RUN: mlir-opt %s -one-shot-bufferize="test-analysis-only analysis-fuzzer-seed=91" -verify-diagnostics -split-input-file -o /dev/null
77

88
// Run with top-down analysis.
9-
// RUN: mlir-opt %s -one-shot-bufferize="allow-unknown-ops analysis-heuristic=top-down" -split-input-file | FileCheck %s --check-prefix=CHECK-TOP-DOWN-ANALYSIS
9+
// RUN: mlir-opt %s -one-shot-bufferize="allow-unknown-ops analysis-heuristic=top-down" -verify-diagnostics -split-input-file | FileCheck %s --check-prefix=CHECK-TOP-DOWN-ANALYSIS
1010

1111
// Test without analysis: Insert a copy on every buffer write.
1212
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-unknown-ops copy-before-write" -split-input-file | FileCheck %s --check-prefix=CHECK-COPY-BEFORE-WRITE
@@ -235,3 +235,37 @@ func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32
235235
return
236236
}
237237

238+
// -----
239+
240+
func.func @materialize_in_func_bbarg(%t: tensor<?xf32>, %dest: tensor<?xf32>)
241+
-> tensor<?xf32> {
242+
// This op is not bufferizable because function block arguments are
243+
// read-only in regular One-Shot Bufferize. (Run One-Shot Module
244+
// Bufferization instead.)
245+
// expected-error @below{{not bufferizable under the given constraints: would write to read-only buffer}}
246+
%0 = bufferization.materialize_in_destination %t in %dest
247+
: (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
248+
return %0 : tensor<?xf32>
249+
}
250+
251+
// -----
252+
253+
func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5xf32>, f32) {
254+
%dest = bufferization.alloc_tensor() : tensor<5xf32>
255+
// Note: The location of the RaW conflict may not be accurate (such as in this
256+
// example). This is because the analysis operates on "alias sets" and not
257+
// single SSA values. The location may point to any SSA value in the alias set
258+
// that participates in the conflict.
259+
// expected-error @below{{not bufferizable under the given constraints: cannot avoid RaW conflict}}
260+
%dest_filled = linalg.fill ins(%f : f32) outs(%dest : tensor<5xf32>) -> tensor<5xf32>
261+
%src = bufferization.alloc_tensor() : tensor<5xf32>
262+
%src_filled = linalg.fill ins(%f2 : f32) outs(%src : tensor<5xf32>) -> tensor<5xf32>
263+
264+
%0 = bufferization.materialize_in_destination %src_filled in %dest_filled
265+
: (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
266+
// Read from %dest_filled, which makes it impossible to bufferize the
267+
// materialize_in_destination op in-place.
268+
%r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
269+
270+
return %0, %r : tensor<5xf32>, f32
271+
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
153153
func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> {
154154
// Unranked tensor OpOperands always bufferize in-place. With this limitation,
155155
// there is no way to bufferize this IR correctly.
156-
// expected-error @+1 {{input IR has RaW conflict}}
156+
// expected-error @+1 {{not bufferizable under the given constraints: cannot avoid RaW conflict}}
157157
func.call @maybe_writing_func(%t) : (tensor<*xf32>) -> ()
158158
return %t : tensor<*xf32>
159159
}

0 commit comments

Comments
 (0)