Skip to content

Commit ca19117

Browse files
maliafzal authored and facebook-github-bot committed
Add logic for fqn_to_feature_names (#3059)
Summary: Pull Request resolved: #3059. This diff adds the implementation of the fqn_to_feature_names method, along with an initial testing framework and unit tests for fqn_to_feature_names. ModelDeltaTracker context: ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model using TorchRec. It is particularly useful for: (1) identifying which embedding rows were accessed during model execution; (2) retrieving the latest delta or unique rows for a model; (3) computing top-k changed embeddings; (4) supporting streaming of updated embeddings between systems during online training. Differential Revision: D75908963
1 parent e62add5 commit ca19117

File tree

3 files changed

+588
-3
lines changed

3 files changed

+588
-3
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
from typing import Dict, List, Optional, Union
9+
import logging as logger
10+
from collections import Counter, OrderedDict
11+
from typing import Dict, Iterable, List, Optional, Union
1012

1113
import torch
1214

@@ -59,11 +61,15 @@ def __init__(
5961
consumers: Optional[List[str]] = None,
6062
delete_on_read: bool = True,
6163
mode: TrackingMode = TrackingMode.ID_ONLY,
64+
fqns_to_skip: Iterable[str] = (),
6265
) -> None:
6366
self._model = model
6467
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
6568
self._delete_on_read = delete_on_read
6669
self._mode = mode
70+
self._fqn_to_feature_map: Dict[str, List[str]] = {}
71+
self._fqns_to_skip: Iterable[str] = fqns_to_skip
72+
self.fqn_to_feature_names()
6773
pass
6874

6975
def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
@@ -85,14 +91,71 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
8591
"""
8692
return {}
8793

88-
def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
94+
def fqn_to_feature_names(self) -> Dict[str, List[str]]:
    """
    Returns a mapping from embedding table FQN to the feature names of that table.

    Walks ``self._model.named_modules()``, collects ``table name -> feature names``
    from every module that is an instance of ``SUPPORTED_MODULES``, and resolves
    each table name to the FQN of a module whose dotted path contains that table
    name as an exact component. Modules whose dotted path contains any entry of
    ``self._fqns_to_skip`` as a component are ignored.

    The result is cached in ``self._fqn_to_feature_map``; a non-empty cached
    mapping is returned as-is on subsequent calls.

    Returns:
        Dict[str, List[str]]: ordered mapping from FQN (with the substring
        "_dmp_wrapped_module.module." removed) to the table's feature names.
        Tables for which no FQN could be located are logged and omitted.
    """
    if self._fqn_to_feature_map:
        return self._fqn_to_feature_map

    # `_fqns_to_skip` is only typed as Iterable, so it may be a one-shot iterator.
    # Snapshot it once (and keep the snapshot) so every module is checked against
    # the full skip list, on this call and on any later recomputation.
    skip_components = frozenset(self._fqns_to_skip)
    self._fqns_to_skip = skip_components

    table_to_feature_names: Dict[str, List[str]] = OrderedDict()
    table_to_fqn: Dict[str, str] = OrderedDict()
    for fqn, named_module in self._model.named_modules():
        split_fqn = fqn.split(".")

        # Match on whole dotted components, not substrings.
        if not skip_components.isdisjoint(split_fqn):
            logger.info("Skipping %s because it is part of fqns_to_skip", fqn)
            continue

        # Using FQNs of the embedding and mapping them to features as
        # state_dict() API uses these to key states.
        if isinstance(named_module, SUPPORTED_MODULES):
            for table_name, config in named_module._table_name_to_config.items():
                logger.info(
                    "Found %s for %s with features %s",
                    table_name,
                    fqn,
                    config.feature_names,
                )
                table_to_feature_names[table_name] = config.feature_names

        for table_name in table_to_feature_names:
            # Using the split FQN to get the exact table name matching. Otherwise,
            # checking "table_name in fqn" will incorrectly match fqn with all the
            # table names that have the same prefix.
            if table_name not in split_fqn:
                continue
            embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
            if table_name in table_to_fqn:
                # Sanity check: warn when a table that already has an FQN is
                # about to be remapped to a different module path.
                logger.warning(
                    "Override %s with %s for entry %s",
                    table_to_fqn[table_name],
                    embedding_fqn,
                    table_name,
                )
            table_to_fqn[table_name] = embedding_fqn
    logger.info("Table to fqn: %s", table_to_fqn)

    # Some ads models have duplicate feature names, so we are relaxing the
    # condition (warn instead of fail) in case we come across duplicates.
    feature_counts = Counter(
        name for names in table_to_feature_names.values() for name in names
    )
    duplicates = [name for name, count in feature_counts.items() if count > 1]
    if duplicates:
        logger.warning("duplicate feature names found: %s", duplicates)

    fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
    for table_name, feature_names in table_to_feature_names.items():
        if table_name not in table_to_fqn:
            # This is likely unexpected, where we can't locate the FQN
            # associated with this table.
            logger.warning(
                "Table %s not found in %s, skipping", table_name, table_to_fqn
            )
            continue
        fqn_to_feature_names[table_to_fqn[table_name]] = feature_names

    self._fqn_to_feature_map = fqn_to_feature_names
    return fqn_to_feature_names
96159

97160
def clear(self, consumer: Optional[str] = None) -> None:
98161
"""

0 commit comments

Comments
 (0)