Skip to content

Commit 15e915a

Browse files
authored
[mlir][dataflow] Propagate errors from visitOperation (#105448)
Base `DataFlowAnalysis::visit` returns `LogicalResult`, but wrappers's Sparse/Dense/Forward/Backward `visitOperation` doesn't. Sometimes it's needed to abort solver early if some unrecoverable condition detected inside analysis. Update `visitOperation` to return `LogicalResult` and propagate it to `solver.initializeAndRun()`. Only `visitOperation` is updated for now, it's possible to update other hooks like `visitNonControlFlowArguments`, bit it's not needed immediately and let's keep this PR small. Hijacked `UnderlyingValueAnalysis` test analysis to test it.
1 parent 14c7e4a commit 15e915a

16 files changed

+220
-149
lines changed

flang/lib/Optimizer/Transforms/StackArrays.cpp

+16-13
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ class AllocationAnalysis
149149
public:
150150
using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
151151

152-
void visitOperation(mlir::Operation *op, const LatticePoint &before,
153-
LatticePoint *after) override;
152+
mlir::LogicalResult visitOperation(mlir::Operation *op,
153+
const LatticePoint &before,
154+
LatticePoint *after) override;
154155

155156
/// At an entry point, the last modifications of all memory resources are
156157
/// yet to be determined
@@ -159,7 +160,7 @@ class AllocationAnalysis
159160
protected:
160161
/// Visit control flow operations and decide whether to call visitOperation
161162
/// to apply the transfer function
162-
void processOperation(mlir::Operation *op) override;
163+
mlir::LogicalResult processOperation(mlir::Operation *op) override;
163164
};
164165

165166
/// Drives analysis to find candidate fir.allocmem operations which could be
@@ -329,9 +330,8 @@ std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
329330
return it->second;
330331
}
331332

332-
void AllocationAnalysis::visitOperation(mlir::Operation *op,
333-
const LatticePoint &before,
334-
LatticePoint *after) {
333+
mlir::LogicalResult AllocationAnalysis::visitOperation(
334+
mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
335335
LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
336336
<< "\n");
337337
LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");
@@ -346,14 +346,14 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,
346346
if (attr && attr.getValue()) {
347347
LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
348348
// skip allocation marked not to be moved
349-
return;
349+
return mlir::success();
350350
}
351351

352352
auto retTy = allocmem.getAllocatedType();
353353
if (!mlir::isa<fir::SequenceType>(retTy)) {
354354
LLVM_DEBUG(llvm::dbgs()
355355
<< "--Allocation is not for an array: skipping\n");
356-
return;
356+
return mlir::success();
357357
}
358358

359359
mlir::Value result = op->getResult(0);
@@ -387,6 +387,7 @@ void AllocationAnalysis::visitOperation(mlir::Operation *op,
387387

388388
LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
389389
propagateIfChanged(after, changed);
390+
return mlir::success();
390391
}
391392

392393
void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
@@ -395,18 +396,20 @@ void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
395396

396397
/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
397398
/// being that call operations are passed through to the transfer function
398-
void AllocationAnalysis::processOperation(mlir::Operation *op) {
399+
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
399400
// If the containing block is not executable, bail out.
400401
if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
401-
return;
402+
return mlir::success();
402403

403404
// Get the dense lattice to update
404405
mlir::dataflow::AbstractDenseLattice *after = getLattice(op);
405406

406407
// If this op implements region control-flow, then control-flow dictates its
407408
// transfer function.
408-
if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
409-
return visitRegionBranchOperation(op, branch, after);
409+
if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
410+
visitRegionBranchOperation(op, branch, after);
411+
return mlir::success();
412+
}
410413

411414
// pass call operations through to the transfer function
412415

@@ -418,7 +421,7 @@ void AllocationAnalysis::processOperation(mlir::Operation *op) {
418421
before = getLatticeFor(op, op->getBlock());
419422

420423
/// Invoke the operation transfer function
421-
visitOperationImpl(op, *before, after);
424+
return visitOperationImpl(op, *before, after);
422425
}
423426

424427
llvm::LogicalResult

mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ class SparseConstantPropagation
101101
public:
102102
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
103103

104-
void visitOperation(Operation *op,
105-
ArrayRef<const Lattice<ConstantValue> *> operands,
106-
ArrayRef<Lattice<ConstantValue> *> results) override;
104+
LogicalResult
105+
visitOperation(Operation *op,
106+
ArrayRef<const Lattice<ConstantValue> *> operands,
107+
ArrayRef<Lattice<ConstantValue> *> results) override;
107108

108109
void setToEntryState(Lattice<ConstantValue> *lattice) override;
109110
};

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

