|
29 | 29 | kernel_name_to_category,
|
30 | 30 | parse_bw_and_kernel_name,
|
31 | 31 | profiler_output_to_gpu_time_for_key,
|
32 |
| - profiler_output_to_time_by_kernel_name, |
| 32 | + profiler_output_to_filtered_time_by_kernel_name, |
33 | 33 | )
|
34 | 34 |
|
35 | 35 | # don't truncate long kernel names
|
@@ -312,85 +312,89 @@ def float8_forw_backward_wrapper(x):
|
312 | 312 | # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
|
313 | 313 | # to populate triton kernel bandwidth further down in the script
|
314 | 314 | f = io.StringIO()
|
315 |
| - with redirect_stdout(f): |
316 |
| - # warm up |
317 |
| - for _ in range(1): |
| 315 | + try: |
| 316 | + with redirect_stdout(f): |
| 317 | + # warm up |
| 318 | + for _ in range(1): |
| 319 | + if dtype_filter != "float8": |
| 320 | + ref_forw_backward(input_tensor) |
| 321 | + if dtype_filter != "bfloat16": |
| 322 | + float8_forw_backward_wrapper(input_tensor) |
| 323 | + |
| 324 | + profile_iters = 5 |
| 325 | + ref_times, float8_times = None, None |
| 326 | + data = [] |
| 327 | + |
| 328 | + num_leaf_tensors = 1 + len(list(m_ref.parameters())) |
| 329 | + |
318 | 330 | if dtype_filter != "float8":
|
319 |
| - ref_forw_backward(input_tensor) |
320 |
| - if dtype_filter != "bfloat16": |
321 |
| - float8_forw_backward_wrapper(input_tensor) |
322 |
| - |
323 |
| - profile_iters = 5 |
324 |
| - ref_times, float8_times = None, None |
325 |
| - data = [] |
326 |
| - |
327 |
| - if dtype_filter != "float8": |
328 |
| - # Profile Reference Model |
329 |
| - print("profiling ref") |
330 |
| - ref_suffix = f"_{model_type}_ref_compile_{compile}.json" |
331 |
| - ref_path = profile_path_prefix + ref_suffix |
332 |
| - profile_config = ProfileConfig( |
333 |
| - ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True |
334 |
| - ) |
335 |
| - p = profile_function(profile_config, ref_forw_backward, input_tensor) |
336 |
| - print(f"saved {ref_path}") |
337 |
| - ref_times = profiler_output_to_time_by_kernel_name(p) |
338 |
| - total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters |
339 |
| - for k, v in ref_times.items(): |
340 |
| - v_ms = v / 1e3 / profile_iters |
341 |
| - data.append( |
342 |
| - [ |
343 |
| - "0_ref", |
344 |
| - k, |
345 |
| - kernel_name_to_category(k), |
346 |
| - v_ms, |
347 |
| - v_ms / total_time_ms, |
348 |
| - None, |
349 |
| - ] |
| 331 | + # Profile Reference Model |
| 332 | + print("profiling ref") |
| 333 | + ref_suffix = f"_{model_type}_ref_compile_{compile}.json" |
| 334 | + ref_path = profile_path_prefix + ref_suffix |
| 335 | + profile_config = ProfileConfig( |
| 336 | + ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True |
350 | 337 | )
|
| 338 | + p = profile_function(profile_config, ref_forw_backward, input_tensor) |
| 339 | + print(f"saved {ref_path}") |
| 340 | + ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors) |
| 341 | + total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters |
| 342 | + for k, v in ref_times.items(): |
| 343 | + v_ms = v / 1e3 / profile_iters |
| 344 | + data.append( |
| 345 | + [ |
| 346 | + "0_ref", |
| 347 | + k, |
| 348 | + kernel_name_to_category(k), |
| 349 | + v_ms, |
| 350 | + v_ms / total_time_ms, |
| 351 | + None, |
| 352 | + ] |
| 353 | + ) |
351 | 354 |
|
352 |
| - if dtype_filter != "bfloat16": |
353 |
| - # Profile Float8 Model |
354 |
| - print("profiling float8") |
355 |
| - float8_suffix = ( |
356 |
| - f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" |
357 |
| - ) |
358 |
| - float8_path = profile_path_prefix + float8_suffix |
359 |
| - profile_config = ProfileConfig( |
360 |
| - float8_path, |
361 |
| - float8_suffix, |
362 |
| - iters=profile_iters, |
363 |
| - warmup_iters=2, |
364 |
| - sync=True, |
365 |
| - ) |
366 |
| - p = profile_function( |
367 |
| - profile_config, float8_forw_backward_wrapper, input_tensor |
368 |
| - ) |
369 |
| - print(f"saved {float8_path}") |
370 |
| - float8_times = profiler_output_to_time_by_kernel_name(p) |
371 |
| - total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters |
372 |
| - for k, v in float8_times.items(): |
373 |
| - v_ms = v / 1e3 / profile_iters |
374 |
| - data.append( |
375 |
| - [ |
376 |
| - "1_float8", |
377 |
| - k, |
378 |
| - kernel_name_to_category(k), |
379 |
| - v / 1e3 / profile_iters, |
380 |
| - v_ms / total_time_ms, |
381 |
| - None, |
382 |
| - ] |
| 355 | + if dtype_filter != "bfloat16": |
| 356 | + # Profile Float8 Model |
| 357 | + print("profiling float8") |
| 358 | + float8_suffix = ( |
| 359 | + f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" |
383 | 360 | )
|
| 361 | + float8_path = profile_path_prefix + float8_suffix |
| 362 | + profile_config = ProfileConfig( |
| 363 | + float8_path, |
| 364 | + float8_suffix, |
| 365 | + iters=profile_iters, |
| 366 | + warmup_iters=2, |
| 367 | + sync=True, |
| 368 | + ) |
| 369 | + p = profile_function( |
| 370 | + profile_config, float8_forw_backward_wrapper, input_tensor |
| 371 | + ) |
| 372 | + print(f"saved {float8_path}") |
| 373 | + float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors) |
| 374 | + total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters |
| 375 | + for k, v in float8_times.items(): |
| 376 | + v_ms = v / 1e3 / profile_iters |
| 377 | + data.append( |
| 378 | + [ |
| 379 | + "1_float8", |
| 380 | + k, |
| 381 | + kernel_name_to_category(k), |
| 382 | + v / 1e3 / profile_iters, |
| 383 | + v_ms / total_time_ms, |
| 384 | + None, |
| 385 | + ] |
| 386 | + ) |
| 387 | + |
| 388 | + # get the time spent per user annotation |
| 389 | + sync_time_us = profiler_output_to_gpu_time_for_key( |
| 390 | + p, "scale_amax_and_scales" |
| 391 | + ) |
| 392 | + sync_time_ms = sync_time_us / profile_iters / 1e3 |
| 393 | + print(f"Sync time ms: {sync_time_ms}") |
384 | 394 |
|
385 |
| - # get the time spent per user annotation |
386 |
| - sync_time_us = profiler_output_to_gpu_time_for_key( |
387 |
| - p, "scale_amax_and_scales" |
388 |
| - ) |
389 |
| - sync_time_ms = sync_time_us / profile_iters / 1e3 |
390 |
| - print(f"Sync time ms: {sync_time_ms}") |
391 |
| - |
392 |
| - # print the redirected stdout back to regular stdout |
393 |
| - print(f.getvalue()) |
| 395 | + finally: |
| 396 | + # print the redirected stdout back to regular stdout |
| 397 | + print(f.getvalue()) |
394 | 398 |
|
395 | 399 | # populate the triton kernel bandwidth
|
396 | 400 | for line in f.getvalue().split("\n"):
|
|
0 commit comments