Skip to content

Commit edc8b60

Browse files
[mlir][linalg] ValueBoundsOpInterface: Add LinalgOps
Also add a few more complex test cases. Differential Revision: https://reviews.llvm.org/D145806
1 parent 10dbf23 commit edc8b60

File tree

7 files changed

+160
-0
lines changed

7 files changed

+160
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace linalg {
16+
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace linalg
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
4242
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
4343
#include "mlir/Dialect/Linalg/IR/Linalg.h"
44+
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
4445
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
4546
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
4647
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
@@ -141,6 +142,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
141142
registry);
142143
linalg::registerBufferizableOpInterfaceExternalModels(registry);
143144
linalg::registerTilingInterfaceExternalModels(registry);
145+
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
144146
memref::registerBufferizableOpInterfaceExternalModels(registry);
145147
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
146148
memref::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
22
LinalgInterfaces.cpp
33
LinalgOps.cpp
44
LinalgDialect.cpp
5+
ValueBoundsOpInterfaceImpl.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
@@ -30,5 +31,6 @@ add_mlir_dialect_library(MLIRLinalgDialect
3031
MLIRMemRefDialect
3132
MLIRTensorDialect
3233
MLIRTilingInterface
34+
MLIRValueBoundsOpInterface
3335
MLIRViewLikeInterface
3436
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13+
14+
using namespace mlir;
15+
16+
namespace mlir {
17+
namespace linalg {
18+
namespace {
19+
20+
/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
21+
/// the `ValueBoundsOpInterface` with each of them.
22+
template <typename... Ops> struct LinalgValueBoundsOpInterfaceHelper {
23+
static void registerOpInterface(MLIRContext *ctx) {
24+
(Ops::template attachInterface<DstValueBoundsOpInterfaceExternalModel<Ops>>(
25+
*ctx),
26+
...);
27+
}
28+
};
29+
30+
} // namespace
31+
} // namespace linalg
32+
} // namespace mlir
33+
34+
void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
35+
DialectRegistry &registry) {
36+
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
37+
// Register all Linalg structured ops.
38+
LinalgValueBoundsOpInterfaceHelper<
39+
#define GET_OP_LIST
40+
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
41+
>::registerOpInterface(ctx);
42+
});
43+
}

mlir/test/Dialect/Affine/value-bounds-reification.mlir

+79
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,82 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
2020

2121
return %4, %5, %6 : index, index, index
2222
}
23+
24+
// -----
25+
26+
// CHECK-LABEL: func @reify_slice_bound(
27+
// CHECK: %[[c5:.*]] = arith.constant 5 : index
28+
// CHECK: "test.some_use"(%[[c5]])
29+
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
30+
%c0 = arith.constant 0 : index
31+
%c4 = arith.constant 4 : index
32+
scf.for %iv = %c0 to %ub step %c4 {
33+
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
34+
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
35+
%filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
36+
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
37+
"test.some_use"(%bound) : (index) -> ()
38+
}
39+
return
40+
}
41+
42+
// -----
43+
44+
// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 - s1 + 1)>
45+
// CHECK-LABEL: func @scf_for(
46+
// CHECK-SAME: %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index
47+
// CHECK: %[[bound:.*]] = affine.apply #[[$map]]()[%[[ub]], %[[lb]]]
48+
// CHECK: "test.some_use"(%[[bound]])
49+
func.func @scf_for(%lb: index, %ub: index, %step: index) {
50+
scf.for %iv = %lb to %ub step %step {
51+
%0 = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv)[%ub]
52+
%bound = "test.reify_bound"(%0) {type = "UB"} : (index) -> (index)
53+
"test.some_use"(%bound) : (index) -> ()
54+
}
55+
return
56+
}
57+
58+
// -----
59+
60+
// CHECK-LABEL: func @reify_slice_bound2(
61+
func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
62+
%ub2: index, %t1: tensor<1x?xi8>,
63+
%t2: tensor<?x?xi8>, %t3: tensor<1x?xi32>) {
64+
%c0 = arith.constant 0 : index
65+
%c1 = arith.constant 1 : index
66+
%c32 = arith.constant 32 : index
67+
scf.for %iv0 = %lb0 to %ub0 step %step0 {
68+
// CHECK: %[[c129:.*]] = arith.constant 129 : index
69+
// CHECK: "test.some_use"(%[[c129]])
70+
%ub1 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%iv0)[%ub0]
71+
%ub1_ub = "test.reify_bound"(%ub1) {type = "UB"} : (index) -> (index)
72+
"test.some_use"(%ub1_ub) : (index) -> ()
73+
74+
// CHECK: %[[c129:.*]] = arith.constant 129 : index
75+
// CHECK: "test.some_use"(%[[c129]])
76+
%lb1 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 32)>()[%ub1]
77+
%lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index)
78+
"test.some_use"(%lb1_ub) : (index) -> ()
79+
80+
scf.for %iv1 = %lb1 to %ub1 step %c32 {
81+
// CHECK: %[[c32:.*]] = arith.constant 32 : index
82+
// CHECK: "test.some_use"(%[[c32]])
83+
%sz = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv1)[%ub1]
84+
%sz_ub = "test.reify_bound"(%sz) {type = "UB"} : (index) -> (index)
85+
"test.some_use"(%sz_ub) : (index) -> ()
86+
87+
scf.for %iv2 = %c0 to %ub2 step %c1 {
88+
%slice1 = tensor.extract_slice %t1[0, %iv2] [1, 1] [1, 1] : tensor<1x?xi8> to tensor<1x1xi8>
89+
%slice2 = tensor.extract_slice %t2[%iv2, 0] [1, %sz] [1, 1] : tensor<?x?xi8> to tensor<1x?xi8>
90+
%slice3 = tensor.extract_slice %t3[0, 0] [1, %sz] [1, 1] : tensor<1x?xi32> to tensor<1x?xi32>
91+
%matmul = linalg.matmul ins(%slice1, %slice2 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%slice3 : tensor<1x?xi32>) -> tensor<1x?xi32>
92+
93+
// CHECK: %[[c32:.*]] = arith.constant 32 : index
94+
// CHECK: "test.some_use"(%[[c32]])
95+
%matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
96+
"test.some_use"(%matmul_ub) : (index) -> ()
97+
}
98+
}
99+
}
100+
return
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
2+
// RUN: -split-input-file | FileCheck %s
3+
4+
// CHECK-LABEL: func @linalg_fill(
5+
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
6+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
7+
// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
8+
// CHECK: return %[[dim]]
9+
func.func @linalg_fill(%t: tensor<?xf32>, %f: f32) -> index {
10+
%0 = linalg.fill ins(%f : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32>
11+
%1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
12+
return %1 : index
13+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -8654,6 +8654,7 @@ cc_library(
86548654
":Support",
86558655
":TensorDialect",
86568656
":TilingInterface",
8657+
":ValueBoundsOpInterface",
86578658
":ViewLikeInterface",
86588659
"//llvm:Support",
86598660
],

0 commit comments

Comments
 (0)