-
Notifications
You must be signed in to change notification settings - Fork 172
Add cluster
to LaunchConfig
to support thread block clusters on Hopper
#261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
5d253f1
add support for cluster
leofang 4abe520
add a code sample; apply a WAR to a potential bug
leofang 7f58ae4
Merge branch 'main' into cluster
leofang 895c9fa
more robust treatments
leofang a003b98
add release note entries
leofang 20692de
fix invalid context during test teardown
leofang 5ac409b
improve comments in the code sample
leofang 37c2843
Merge branch 'main' into cluster
leofang 6c35033
Merge branch 'main' into cluster
leofang 7d117f2
Merge branch 'main' into cluster
leofang 2c3a619
Merge branch 'main' into cluster
leofang b8004e9
switch from chip chen to compute capability in comments
ksimpson-work 4b95ba4
Merge remote-tracking branch 'leofang/cluster' into HEAD
ksimpson-work File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. | ||
# | ||
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE | ||
|
||
import os | ||
import sys | ||
|
||
from cuda.core.experimental import Device, LaunchConfig, Program, launch | ||
|
||
# prepare include | ||
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME")) | ||
if cuda_path is None: | ||
print("this demo requires a valid CUDA_PATH environment variable set", file=sys.stderr) | ||
sys.exit(0) | ||
cuda_include_path = os.path.join(cuda_path, "include") | ||
|
||
# print cluster info using a kernel | ||
code = r""" | ||
#include <cooperative_groups.h> | ||
|
||
namespace cg = cooperative_groups; | ||
|
||
extern "C" | ||
__global__ void check_cluster_info() { | ||
auto g = cg::this_grid(); | ||
auto b = cg::this_thread_block(); | ||
if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) { | ||
printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z); | ||
printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z); | ||
printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z); | ||
} | ||
} | ||
""" | ||
|
||
dev = Device() | ||
arch = dev.compute_capability | ||
if arch < (9, 0): | ||
print( | ||
"this demo requires compute capability >= 9.0 (since thread block cluster is a hardware feature)", | ||
file=sys.stderr, | ||
) | ||
sys.exit(0) | ||
arch = "".join(f"{i}" for i in arch) | ||
|
||
# prepare program & compile kernel | ||
dev.set_current() | ||
prog = Program(code, code_type="c++") | ||
mod = prog.compile( | ||
target_type="cubin", | ||
# TODO: update this after NVIDIA/cuda-python#237 is merged | ||
options=(f"-arch=sm_{arch}", "-std=c++17", f"-I{cuda_include_path}"), | ||
) | ||
ker = mod.get_kernel("check_cluster_info") | ||
|
||
# prepare launch config | ||
grid = 4 | ||
cluster = 2 | ||
block = 32 | ||
config = LaunchConfig(grid=grid, cluster=cluster, block=block, stream=dev.default_stream) | ||
|
||
# launch kernel on the default stream | ||
launch(ker, config) | ||
dev.sync() | ||
|
||
print("done!") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.