Skip to content

Commit 2ca9856

Browse files
author
Michael Levesque-Dion
committed
Make inference test for chlo.broadcast_select less brittle
In llvm/llvm-project#74438, the folder for `shape.shape_of` is changed to a canonicalizer. This means that constant shapes no longer get folded automatically (`--canonicalize` must be used). This will cause a test failure when we do the next LLVM integrate, because the `broadcast_select_reify` test expects the `shape.shape_of` operation to be folded into `shape.const_shape`. The test also expects the constant shape value to be pushed to the rightmost arg of the `shape.broadcast` operation, which will not be the case if canonicalization is not applied. Additional context: - The old folder for `shape.shape_of` returned its input shape as a tensor attribute, so it would [automatically get materialized](https://mlir.llvm.org/docs/Canonicalization/#generating-constants-from-attributes) [to a `shape.const_shape` op](https://github.com/llvm/llvm-project/blob/98d8dce6e9e21a995f6a06fa4485fa529931be37/mlir/lib/Dialect/Shape/IR/Shape.cpp#L154-L156). - The new canonicalizer does the materialization explicitly. - [BroadcastOp is Commutative](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td#L57) and [ConstShape is ConstantLike](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td#L105), so if `shape.shape_of` is folded to `shape.const_shape`, the resulting value becomes the rightmost argument to `shape.broadcast`. Indeed, according to the docs [constant arguments of commutative ops are shifted to the right](https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules:~:text=Move%20constant%20operands%20to%20commutative%20operators%20to%20the%20right%20side), and this is implemented [here](https://github.com/llvm/llvm-project/blob/7ddd3d776402f9cc7d5f13b5940ba38a285223c2/mlir/lib/IR/Operation.cpp#L802).
1 parent b046f11 commit 2ca9856

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

stablehlo/tests/infer_chlo.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,14 @@ func.func @broadcast_select_branch_mismatch(%arg0: tensor<2xi1>, %arg1: tensor<2
120120
// -----
121121
// CHECK-LABEL: @broadcast_select_reify
122122
func.func @broadcast_select_reify(%arg0: tensor<2xi1>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<1xindex> {
123-
// CHECK: %0 = shape.const_shape [2] : tensor<1xindex>
124-
// CHECK-NEXT: %1 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<1xindex>
125-
// CHECK-NEXT: %2 = shape.shape_of %arg2 : tensor<?xi32> -> tensor<1xindex>
126-
// CHECK-NEXT: %3 = shape.broadcast %1, %2, %0 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex>
123+
// CHECK-DAG: %[[ARG0_S:.+]] = shape
124+
// CHECK-DAG: %[[ARG1_S:.+]] = shape
125+
// CHECK-DAG: %[[ARG2_S:.+]] = shape
126+
// CHECK-NEXT: %[[BCAST_S:.+]] = shape.broadcast
127+
// CHECK-DAG: %[[ARG0_S]]
128+
// CHECK-DAG: %[[ARG1_S]]
129+
// CHECK-DAG: %[[ARG2_S]]
130+
// CHECK-NEXT: return %[[BCAST_S]] : tensor<1xindex>
127131
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
128132
%1 = "hlo_test_infer.reify_return_type_shapes"(%0) : (tensor<?xi32>) -> tensor<1xindex>
129133
return %1: tensor<1xindex>

0 commit comments

Comments
 (0)