Skip to content

Commit b25dec3

Browse files
sryapfacebook-github-bot
authored andcommitted
Use bounds_check_indices v2 on ROCm (#3916)
Summary: X-link: facebookresearch/FBGEMM#1005 Pull Request resolved: #3916 This diff forces using bounds_check_indices v2 on ROCm because ROCm has a constraint that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed to be smaller than 2^32 Reviewed By: q10, jianyuh, joebos Differential Revision: D72334377 fbshipit-source-id: 9c955b691e4462721d500b5b643e037d71e13e0c
1 parent 15a3e5c commit b25dec3

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,21 @@ void bounds_check_indices_cuda(
6060
const int64_t info_B_num_bits,
6161
const int64_t info_B_mask,
6262
const int8_t bounds_check_version) {
63+
#if USE_ROCM
64+
// Force using bounds_check_indices v2 on ROCm because ROCm has a constraint
65+
// that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can
66+
// be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the
67+
// gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed
68+
// to be smaller than 2^32
69+
const auto bounds_check_indices_fn = _bounds_check_indices_cuda_v2;
70+
#else
6371
TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2);
64-
const static bool use_v2 =
65-
fbgemm_gpu::config::is_feature_enabled(
66-
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2) ||
67-
bounds_check_version == 2;
68-
const auto bounds_check_indices_fn =
69-
use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1;
72+
const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled(
73+
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
74+
const auto bounds_check_indices_fn = (use_v2 || bounds_check_version == 2)
75+
? _bounds_check_indices_cuda_v2
76+
: _bounds_check_indices_cuda_v1;
77+
#endif
7078
bounds_check_indices_fn(
7179
rows_per_table,
7280
indices,

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def bounds_check_indices_abstract(
829829
b_t_map: Optional[torch.Tensor] = None,
830830
info_B_num_bits: int = -1,
831831
info_B_mask: int = -1,
832+
bounds_check_version: int = 1,
832833
) -> None:
833834
"""
834835
This meta function is used to fake the bounds checking

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,16 @@ def __init__( # noqa C901
716716
# See:
717717
# https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
718718
cache_precision = SparseType.FP32
719+
self.log("Override cache_precision=SparseType.FP32 on ROCm")
720+
721+
# NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a
722+
# constraint that the gridDim * blockDim has to be smaller than
723+
# 2^32. The v1 kernel can be launched with gridDim * blockDim >
724+
# 2^32 while the v2 kernel limits the gridDim size to 64 * # of
725+
# SMs. Thus, its gridDim * blockDim is guaranteed to be smaller
726+
# than 2^32
727+
self.bounds_check_version = 2
728+
self.log("Override bounds_check_version=2 on ROCm")
719729
else:
720730
# NOTE: The changes from D65865527 are retained here until we can
721731
# test that the the hack also works for non-ROCm environments.

0 commit comments

Comments
 (0)