Skip to content

Commit 03001e5

Browse files
authored
Merge pull request #475 from vzhurba01/324-linker-example
Add a `Linker` example
2 parents 3d413ed + 3097c66 commit 03001e5

File tree

1 file changed

+302
-0
lines changed

1 file changed

+302
-0
lines changed

cuda_core/examples/jit_lto_fractal.py

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2+
#
3+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4+
5+
# ################################################################################
6+
#
7+
# This demo aims to illustrate a couple takeaways:
8+
#
9+
# 1. How to use the JIT LTO feature provided by the Linker class to link multiple objects together
10+
# 2. That linking allows for libraries to modify workflows dynamically at runtime
11+
#
12+
# This demo mimics a relationship between a library and a user. The user's sole responsibility is to
13+
# provide device code that generates some art. Whereas the library is responsible for all steps involved in
14+
# setting up the device, launch configurations and arguments, as well as linking the provided device code.
15+
#
16+
# Two algorithms are implemented:
17+
# 1. A Mandelbrot set
18+
# 2. A Julia set
19+
#
20+
# The user can choose which algorithm to use at runtime and generate the resulting image.
21+
#
22+
# ################################################################################
23+
24+
import argparse
25+
import sys
26+
27+
import cupy as cp
28+
29+
from cuda.core.experimental import Device, LaunchConfig, Linker, LinkerOptions, Program, ProgramOptions, launch
30+
31+
32+
# ################################################################################
33+
#
34+
# This Mocklibrary is responsible for all steps involved launching the device code.
35+
#
36+
# The user is responsible for providing the device code that will be linked into the library's workflow.
37+
# The provided device code must contain a function with the signature `void generate_art(float* Data)`
38+
class MockLibrary:
39+
def __init__(self):
40+
# For this mock library, the main workflow is intentionally kept simple by limiting itself to only calling the
41+
# externally defined generate_art function. More involved libraries have the option of applying pre and post
42+
# processing steps before calling user-defined device code. Conversely, these responsibilities can be reversed
43+
# such that the library owns the bulk of the workflow while allowing users to provide customized pre/post
44+
# processing steps.
45+
code_main = r"""
46+
extern __device__ void generate_art(float* Data);
47+
48+
extern "C"
49+
__global__
50+
void main_workflow(float* Data) {
51+
// Preprocessing steps can be called here
52+
// ...
53+
54+
// Call the user-defined device code
55+
generate_art(Data);
56+
57+
// Postprocessing steps can be called here
58+
// ...
59+
}
60+
"""
61+
62+
# Most of the launch configurations can be preemptively done before the user provides their device code
63+
# Therefore lets compile our main workflow device code now, and link the remaining pieces at a later time
64+
self.program_options = ProgramOptions(relocatable_device_code=True)
65+
self.main_object_code = Program(code_main, "c++", options=self.program_options).compile("ptx")
66+
67+
# Setup device state
68+
self.dev = Device()
69+
self.dev.set_current()
70+
self.stream = self.dev.create_stream()
71+
72+
# Setup a buffer to store the RGBA results for the width and height specified
73+
self.width = 1024
74+
self.height = 512
75+
self.buffer = cp.empty(self.width * self.height * 4, dtype=cp.float32)
76+
77+
# Setup the launch configuration such that each thread will be generating one pixel, and subdivide
78+
# the problem into 16x16 chunks.
79+
self.grid = (self.width / 16, self.height / 16, 1.0)
80+
self.block = (16, 16, 1)
81+
self.config = LaunchConfig(grid=self.grid, block=self.block, stream=self.stream)
82+
83+
def link(self, user_code, target_type):
84+
if target_type == "ltoir":
85+
program_options = ProgramOptions(link_time_optimization=True)
86+
linker_options = LinkerOptions(link_time_optimization=True)
87+
elif target_type == "ptx":
88+
program_options = self.program_options
89+
linker_options = LinkerOptions()
90+
else:
91+
raise AssertionError(f"Invalid {target_type=}")
92+
93+
# First, user-defined code is compiled into a PTX object code
94+
user_object_code = Program(user_code, "c++", options=program_options).compile(target_type)
95+
96+
# Then a Linker is created to link the main object code with the user-defined code
97+
linker = Linker(self.main_object_code, user_object_code, options=linker_options)
98+
99+
# We emit the linked code as cubin
100+
linked_code = linker.link("cubin")
101+
102+
# Now we're ready to retrieve the main device function and execute our library's workflow
103+
return linked_code.get_kernel("main_workflow")
104+
105+
def run(self, kernel):
106+
launch(kernel, self.config, self.buffer.data.ptr)
107+
self.stream.sync()
108+
109+
# Return the result as a NumPy array (on host).
110+
return cp.asnumpy(self.buffer).reshape(self.height, self.width, 4)
111+
112+
113+
# Now lets proceed with code from the user's perspective!
114+
#
115+
# ################################################################################
116+
117+
# Simple implementation of Mandelbrot set from Wikipedia
118+
# http://en.wikipedia.org/wiki/Mandelbrot_set
119+
#
120+
# Note that this kernel is meant to be a simple, straight-forward
121+
# implementation. No attempt is made to optimize this GPU code.
122+
code_mandelbrot = r"""
123+
__device__
124+
void generate_art(float* Data) {
125+
// Which pixel am I?
126+
unsigned DataX = blockIdx.x * blockDim.x + threadIdx.x;
127+
unsigned DataY = blockIdx.y * blockDim.y + threadIdx.y;
128+
unsigned Width = gridDim.x * blockDim.x;
129+
unsigned Height = gridDim.y * blockDim.y;
130+
131+
float R, G, B, A;
132+
133+
// Scale coordinates to (-2.5, 1) and (-1, 1)
134+
135+
float NormX = (float)DataX / (float)Width;
136+
NormX *= 3.5f;
137+
NormX -= 2.5f;
138+
139+
float NormY = (float)DataY / (float)Height;
140+
NormY *= 2.0f;
141+
NormY -= 1.0f;
142+
143+
float X0 = NormX;
144+
float Y0 = NormY;
145+
146+
float X = 0.0f;
147+
float Y = 0.0f;
148+
149+
unsigned Iter = 0;
150+
unsigned MaxIter = 1000;
151+
152+
// Iterate
153+
while(X*X + Y*Y < 4.0f && Iter < MaxIter) {
154+
float XTemp = X*X - Y*Y + X0;
155+
Y = 2.0f*X*Y + Y0;
156+
157+
X = XTemp;
158+
159+
Iter++;
160+
}
161+
162+
unsigned ColorG = Iter % 50;
163+
unsigned ColorB = Iter % 25;
164+
165+
R = 0.0f;
166+
G = (float)ColorG / 50.0f;
167+
B = (float)ColorB / 25.0f;
168+
A = 1.0f;
169+
170+
unsigned i = DataY*Width*4+DataX*4;
171+
Data[i+0] = R;
172+
Data[i+1] = G;
173+
Data[i+2] = B;
174+
Data[i+3] = A;
175+
}
176+
"""
177+
178+
# Simple implementation of Julia set from Wikipedia
179+
# http://en.wikipedia.org/wiki/Julia_set
180+
#
181+
# Note that this kernel is meant to be a simple, straight-forward
182+
# implementation. No attempt is made to optimize this GPU code.
183+
code_julia = r"""
184+
__device__
185+
void generate_art(float* Data) {
186+
// Which pixel am I?
187+
unsigned DataX = blockIdx.x * blockDim.x + threadIdx.x;
188+
unsigned DataY = blockIdx.y * blockDim.y + threadIdx.y;
189+
unsigned Width = gridDim.x * blockDim.x;
190+
unsigned Height = gridDim.y * blockDim.y;
191+
192+
float R, G, B, A;
193+
194+
// Scale coordinates to (-2, 2) for both x and y
195+
// Scale coordinates to (-2.5, 1) and (-1, 1)
196+
float X = (float)DataX / (float)Width;
197+
X *= 4.0f;
198+
X -= 2.0f;
199+
200+
float Y = (float)DataY / (float)Height;
201+
Y *= 2.0f;
202+
Y -= 1.0f;
203+
204+
// Julia set uses a fixed constant C
205+
float Cx = -0.8f; // Try different values for different patterns
206+
float Cy = 0.156f; // Try different values for different patterns
207+
208+
unsigned Iter = 0;
209+
unsigned MaxIter = 1000;
210+
211+
// Iterate
212+
while(X*X + Y*Y < 4.0f && Iter < MaxIter) {
213+
float XTemp = X*X - Y*Y + Cx;
214+
Y = 2.0f*X*Y + Cy;
215+
X = XTemp;
216+
Iter++;
217+
}
218+
219+
unsigned ColorG = Iter % 50;
220+
unsigned ColorB = Iter % 25;
221+
222+
R = 0.0f;
223+
G = (float)ColorG / 50.0f;
224+
B = (float)ColorB / 25.0f;
225+
A = 1.0f;
226+
227+
unsigned i = DataY*Width*4+DataX*4;
228+
Data[i+0] = R;
229+
Data[i+1] = G;
230+
Data[i+2] = B;
231+
Data[i+3] = A;
232+
}
233+
"""
234+
235+
236+
def main():
237+
# Parse command line arguments
238+
# Two different kernels are implemented with unique algorithms, and the user can choose which one should be used
239+
# Both kernels fulfill the signature required by the MockLibrary: `void generate_art(float* Data)`
240+
parser = argparse.ArgumentParser()
241+
parser.add_argument(
242+
"--target",
243+
"-t",
244+
type=str,
245+
default="all",
246+
choices=["mandelbrot", "julia", "all"],
247+
help="Type of visualization to generate",
248+
)
249+
parser.add_argument(
250+
"--format",
251+
"-f",
252+
type=str,
253+
default="ltoir",
254+
choices=["ptx", "ltoir"],
255+
help="Type of intermediate format for the device functions to be linked",
256+
)
257+
parser.add_argument(
258+
"--display",
259+
"-d",
260+
action="store_true",
261+
help="Display the generated images",
262+
)
263+
args = parser.parse_args()
264+
265+
if args.display:
266+
try:
267+
import matplotlib.pyplot as plt
268+
except ImportError:
269+
print("this example requires matplotlib installed in order to display the image", file=sys.stderr)
270+
sys.exit(0)
271+
272+
result_to_display = []
273+
lib = MockLibrary()
274+
275+
# Process mandelbrot option
276+
if args.target in ("mandelbrot", "all"):
277+
# The library will compile and link their main kernel with the provided Mandelbrot kernel
278+
kernel = lib.link(code_mandelbrot, args.format)
279+
result = lib.run(kernel)
280+
result_to_display.append((result, "Mandelbrot"))
281+
282+
# Process julia option
283+
if args.target in ("julia", "all"):
284+
# Likewise, the same library can be configured to instead use the provided Julia kernel
285+
kernel = lib.link(code_julia, args.format)
286+
result = lib.run(kernel)
287+
result_to_display.append((result, "Julia"))
288+
289+
# Display the generated images if requested
290+
if args.display:
291+
fig = plt.figure()
292+
for i, (image, title) in enumerate(result_to_display):
293+
axs = fig.add_subplot(len(result_to_display), 1, i + 1)
294+
axs.imshow(image)
295+
axs.set_title(title)
296+
axs.axis("off")
297+
plt.show()
298+
299+
300+
if __name__ == "__main__":
301+
main()
302+
print("done!")

0 commit comments

Comments
 (0)