
Commit 0b110fa

aliafzal authored and facebook-github-bot committed
Add logic for fqn_to_feature_names (#3059)
Summary:

This Diff

Added the implementation of the fqn_to_feature_names method, along with an initial testing framework and unit tests for fqn_to_feature_names.

ModelDeltaTracker Context

ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model built with TorchRec. It is particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming of updated embeddings between systems during online training

Differential Revision: D75908963
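A minimal usage sketch (not from this diff; `model` and the skipped table name are stand-ins, and in this diff only fqn_to_feature_names() is implemented while the other tracker methods remain stubs):

    from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker

    # Assumption: `model` is a DistributedModelParallel-wrapped TorchRec model.
    tracker = ModelDeltaTracker(model, fqns_to_skip=["table_to_ignore"])
    # Maps each embedding-module FQN to the feature names it serves,
    # e.g. {"ec.embeddings.sparse_table_1": ["f1", "f2", "f3"]}.
    print(tracker.fqn_to_feature_names())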
1 parent a401ef1 commit 0b110fa

File tree

3 files changed: +570 −3 lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 66 additions & 3 deletions
@@ -6,7 +6,9 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
-from typing import Dict, List, Optional, Union
+import logging as logger
+from collections import Counter, OrderedDict
+from typing import Dict, Iterable, List, Optional, Union

 import torch

@@ -59,11 +61,15 @@ def __init__(
         consumers: Optional[List[str]] = None,
         delete_on_read: bool = True,
         mode: TrackingMode = TrackingMode.ID_ONLY,
+        fqns_to_skip: Iterable[str] = (),
     ) -> None:
         self._model = model
         self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
         self._delete_on_read = delete_on_read
         self._mode = mode
+        self._fqn_to_feature_map: Dict[str, List[str]] = {}
+        self._fqns_to_skip: Iterable[str] = fqns_to_skip
+        self.fqn_to_feature_names()
         pass

     def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
@@ -85,14 +91,71 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         """
         return {}

-    def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
+    def fqn_to_feature_names(self) -> Dict[str, List[str]]:
         """
-        Returns a mapping from FQN to feature names for a given module.
-
-        Args:
-            module (nn.Module): the module to retrieve feature names for.
+        Returns a mapping from FQN to feature names for the tracked model's
+        embedding modules.
         """
-        return {}
+        if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
+            return self._fqn_to_feature_map
+
+        table_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        table_to_fqn: Dict[str, str] = OrderedDict()
+        for fqn, named_module in self._model.named_modules():
+            split_fqn = fqn.split(".")
+
+            should_skip = False
+            for fqn_to_skip in self._fqns_to_skip:
+                if fqn_to_skip in split_fqn:
+                    logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
+                    should_skip = True
+                    break
+            if should_skip:
+                continue
+
+            # Use the embedding modules' FQNs to key features, since the
+            # state_dict() API keys states by these FQNs.
+            if isinstance(named_module, SUPPORTED_MODULES):
+                for table_name, config in named_module._table_name_to_config.items():
+                    logger.info(
+                        f"Found {table_name} for {fqn} with features {config.feature_names}"
+                    )
+                    table_to_feature_names[table_name] = config.feature_names
+            for table_name in table_to_feature_names:
+                # Use the split FQN for an exact table-name match. Checking
+                # "table_name in fqn" would incorrectly match every table name
+                # that shares a prefix with the FQN.
+                if table_name in split_fqn:
+                    embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
+                    if table_name in table_to_fqn:
+                        # Sanity check: no more than one table should map to the same FQN.
+                        logger.warning(
+                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                        )
+                    table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {table_to_fqn}")
+        flatten_names = [
+            name for names in table_to_feature_names.values() for name in names
+        ]
+        # Some ads models have duplicate feature names, so the uniqueness check
+        # is relaxed to a warning when duplicates appear.
+        if len(set(flatten_names)) != len(flatten_names):
+            counts = Counter(flatten_names)
+            duplicates = [item for item, count in counts.items() if count > 1]
+            logger.warning(f"duplicate feature names found: {duplicates}")
+
+        fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        for table_name in table_to_feature_names:
+            if table_name not in table_to_fqn:
+                # Unexpected: the FQN associated with this table could not be located.
+                logger.warning(
+                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                )
+                continue
+            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
+                table_name
+            ]
+        self._fqn_to_feature_map = fqn_to_feature_names
+        return fqn_to_feature_names

     def clear(self, consumer: Optional[str] = None) -> None:
         """
Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import os
from dataclasses import dataclass
from typing import cast, Dict, List

import torch
import torchrec
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType

from parameterized import parameterized
from torch import distributed as dist, nn
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import ModuleSharder, ShardingType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.tests.utils import (
    EmbeddingTableProps,
    generate_planner_constraints,
    TestEBCModel,
    TestECModel,
)

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.modules.embedding_configs import (
    EmbeddingBagConfig,
    EmbeddingConfig,
    PoolingType,
)

NUM_EMBEDDINGS: int = 16
EMBEDDING_DIM: int = 256


class ModelDeltaTrackerTest(MultiProcessTestCase):
    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, methodName="runTest") -> None:
        super().__init__(methodName)

    @property
    def world_size(self) -> int:
        return 2

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self) -> None:
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _get_store(self) -> dist.FileStore:
        return dist.FileStore(self.file_name, self.world_size)

    def _get_process_group(self) -> dist.ProcessGroup:
        store = self._get_store()
        dist.init_process_group(
            "nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        return dist.distributed_c10d._get_default_group()

    def _get_models(
        self,
        embedding_type: str,
        tables: Dict[str, EmbeddingTableProps],
        optimizer_type: OptimType = OptimType.ADAM,
    ) -> DistributedModelParallel:
        torch.manual_seed(0)
        torch.cuda.set_device(self.rank)
        pg = self._get_process_group()
        test_model = (
            TestECModel(
                tables=[
                    EmbeddingConfig(
                        name=table_name,
                        embedding_dim=table.embedding_dim,
                        num_embeddings=table.num_embeddings,
                        feature_names=table.feature_names,
                    )
                    for table_name, table in tables.items()
                ]
            )
            if embedding_type == "EC"
            else TestEBCModel(
                tables=[
                    EmbeddingBagConfig(
                        name=table_name,
                        embedding_dim=table.embedding_dim,
                        num_embeddings=table.num_embeddings,
                        feature_names=table.feature_names,
                        pooling=table.pooling,
                    )
                    for table_name, table in tables.items()
                ]
            )
        )
        planner = EmbeddingShardingPlanner(
            topology=Topology(self.world_size, "cuda"),
            constraints=generate_planner_constraints(tables),
        )
        sharders = [
            cast(
                ModuleSharder[nn.Module],
                EmbeddingCollectionSharder(
                    fused_params={
                        "optimizer": optimizer_type,
                        "beta1": 0.9,
                        "beta2": 0.99,
                    }
                ),
            ),
            cast(
                ModuleSharder[nn.Module],
                EmbeddingBagCollectionSharder(
                    fused_params={"optimizer": optimizer_type}
                ),
            ),
        ]
        plan = planner.collective_plan(test_model, sharders, pg)
        return DistributedModelParallel(
            module=test_model,
            device=torch.device(f"cuda:{self.rank}"),
            env=torchrec.distributed.ShardingEnv.from_process_group(pg),
            plan=plan,
            sharders=sharders,
        )

    @dataclass
    class ModelDeltaTrackerInputTestParams:
        # input parameters
        embedding_type: str
        embedding_tables: Dict[str, EmbeddingTableProps]
        fqns_to_skip: List[str]

    @dataclass
    class FqnToFeatureNamesOutputTestParams:
        # expected output parameters
        expected_fqn_to_feature_names: Dict[str, List[str]]

    @parameterized.expand(
        [
            (
                "EC_model_test",
                ModelDeltaTrackerInputTestParams(
                    embedding_type="EC",
                    embedding_tables={
                        "sparse_table_1": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f1", "f2", "f3"],
                            pooling=PoolingType.NONE,
                        ),
                        "sparse_table_2": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f4", "f5", "f6"],
                            pooling=PoolingType.NONE,
                        ),
                    },
                    fqns_to_skip=[],
                ),
                FqnToFeatureNamesOutputTestParams(
                    expected_fqn_to_feature_names={
                        "ec.embeddings.sparse_table_1": ["f1", "f2", "f3"],
                        "ec.embeddings.sparse_table_2": ["f4", "f5", "f6"],
                    },
                ),
            ),
            (
                "EBC_model_test",
                ModelDeltaTrackerInputTestParams(
                    embedding_type="EBC",
                    embedding_tables={
                        "sparse_table_1": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f1", "f2", "f3"],
                            pooling=PoolingType.SUM,
                        ),
                        "sparse_table_2": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f4", "f5", "f6"],
                            pooling=PoolingType.SUM,
                        ),
                    },
                    fqns_to_skip=[],
                ),
                FqnToFeatureNamesOutputTestParams(
                    expected_fqn_to_feature_names={
                        "ebc.embedding_bags.sparse_table_1": ["f1", "f2", "f3"],
                        "ebc.embedding_bags.sparse_table_2": ["f4", "f5", "f6"],
                    },
                ),
            ),
            (
                "EC_model_test_with_duplicate_feature_names",
                ModelDeltaTrackerInputTestParams(
                    embedding_type="EC",
                    embedding_tables={
                        "sparse_table_1": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f1", "f2", "f3"],
                            pooling=PoolingType.NONE,
                        ),
                        "sparse_table_2": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f3", "f4", "f5"],
                            pooling=PoolingType.NONE,
                        ),
                    },
                    fqns_to_skip=[],
                ),
                FqnToFeatureNamesOutputTestParams(
                    expected_fqn_to_feature_names={
                        "ec.embeddings.sparse_table_1": ["f1", "f2", "f3"],
                        "ec.embeddings.sparse_table_2": ["f3", "f4", "f5"],
                    },
                ),
            ),
            (
                "EBC_model_test_fqns_to_skip",
                ModelDeltaTrackerInputTestParams(
                    embedding_type="EBC",
                    embedding_tables={
                        "sparse_table_1": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f1", "f2", "f3"],
                            pooling=PoolingType.SUM,
                        ),
                        "sparse_table_2": EmbeddingTableProps(
                            num_embeddings=NUM_EMBEDDINGS,
                            embedding_dim=EMBEDDING_DIM,
                            sharding=ShardingType.ROW_WISE,
                            feature_names=["f4", "f5", "f6"],
                            pooling=PoolingType.SUM,
                        ),
                    },
                    fqns_to_skip=["sparse_table_1"],
                ),
                FqnToFeatureNamesOutputTestParams(
                    expected_fqn_to_feature_names={
                        "ebc.embedding_bags.sparse_table_2": ["f4", "f5", "f6"],
                    },
                ),
            ),
        ]
    )
    def test_fqn_to_feature_names(
        self,
        _test_name: str,
        input_params: ModelDeltaTrackerInputTestParams,
        output_params: FqnToFeatureNamesOutputTestParams,
    ) -> None:
        model = self._get_models(
            input_params.embedding_type, input_params.embedding_tables
        )
        model_dt = ModelDeltaTracker(model, fqns_to_skip=input_params.fqns_to_skip)
        self.assertEqual(
            model_dt.fqn_to_feature_names(), output_params.expected_fqn_to_feature_names
        )
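The expected keys in these cases ("ec.embeddings.sparse_table_1" and the like) reflect the tracker stripping the DistributedModelParallel wrapper prefix; a standalone illustration of the same replace() call (the raw FQN below is hypothetical):

    # DMP prepends "_dmp_wrapped_module.module." to every FQN; the tracker
    # removes it so keys line up with the unwrapped module's state_dict().
    raw_fqn = "_dmp_wrapped_module.module.ec.embeddings.sparse_table_1"
    print(raw_fqn.replace("_dmp_wrapped_module.module.", ""))
    # -> "ec.embeddings.sparse_table_1"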
