
Commit 81cb119

spcyppt authored and facebook-github-bot committed
Compute info_B_num_bits from T to make it a constant (#3748)
Summary: X-link: facebookresearch/FBGEMM#829

`b_t_map` packs batch (`b`) and feature (`t`) information. `info_B_num_bits` says how many bits are used to cover the batch information, and it is currently recomputed on each iteration from the batch size. This calculation is problematic when `max_B_` is symbolic, which causes issues in eager AOT mode: when `max_B_` is symbolic, `info_B_num_bits` is not recomputed and falls back to the default value, which can fail if there are not enough bits for B.

To resolve this, we make `info_B_num_bits` a constant. The current implementation adjusts `info_B_num_bits` based on the batch size, so it changes every iteration, and simply freezing the value could again leave insufficient bits for B. This diff therefore implements `get_info_B_num_bits_from_T`: we first calculate how many bits are required to cover the `T` information, since the number of features is known at TBE initialization and remains the same throughout the run, and the remaining bits are given to `B`. Because `info_T_num_bits` never changes, `info_B_num_bits` never changes either. If there are not enough bits for B, the check fails loudly.

In the V1 interface we keep the signature unchanged, since we have already hit the limit on the maximum number of arguments. In the V2 interface (next diff), we compute `info_B_num_bits` and `info_B_mask` once, store them as module parameters, and pass them to the lookup and the corresponding Autograd and backend functions.

Reviewed By: sryap

Differential Revision: D69387123
1 parent 69879df commit 81cb119
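To make the bit split concrete, here is a minimal Python sketch of the scheme described in the summary. It assumes the packed info word is 32 bits wide (`DEFAULT_INFO_NUM_BITS` in the source); the Python function name and the printed example are illustrative only.

```python
import math

# Assumed total width of a packed b_t_map entry (DEFAULT_INFO_NUM_BITS).
DEFAULT_INFO_NUM_BITS = 32

def info_b_num_bits_from_t(T: int, B: int = 1) -> tuple[int, int]:
    """Illustrative Python mirror of get_info_B_num_bits_from_T."""
    # Bits needed for T, fixed at TBE initialization; the rest go to B.
    info_T_num_bits = math.floor(math.log2(T)) + 1
    info_B_num_bits = DEFAULT_INFO_NUM_BITS - info_T_num_bits
    info_B_mask = (1 << info_B_num_bits) - 1
    # Fail loudly instead of silently truncating when B does not fit.
    if B > info_B_mask:
        raise ValueError(f"not enough info bits for B={B}")
    return info_B_num_bits, info_B_mask

# T=100 features take 7 bits, leaving 25 bits for B (max B = 33,554,431).
print(info_b_num_bits_from_t(100, 4096))  # -> (25, 33554431)
```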

8 files changed: +87 −17 lines

fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -413,6 +413,8 @@ def generate() -> None:
             ],
             "aux_int": [
                 "iter",  # 0
+                "info_B_num_bits",  # 1
+                "info_B_mask",  # 2
             ],
             "aux_float": [
                 "gwd_lower_bound",  # 0
```

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp

Lines changed: 2 additions & 13 deletions

```diff
@@ -685,19 +685,8 @@ class {{ autograd_func }} :
     // Default values for Dynamo tracing
     // SymInt does not support bitshifts operator
     // Constanting info_B_num_bits, info_B_mask for Dynamo for now.
-    int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
-    uint32_t info_B_mask = (1u << info_B_num_bits) - 1;
-    if (max_B_.is_symbolic()) {
-      // int32_t info_B_num_bits = 22;
-      // uint32_t info_B_mask = (1u << info_B_num_bits) - 1;
-
-      // TODO(ivankobzarev): Guarding Dynamo that T and B fits in constanted number of bits.
-      // TORCH_CHECK(max_B_ < 1u << info_B_num_bits)
-      // TORCH_CHECK(T < 1u << (DEFAULT_INFO_NUM_BITS - info_B_num_bits))
-    } else {
-      // TODO: don't guard here
-      std::tie(info_B_num_bits, info_B_mask) = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__));
-    }
+    const auto info_B_num_bits = static_cast<int32_t>(aux_int[IDX_INFO_B_NUM_BITS]);
+    const auto info_B_mask = static_cast<uint32_t>(aux_int[IDX_INFO_B_MASK]);

     {%- if vbe %}
     static auto generate_vbe_metadata_op =
```
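With `info_B_num_bits` and `info_B_mask` now read from `aux_int` as constants, the packing of `b_t_map` entries stays stable across iterations. A hypothetical sketch of that packing (not FBGEMM's kernel code):

```python
info_B_num_bits = 25
info_B_mask = (1 << info_B_num_bits) - 1

def pack(t: int, b: int) -> int:
    # Feature id in the high bits, batch index in the low bits.
    return (t << info_B_num_bits) | b

def unpack(b_t: int) -> tuple[int, int]:
    return b_t >> info_B_num_bits, b_t & info_B_mask

assert unpack(pack(t=3, b=4095)) == (3, 4095)
```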

fbgemm_gpu/codegen/training/python/lookup_args.template

Lines changed: 30 additions & 0 deletions

```diff
@@ -77,6 +77,36 @@ class OptimizerArgs(NamedTuple):
     regularization_mode: int
     use_rowwise_bias_correction: bool  # Used for OptimType.ADAM

+class CommonArgsPT2(NamedTuple):
+    placeholder_autograd_tensor: torch.Tensor
+    dev_weights: torch.Tensor
+    host_weights: torch.Tensor
+    uvm_weights: torch.Tensor
+    lxu_cache_weights: torch.Tensor
+    weights_placements: torch.Tensor
+    weights_offsets: torch.Tensor
+    D_offsets: torch.Tensor
+    total_D: int
+    max_D: int
+    hash_size_cumsum: torch.Tensor
+    total_hash_size_bits: int
+    indices: torch.Tensor
+    offsets: torch.Tensor
+    pooling_mode: int
+    indice_weights: Optional[torch.Tensor]
+    feature_requires_grad: Optional[torch.Tensor]
+    lxu_cache_locations: torch.Tensor
+    uvm_cache_stats: Optional[torch.Tensor]
+    output_dtype: int
+    vbe_metadata: VBEMetadata
+    is_experimental: bool
+    use_uniq_cache_locations_bwd: bool
+    use_homogeneous_placements: bool
+    info_B_num_bits: int
+    info_B_mask: int
+{%- if ssd %}
+    ssd_tensors: Dict[str, torch.Tensor]
+{%- endif %}

 class OptimizerArgsPT2(NamedTuple):
     """
```

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h

Lines changed: 1 addition & 0 deletions

```diff
@@ -22,6 +22,7 @@ std::tuple<int64_t, int64_t>
 get_infos_metadata(at::Tensor unused, int64_t B, int64_t T);

 std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(int32_t B, int32_t T);
+std::tuple<int32_t, uint32_t> get_info_B_num_bits_from_T(int32_t T, int32_t B);

 std::tuple<at::Tensor /*row_output_offsets*/, at::Tensor /*b_t_map*/>
 generate_vbe_metadata(
```
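Note the argument orders: the existing `adjust_info_B_num_bits` takes `(B, T)`, while the new `get_info_B_num_bits_from_T` takes `(T, B)`; the call sites below are updated accordingly.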

fbgemm_gpu/src/sparse_ops/sparse_batched_unary_embeddings.cu

Lines changed: 1 addition & 1 deletion

```diff
@@ -186,7 +186,7 @@ DLL_PUBLIC Tensor batched_unary_embeddings_backward_cuda(

   int32_t info_B_num_bits;
   uint32_t info_B_mask;
-  std::tie(info_B_num_bits, info_B_mask) = adjust_info_B_num_bits(B, T);
+  std::tie(info_B_num_bits, info_B_mask) = get_info_B_num_bits_from_T(T, B);

   // weight: [N, sum_E]
   // total_hash_size_bits = log2(sum_E)
```

fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,5 +15,5 @@ using namespace fbgemm_gpu;

 DLL_PUBLIC std::tuple<int64_t, int64_t>
 get_infos_metadata(Tensor unused, int64_t B, int64_t T) {
-  return adjust_info_B_num_bits(B, T);
+  return get_info_B_num_bits_from_T(T, B);
 }
```

fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp

Lines changed: 48 additions & 1 deletion

```diff
@@ -13,6 +13,53 @@

 using Tensor = at::Tensor;

+/// Find the number of bits needed to accommodate a value.
+///
+/// Returns the number of bits required to represent `n`,
+/// e.g., the function returns 3 if `n` is between 4 (100)
+/// and 7 (111), as 3 bits are required to represent those
+/// numbers.
+///
+/// @param n positive decimal number
+///
+DLL_PUBLIC int32_t get_num_bits(int32_t n) {
+  TORCH_CHECK(n > 0, "Expect n to be positive but got ", n);
+  return static_cast<int32_t>(std::floor(std::log2(n) + 1));
+}
+
+/// Calculates the number of bits needed to accommodate the batch size (B)
+/// and the number of tables (T), starting from T. We first calculate how
+/// many bits T needs; the rest are for B, since T does not change once the
+/// TBE is initialized but B can.
+///
+/// info_B_num_bits Number of bits needed to accommodate the batch size
+/// info_B_mask     Bit mask for the B information
+/// @param T Number of tables (features)
+/// @param B Batch size
+///
+DLL_PUBLIC std::tuple<int32_t, uint32_t> get_info_B_num_bits_from_T(
+    int32_t T,
+    int32_t B = 1) {
+  TORCH_CHECK(B > 0, "B must be positive. Got B = ", B);
+  TORCH_CHECK(T > 0, "T must be positive. Got T = ", T);
+  const int32_t info_T_num_bits = get_num_bits(T);
+  const int32_t info_B_num_bits = DEFAULT_INFO_NUM_BITS - info_T_num_bits;
+  const uint32_t info_B_mask = (1u << info_B_num_bits) - 1;
+  TORCH_CHECK(
+      B <= info_B_mask,
+      "Not enough info bits to accommodate T and B. T = ",
+      T,
+      " takes ",
+      info_T_num_bits,
+      " bits and info_B_num_bits is ",
+      info_B_num_bits,
+      ". Expected max_B = ",
+      info_B_mask,
+      " but got B = ",
+      B);
+
+  return {info_B_num_bits, info_B_mask};
+}
+
 DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
     int32_t B,
     int32_t T) {
@@ -79,7 +126,7 @@ generate_vbe_metadata_cpu(

 std::tuple<int64_t, int64_t>
 get_infos_metadata_cpu(Tensor unused, int64_t B, int64_t T) {
-  return adjust_info_B_num_bits(B, T);
+  return get_info_B_num_bits_from_T(T, B);
 }

 } // namespace
```
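A quick sanity check of the `get_num_bits` math against Python's integer `bit_length`, which computes the same quantity without floating point (an illustrative equivalence check, not project code):

```python
import math

def get_num_bits(n: int) -> int:
    # Mirrors the C++ floor(log2(n) + 1) above.
    assert n > 0
    return math.floor(math.log2(n) + 1)

# bit_length() gives the identical answer, including at powers of 2.
for n in (1, 2, 3, 4, 7, 8, 100, 1023, 1024):
    assert get_num_bits(n) == n.bit_length()
```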

fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -190,7 +190,8 @@ def test_transpose(self, B: int, T: int, E: int) -> None:
         self.assertTrue(
             torch.equal(linear_indices_sorted.cpu(), linear_indices_sorted_ref)
         )
-        self.assertTrue(torch.equal(infos_sorted.cpu(), infos_sorted_ref))
+        infos_sorted = infos_sorted.cpu()
+        self.assertTrue(torch.equal(infos_sorted, infos_sorted_ref.to(torch.int32)))

         # fbgemm impl has padding so we need slice
         num = sorted_linear_indices_run_ref.numel()
```
