|
7 | 7 |
|
8 | 8 | # pyre-strict
|
9 | 9 | import logging as logger
|
10 |
| -from collections import Counter, OrderedDict |
| 10 | +from collections import OrderedDict |
11 | 11 | from typing import Dict, Iterable, List, Optional, Union
|
12 | 12 |
|
13 | 13 | import torch
|
|
32 | 32 | }
|
33 | 33 |
|
34 | 34 | # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
|
35 |
| -SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection] |
| 35 | +SUPPORTED_MODULES_TO_PREFIX = { |
| 36 | + ShardedEmbeddingCollection: ".embeddings", |
| 37 | + ShardedEmbeddingBagCollection: ".embedding_bags", |
| 38 | +} |
36 | 39 |
|
37 | 40 |
|
38 | 41 | class ModelDeltaTracker:
|
@@ -101,59 +104,30 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
|
101 | 104 | if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
|
102 | 105 | return self._fqn_to_feature_map
|
103 | 106 |
|
104 |
| - table_to_feature_names: Dict[str, List[str]] = OrderedDict() |
105 |
| - table_to_fqn: Dict[str, str] = OrderedDict() |
| 107 | + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() |
106 | 108 | for fqn, named_module in self._model.named_modules():
|
107 |
| - split_fqn = fqn.split(".") |
108 |
| - |
109 |
| - should_skip = False |
110 |
| - for fqn_to_skip in self._fqns_to_skip: |
111 |
| - if fqn_to_skip in split_fqn: |
112 |
| - logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") |
113 |
| - should_skip = True |
114 |
| - break |
115 |
| - if should_skip: |
116 |
| - continue |
117 |
| - |
118 |
| - # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. |
119 |
| - if isinstance(named_module, SUPPORTED_MODULES): |
| 109 | + |
| 110 | + if type(named_module) in SUPPORTED_MODULES_TO_PREFIX: |
120 | 111 | for table_name, config in named_module._table_name_to_config.items():
|
121 |
| - logger.info( |
122 |
| - f"Found {table_name} for {fqn} with features {config.feature_names}" |
| 112 | + embedding_fqn = ( |
| 113 | + fqn.replace("_dmp_wrapped_module.module.", "") |
| 114 | + + SUPPORTED_MODULES_TO_PREFIX[type(named_module)] |
| 115 | + + f".{table_name}" |
123 | 116 | )
|
124 |
| - table_to_feature_names[table_name] = config.feature_names |
125 |
| - for table_name in table_to_feature_names: |
126 |
| - # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" |
127 |
| - # will incorrectly match fqn with all the table names that have the same prefix |
128 |
| - if table_name in split_fqn: |
129 |
| - embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "") |
130 |
| - if table_name in table_to_fqn: |
131 |
| - # Sanity check for validating that we don't have more then one tbale mapping to same fqn. |
132 |
| - logger.warning( |
133 |
| - f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" |
134 |
| - ) |
135 |
| - table_to_fqn[table_name] = embedding_fqn |
136 |
| - logger.info(f"Table to fqn: {table_to_fqn}") |
137 |
| - flatten_names = [ |
138 |
| - name for names in table_to_feature_names.values() for name in names |
139 |
| - ] |
140 |
| - # Some ads models have duplicate feature names, so we are relaxing the condition in case we come across duplicate feature names. |
141 |
| - if len(set(flatten_names)) != len(flatten_names): |
142 |
| - counts = Counter(flatten_names) |
143 |
| - duplicates = [item for item, count in counts.items() if count > 1] |
144 |
| - logger.warning(f"duplicate feature names found: {duplicates}") |
145 | 117 |
|
146 |
| - fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() |
147 |
| - for table_name in table_to_feature_names: |
148 |
| - if table_name not in table_to_fqn: |
149 |
| - # This is likely unexpected, where we can't locate the FQN associated with this table. |
150 |
| - logger.warning( |
151 |
| - f"Table {table_name} not found in {table_to_fqn}, skipping" |
152 |
| - ) |
153 |
| - continue |
154 |
| - fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[ |
155 |
| - table_name |
156 |
| - ] |
| 118 | + should_skip = False |
| 119 | + for fqn_to_skip in self._fqns_to_skip: |
| 120 | + if fqn_to_skip in embedding_fqn: |
| 121 | + logger.info( |
| 122 | + f"Skipping {fqn} because it is part of fqns_to_skip" |
| 123 | + ) |
| 124 | + should_skip = True |
| 125 | + break |
| 126 | + if should_skip: |
| 127 | + continue |
| 128 | + if embedding_fqn not in fqn_to_feature_names: |
| 129 | + fqn_to_feature_names[embedding_fqn] = [] |
| 130 | + fqn_to_feature_names[embedding_fqn].extend(config.feature_names) |
157 | 131 | self._fqn_to_feature_map = fqn_to_feature_names
|
158 | 132 | return fqn_to_feature_names
|
159 | 133 |
|
|
0 commit comments