+22-20
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
8787
protected:
8888
/// Propagate the dense lattice before the execution of an operation to the
8989
/// lattice after its execution.
90-
virtual void visitOperationImpl(Operation *op,
91-
const AbstractDenseLattice &before,
92-
AbstractDenseLattice *after) = 0;
90+
virtual LogicalResult visitOperationImpl(Operation *op,
91+
const AbstractDenseLattice &before,
92+
AbstractDenseLattice *after) = 0;
9393

9494
/// Get the dense lattice after the execution of the given program point.
9595
virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
@@ -114,7 +114,7 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
114114
/// operation, then the state after the execution of the operation is set by
115115
/// control-flow or the callgraph. Otherwise, this function invokes the
116116
/// operation transfer function.
117-
virtual void processOperation(Operation *op);
117+
virtual LogicalResult processOperation(Operation *op);
118118

119119
/// Propagate the dense lattice forward along the control flow edge from
120120
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
@@ -191,8 +191,8 @@ class DenseForwardDataFlowAnalysis
191191
/// Visit an operation with the dense lattice before its execution. This
192192
/// function is expected to set the dense lattice after its execution and
193193
/// trigger change propagation in case of change.
194-
virtual void visitOperation(Operation *op, const LatticeT &before,
195-
LatticeT *after) = 0;
194+
virtual LogicalResult visitOperation(Operation *op, const LatticeT &before,
195+
LatticeT *after) = 0;
196196

197197
/// Hook for customizing the behavior of lattice propagation along the call
198198
/// control flow edges. Two types of (forward) propagation are possible here:
@@ -263,10 +263,11 @@ class DenseForwardDataFlowAnalysis
263263

