 {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
 
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
-#include "fbgemm_gpu/utils/tensor_accessor.h"
 #include "fbgemm_gpu/sparse_ops.h"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
+#include "fbgemm_gpu/utils/barrier_isolation.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
-
+#include "fbgemm_gpu/utils/tensor_accessor.h"
 {%- if is_rocm %}
 #include "fbgemm_gpu/rocm/cdna_guard.h"
 {%- endif %}
@@ -790,32 +790,34 @@ Tensor {{ embedding_cuda_op }}(
     // {{ locs_or_addrs_tensor }} run ids and sorted_linear_indices run ids.
     auto dev_or_uvm_unique_indices = at::zeros_like(weights_placements);
 
+    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name = "split_embedding_backward_count_unique_indices_kernel";
 #endif
-    split_embedding_backward_count_unique_indices_kernel<
-      {{ "int64_t" if nobag else "int32_t" }},
-      {{ "int64_t" if nobag else "uint32_t" }},
-      {{ "true" if nobag else "false" }}
-    ><<<
-      div_round_up(total_unique_indices, kMaxThreads),
-      kMaxThreads,
-      0,
-      at::cuda::getCurrentCUDAStream()
-    >>>(
-      MAKE_PTA_WITH_NAME(
-          func_name, sorted_linear_indices_num_runs, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(
-          func_name, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(
-          func_name, infos_sorted, {{ "int64_t" if nobag else "int32_t" }}, 1, 32),
-      MAKE_PTA_WITH_NAME(
-          func_name, weights_placements, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(
-          func_name, dev_or_uvm_unique_indices, int32_t, 1, 32),
-      info_B_num_bits
-    );
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+      split_embedding_backward_count_unique_indices_kernel<
+        {{ "int64_t" if nobag else "int32_t" }},
+        {{ "int64_t" if nobag else "uint32_t" }},
+        {{ "true" if nobag else "false" }}
+      ><<<
+        div_round_up(total_unique_indices, kMaxThreads),
+        kMaxThreads,
+        0,
+        at::cuda::getCurrentCUDAStream()
+      >>>(
+        MAKE_PTA_WITH_NAME(
+            func_name, sorted_linear_indices_num_runs, int32_t, 1, 32),
+        MAKE_PTA_WITH_NAME(
+            func_name, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
+        MAKE_PTA_WITH_NAME(
+            func_name, infos_sorted, {{ "int64_t" if nobag else "int32_t" }}, 1, 32),
+        MAKE_PTA_WITH_NAME(
+            func_name, weights_placements, int32_t, 1, 32),
+        MAKE_PTA_WITH_NAME(
+            func_name, dev_or_uvm_unique_indices, int32_t, 1, 32),
+        info_B_num_bits
+      );
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }); // DEBUG_KERNEL_BARRIER_ISOLATE
 
     table_unique_indices_offsets =
         fbgemm_gpu::asynchronous_complete_cumsum_gpu(dev_or_uvm_unique_indices).to(at::kInt);
@@ -940,31 +942,32 @@ Tensor {{ embedding_cuda_op }}(
     grad_output_mean = at::empty_like(grad_output_reshaped);
     {%- if not dense or not vbe %}
 
+    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name1 = "grad_mean{{ vdesc }}_kernel";
+      const auto func_name1 = "grad_mean{{ vdesc }}_kernel";
 #endif
+      grad_mean{{ vdesc }}_kernel<<<
+        div_round_up(total_B, kMaxThreads / kWarpSize),
+        dim3(kWarpSize, kMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>
+        (
+          MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
+          MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
+          MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
+          {%- if vbe %}
+          MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
+          info_B_num_bits,
+          info_B_mask
+          {%- else %}
+          FixedDivisor(total_B / T)
+          {%- endif %}
+        );
 
-    grad_mean{{ vdesc }}_kernel<<<
-      div_round_up(total_B, kMaxThreads / kWarpSize),
-      dim3(kWarpSize, kMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>
-      (
-        MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
-        MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
-        MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
-        {%- if vbe %}
-        MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
-        info_B_num_bits,
-        info_B_mask
-        {%- else %}
-        FixedDivisor(total_B / T)
-        {%- endif %}
-      );
-
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }); // DEBUG_KERNEL_BARRIER_ISOLATE
     {%- endif %} // if not dense or not vbe
 
     grad_output_accessor = MAKE_PTA_WITH_NAME("{{ embedding_cuda_op }}.2", grad_output_mean, grad_t, 2, 64);
@@ -1005,27 +1008,29 @@ Tensor {{ embedding_cuda_op }}(
         use_deterministic_algorithms ? 0 : (indices.numel() / max_segment_length_per_cta),
         indices.options().dtype(at::kInt));
 
+    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name2 = "split_embedding_backward_codegen_find_long_segments";
+      const auto func_name2 = "split_embedding_backward_codegen_find_long_segments";
 #endif
 
-    split_embedding_backward_codegen_find_long_segments<<<
-        div_round_up(total_unique_indices, kMaxThreads),
-        kMaxThreads,
-        0,
-        at::cuda::getCurrentCUDAStream()
-        >>>(
-        MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_num_runs, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_run_lengths, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, num_long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, long_run_id_to_really_long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, num_really_long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, grad_accum_counter, int32_t, 1, 32),
-        max_segment_length_per_warp,
-        max_segment_length_per_cta,
-        use_deterministic_algorithms);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+      split_embedding_backward_codegen_find_long_segments<<<
+          div_round_up(total_unique_indices, kMaxThreads),
+          kMaxThreads,
+          0,
+          at::cuda::getCurrentCUDAStream()
+          >>>(
+          MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_num_runs, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_run_lengths, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name2, long_run_ids, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name2, num_long_run_ids, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name2, long_run_id_to_really_long_run_ids, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name2, num_really_long_run_ids, int32_t, 1, 32),
+          MAKE_PTA_WITH_NAME(func_name2, grad_accum_counter, int32_t, 1, 32),
+          max_segment_length_per_warp,
+          max_segment_length_per_cta,
+          use_deterministic_algorithms);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }); // DEBUG_KERNEL_BARRIER_ISOLATE
 
     // A temp buffer to accumulate gradients with atomics.
     auto temp_grad_accum = at::zeros(
@@ -1079,8 +1084,9 @@ Tensor {{ embedding_cuda_op }}(
         div_round_up(total_unique_indices, kMaxThreads),
         get_max_thread_blocks_());
 
+    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name3 = "{{ cta_kernel }}";
+      const auto func_name3 = "{{ cta_kernel }}";
 #endif
     backward_cta_per_row_kernel
         <<<cta_per_row_grid_size,
@@ -1161,6 +1167,8 @@ Tensor {{ embedding_cuda_op }}(
     );
 
     C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }); // DEBUG_KERNEL_BARRIER_ISOLATE
+
     {%- set warp_kernel =
         "batch_index_select_dim0_codegen_backward_kernel_warp_per_row"
         if is_index_select else
@@ -1241,7 +1249,7 @@ Tensor {{ embedding_cuda_op }}(
     {%- endif %}
 #endif
 
-
+    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
 #ifdef FBGEMM_GPU_MEMCHECK
     const auto func_name4 = "{{ warp_kernel }}";
 #endif
@@ -1316,6 +1324,8 @@ Tensor {{ embedding_cuda_op }}(
       {%- endif %}
     );
     C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    }); // DEBUG_KERNEL_BARRIER_ISOLATE
   }); // DISPATCH_PLACEHOLDER_TYPES
   return;
 
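Note on the pattern this diff applies: every kernel launch and its C10_CUDA_KERNEL_LAUNCH_CHECK() call are wrapped in a DEBUG_KERNEL_BARRIER_ISOLATE([&] { ... }); block, with fbgemm_gpu/utils/barrier_isolation.cuh added to the includes. The sketch below is only an illustrative, self-contained approximation of that wrapping pattern, not the FBGEMM implementation; the macro KERNEL_BARRIER_ISOLATE_SKETCH, the DEBUG_BARRIER_ISOLATION_SKETCH flag, and dummy_kernel are hypothetical names invented for this example.

// Illustrative sketch only (see assumptions above): a variadic macro that,
// when the debug flag is set, brackets the wrapped launch with device-wide
// synchronization so the kernel runs in isolation; otherwise it simply
// invokes the lambda.
#include <cuda_runtime.h>
#include <cstdio>

#ifdef DEBUG_BARRIER_ISOLATION_SKETCH
#define KERNEL_BARRIER_ISOLATE_SKETCH(...)  \
  do {                                      \
    cudaDeviceSynchronize();                \
    (__VA_ARGS__)();                        \
    cudaDeviceSynchronize();                \
  } while (0)
#else
#define KERNEL_BARRIER_ISOLATE_SKETCH(...) (__VA_ARGS__)()
#endif

// Hypothetical kernel standing in for the FBGEMM backward kernels.
__global__ void dummy_kernel(int* out) {
  out[blockIdx.x * blockDim.x + threadIdx.x] = threadIdx.x;
}

int main() {
  int* out = nullptr;
  cudaMalloc(&out, 128 * sizeof(int));

  // Same shape as the diff: the launch and its error check live inside a
  // lambda handed to the isolation macro.
  KERNEL_BARRIER_ISOLATE_SKETCH([&] {
    dummy_kernel<<<1, 128>>>(out);
    // FBGEMM uses C10_CUDA_KERNEL_LAUNCH_CHECK(); this is the plain CUDA
    // runtime equivalent.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
      std::printf("launch failed: %s\n", cudaGetErrorString(err));
    }
  });

  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}

Keeping the launch inside a lambda means the synchronization policy lives in one place and compiles down to a plain call in non-debug builds.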