
Commit 55b93a5

htyumanman-ren authored and committed
Automatic Warp Specialization Optimization (triton-lang#5622)
Warp specialization enhances kernel performance by utilizing an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. The data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we’ve developed an automatic warp specialization optimization that partitions a user kernel into asynchronous tasks (which map to warp groups on NVIDIA GPUs), which naturally execute concurrently, leveraging the hardware’s multitasking warp scheduler. To enable warp specialization, the user just needs to specify certain autotune flags, i.e., `num_consumer_groups` and `num_buffers_warp_spec`. For example, a warp-specialized GEMM implementation might look like the code below. You can find a complete example in 09-persistent-matmul.py.

```python
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    tl.store(c_ptrs, c)
```

Merged up to 9e980f7112eb7d0f8c91e9e0f22742bdd09c6f80 in ws-3.3.x
1 parent 5fdff50 commit 55b93a5

File tree

88 files changed: +13490 −347 lines


README.md

Lines changed: 157 additions & 0 deletions
@@ -308,3 +308,160 @@ the [triton-dev-containers repository](https://github.com/redhat-et/triton-dev-c
For detailed instructions on how to use the dev containers please see
the [dev container user guide](https://github.com/redhat-et/triton-dev-containers/blob/main/.devcontainer/devcontainer.md)

# Warp Specialization Support

Warp specialization enhances kernel performance by utilizing an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. The data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we’ve developed an automatic warp specialization optimization that partitions a user kernel into asynchronous tasks (which map to warp groups on NVIDIA GPUs) that naturally execute concurrently, leveraging the hardware’s multitasking warp scheduler. The following sections provide a breakdown of the compiler features developed to enable warp specialization.

## Asynchronous Tasks

Warp specialization is built on the concept of partitioning the user’s program into asynchronous tasks (referred to as "async tasks" or “tasks” in the following sections). Each async task is executed by a standalone warp group on the supported hardware to achieve instruction-level parallelism. While optimally and automatically partitioning asynchronous tasks remains a challenge for compilers, our approach to automatic task partitioning has proven effective for kernels that resemble typical examples such as GEMM and Flash Attention.

To enable warp specialization, the user just needs to specify certain autotune flags, namely `num_consumer_groups` and `num_buffers_warp_spec`. For example, a warp-specialized GEMM implementation might look like the code below. You can find a complete example in 09-persistent-matmul.py.

```python
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    tl.store(c_ptrs, c)
```

The compiler automatically determines how to utilize one producer warp group and two consumer warp groups to execute the kernel. It begins by assigning task IDs to certain anchor operations, which influence the task assignments for the remaining operations. Once the anchor tasks are annotated, the compiler assigns the non-anchor operations to tasks as follows:

- Control dependencies exclusive to an anchor operation are included in the same task as the anchor operation.
- Data dependencies exclusive to an anchor operation are included in the same task as the anchor operation, unless they are another anchor operation.
- Control or data dependencies shared between tasks are included in all those tasks.

For the GEMM example above, the compiler computes a task scheme and annotates it in the IR using MLIR attributes. To illustrate this more clearly, let’s use source code annotations. After task propagation:

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # async_task 0, 1
    num_pid_m = tl.cdiv(M, BLOCK_M)  # async_task 0, 1
    num_pid_n = tl.cdiv(N, BLOCK_N)  # async_task 0, 1
    pid_m = pid // num_pid_n  # async_task 0, 1
    pid_n = pid % num_pid_n  # async_task 0, 1
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # async_task 0, 1
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # async_task 0, 1
    offs_k = tl.arange(0, BLOCK_K)  # async_task 0
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # async_task 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  # async_task 1
    for k in range(0, tl.cdiv(K, BLOCK_K)):  # async_task 0, 1
        a = tl.load(a_ptrs)  # async_task 0
        b = tl.load(b_ptrs)  # async_task 0
        acc += tl.dot(a, b)  # async_task 1
        a_ptrs += BLOCK_K * stride_ak  # async_task 0
        b_ptrs += BLOCK_K * stride_bk  # async_task 0
    c = acc.to(tl.float16)  # async_task 1
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]  # async_task 1
    tl.store(c_ptrs, c)  # async_task 1
```

## Data Partitioning

To further improve performance, the compiler will split the same workload across two async tasks. This way, when one task is blocked on a heavy computation (e.g., the dot operation), the other task can execute other operations in parallel. The compiler determines how to divide the work between the two tasks to maximize performance. On the H100 GPU, the compiler will, by default, attempt to split the input tensor A along the M dimension so that each consumer computes half of the output tensor independently. This approach is known as cooperative partitioning. If this split is not advantageous (for instance, if it results in a smaller-than-native `wgmma` instruction), the compiler will instead attempt to split along the N dimension.

The transformed code for the above GEMM kernel with a configured tile size of [128, 256, 64] will look like the code below (using source annotations instead of IR for illustration).

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # async_task 0, 1, 2
    num_pid_m = tl.cdiv(M, BLOCK_M)  # async_task 0, 1, 2
    num_pid_n = tl.cdiv(N, BLOCK_N)  # async_task 0, 1, 2
    pid_m = pid // num_pid_n  # async_task 0, 1, 2
    pid_n = pid % num_pid_n  # async_task 0, 1, 2
    offs_m_1 = pid_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)  # async_task 0, 1, 2
    offs_m_2 = pid_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)  # async_task 0, 1, 2
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # async_task 0, 1, 2
    offs_k = tl.arange(0, BLOCK_K)  # async_task 0
    a_ptrs_1 = a_ptr + (offs_m_1[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    a_ptrs_2 = a_ptr + (offs_m_2[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # async_task 0
    acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)  # async_task 1
    acc_2 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)  # async_task 2
    for k in range(0, tl.cdiv(K, BLOCK_K)):  # async_task 0, 1, 2
        a_1 = tl.load(a_ptrs_1)  # async_task 0
        a_2 = tl.load(a_ptrs_2)  # async_task 0
        b = tl.load(b_ptrs)  # async_task 0
        acc_1 += tl.dot(a_1, b)  # async_task 1
        acc_2 += tl.dot(a_2, b)  # async_task 2
        a_ptrs_1 += BLOCK_K * stride_ak  # async_task 0
        a_ptrs_2 += BLOCK_K * stride_ak  # async_task 0
        b_ptrs += BLOCK_K * stride_bk  # async_task 0
    c_1 = acc_1.to(tl.float16)  # async_task 1
    c_2 = acc_2.to(tl.float16)  # async_task 2
    c_ptrs_1 = c_ptr + stride_cm * offs_m_1[:, None] + stride_cn * offs_n[None, :]  # async_task 1
    c_ptrs_2 = c_ptr + stride_cm * offs_m_2[:, None] + stride_cn * offs_n[None, :]  # async_task 2
    tl.store(c_ptrs_1, c_1)  # async_task 1
    tl.store(c_ptrs_2, c_2)  # async_task 2
```

## Code Partitioning

We assume all operations are already marked with a list of taskIds. We first find all communications required between warp groups. Each communication starts from a load operation with a single taskId and ends at a direct user of the load that belongs to a different taskId. For `ForOps` containing a communication channel, we add two additional arguments: `phase` and `bufferIndex`.
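
As a rough illustration of what those two loop-carried values track (this is the standard circular-buffer scheme and an assumption about the pass, not code taken from it), here is how `bufferIndex` and `phase` would evolve across iterations for a ring of `num_buffers` SMEM slots:

```python
# Hypothetical sketch: how bufferIndex and phase evolve for a circular buffer.
num_buffers = 3
for k in range(8):
    buffer_index = k % num_buffers   # which SMEM slot this iteration uses
    phase = (k // num_buffers) % 2   # parity flips each time the ring wraps
    print(f"iter {k}: bufferIndex={buffer_index}, phase={phase}")
```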

We introduce a tuning configuration: `num_buffers_warp_spec`. For each communication channel inside a `ForOp`, we use an array of buffers in SMEM to save the results; the size of the array is determined by `num_buffers_warp_spec`. We also use an array of barriers for each communication channel that is inside a `ForOp`. In this pass, four new operations are introduced to correctly synchronize between the producer and the consumer: `ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`. Each of the four new ops takes a token and a buffer index; `ProducerAcquire` and `ConsumerWait` take an additional phase operand.
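
To make the hand-off these ops implement concrete, here is a minimal host-side Python model of a multi-buffered producer/consumer pipeline. It is only an analogy for the acquire/commit/wait/release protocol (the real implementation uses SMEM buffers and barriers, not Python semaphores), and all names in it are illustrative:

```python
import threading

NUM_BUFFERS = 3   # plays the role of num_buffers_warp_spec
NUM_ITERS = 8

buffers = [None] * NUM_BUFFERS
# "empty" gates the producer (ProducerAcquire waits on it); "full" gates the
# consumer (ConsumerWait waits on it). Commit/Release are the matching signals.
empty = [threading.Semaphore(1) for _ in range(NUM_BUFFERS)]
full = [threading.Semaphore(0) for _ in range(NUM_BUFFERS)]

def producer():
    for k in range(NUM_ITERS):
        idx = k % NUM_BUFFERS      # bufferIndex for this iteration
        empty[idx].acquire()       # ProducerAcquireOp: wait until the slot is free
        buffers[idx] = k * k       # stand-in for the async copy into SMEM
        full[idx].release()        # ProducerCommitOp: signal that data is ready

def consumer():
    total = 0
    for k in range(NUM_ITERS):
        idx = k % NUM_BUFFERS
        full[idx].acquire()        # ConsumerWaitOp: wait for the staged tile
        total += buffers[idx]      # stand-in for the dot that consumes it
        empty[idx].release()       # ConsumerReleaseOp: hand the slot back
    print("consumer result:", total)

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```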
For `ForOps` with multiple taskIds, we clone one copy for each taskId; each copy contains only the operations with that specific taskId. In the end, we create multiple `IfOps`, one for each possible taskId. We go through the body of the function, clone each op for every taskId attached to it, and put the cloned op in the corresponding `IfOp`.
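
The sketch below mimics that cloning step on the host side. It is not the MLIR pass itself; the op names and task assignments are hypothetical and only echo the annotated GEMM example above:

```python
# Minimal sketch: clone ops annotated with taskIds into per-task bodies,
# one body per future IfOp.
from collections import defaultdict

annotated_ops = [
    ("tl.program_id", [0, 1, 2]),
    ("tl.load a_1",   [0]),
    ("tl.load a_2",   [0]),
    ("tl.load b",     [0]),
    ("tl.dot a_1, b", [1]),
    ("tl.dot a_2, b", [2]),
    ("tl.store c_1",  [1]),
    ("tl.store c_2",  [2]),
]

def partition_by_task(ops):
    bodies = defaultdict(list)
    for name, task_ids in ops:
        for tid in task_ids:          # ops shared between tasks are duplicated
            bodies[tid].append(name)  # each clone lands in its task's body
    return dict(bodies)

for tid, body in sorted(partition_by_task(annotated_ops).items()):
    print(f"if async_task_id == {tid}:")
    for op in body:
        print(f"    {op}")
```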

To adjust register usage, we introduce two new ops: `RegAllocOp` and `RegDeallocOp`, both taking an integer operand. For each warp group, we decide whether to insert a `RegAllocOp` or a `RegDeallocOp`. The current heuristic is simple: if the taskId is 0 (the producer), we add a `RegDeallocOp`; otherwise we use a `RegAllocOp`. The amount of register adjustment can be tuned via `reg_dec_producer` and `reg_inc_consumer`.
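
A compact restatement of that heuristic as code (the numeric amounts here are arbitrary placeholders, not the actual defaults):

```python
def reg_adjustment(task_id, reg_dec_producer=24, reg_inc_consumer=240):
    # Task 0 is the producer warp group; it mostly issues async copies,
    # so it donates registers while the consumer groups receive extra.
    if task_id == 0:
        return ("RegDeallocOp", reg_dec_producer)
    return ("RegAllocOp", reg_inc_consumer)

print([reg_adjustment(t) for t in range(3)])
```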

This pass also lowers `loadOp`s to `AsyncTMACopyGlobalToLocalOp` or `AsyncCopyGlobalToLocalOp`, so the communication can be expressed via SMEM. For TMA, the producer will become `ProducerAcquire` -> `barrier_expect` -> `AsyncTMACopyGlobalToLocalOp`, and the consumer will contain `wait_barrier` -> ops -> `ConsumerRelease`. For non-TMA loads, the producer will become `ProducerAcquire` -> `AsyncCopyGlobalToLocalOp` -> `ProducerCommitOp`, and the consumer will contain `ConsumerWaitOp` -> ops -> `ConsumerRelease`.

include/triton/Analysis/Allocation.h

Lines changed: 4 additions & 2 deletions
```diff
@@ -197,14 +197,16 @@ class Allocation {
     size_t size;
     size_t alignment;
     size_t offset;
+    SetVector<int> regionIds;
+    int sharingGroup; // -1 means not shared

     bool operator==(const BufferT &other) const { return id == other.id; }
     bool operator<(const BufferT &other) const { return id < other.id; }

     BufferT(BufferKind kind, BufferId id, Operation *owner, size_t size,
-            size_t alignment = 4, size_t offset = 0)
+            size_t alignment = 4, size_t offset = 0, int sharingGroup = -1)
         : kind(kind), id(id), owner(owner), size(size), alignment(alignment),
-          offset(offset) {}
+          offset(offset), sharingGroup(sharingGroup) {}

     size_t setOffsetAligned(size_t newOffset) {
       return offset = llvm::alignTo(newOffset, alignment);
```

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -102,6 +102,10 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                   const TargetInfoBase &targetInfo,
                                   PatternBenefit benefit);

+void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                        RewritePatternSet &patterns,
+                                        PatternBenefit benefit);
+
 } // namespace triton
 } // namespace mlir
```

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 76 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/Types.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
@@ -331,6 +332,81 @@ class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
 namespace mlir {
 namespace triton {

+static inline void insertBarrier(OpBuilder &builder, Operation *op) {
+  auto barrierOp = builder.create<mlir::gpu::BarrierOp>(op->getLoc());
+  auto asyncTaskIds = getAsyncTaskIds(op);
+  assert(asyncTaskIds.size() <= 1);
+  if (asyncTaskIds.size() == 1) {
+    int asyncTaskId = asyncTaskIds[0];
+    int barId = asyncTaskId + nameBarrierIdBegin;
+    assert(barId < nameBarrierIdEnd);
+    auto mod = op->getParentOfType<ModuleOp>();
+    int numWarps = mlir::triton::gpu::lookupNumWarps(op);
+    int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+    int numThreads = numWarps * warpSize;
+    barrierOp->setAttr("bar_id", builder.getI64IntegerAttr(barId));
+    barrierOp->setAttr("num_threads", builder.getI64IntegerAttr(numThreads));
+  }
+}
+
+// Delinearize supposing order is [0, 1, .. , n]
+template <typename T>
+llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
+                                          llvm::ArrayRef<T> shape) {
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearRemain = linearIndex;
+  llvm::SmallVector<T> multiDimIndex(rank);
+  for (int i = rank - 1; i >= 0; --i) {
+    multiDimIndex[i] = linearRemain / accMul;
+    linearRemain = linearRemain % accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return multiDimIndex;
+}
+
+template <typename T>
+llvm::SmallVector<T> getMultiDimIndex(T linearIndex, llvm::ArrayRef<T> shape,
+                                      llvm::ArrayRef<unsigned> order) {
+  size_t rank = shape.size();
+  assert(rank == order.size());
+  auto reordered = applyPermutation(shape, order);
+  auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
+  llvm::SmallVector<T> multiDim(rank);
+  for (unsigned i = 0; i < rank; ++i) {
+    multiDim[order[i]] = reorderedMultiDim[i];
+  }
+  return multiDim;
+}
+
+// Linearize supposing order is [0, 1, .. , n]
+template <typename T>
+T getLinearIndexImpl(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape) {
+  assert(multiDimIndex.size() == shape.size());
+  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
+  size_t rank = shape.size();
+  T accMul = product(shape.drop_back());
+  T linearIndex = 0;
+  for (int i = rank - 1; i >= 0; --i) {
+    linearIndex += multiDimIndex[i] * accMul;
+    if (i != 0) {
+      accMul = accMul / shape[i - 1];
+    }
+  }
+  return linearIndex;
+}
+
+template <typename T>
+T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
+                 llvm::ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  return getLinearIndexImpl<T>(applyPermutation(multiDimIndex, order),
+                               applyPermutation(shape, order));
+}
+
 namespace gpu {
 Type getFunctionType(Type resultType, ValueRange operands);
```

include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td

Lines changed: 14 additions & 0 deletions
```diff
@@ -10,6 +10,20 @@ class TTG_TypeDef<string name, string _mnemonic, list<Trait> traits = []>
   let mnemonic = _mnemonic;
 }

+def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
+  let parameters = (ins "int32_t":$type);
+
+  let builders = [
+    TypeBuilder<(ins "unsigned":$type), [{
+      return $_get($_ctxt, type);
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+
+  let skipDefaultBuilders = 1;
+}
+
 def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", "async.token", []> {
   let summary = "async token type";
   let description = [{
```
