Skip to content

Commit 1dd2d12

Browse files
giuserosbertmaher
authored andcommitted
[AMD] Disable block merging to avoid block argument explosion (triton-lang#4176)
This PR disable block merging when running `convert-builtin-func-to-llvm`. The reason behind this is that for now block merging can double the arguments of the blocks. This means that after a while we can start witnessing a block argument "explosion" which hangs the compiler. I am working on this ticket: llvm/llvm-project#63230 to make block merging better, but in the meantime, we should stop merging blocks to avoid compiler hangs. I added the minimal test to reproduce the explosion. The test for now is checking that we don't try to merge blocks.
1 parent 53e4aa9 commit 1dd2d12

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: triton-opt --convert-builtin-func-to-llvm %s | FileCheck %s
2+
3+
// Trying to merge those blocks will cause a lot of duplication in the block arguments, which will cause
4+
// an exponential growth of the argument length. Make sure we don't try to merge those blocks.
5+
module {
6+
llvm.func @rand() -> i1
7+
llvm.func @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(!llvm.ptr<1>, i32, i1) attributes {libname = "", libpath = ""}
8+
9+
llvm.func @top(%arg0: i64, %1 : !llvm.ptr<1>, %2 : !llvm.ptr<1>, %3 : !llvm.ptr<1>, %4 : !llvm.ptr<1>) {
10+
%0 = llvm.mlir.constant(0 : i64) : i64
11+
%10 = llvm.icmp "eq" %arg0, %0 : i64
12+
%true = llvm.mlir.constant(1 : i1) : i1
13+
%c = llvm.mlir.constant(1 : i32) : i32
14+
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
15+
llvm.cond_br %10, ^bb1, ^bb14
16+
^bb1: // pred: ^bb0
17+
%11 = llvm.call @rand() : () -> i1
18+
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
19+
llvm.cond_br %11, ^bb2, ^bb3
20+
^bb2: // pred: ^bb1
21+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
22+
llvm.br ^bb4
23+
^bb3: // pred: ^bb1
24+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
25+
llvm.br ^bb4
26+
^bb4: // 2 preds: ^bb2, ^bb3
27+
%14 = llvm.call @rand() : () -> i1
28+
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
29+
llvm.cond_br %14, ^bb5, ^bb6
30+
^bb5: // pred: ^bb4
31+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
32+
llvm.br ^bb13
33+
^bb6: // pred: ^bb4
34+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
35+
llvm.br ^bb13
36+
^bb13: // 2 preds: ^bb11, ^bb12
37+
llvm.br ^bb27
38+
^bb14: // pred: ^bb0
39+
%23 = llvm.call @rand() : () -> i1
40+
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
41+
llvm.cond_br %23, ^bb15, ^bb16
42+
^bb15: // pred: ^bb14
43+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
44+
llvm.br ^bb17
45+
^bb16: // pred: ^bb14
46+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
47+
llvm.br ^bb17
48+
^bb17: // 2 preds: ^bb15, ^bb16
49+
%26 = llvm.call @rand() : () -> i1
50+
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
51+
llvm.cond_br %26, ^bb18, ^bb19
52+
^bb18: // pred: ^bb17
53+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
54+
llvm.br ^bb26
55+
^bb19: // pred: ^bb17
56+
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
57+
llvm.br ^bb26
58+
^bb26: // 2 preds: ^bb24, ^bb25
59+
llvm.br ^bb27
60+
^bb27: // 2 preds: ^bb13, ^bb26
61+
llvm.return
62+
}
63+
}

third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,17 @@ struct ConvertBuiltinFuncToLLVM
165165
MLIRContext *context = &getContext();
166166
ModuleOp mod = getOperation();
167167

168+
// Disable block merging because of:
169+
// https://github.com/llvm/llvm-project/issues/63230
170+
// TODO(giuseros): enable block merging once the above ticket is completed
171+
GreedyRewriteConfig config;
172+
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
173+
168174
RewritePatternSet patterns(context);
169175
patterns.add<CallOpConversion>(context);
170176

171-
if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) {
177+
if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns), config)
178+
.failed()) {
172179
signalPassFailure();
173180
}
174181
}

0 commit comments

Comments
 (0)