|
7 | 7 |
|
8 | 8 | # pyre-strict
|
9 | 9 | import logging as logger
|
10 |
| -from collections import Counter, OrderedDict |
| 10 | +from collections import OrderedDict |
11 | 11 | from typing import Dict, Iterable, List, Optional, Union
|
12 | 12 |
|
13 | 13 | import torch
|
|
32 | 32 | }
|
33 | 33 |
|
34 | 34 | # Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
|
35 |
| -SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection] |
| 35 | +SUPPORTED_MODULES_TO_PREFIX = { |
| 36 | + ShardedEmbeddingCollection: ".embeddings", |
| 37 | + ShardedEmbeddingBagCollection: ".embedding_bags", |
| 38 | +} |
36 | 39 |
|
37 | 40 |
|
38 | 41 | class ModelDeltaTracker:
|
@@ -100,61 +103,31 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
|
100 | 103 | if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
|
101 | 104 | return self._fqn_to_feature_map
|
102 | 105 |
|
103 |
| - table_to_feature_names: Dict[str, List[str]] = OrderedDict() |
104 |
| - table_to_fqn: Dict[str, str] = OrderedDict() |
| 106 | + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() |
105 | 107 | for fqn, named_module in self._model.named_modules():
|
106 |
| - split_fqn = fqn.split(".") |
107 | 108 | # Skipping partial FQNs present in fqns_to_skip
|
108 | 109 | # TODO: Validate if we need to support more complex patterns for skipping fqns
|
109 |
| - should_skip = False |
110 |
| - for fqn_to_skip in self._fqns_to_skip: |
111 |
| - if fqn_to_skip in split_fqn: |
112 |
| - logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") |
113 |
| - should_skip = True |
114 |
| - break |
115 |
| - if should_skip: |
116 |
| - continue |
117 |
| - |
118 |
| - # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. |
119 |
| - if isinstance(named_module, SUPPORTED_MODULES): |
| 110 | + if type(named_module) in SUPPORTED_MODULES_TO_PREFIX: |
120 | 111 | for table_name, config in named_module._table_name_to_config.items():
|
121 |
| - logger.info( |
122 |
| - f"Found {table_name} for {fqn} with features {config.feature_names}" |
| 112 | + embedding_fqn = ( |
| 113 | + fqn.replace("_dmp_wrapped_module.module.", "") |
| 114 | + + SUPPORTED_MODULES_TO_PREFIX[type(named_module)] |
| 115 | + + f".{table_name}" |
123 | 116 | )
|
124 |
| - table_to_feature_names[table_name] = config.feature_names |
125 |
| - for table_name in table_to_feature_names: |
126 |
| - # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" |
127 |
| - # will incorrectly match fqn with all the table names that have the same prefix |
128 |
| - if table_name in split_fqn: |
129 |
| - embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "") |
130 |
| - if table_name in table_to_fqn: |
131 |
| - # Sanity check for validating that we don't have more than one table mapping to same fqn. |
132 |
| - logger.warning( |
133 |
| - f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" |
134 |
| - ) |
135 |
| - table_to_fqn[table_name] = embedding_fqn |
136 |
| - logger.info(f"Table to fqn: {table_to_fqn}") |
137 |
| - flatten_names = [ |
138 |
| - name for names in table_to_feature_names.values() for name in names |
139 |
| - ] |
140 |
| - # TODO: Validate if there is a better way to handle duplicate feature names. |
141 |
| - # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case. |
142 |
| - if len(set(flatten_names)) != len(flatten_names): |
143 |
| - counts = Counter(flatten_names) |
144 |
| - duplicates = [item for item, count in counts.items() if count > 1] |
145 |
| - logger.warning(f"duplicate feature names found: {duplicates}") |
146 | 117 |
|
147 |
| - fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() |
148 |
| - for table_name in table_to_feature_names: |
149 |
| - if table_name not in table_to_fqn: |
150 |
| - # This is likely unexpected, where we can't locate the FQN associated with this table. |
151 |
| - logger.warning( |
152 |
| - f"Table {table_name} not found in {table_to_fqn}, skipping" |
153 |
| - ) |
154 |
| - continue |
155 |
| - fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[ |
156 |
| - table_name |
157 |
| - ] |
| 118 | + should_skip = False |
| 119 | + for fqn_to_skip in self._fqns_to_skip: |
| 120 | + if fqn_to_skip in embedding_fqn: |
| 121 | + logger.info( |
| 122 | + f"Skipping {fqn} because it is part of fqns_to_skip" |
| 123 | + ) |
| 124 | + should_skip = True |
| 125 | + break |
| 126 | + if should_skip: |
| 127 | + continue |
| 128 | + if embedding_fqn not in fqn_to_feature_names: |
| 129 | + fqn_to_feature_names[embedding_fqn] = [] |
| 130 | + fqn_to_feature_names[embedding_fqn].extend(config.feature_names) |
158 | 131 | self._fqn_to_feature_map = fqn_to_feature_names
|
159 | 132 | return fqn_to_feature_names
|
160 | 133 |
|
|
0 commit comments