Add hash_roundrobin routing mode to mitigate modulo-aliasing imbalance #367
shijieliu merged 5 commits into NVIDIA:main
Greptile Summary
This PR adds
Confidence Score: 4/5
Safe to merge after fixing the stale local_rank reference in the module-level rank_print; the core routing and checkpoint logic is sound.
One P1 defect in example.py: the module-level rank_print closure references local_rank, which is no longer in module scope, so builtins.print is broken between import and init_runtime(). All other changes (CUDA kernel, checkpoint metadata, dist_type validation, and tests) are correct and well tested.
corelib/dynamicemb/example/example.py (stale local_rank reference in module-level rank_print)
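The stale-closure defect flagged above can be avoided by resolving the rank when the wrapper is called rather than when the module is imported. The following is a minimal sketch, not the actual fix in example.py: the name rank_print mirrors the review comment, but the body is hypothetical and assumes the rank is available in the LOCAL_RANK environment variable that torchrun sets for each worker.

```python
import builtins
import os

# Capture the real print once, so the wrapper can still delegate to it
# after builtins.print has been replaced.
_builtin_print = builtins.print


def rank_print(*args, **kwargs):
    """Hypothetical rank-aware print: emit output on local rank 0 only.

    The rank is looked up on every call from LOCAL_RANK (set by torchrun),
    defaulting to 0 for single-process runs, so the wrapper never refers to
    a module-level local_rank that may not exist yet.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if local_rank == 0:
        _builtin_print(*args, **kwargs)
```

Because this sketch reads nothing at import time, installing it with builtins.print = rank_print would be safe before init_runtime() runs.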
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[bucketize_kjt_before_all2all] --> B{dist_type_str per feature}
B -->|continuous| C[dist_type = 0]
B -->|roundrobin| D[dist_type = 1]
B -->|hash_roundrobin| E[dist_type = 2]
C --> F[torch.int32 tensor]
D --> F
E --> F
F --> G[CUDA: block_bucketize_sparse_features]
G --> H{dist_type per index}
H -->|0| I[p = idx / blk_size or idx % my_size]
H -->|1| J[p = idx % my_size]
H -->|2| K["p = hash_key(idx) % my_size"]
I --> L[new_lengths / new_indices]
J --> L
K --> L
subgraph Checkpoint Safety
M[DynamicEmbDump] -->|writes dist_type to meta.json| N[checkpoint]
N -->|load| O{dist_type match?}
O -->|yes| P[load succeeds]
O -->|no| Q[ValueError raised]
O -->|key absent - legacy| R[default to roundrobin]
end
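The routing step in the flowchart can be modeled in plain Python to make the three dist_type codes concrete. This is a hedged sketch, not the CUDA kernel: it uses Python ints instead of a torch.int32 tensor, the continuous branch assumes the fallback to idx % my_size applies to indices at or beyond blk_size * my_size, and mix64 (a SplitMix64-style finalizer) is a stand-in because the kernel's actual hash_key function is not shown here. The helper names are hypothetical.

```python
# dist_type codes as listed in the flowchart (string -> int passed to the kernel).
DIST_TYPE_CODES = {"continuous": 0, "roundrobin": 1, "hash_roundrobin": 2}

MASK64 = (1 << 64) - 1


def mix64(key: int) -> int:
    """Stand-in 64-bit hash (SplitMix64 finalizer); the real hash_key may differ."""
    z = key & MASK64
    z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
    return z ^ (z >> 31)


def owner_rank(idx: int, dist_type: int, blk_size: int, my_size: int) -> int:
    """Return the rank that owns idx under the given dist_type code."""
    if dist_type == 0:
        # continuous: contiguous blocks of blk_size rows per rank; indices past
        # the last block fall back to modulo (assumed fallback rule).
        return idx // blk_size if idx < blk_size * my_size else idx % my_size
    if dist_type == 1:
        # roundrobin: plain modulo on the raw key.
        return idx % my_size
    if dist_type == 2:
        # hash_roundrobin: hash the raw key first, then take the modulo.
        return mix64(idx) % my_size
    raise ValueError(f"unknown dist_type code: {dist_type}")


def bucketize(indices, dist_type_str, blk_size, my_size):
    """Group one feature's indices by owner rank, a list-based analogue of
    new_lengths / new_indices."""
    code = DIST_TYPE_CODES[dist_type_str]
    buckets = [[] for _ in range(my_size)]
    for idx in indices:
        buckets[owner_rank(idx, code, blk_size, my_size)].append(idx)
    return buckets
```

For example, bucketize([0, 1, 2, 3], "roundrobin", 4, 2) returns [[0, 2], [1, 3]], while the same call with "hash_roundrobin" assigns each index by mix64(idx) % 2 instead.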
hi @ShaobinChen-AH thanks for your contribution!
jiashuy left a comment
We have dist_type in DynamicEmbParameterSharding, which is not exposed to users.
So if you want to use hash_roundrobin, we have two choices:
- expose dist_type in DynamicEmbTableOptions, and make roundrobin the default value
- use hash_roundrobin as the default here, and adjust our tests that assume dist_type is roundrobin
dist_type is now exposed via DynamicEmbTableOptions, with roundrobin kept as the default for compatibility. The planner now reads opts.dist_type instead of hardcoding the routing mode, so hash_roundrobin is an explicit opt-in path.
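The compatibility contract described in this reply (a user-facing dist_type option that defaults to roundrobin and that the planner reads instead of hardcoding) can be sketched as follows. TableOptionsSketch and planner_dist_type are hypothetical names; the real DynamicEmbTableOptions has many more fields and may validate its values differently.

```python
from dataclasses import dataclass

# Routing modes mentioned in this PR.
VALID_DIST_TYPES = ("continuous", "roundrobin", "hash_roundrobin")


@dataclass
class TableOptionsSketch:
    """Hypothetical stand-in for DynamicEmbTableOptions; only dist_type is modeled."""

    # roundrobin stays the default so existing configs keep their routing.
    dist_type: str = "roundrobin"

    def __post_init__(self) -> None:
        # Reject typos early instead of silently falling back to another mode.
        if self.dist_type not in VALID_DIST_TYPES:
            raise ValueError(
                f"dist_type must be one of {VALID_DIST_TYPES}, got {self.dist_type!r}"
            )


def planner_dist_type(opts: TableOptionsSketch) -> str:
    # The planner takes the routing mode from the table options rather than
    # hardcoding it, so hash_roundrobin is used only when explicitly requested.
    return opts.dist_type
```

With this shape, TableOptionsSketch() keeps roundrobin routing, and TableOptionsSketch(dist_type="hash_roundrobin") opts in.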
I updated dump/load to persist dist_type in checkpoint metadata and validate it at load time, so mismatched input-distribution settings now fail loudly instead of silently loading. I also added end-to-end validation through the user-facing path (DynamicEmbTableOptions(dist_type="hash_roundrobin") -> planner/sharding/input-dist -> dump/load smoke), in addition to the existing kernel/parity benchmark.
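The load-time check described here, together with the legacy rule from the flowchart (a checkpoint without the key is treated as roundrobin), can be sketched with the standard json module. dump_meta and validate_meta are hypothetical helper names, and real DynamicEmbDump metadata contains more than this single key.

```python
import json

# Checkpoints written before dist_type was persisted are assumed to be roundrobin.
LEGACY_DEFAULT_DIST_TYPE = "roundrobin"


def dump_meta(dist_type, extra=None):
    """Serialize checkpoint metadata, always recording the routing mode."""
    meta = dict(extra or {})
    meta["dist_type"] = dist_type
    return json.dumps(meta, sort_keys=True)


def validate_meta(meta_json, expected_dist_type):
    """Fail loudly if the checkpoint was written with a different routing mode."""
    meta = json.loads(meta_json)
    stored = meta.get("dist_type", LEGACY_DEFAULT_DIST_TYPE)
    if stored != expected_dist_type:
        raise ValueError(
            f"checkpoint was written with dist_type={stored!r}, "
            f"but the current table uses dist_type={expected_dist_type!r}"
        )
    return stored
```

Raising instead of loading matters because rows are sharded by owner rank: loading a roundrobin checkpoint into a hash_roundrobin table would place keys on ranks that never receive their lookups.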
Force-pushed from 5e9efdc to 009c5e3
Thanks! @ShaobinChen-AH, could you also help update the example to demonstrate how to use the different input_dist types?
/build
I updated the DynamicEmb MovieLens example to demonstrate different input distribution policies via a new --dist_type CLI option (continuous / roundrobin / hash_roundrobin, default roundrobin). ALL PASSED.
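A --dist_type flag like the one described here can be declared with argparse so that invalid values are rejected before any distributed setup runs. This is a sketch only; build_parser is a hypothetical name, and the parser in the actual MovieLens example may define the option differently and has other arguments.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="DynamicEmb MovieLens example (sketch)")
    parser.add_argument(
        "--dist_type",
        # The three policies named in the comment above.
        choices=["continuous", "roundrobin", "hash_roundrobin"],
        # Keep roundrobin as the default, matching DynamicEmbTableOptions.
        default="roundrobin",
        help="input-distribution policy used to route keys to ranks",
    )
    return parser
```

For example, build_parser().parse_args(["--dist_type", "hash_roundrobin"]).dist_type is "hash_roundrobin", omitting the flag yields "roundrobin", and an unknown value makes argparse exit with a usage error.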
/build
/build
❌ Pipeline #49661872 -- failed
Result: 11/14 jobs passed
Description
Summary
This PR adds hash_roundrobin as a new DynamicEmb RW routing mode and makes it the default for DynamicEmb row-wise planning.
The goal is narrow: fix load imbalance caused by pathological raw-key patterns that can break plain modulo-based roundrobin. The new mode hashes the raw key first, then assigns the owner rank from the hashed key.
This PR does not claim to solve general hot-key or Zipf-skew load balancing.
Changes
- hash_roundrobin support in the DynamicEmb input-distribution path
- hash_roundrobin to dist_type = 2 for the CUDA extension path
- test/unit_test.sh flow
- dist_type values and clarify the intended scope of hash_roundrobin
Why
Issue #350 points out that plain roundrobin can become imbalanced when raw keys follow special patterns. Hashing the raw key before RW rank assignment makes the routing much less sensitive to those patterns while preserving the existing overall bucketization flow.
Validation
Validated on a clean rebuild in the target Ubuntu Docker environment.
- python3 -m pytest -svv test/unit_tests/test_hash_roundrobin_kuairand.py: 16 passed
- CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes 1 --nproc_per_node 2 ./test/unit_tests/test_sequence_embedding_fw.py --print_sharding_plan --optimizer_type "adam" --use_index_dedup True
- Hardware: NVIDIA RTX A6000
Notes