264264
/// Type-erased wrappers that convert the abstract dense lattice to a derived
265265
/// lattice and invoke the virtual hooks operating on the derived lattice.
266-
void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
267-
AbstractDenseLattice *after) final {
268-
visitOperation(op, static_cast<const LatticeT &>(before),
269-
static_cast<LatticeT *>(after));
266+
LogicalResult visitOperationImpl(Operation *op,
267+
const AbstractDenseLattice &before,
268+
AbstractDenseLattice *after) final {
269+
return visitOperation(op, static_cast<const LatticeT &>(before),
270+
static_cast<LatticeT *>(after));
270271
}
271272
void visitCallControlFlowTransfer(CallOpInterface call,
272273
CallControlFlowAction action,
@@ -326,9 +327,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
326327
protected:
327328
/// Propagate the dense lattice after the execution of an operation to the
328329
/// lattice before its execution.
329-
virtual void visitOperationImpl(Operation *op,
330-
const AbstractDenseLattice &after,
331-
AbstractDenseLattice *before) = 0;
330+
virtual LogicalResult visitOperationImpl(Operation *op,
331+
const AbstractDenseLattice &after,
332+
AbstractDenseLattice *before) = 0;
332333

333334
/// Get the dense lattice before the execution of the program point. That is,
334335
/// before the execution of the given operation or after the execution of the
@@ -353,7 +354,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
353354
/// Visit an operation. Dispatches to specialized methods for call or region
354355
/// control-flow operations. Otherwise, this function invokes the operation
355356
/// transfer function.
356-
virtual void processOperation(Operation *op);
357+
virtual LogicalResult processOperation(Operation *op);
357358

358359
/// Propagate the dense lattice backwards along the control flow edge from
359360
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
@@ -442,8 +443,8 @@ class DenseBackwardDataFlowAnalysis
442443
/// Transfer function. Visits an operation with the dense lattice after its
443444
/// execution. This function is expected to set the dense lattice before its
444445
/// execution and trigger propagation in case of change.
445-
virtual void visitOperation(Operation *op, const LatticeT &after,
446-
LatticeT *before) = 0;
446+
virtual LogicalResult visitOperation(Operation *op, const LatticeT &after,
447+
LatticeT *before) = 0;
447448

448449
/// Hook for customizing the behavior of lattice propagation along the call
449450
/// control flow edges. Two types of (back) propagation are possible here:
@@ -513,10 +514,11 @@ class DenseBackwardDataFlowAnalysis
513514

514515
/// Type-erased wrappers that convert the abstract dense lattice to a derived
515516
/// lattice and invoke the virtual hooks operating on the derived lattice.
516-
void visitOperationImpl(Operation *op, const AbstractDenseLattice &after,
517-
AbstractDenseLattice *before) final {
518-
visitOperation(op, static_cast<const LatticeT &>(after),
519-
static_cast<LatticeT *>(before));
517+
LogicalResult visitOperationImpl(Operation *op,
518+
const AbstractDenseLattice &after,
519+
AbstractDenseLattice *before) final {
520+
return visitOperation(op, static_cast<const LatticeT &>(after),
521+
static_cast<LatticeT *>(before));
520522
}
521523
void visitCallControlFlowTransfer(CallOpInterface call,
522524
CallControlFlowAction action,

mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ class IntegerRangeAnalysis
5555

5656
/// Visit an operation. Invoke the transfer function on each operation that
5757
/// implements `InferIntRangeInterface`.
58-
void visitOperation(Operation *op,
59-
ArrayRef<const IntegerValueRangeLattice *> operands,
60-
ArrayRef<IntegerValueRangeLattice *> results) override;
58+
LogicalResult
59+
visitOperation(Operation *op,
60+
ArrayRef<const IntegerValueRangeLattice *> operands,
61+
ArrayRef<IntegerValueRangeLattice *> results) override;
6162

6263
/// Visit block arguments or operation results of an operation with region
6364
/// control-flow for which values are not defined by region control-flow. This

mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
7979
public:
8080
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
8181

82-
void visitOperation(Operation *op, ArrayRef<Liveness *> operands,
83-
ArrayRef<const Liveness *> results) override;
82+
LogicalResult visitOperation(Operation *op, ArrayRef<Liveness *> operands,
83+
ArrayRef<const Liveness *> results) override;
8484

8585
void visitBranchOperand(OpOperand &operand) override;
8686

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

+14-12
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
197197

198198
/// The operation transfer function. Given the operand lattices, this
199199
/// function is expected to set the result lattices.
200-
virtual void
200+
virtual LogicalResult
201201
visitOperationImpl(Operation *op,
202202
ArrayRef<const AbstractSparseLattice *> operandLattices,
203203
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
@@ -238,7 +238,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
238238
/// Visit an operation. If this is a call operation or an operation with
239239
/// region control-flow, then its result lattices are set accordingly.
240240
/// Otherwise, the operation transfer function is invoked.
241-
void visitOperation(Operation *op);
241+
LogicalResult visitOperation(Operation *op);
242242

243243
/// Visit a block to compute the lattice values of its arguments. If this is
244244
/// an entry block, then the argument values are determined from the block's
@@ -277,8 +277,9 @@ class SparseForwardDataFlowAnalysis
277277

278278
/// Visit an operation with the lattices of its operands. This function is
279279
/// expected to set the lattices of the operation's results.
280-
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
281-
ArrayRef<StateT *> results) = 0;
280+
virtual LogicalResult visitOperation(Operation *op,
281+
ArrayRef<const StateT *> operands,
282+
ArrayRef<StateT *> results) = 0;
282283

283284
/// Visit a call operation to an externally defined function given the
284285
/// lattices of its arguments.
@@ -328,10 +329,10 @@ class SparseForwardDataFlowAnalysis
328329
private:
329330
/// Type-erased wrappers that convert the abstract lattice operands to derived
330331
/// lattices and invoke the virtual hooks operating on the derived lattices.
331-
void visitOperationImpl(
332+
LogicalResult visitOperationImpl(
332333
Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
333334
ArrayRef<AbstractSparseLattice *> resultLattices) override {
334-
visitOperation(
335+
return visitOperation(
335336
op,
336337
{reinterpret_cast<const StateT *const *>(operandLattices.begin()),
337338
operandLattices.size()},
@@ -387,7 +388,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
387388

388389
/// The operation transfer function. Given the result lattices, this
389390
/// function is expected to set the operand lattices.
390-
virtual void visitOperationImpl(
391+
virtual LogicalResult visitOperationImpl(
391392
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
392393
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
393394

@@ -424,7 +425,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
424425
/// Visit an operation. If this is a call operation or an operation with
425426
/// region control-flow, then its operand lattices are set accordingly.
426427
/// Otherwise, the operation transfer function is invoked.
427-
void visitOperation(Operation *op);
428+
LogicalResult visitOperation(Operation *op);
428429

429430
/// Visit a block.
430431
void visitBlock(Block *block);
@@ -474,8 +475,9 @@ class SparseBackwardDataFlowAnalysis
474475

475476
/// Visit an operation with the lattices of its results. This function is
476477
/// expected to set the lattices of the operation's operands.
477-
virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
478-
ArrayRef<const StateT *> results) = 0;
478+
virtual LogicalResult visitOperation(Operation *op,
479+
ArrayRef<StateT *> operands,
480+
ArrayRef<const StateT *> results) = 0;
479481

480482
/// Visit a call to an external function. This function is expected to set
481483
/// lattice values of the call operands. By default, calls `visitCallOperand`
@@ -510,10 +512,10 @@ class SparseBackwardDataFlowAnalysis
510512
private:
511513
/// Type-erased wrappers that convert the abstract lattice operands to derived
512514
/// lattices and invoke the virtual hooks operating on the derived lattices.
513-
void visitOperationImpl(
515+
LogicalResult visitOperationImpl(
514516
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
515517
ArrayRef<const AbstractSparseLattice *> resultLattices) override {
516-
visitOperation(
518+
return visitOperation(
517519
op,
518520
{reinterpret_cast<StateT *const *>(operandLattices.begin()),
519521
operandLattices.size()},

mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void ConstantValue::print(raw_ostream &os) const {
4343
// SparseConstantPropagation
4444
//===----------------------------------------------------------------------===//
4545

46-
void SparseConstantPropagation::visitOperation(
46+
LogicalResult SparseConstantPropagation::visitOperation(
4747
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
4848
ArrayRef<Lattice<ConstantValue> *> results) {
4949
LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
@@ -54,14 +54,14 @@ void SparseConstantPropagation::visitOperation(
5454
// folding.
5555
if (op->getNumRegions()) {
5656
setAllToEntryStates(results);
57-
return;
57+
return success();
5858
}
5959

6060
SmallVector<Attribute, 8> constantOperands;
6161
constantOperands.reserve(op->getNumOperands());
6262
for (auto *operandLattice : operands) {
6363
if (operandLattice->getValue().isUninitialized())
64-
return;
64+
return success();
6565
constantOperands.push_back(operandLattice->getValue().getConstantValue());
6666
}
6767

@@ -77,7 +77,7 @@ void SparseConstantPropagation::visitOperation(
7777
foldResults.reserve(op->getNumResults());
7878
if (failed(op->fold(constantOperands, foldResults))) {
7979
setAllToEntryStates(results);
80-
return;
80+
return success();
8181
}
8282

8383
// If the folding was in-place, mark the results as overdefined and reset
@@ -87,7 +87,7 @@ void SparseConstantPropagation::visitOperation(
8787
op->setOperands(originalOperands);
8888
op->setAttrs(originalAttrs);
8989
setAllToEntryStates(results);
90-
return;
90+
return success();
9191
}
9292

9393
// Merge the fold results into the lattice for this operation.
@@ -108,6 +108,7 @@ void SparseConstantPropagation::visitOperation(
108108
lattice, *getLatticeElement(foldResult.get<Value>()));
109109
}
110110
}
111+
return success();
111112
}
112113

113114
void SparseConstantPropagation::setToEntryState(

0 commit comments

Comments
 (0)