Skip to content

Commit 681d951

Browse files
kausvfacebook-github-bot
authored andcommitted
Simplify fqns_to_feature_names in Delta Tracker
Summary: Each table in collection will have FQN. We can simplify the logic with this assumption to avoid two iterations. Existing tests passed with this logic. Differential Revision: D76432354
1 parent ca19117 commit 681d951

File tree

1 file changed

+25
-51
lines changed

1 file changed

+25
-51
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# pyre-strict
99
import logging as logger
10-
from collections import Counter, OrderedDict
10+
from collections import OrderedDict
1111
from typing import Dict, Iterable, List, Optional, Union
1212

1313
import torch
@@ -32,7 +32,10 @@
3232
}
3333

3434
# Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
35-
SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection]
35+
SUPPORTED_MODULES_TO_PREFIX = {
36+
ShardedEmbeddingCollection: ".embeddings",
37+
ShardedEmbeddingBagCollection: ".embedding_bags",
38+
}
3639

3740

3841
class ModelDeltaTracker:
@@ -101,59 +104,30 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
101104
if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
102105
return self._fqn_to_feature_map
103106

104-
table_to_feature_names: Dict[str, List[str]] = OrderedDict()
105-
table_to_fqn: Dict[str, str] = OrderedDict()
107+
fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
106108
for fqn, named_module in self._model.named_modules():
107-
split_fqn = fqn.split(".")
108-
109-
should_skip = False
110-
for fqn_to_skip in self._fqns_to_skip:
111-
if fqn_to_skip in split_fqn:
112-
logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
113-
should_skip = True
114-
break
115-
if should_skip:
116-
continue
117-
118-
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
119-
if isinstance(named_module, SUPPORTED_MODULES):
109+
110+
if type(named_module) in SUPPORTED_MODULES_TO_PREFIX:
120111
for table_name, config in named_module._table_name_to_config.items():
121-
logger.info(
122-
f"Found {table_name} for {fqn} with features {config.feature_names}"
112+
embedding_fqn = (
113+
fqn.replace("_dmp_wrapped_module.module.", "")
114+
+ SUPPORTED_MODULES_TO_PREFIX[type(named_module)]
115+
+ f".{table_name}"
123116
)
124-
table_to_feature_names[table_name] = config.feature_names
125-
for table_name in table_to_feature_names:
126-
# Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
127-
# will incorrectly match fqn with all the table names that have the same prefix
128-
if table_name in split_fqn:
129-
embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
130-
if table_name in table_to_fqn:
131-
# Sanity check for validating that we don't have more then one tbale mapping to same fqn.
132-
logger.warning(
133-
f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
134-
)
135-
table_to_fqn[table_name] = embedding_fqn
136-
logger.info(f"Table to fqn: {table_to_fqn}")
137-
flatten_names = [
138-
name for names in table_to_feature_names.values() for name in names
139-
]
140-
# Some ads models have duplicate feature names, so we are relaxing the condition in case we come across duplicate feature names.
141-
if len(set(flatten_names)) != len(flatten_names):
142-
counts = Counter(flatten_names)
143-
duplicates = [item for item, count in counts.items() if count > 1]
144-
logger.warning(f"duplicate feature names found: {duplicates}")
145117

146-
fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
147-
for table_name in table_to_feature_names:
148-
if table_name not in table_to_fqn:
149-
# This is likely unexpected, where we can't locate the FQN associated with this table.
150-
logger.warning(
151-
f"Table {table_name} not found in {table_to_fqn}, skipping"
152-
)
153-
continue
154-
fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
155-
table_name
156-
]
118+
should_skip = False
119+
for fqn_to_skip in self._fqns_to_skip:
120+
if fqn_to_skip in embedding_fqn:
121+
logger.info(
122+
f"Skipping {fqn} because it is part of fqns_to_skip"
123+
)
124+
should_skip = True
125+
break
126+
if should_skip:
127+
continue
128+
if embedding_fqn not in fqn_to_feature_names:
129+
fqn_to_feature_names[embedding_fqn] = []
130+
fqn_to_feature_names[embedding_fqn].extend(config.feature_names)
157131
self._fqn_to_feature_map = fqn_to_feature_names
158132
return fqn_to_feature_names
159133

0 commit comments

Comments
 (0)