Skip to content

Commit 86cbee1

Browse files
JacoCheung and claude
authored
fix(dynamicemb): traverse nn.Module children in check_emb_collection_modules (#355)
* fix(dynamicemb): traverse nn.Module children in check_emb_collection_modules The function only checked direct attributes, missing EmbeddingCollection wrapped inside DistributedModelParallel or other nn.Module containers. Now recursively walks module.children() to find the embedding submodule. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * test(dynamicemb): verify get_dynamic_emb_module traverses DMP wrappers Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor(examples): drop dynamicemb_utils, use get_dynamic_emb_module check_emb_collection_modules now traverses nn.Module.children() and guards against cycles, so the examples-side duplicate find_dynamicemb_modules is redundant. Route both call sites (cache stats hook + FILL_DYNAMICEMB_TABLES fill path) through dynamicemb's public get_dynamic_emb_module and delete commons/utils/dynamicemb_utils.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(dynamicemb/test): rewrite check_counter_table_checkpoint for fused MultiTableKVCounter After the admission-counter fusion refactor (#343), `_admission_counter` became a single MultiTableKVCounter (or None) instead of a list of per-table KVCounters. The verification helper still iterated it as a list, so the dump/load counter check was silently a no-op on any branch whose `get_dynamic_emb_module` could not traverse into DMP wrappers, and raised `TypeError: 'MultiTableKVCounter' object is not iterable` on any branch that could. PR #355's `children()` traversal surfaced the latter. Rewrite the check against the current API: - Iterate logical tables via `range(len(table._table_names))`. - Export one table_id's (keys, frequency) with `cnt.table_._batched_export_keys_scores([freq_name], device, table_id)`. - Look those keys up in the peer counter via `cnt_y.table_.lookup(keys, table_ids, score_arg)` and assert both `founds.all()` and that the returned score_out matches the exported frequency. 
- Handle the None (no-admission) case symmetrically. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3db734f commit 86cbee1

5 files changed

Lines changed: 113 additions & 97 deletions

File tree

corelib/dynamicemb/dynamicemb/dump_load.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import warnings
1919
from collections import deque
20-
from typing import Dict, List, Optional, Tuple
20+
from typing import Dict, List, Optional, Set, Tuple
2121

2222
import torch
2323
import torch.distributed as dist
@@ -50,31 +50,42 @@ def find_sharded_modules(
5050
return sharded_modules
5151

5252

53-
def check_emb_collection_modules(module: nn.Module, ret_list: List[nn.Module]):
53+
def check_emb_collection_modules(
54+
module: nn.Module,
55+
ret_list: List[nn.Module],
56+
visited: Optional[Set[int]] = None,
57+
):
58+
if visited is None:
59+
visited = set()
60+
61+
mid = id(module)
62+
if mid in visited:
63+
return
64+
visited.add(mid)
65+
5466
if isinstance(module, BatchedDynamicEmbeddingTablesV2):
5567
ret_list.append(module)
56-
return ret_list
68+
return
5769

5870
if isinstance(module, nn.Module):
59-
if hasattr(module, "_emb_module"):
60-
check_emb_collection_modules(module._emb_module, ret_list)
61-
62-
if hasattr(module, "_emb_modules"):
63-
check_emb_collection_modules(module._emb_modules, ret_list)
64-
65-
if hasattr(module, "_lookups"):
66-
tmp_module_list = module._lookups
67-
for tmp_emb_module in tmp_module_list:
68-
check_emb_collection_modules(tmp_emb_module, ret_list)
69-
70-
if isinstance(module, nn.ModuleList):
71-
for i in range(len(module)):
72-
tmp_emb_module = module[i]
73-
74-
if isinstance(tmp_emb_module, nn.Module):
75-
check_emb_collection_modules(tmp_emb_module, ret_list)
76-
else:
71+
# Follow TorchRec/DynamicEmb internal attributes.
72+
# _lookups is stored as a plain Python list (not nn.ModuleList) so
73+
# nn.Module.children() cannot discover it — traverse it explicitly.
74+
for attr in ("_lookups", "_emb_modules", "_emb_module"):
75+
child = getattr(module, attr, None)
76+
if child is None:
7777
continue
78+
if isinstance(child, (list, nn.ModuleList)):
79+
for item in child:
80+
check_emb_collection_modules(item, ret_list, visited)
81+
else:
82+
check_emb_collection_modules(child, ret_list, visited)
83+
84+
# Recurse into nn.Module children so we can traverse through
85+
# wrapper modules (DMP / DDP / Float16Module) that are not
86+
# covered by the private attributes above.
87+
for child in module.children():
88+
check_emb_collection_modules(child, ret_list, visited)
7889

7990

8091
def get_dynamic_emb_module(model: nn.Module) -> List[nn.Module]:

corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
DynamicEmbTableOptions,
2929
FrequencyAdmissionStrategy,
3030
)
31+
from dynamicemb.batched_dynamicemb_tables import BatchedDynamicEmbeddingTablesV2
3132
from dynamicemb.dump_load import (
3233
DynamicEmbDump,
3334
DynamicEmbLoad,
@@ -122,6 +123,54 @@ def assert_batched_dynamicemb_storage_class(
122123
)
123124

124125

126+
def assert_get_dynamic_emb_module_finds_submodules(model) -> None:
127+
"""Verify get_dynamic_emb_module discovers BatchedDynamicEmbeddingTablesV2.
128+
129+
Tests two paths:
130+
1. Via find_sharded_modules + get_dynamic_emb_module (existing usage)
131+
2. Via get_dynamic_emb_module directly on the DMP model (requires
132+
children() traversal through wrapper modules, the fix for #353)
133+
134+
Both must return the same set of modules.
135+
"""
136+
# Path 1: existing approach - find sharded modules first, then search each
137+
via_sharded = []
138+
for _, _, sharded_module in find_sharded_modules(model, ""):
139+
via_sharded.extend(get_dynamic_emb_module(sharded_module))
140+
141+
# Path 2: search directly on the DMP wrapper (requires children() traversal)
142+
via_dmp = get_dynamic_emb_module(model)
143+
144+
assert (
145+
len(via_sharded) > 0
146+
), "find_sharded_modules + get_dynamic_emb_module found no modules"
147+
assert (
148+
len(via_dmp) > 0
149+
), "get_dynamic_emb_module on DMP model found no modules (children() traversal broken)"
150+
151+
# Every module found via either path must be BatchedDynamicEmbeddingTablesV2
152+
for m in via_sharded:
153+
assert isinstance(
154+
m, BatchedDynamicEmbeddingTablesV2
155+
), f"Expected BatchedDynamicEmbeddingTablesV2, got {type(m)}"
156+
for m in via_dmp:
157+
assert isinstance(
158+
m, BatchedDynamicEmbeddingTablesV2
159+
), f"Expected BatchedDynamicEmbeddingTablesV2, got {type(m)}"
160+
161+
# Both paths must discover the exact same set of modules (by identity)
162+
ids_sharded = set(id(m) for m in via_sharded)
163+
ids_dmp = set(id(m) for m in via_dmp)
164+
assert ids_sharded == ids_dmp, (
165+
f"Module sets differ: via_sharded has {len(ids_sharded)} modules, "
166+
f"via_dmp has {len(ids_dmp)} modules"
167+
)
168+
169+
# No duplicates in either result
170+
assert len(via_sharded) == len(ids_sharded), "Duplicates in via_sharded path"
171+
assert len(via_dmp) == len(ids_dmp), "Duplicates in via_dmp path"
172+
173+
125174
def update_scores(
126175
score_strategy: str,
127176
expect_scores: Dict[int, int],
@@ -354,32 +403,46 @@ def create_model(
354403

355404

356405
def check_counter_table_checkpoint(x, y):
357-
device = torch.cuda.current_device()
406+
device = torch.device(f"cuda:{torch.cuda.current_device()}")
358407
tables_x = get_dynamic_emb_module(x)
359408
tables_y = get_dynamic_emb_module(y)
409+
assert len(tables_x) == len(tables_y)
360410

361411
for table_x, table_y in zip(tables_x, tables_y):
362-
for cnt_tx, cnt_ty in zip(
363-
table_x._admission_counter, table_y._admission_counter
364-
):
365-
assert cnt_tx.table_.size() == cnt_ty.table_.size()
366-
367-
for keys, named_scores, _ in cnt_tx._batched_export_keys_scores(
368-
cnt_tx.table_.score_names_, torch.device(f"cuda:{device}")
412+
cnt_x = table_x._admission_counter
413+
cnt_y = table_y._admission_counter
414+
if cnt_x is None:
415+
assert cnt_y is None
416+
continue
417+
assert cnt_x.table_.size() == cnt_y.table_.size()
418+
419+
freq_name = cnt_x.score_name_
420+
for table_id in range(len(table_x._table_names)):
421+
for keys, named_scores, _ in cnt_x.table_._batched_export_keys_scores(
422+
[freq_name], device, table_id
369423
):
370424
if keys.numel() == 0:
371425
continue
372-
freq_name = cnt_tx.table_.score_names_[0]
373426
frequencies = named_scores[freq_name]
374427

428+
lookup_table_ids = torch.full(
429+
(keys.numel(),), table_id, dtype=torch.int64, device=device
430+
)
375431
score_arg_lookup = ScoreArg(
376432
name=freq_name,
377433
value=torch.zeros_like(frequencies),
378434
policy=ScorePolicy.CONST,
379435
)
380-
_, founds, _ = cnt_ty.lookup(keys, score_arg_lookup)
381-
382-
assert torch.equal(frequencies, score_arg_lookup)
436+
score_out, founds, _ = cnt_y.table_.lookup(
437+
keys, lookup_table_ids, score_arg_lookup
438+
)
439+
assert founds.all(), (
440+
f"counter keys missing from loaded table_id={table_id}: "
441+
f"{keys[~founds].tolist()}"
442+
)
443+
assert torch.equal(
444+
frequencies, score_out
445+
), f"counter frequency mismatch for table_id={table_id}"
383446

384447

385448
@click.command()
@@ -441,6 +504,8 @@ def test_model_load_dump(
441504
),
442505
)
443506

507+
assert_get_dynamic_emb_module_finds_submodules(ref_model)
508+
444509
expect_scores_collection: Dict[str, Dict[int, int]] = {}
445510
kjts, feature_names, all_kjts = generate_sparse_feature(
446511
num_embedding_collections=num_embedding_collections,

examples/commons/utils/dynamicemb_cache_stats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import torch
2323
import torch.nn as nn
24-
from commons.utils.dynamicemb_utils import find_dynamicemb_modules
24+
from dynamicemb.dump_load import get_dynamic_emb_module
2525

2626

2727
class _CacheDebugHook:
@@ -64,7 +64,7 @@ def install_cache_debug_hooks(model: nn.Module) -> int:
6464
Number of hooks installed.
6565
"""
6666
count = 0
67-
modules = find_dynamicemb_modules(model)
67+
modules = get_dynamic_emb_module(model)
6868
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
6969
if rank == 0:
7070
print(f"[CACHE_DEBUG] Found {len(modules)} DynamicEmb module(s)", flush=True)

examples/commons/utils/dynamicemb_utils.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

examples/hstu/training/pretrain_gr_ranking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ def main():
159159
maybe_load_ckpts(trainer_args.ckpt_load_dir, model, dense_optimizer)
160160

161161
if os.environ.get("FILL_DYNAMICEMB_TABLES", "0") == "1":
162-
from commons.utils.dynamicemb_utils import find_dynamicemb_modules
162+
from dynamicemb.dump_load import get_dynamic_emb_module
163163

164-
for dyn_module in find_dynamicemb_modules(model_train):
164+
for dyn_module in get_dynamic_emb_module(model_train):
165165
if hasattr(dyn_module, "fill_tables"):
166166
try:
167167
dyn_module.fill_tables(load_factor=0.95)

0 commit comments

Comments
 (0)