From 11878d6f25af22746bc0b8843910fb979c4b3fc5 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 9 Oct 2023 23:31:22 +0200 Subject: [PATCH] [mlir][sparse] Extract `StorageSpecifierToLLVMPass` from bufferization pipeline `StorageSpecifierToLLVMPass` does not have to be part of the bufferization mini pipeline. It can run after the bufferization pipeline. This is desirable because it keeps the bufferization pipeline smaller. --- .../SparseTensor/Pipelines/SparseTensorPipelines.cpp | 2 ++ .../Transforms/SparsificationAndBufferizationPass.cpp | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index 54069064839ea..7569413546c0a 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -42,6 +42,8 @@ void mlir::sparse_tensor::buildSparseCompiler( /*enableSIMDIndex32=*/options.force32BitVectorIndices)); if (options.testBufferizationAnalysisOnly) return; + + pm.addPass(createStorageSpecifierToLLVMPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass( mlir::bufferization::createFinalizingBufferizePass()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 6fca8f82e3566..480e18e257277 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -39,7 +39,7 @@ namespace sparse_tensor { /// Return `true` if one of the given types is a sparse tensor type. static bool containsSparseTensor(TypeRange types) { for (Type t : types) - if (getSparseTensorEncoding(t)) + if (isa(t) && getSparseTensorEncoding(t)) return true; return false; } @@ -97,7 +97,8 @@ class SparsificationAndBufferizationPass return false; }); - if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions))) + if (failed(bufferization::bufferizeModuleOp(cast(getOperation()), + updatedOptions))) return failure(); bufferization::removeBufferizationAttributesInModule(getOperation()); @@ -154,7 +155,6 @@ class SparsificationAndBufferizationPass pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs, enableBufferInitialization)); pm.addPass(createSparseBufferRewritePass(enableBufferInitialization)); - pm.addPass(createStorageSpecifierToLLVMPass()); } if (failed(runPipeline(pm, getOperation()))) return signalPassFailure();