Skip to content

Commit fed4cf7

Browse files
committed
Review comments
1 parent 59d5dea commit fed4cf7

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

cuda_core/examples/jit_lto_fractal.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
# 1. How to use the JIT LTO feature provided by the Linker class to link multiple objects together
1010
# 2. That linking allows for libraries to modify workflows dynamically at runtime
1111
#
12-
# This demo mimics a relationship between a library and a user. The user's sole responsability is to
13-
# provide device code that generates art. Where as the library is responsible for all steps involved in
14-
# setting up the device, launch configurations and arguments as well as linking the provided device code.
12+
# This demo mimics a relationship between a library and a user. The user's sole responsibility is to
13+
# provide device code that generates some art, whereas the library is responsible for all steps involved in
14+
# setting up the device, launch configurations and arguments, as well as linking the provided device code.
1515
#
1616
# Two algorithms are implemented:
1717
# 1. A Mandelbrot set
@@ -37,9 +37,9 @@
3737
# The provided device code must contain a function with the signature `void generate_art(float* Data)`
3838
class MockLibrary:
3939
def __init__(self):
40-
# For this mock library, the main workflow is intentially kept simple by limiting itself to only calling the
40+
# For this mock library, the main workflow is intentionally kept simple by limiting itself to only calling the
4141
# externally defined generate_art function. More involved libraries have the option of applying pre and post
42-
# processing steps before calling user-defined device code. Conversely, these responsabilities can be reversed
42+
# processing steps before calling user-defined device code. Conversely, these responsibilities can be reversed
4343
# such that the library owns the bulk of the workflow while allowing users to provide customized pre/post
4444
# processing steps.
4545
code_main = r"""
@@ -61,33 +61,32 @@ def __init__(self):
6161

6262
# Most of the launch configurations can be preemptively done before the user provides their device code
6363
# Therefore, let's compile our main workflow device code now, and link the remaining pieces at a later time
64-
self.arch = "".join(f"{i}" for i in Device().compute_capability)
65-
self.program_options = ProgramOptions(std="c++11", arch=f"sm_{self.arch}", relocatable_device_code=True)
64+
self.program_options = ProgramOptions(relocatable_device_code=True)
6665
self.main_object_code = Program(code_main, "c++", options=self.program_options).compile("ptx")
6766

6867
# Set up device state
6968
self.dev = Device()
7069
self.dev.set_current()
7170
self.stream = self.dev.create_stream()
7271

73-
# Setup buffer to store our results
72+
# Set up a buffer to store the RGBA results for the width and height specified
7473
self.width = 1024
7574
self.height = 512
7675
self.buffer = cp.empty(self.width * self.height * 4, dtype=cp.float32)
7776

7877
# Set up the launch configuration such that each thread will be generating one pixel, and subdivide
7978
# the problem into 16x16 chunks.
80-
self.grid = (self.width / 16, self.height / 16, 1)
79+
self.grid = (self.width / 16, self.height / 16, 1.0)
8180
self.block = (16, 16, 1)
8281
self.config = LaunchConfig(grid=self.grid, block=self.block, stream=self.stream)
8382

8483
def link(self, user_code, target_type):
8584
if target_type == "ltoir":
86-
program_options = ProgramOptions(std="c++11", arch=f"sm_{self.arch}", link_time_optimization=True)
87-
linker_options = LinkerOptions(arch=f"sm_{self.arch}", link_time_optimization=True)
85+
program_options = ProgramOptions(link_time_optimization=True)
86+
linker_options = LinkerOptions(link_time_optimization=True)
8887
elif target_type == "ptx":
8988
program_options = self.program_options
90-
linker_options = LinkerOptions(arch=f"sm_{self.arch}")
89+
linker_options = LinkerOptions()
9190
else:
9291
raise AssertionError
9392

@@ -119,7 +118,7 @@ def run(self, kernel):
119118
# http://en.wikipedia.org/wiki/Mandelbrot_set
120119
#
121120
# Note that this kernel is meant to be a simple, straightforward
122-
# implementation, and so may not represent optimized GPU code.
121+
# implementation. No attempt is made to optimize this GPU code.
123122
code_mandelbrot = r"""
124123
__device__
125124
void generate_art(float* Data) {
@@ -168,18 +167,19 @@ def run(self, kernel):
168167
B = (float)ColorB / 25.0f;
169168
A = 1.0f;
170169
171-
Data[DataY*Width*4+DataX*4+0] = R;
172-
Data[DataY*Width*4+DataX*4+1] = G;
173-
Data[DataY*Width*4+DataX*4+2] = B;
174-
Data[DataY*Width*4+DataX*4+3] = A;
170+
unsigned i = DataY*Width*4+DataX*4;
171+
Data[i+0] = R;
172+
Data[i+1] = G;
173+
Data[i+2] = B;
174+
Data[i+3] = A;
175175
}
176176
"""
177177

178178
# Simple implementation of Julia set from Wikipedia
179179
# http://en.wikipedia.org/wiki/Julia_set
180180
#
181181
# Note that this kernel is meant to be a simple, straightforward
182-
# implementation, and so may not represent optimized GPU code.
182+
# implementation. No attempt is made to optimize this GPU code.
183183
code_julia = r"""
184184
__device__
185185
void generate_art(float* Data) {
@@ -224,15 +224,16 @@ def run(self, kernel):
224224
B = (float)ColorB / 25.0f;
225225
A = 1.0f;
226226
227-
Data[DataY*Width*4+DataX*4+0] = R;
228-
Data[DataY*Width*4+DataX*4+1] = G;
229-
Data[DataY*Width*4+DataX*4+2] = B;
230-
Data[DataY*Width*4+DataX*4+3] = A;
227+
unsigned i = DataY*Width*4+DataX*4;
228+
Data[i+0] = R;
229+
Data[i+1] = G;
230+
Data[i+2] = B;
231+
Data[i+3] = A;
231232
}
232233
"""
233234

234235

235-
if __name__ == "__main__":
236+
def main():
236237
# Parse command line arguments
237238
# Two different kernels are implemented with unique algorithms, and the user can choose which one should be used
238239
# Both kernels fulfill the signature required by the MockLibrary: `void generate_art(float* Data)`
@@ -296,3 +297,7 @@ def run(self, kernel):
296297
plt.show()
297298

298299
print("done!")
300+
301+
302+
if __name__ == "__main__":
303+
main()

0 commit comments

Comments
 (0)