Commit 34b24f7

float8 profiling script: filter out microbenchmarking overhead (#629)
Summary:

Our microbenchmarks have a lot of overhead. This PR attempts to get a cleaner measurement of only the kernels in the fwd+bwd, and subtracts the kernels unrelated to the fwd+bwd code. This makes the kernel summary tables more reflective of GPU-bound real use cases.

Test Plan:

Profiling ln -> linear:

```
python benchmarks/float8/profile_linear_float8.py --dtype_filter both ~/local/tmp --model_type ln_linear
```

New output; note that only kernels relevant to ln and linear are displayed:

```
Summary of GPU time by CPU kernel

    experiment  kernel  category  time_ms  pct_gpu_time  bw_gpbs
1   0_ref  aten::mm  0_gemm  10.153  0.945  None
2   0_ref  triton_red_fused_native_layer_norm_native_layer_norm_backward_0  2_other  0.350  0.033  None
0   0_ref  triton_red_fused_native_layer_norm_0  2_other  0.241  0.022  None
12  1_float8  aten::_scaled_mm  0_gemm  5.182  0.736  None
16  1_float8  triton_red_fused__scaled_mm__to_copy_clamp_clone_mul_native_layer_norm_native_layer_norm_backwar...  1_f8_overhead  0.813  0.115  None
15  1_float8  triton_poi_fused__scaled_mm__to_copy_clamp_clone_mul_reciprocal_view_2  1_f8_overhead  0.302  0.043  None
5   1_float8  triton_red_fused_abs_max_native_layer_norm_0  1_f8_overhead  0.212  0.030  None
10  1_float8  triton_poi_fused__scaled_mm__to_copy_clamp_mul_native_layer_norm_view_5  1_f8_overhead  0.177  0.025  None
11  1_float8  triton_poi_fused__scaled_mm__to_copy_clamp_clone_mul_native_layer_norm_view_6  1_f8_overhead  0.150  0.021  None
13  1_float8  triton_red_fused_abs_max_0  1_f8_overhead  0.126  0.018  None
7   1_float8  triton_red_fused_abs_max_2  1_f8_overhead  0.060  0.008  None
3   1_float8  triton_per_fused_copy_max_roll_0  1_f8_overhead  0.005  0.001  None
6   1_float8  triton_red_fused__to_copy_abs_clamp_max_mul_native_layer_norm_reciprocal_1  1_f8_overhead  0.004  0.001  None
4   1_float8  triton_per_fused_copy_max_roll_1  1_f8_overhead  0.003  0.000  None
14  1_float8  triton_per_fused__scaled_mm__to_copy_abs_clamp_clone_max_mul_reciprocal_view_1  1_f8_overhead  0.003  0.000  None
8   1_float8  triton_per_fused_abs_fill_max_3  1_f8_overhead  0.003  0.000  None
9   1_float8  triton_poi_fused_reciprocal_4  2_other  0.002  0.000  None

Float8 amax/scale sync approx ratio of total time: 0.006

Summary of time (ms) by kernel category

experiment      0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         10.153     5.182       0.510       1.959
1_f8_overhead   0.000     1.858         inf       0.000
2_other         0.591     0.002       0.004     264.393
All            10.743     7.042       0.655       1.526
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 934dead commit 34b24f7
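
The filtering idea described in the summary is roughly: profile fwd+bwd under `torch.profiler`, then drop the kernels that exist only because of the benchmark harness (the loss `sum`, the `fill_`/`copy_` that seed `grad_output`, the per-leaf `add_` grad accumulation, and device syncs). Below is a minimal standalone sketch of that idea, assuming a CUDA device; `measure_fwd_bwd_kernel_times` is a hypothetical helper, not part of this PR, and it skips the per-op count checks that the real implementation performs.

```python
import collections

import torch
from torch.profiler import ProfilerActivity, profile


def measure_fwd_bwd_kernel_times(func, x, num_iter=5):
    # Profile fwd + bwd of `func`, then keep only GPU time attributed to it.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(num_iter):
            y = func(x)
            y.sum().backward()
            torch.cuda.synchronize()

    # Harness-only ops: the `sum` used to build a scalar loss, the `fill_`/`copy_`
    # that seed grad_output with 1.0, the `add_` that accumulates grads into leaf
    # tensors, and the device sync. (The real utils.py also asserts expected counts.)
    overhead = {"aten::sum", "aten::fill_", "aten::copy_", "aten::add_", "cudaDeviceSynchronize"}

    kernel_time_us = collections.defaultdict(float)
    for e in prof.key_averages():
        if e.key in overhead or e.self_device_time_total == 0:
            continue
        kernel_time_us[e.key] += e.self_device_time_total
    return kernel_time_us
```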

File tree

2 files changed: +139, -79 lines changed

2 files changed

+139
-79
lines changed

benchmarks/float8/profile_linear_float8.py

Lines changed: 79 additions & 75 deletions

```diff
@@ -29,7 +29,7 @@
     kernel_name_to_category,
     parse_bw_and_kernel_name,
     profiler_output_to_gpu_time_for_key,
-    profiler_output_to_time_by_kernel_name,
+    profiler_output_to_filtered_time_by_kernel_name,
 )
 
 # don't truncate long kernel names
@@ -312,85 +312,89 @@ def float8_forw_backward_wrapper(x):
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
     f = io.StringIO()
-    with redirect_stdout(f):
-        # warm up
-        for _ in range(1):
-            if dtype_filter != "float8":
-                ref_forw_backward(input_tensor)
-            if dtype_filter != "bfloat16":
-                float8_forw_backward_wrapper(input_tensor)
-
-        profile_iters = 5
-        ref_times, float8_times = None, None
-        data = []
-
-        if dtype_filter != "float8":
-            # Profile Reference Model
-            print("profiling ref")
-            ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-            ref_path = profile_path_prefix + ref_suffix
-            profile_config = ProfileConfig(
-                ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
-            )
-            p = profile_function(profile_config, ref_forw_backward, input_tensor)
-            print(f"saved {ref_path}")
-            ref_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
-            for k, v in ref_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "0_ref",
-                        k,
-                        kernel_name_to_category(k),
-                        v_ms,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
-                )
-
-        if dtype_filter != "bfloat16":
-            # Profile Float8 Model
-            print("profiling float8")
-            float8_suffix = (
-                f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-            )
-            float8_path = profile_path_prefix + float8_suffix
-            profile_config = ProfileConfig(
-                float8_path,
-                float8_suffix,
-                iters=profile_iters,
-                warmup_iters=2,
-                sync=True,
-            )
-            p = profile_function(
-                profile_config, float8_forw_backward_wrapper, input_tensor
-            )
-            print(f"saved {float8_path}")
-            float8_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
-            for k, v in float8_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "1_float8",
-                        k,
-                        kernel_name_to_category(k),
-                        v / 1e3 / profile_iters,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
-                )
-
-            # get the time spent per user annotation
-            sync_time_us = profiler_output_to_gpu_time_for_key(
-                p, "scale_amax_and_scales"
-            )
-            sync_time_ms = sync_time_us / profile_iters / 1e3
-            print(f"Sync time ms: {sync_time_ms}")
-
-    # print the redirected stdout back to regular stdout
-    print(f.getvalue())
+    try:
+        with redirect_stdout(f):
+            # warm up
+            for _ in range(1):
+                if dtype_filter != "float8":
+                    ref_forw_backward(input_tensor)
+                if dtype_filter != "bfloat16":
+                    float8_forw_backward_wrapper(input_tensor)
+
+            profile_iters = 5
+            ref_times, float8_times = None, None
+            data = []
+
+            num_leaf_tensors = 1 + len(list(m_ref.parameters()))
+
+            if dtype_filter != "float8":
+                # Profile Reference Model
+                print("profiling ref")
+                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_path = profile_path_prefix + ref_suffix
+                profile_config = ProfileConfig(
+                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+                )
+                p = profile_function(profile_config, ref_forw_backward, input_tensor)
+                print(f"saved {ref_path}")
+                ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
+                total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+                for k, v in ref_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "0_ref",
+                            k,
+                            kernel_name_to_category(k),
+                            v_ms,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
+
+            if dtype_filter != "bfloat16":
+                # Profile Float8 Model
+                print("profiling float8")
+                float8_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
+                )
+                float8_path = profile_path_prefix + float8_suffix
+                profile_config = ProfileConfig(
+                    float8_path,
+                    float8_suffix,
+                    iters=profile_iters,
+                    warmup_iters=2,
+                    sync=True,
+                )
+                p = profile_function(
+                    profile_config, float8_forw_backward_wrapper, input_tensor
+                )
+                print(f"saved {float8_path}")
+                float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
+                total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                for k, v in float8_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "1_float8",
+                            k,
+                            kernel_name_to_category(k),
+                            v / 1e3 / profile_iters,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
+
+                # get the time spent per user annotation
+                sync_time_us = profiler_output_to_gpu_time_for_key(
+                    p, "scale_amax_and_scales"
+                )
+                sync_time_ms = sync_time_us / profile_iters / 1e3
+                print(f"Sync time ms: {sync_time_ms}")
+
+    finally:
+        # print the redirected stdout back to regular stdout
+        print(f.getvalue())
 
     # populate the triton kernel bandwidth
     for line in f.getvalue().split("\n"):
```
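
The main structural change in this file is wrapping the profiling body in `try`/`finally`, so the output captured with `redirect_stdout` is echoed back even if profiling raises partway through. Here is a minimal sketch of that pattern, with `run_profiling` as a hypothetical stand-in for the script's profiling body:

```python
import io
from contextlib import redirect_stdout


def run_profiling() -> None:
    # hypothetical stand-in for the fwd+bwd profiling body
    print("inductor kernel stats would be printed here")


f = io.StringIO()
try:
    with redirect_stdout(f):
        run_profiling()
finally:
    # echo the captured output so it is visible even if profiling raised
    print(f.getvalue())
```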

benchmarks/float8/utils.py

Lines changed: 60 additions & 4 deletions

```diff
@@ -9,10 +9,44 @@
 from typing import Optional
 
 
-def profiler_output_to_time_by_kernel_name(prof):
+def profiler_output_to_filtered_time_by_kernel_name(
+    prof,
+    num_iter: int,
+    num_leaf_tensors: int,
+):
     """
-    Input: a profiler with captured events.
-    Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name
+    Input:
+    * `prof`: a profiler with captured events
+    * `num_iter`: number of iterations used to capture `prof`
+    * `num_leaf_tensors`: number of leaf tensors to accumulate gradients to
+    Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name,
+    with the microbenchmark overhead filtered out
+
+    Currently assumes that `prof` captured events from a microbenchmark which was
+    set up as follows:
+
+        #
+        # Forward pass
+        #
+
+        # Expected GPU kernel overhead: none
+        y = func(...)
+
+        # Convenient way to set up the backward pass without caring about shapes
+        y_sum = y.sum()
+
+        # Expected GPU kernel overhead:
+        # * the call to `sum`
+
+        #
+        # Backward pass
+        #
+        y_sum.backward()
+
+        # Expected GPU kernel overhead:
+        # * the call to `aten.fill_` to put a tensor with a single 1.0 value as the input to the backward
+        # * the call to `aten.copy_` to fill the first `grad_output` tensor with 1.0
+        # * the call to `aten.add_` to accumulate grads, once per leaf tensor
 
     Note that if there are user_annotations in the captured events, `torch.profiler`
     will include their time in the total GPU time displayed at the bottom of
@@ -23,13 +57,35 @@ def profiler_output_to_time_by_kernel_name(prof):
     thresh = 1e-10
     kernel_name_to_gpu_time_us = collections.defaultdict(float)
     for e in key_averages:
+
         # manually filter top-level CPU events with attributed CUDA time
-        # example CPU event row:
+        # example CPU event row from printing `key_averages`:
         # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1
         # and it maps to this CUDA event:
         # sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1
         if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh):
             continue
+
+        # manually filter expected microbenchmarking overhead, in order of execution
+        if e.key == 'aten::sum':
+            # forward pass sum
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::fill_':
+            # filling the forward pass sum with 1.0
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::copy_':
+            # copying 1.0 from grad_out of `sum` to grad_out of next op
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::add_':
+            # accumulating gradients into leaf tensors
+            assert e.count == (num_iter * num_leaf_tensors), f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'cudaDeviceSynchronize':
+            continue
+
         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
 
```
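
To show how the new helper is meant to be driven, here is a sketch of the microbenchmark shape its docstring assumes (forward, `sum`, backward), plus the call into `profiler_output_to_filtered_time_by_kernel_name`. The `nn.Linear` model, sizes, and import path are illustrative only, assuming the repo's `benchmarks/float8/utils.py` is importable (e.g. when run from the repo root); the helper's internal count asserts require the harness to match this pattern exactly, including warming up so grad accumulation uses `aten::add_` on every profiled iteration.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# assumed import path; adjust to however utils.py is reachable in your setup
from benchmarks.float8.utils import profiler_output_to_filtered_time_by_kernel_name

# toy module standing in for the real ln -> linear benchmark
m = torch.nn.Linear(1024, 1024, device="cuda")
x = torch.randn(4096, 1024, device="cuda", requires_grad=True)
num_leaf_tensors = 1 + len(list(m.parameters()))  # input + weight + bias

# warm up once so leaf grads already exist and accumulation uses `aten::add_`
m(x).sum().backward()

num_iter = 5
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(num_iter):
        y = m(x)          # forward: no expected harness overhead
        y_sum = y.sum()   # harness overhead: aten::sum
        y_sum.backward()  # harness overhead: fill_/copy_ for grad seeding, add_ per leaf
        torch.cuda.synchronize()

times_us = profiler_output_to_filtered_time_by_kernel_name(prof, num_iter, num_leaf_tensors)
for name, t in sorted(times_us.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {t:.1f} us")
```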
