# Warp Specialization Support
Warp specialization enhances kernel performance by using an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. The data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we developed an automatic warp specialization optimization that partitions a user kernel into asynchronous tasks (which map to warp groups on NVIDIA GPUs). These tasks naturally execute concurrently, leveraging the hardware's multitasking warp scheduler.

To enable warp specialization, the user only needs to specify two autotune flags: `num_consumer_groups` and `num_buffers_warp_spec`. For example, a warp-specialized GEMM implementation might look like the code below. You can find a complete example in 09-persistent-matmul.py.
```python
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,  # provided by the config; unused in this simplified kernel
):
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Map the 1D program id to a 2D output tile.
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Accumulate the dot product one K-tile at a time.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    tl.store(c_ptrs, c)
```
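For completeness, a host-side launch might look like the sketch below. The wrapper name `matmul_ws` and the tensor shapes are illustrative assumptions; only the grid computation and the argument order follow from the kernel signature above.

```python
import torch
import triton

def matmul_ws(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (M, K) fp16, b: (K, N) fp16 -> c: (M, N) fp16
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # One program per output tile; BLOCK_SIZE_M/N are filled in by the autotuner.
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),)
    matmul_persistent_ws_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
```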
The following sections provide a breakdown of the compiler features developed to enable warp specialization.
## Asynchronous Tasks
Warp specialization is built on the concept of partitioning the user's program into asynchronous tasks (referred to as "async tasks" or "tasks" in the following sections). Each async task is executed by a standalone warp group on supported hardware to achieve instruction-level parallelism. While optimally and automatically partitioning asynchronous tasks remains a challenge for compilers, our approach to automatic task partitioning has proven effective for kernels that resemble typical examples such as GEMM and Flash Attention.
The compiler automatically determines how to use one producer warp group and two consumer warp groups to execute the kernel shown above. It begins by assigning task IDs to certain anchor operations, which influence the task assignments for the remaining operations. Once the anchor operations are annotated, the compiler assigns the non-anchor operations to tasks as follows:
- Control dependencies exclusive to an anchor operation are included in the same task as the anchor operation.
- Data dependencies exclusive to an anchor operation are included in the same task as the anchor operation, unless they are themselves anchor operations.
- Control or data dependencies shared between tasks are included in all those tasks.
For the GEMM example above, the compiler computes a task scheme and annotates it in the IR using MLIR attributes. To illustrate this more clearly, let's use source code annotations. After task propagation:
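The annotated listing itself is not reproduced in this diff, so the version below reconstructs the idea on the kernel from the example above, writing the task assignment as comments in place of the MLIR attributes. The concrete numbering (task 0 for the producer loads, task 1 for the compute and epilogue) is an assumption used for illustration.

```python
import triton
import triton.language as tl

@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Prologue: index and pointer setup. Values feeding only the loads belong to
    # task 0; offs_m / offs_n also feed the epilogue store, so they are shared.
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)       # async_task 0, 1
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)       # async_task 0, 1
    offs_k = tl.arange(0, BLOCK_SIZE_K)                              # async_task 0
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # async_task 0
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)   # async_task 1
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):                     # async_task 0, 1: shared loop control
        a = tl.load(a_ptrs)                                          # async_task 0: anchor load
        b = tl.load(b_ptrs)                                          # async_task 0: anchor load
        acc += tl.dot(a, b)                                          # async_task 1: anchor dot
        a_ptrs += BLOCK_SIZE_K * stride_ak                           # async_task 0: feeds only the loads
        b_ptrs += BLOCK_SIZE_K * stride_bk                           # async_task 0: feeds only the loads
    c = acc.to(tl.float16)                                           # async_task 1
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]    # async_task 1
    tl.store(c_ptrs, c)                                              # async_task 1: epilogue store
```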
To further improve performance, the compiler splits the same workload across two async tasks. This way, when one task is blocked on a heavy computation (e.g., the dot operation), the other group can execute other operations in parallel. The compiler determines how to divide the work between the two tasks to maximize performance. On the H100 GPU, the compiler will, by default, attempt to split the input tensor A along the M dimension so that each consumer computes half of the output tensor independently. This approach is known as cooperative partitioning. If this split is not advantageous (for instance, if it results in a smaller-than-native `wgmma` instruction), the compiler will instead attempt to split along the N dimension.
The transformed code for the above GEMM kernel, with a configured tile size of [128, 256, 64], will look like the following (using source annotations instead of IR for illustration).
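The transformed listing is not included in this diff; the kernel below is a hand-written sketch of what the cooperative partition amounts to, with the task split shown as comments. The helper names (`offs_m0`, `acc0`, and so on) and the task numbering (0 = producer, 1 and 2 = consumers) are illustrative assumptions, not compiler output.

```python
import triton
import triton.language as tl

@triton.jit
def matmul_ws_cooperative_sketch(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    # Consumer 0 (async_task 1) owns rows [0, BLOCK_SIZE_M // 2) of the output tile;
    # consumer 1 (async_task 2) owns rows [BLOCK_SIZE_M // 2, BLOCK_SIZE_M).
    offs_m0 = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M // 2)
    offs_m1 = pid_m * BLOCK_SIZE_M + BLOCK_SIZE_M // 2 + tl.arange(0, BLOCK_SIZE_M // 2)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs0 = a_ptr + (offs_m0[:, None] * stride_am + offs_k[None, :] * stride_ak)
    a_ptrs1 = a_ptr + (offs_m1[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc0 = tl.zeros((BLOCK_SIZE_M // 2, BLOCK_SIZE_N), dtype=tl.float32)  # async_task 1
    acc1 = tl.zeros((BLOCK_SIZE_M // 2, BLOCK_SIZE_N), dtype=tl.float32)  # async_task 2
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a0 = tl.load(a_ptrs0)             # async_task 0: upper half of the A tile
        a1 = tl.load(a_ptrs1)             # async_task 0: lower half of the A tile
        b = tl.load(b_ptrs)               # async_task 0: B tile shared by both consumers
        acc0 += tl.dot(a0, b)             # async_task 1
        acc1 += tl.dot(a1, b)             # async_task 2
        a_ptrs0 += BLOCK_SIZE_K * stride_ak
        a_ptrs1 += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c_ptrs0 = c_ptr + stride_cm * offs_m0[:, None] + stride_cn * offs_n[None, :]
    c_ptrs1 = c_ptr + stride_cm * offs_m1[:, None] + stride_cn * offs_n[None, :]
    tl.store(c_ptrs0, acc0.to(tl.float16))  # async_task 1
    tl.store(c_ptrs1, acc1.to(tl.float16))  # async_task 2
```

With the [128, 256, 64] tile above, each consumer still issues full 64-row `wgmma` instructions; if `BLOCK_SIZE_M` were small enough that halving it dropped below that shape, the compiler would fall back to splitting along N as described above.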
At this point we assume all operations are already marked with a list of task IDs. We first find all communications required between warp groups. Each communication starts from a load operation with a single task ID and ends at a direct user of the load that belongs to a different task ID. For `ForOps` containing a communication channel, we add two additional arguments: `phase` and `bufferIndex`.
We introduce a tuning configuration: `num_buffers_warp_spec`. For each communication channel inside a `ForOp`, we use an array of buffers in SMEM to hold the results; the size of the array is determined by `num_buffers_warp_spec`. We also use an array of barriers for each communication channel that is inside a `ForOp`. In this pass, four new operations are introduced to correctly synchronize the producer and the consumer: `ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`. Each of these ops takes a token and a buffer index; `ProducerAcquireOp` and `ConsumerWaitOp` take an additional phase operand.
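These are IR-level operations, not a Python API. The plain-Python toy below (all names invented for illustration) only models the bookkeeping that is easy to get wrong: how the buffer index and the phase operand cycle through a ring of `num_buffers_warp_spec` buffers, assuming the common scheme where the phase flips on every wrap-around.

```python
# In the real kernel the producer and consumer run in different warp groups, each
# tracking its own phase; the serial order here only illustrates the bookkeeping.
NUM_BUFFERS = 3  # num_buffers_warp_spec

def producer_acquire(i, p): print(f"ProducerAcquire buf={i} phase={p}")
def producer_commit(i):     print(f"ProducerCommit  buf={i}")
def consumer_wait(i, p):    print(f"ConsumerWait    buf={i} phase={p}")
def consumer_release(i):    print(f"ConsumerRelease buf={i}")

def model_channel(num_iters: int) -> None:
    buf_index, phase = 0, 0
    for _ in range(num_iters):
        producer_acquire(buf_index, phase)  # block until buffers[buf_index] is free
        producer_commit(buf_index)          # producer filled it: signal "data ready"
        consumer_wait(buf_index, phase)     # consumer blocks on the same index/phase
        consumer_release(buf_index)         # buffer can now be reused by the producer
        buf_index += 1                      # advance the ring of buffers
        if buf_index == NUM_BUFFERS:        # phase flips each time the ring wraps
            buf_index, phase = 0, phase ^ 1

model_channel(7)
```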
For `ForOps` with multiple task IDs, we clone one copy for each task ID, where each copy contains only the operations with that specific task ID. In the end, we create multiple `IfOps`, one for each possible task ID. We then go through the body of the function, clone each op for every task ID attached to it, and place the cloned op in the corresponding `IfOp`.
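Schematically, the result is one top-level branch per task, each holding that task's clone of the loop. The sketch below is plain-Python pseudocode; `async_task_id` and the loop bodies are hypothetical stand-ins for what the compiler actually emits.

```python
def partitioned_kernel_shape(async_task_id: int, k_iters: int) -> None:
    # One IfOp per possible task ID; each contains that task's clone of the ForOp.
    if async_task_id == 0:        # producer warp group
        for k in range(k_iters):
            pass                  # its clone: loads / async copies into SMEM buffers
    if async_task_id == 1:        # first consumer warp group
        for k in range(k_iters):
            pass                  # its clone: wait, tl.dot on its half, release
    if async_task_id == 2:        # second consumer warp group
        for k in range(k_iters):
            pass                  # its clone: wait, tl.dot on the other half, release
```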
To adjust register usage, we introduce two new ops: `RegAllocOp` and `RegDeallocOp`, both taking an integer operand. For each warp group, we decide whether to insert a `RegAllocOp` or a `RegDeallocOp`. The current heuristic is simple: if the task ID is 0, we add a `RegDeallocOp`; otherwise we use a `RegAllocOp`. The amount of register adjustment can be tuned via `reg_dec_producer` and `reg_inc_consumer`.
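As a minimal sketch of that heuristic (the helper below and its string output are purely illustrative; the real ops are emitted into the IR, and the operand values come from the tuning knobs):

```python
def register_adjustment_op(task_id: int, reg_dec_producer: int, reg_inc_consumer: int) -> str:
    # Task 0 is the producer: it mostly issues loads and typically needs fewer
    # registers, so it gives some back; every other task is a consumer doing the
    # heavy math and gets the extra registers instead.
    if task_id == 0:
        return f"RegDeallocOp({reg_dec_producer})"
    return f"RegAllocOp({reg_inc_consumer})"

print(register_adjustment_op(0, 40, 232))  # example values -> "RegDeallocOp(40)"
print(register_adjustment_op(1, 40, 232))  # example values -> "RegAllocOp(232)"
```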
This pass also lowers `loadOp`s to `AsyncTMACopyGlobalToLocalOp` or `AsyncCopyGlobalToLocalOp`, so the communication can be expressed via SMEM. For TMA loads, the producer becomes `ProducerAcquire` -> `barrier_expect` -> `AsyncTMACopyGlobalToLocalOp`, and the consumer contains `wait_barrier` -> ops -> `ConsumerRelease`. For non-TMA loads, the producer becomes `ProducerAcquire` -> `AsyncCopyGlobalToLocalOp` -> `ProducerCommit`, and the consumer contains `ConsumerWait` -> ops -> `ConsumerRelease`.