6
6
# LICENSE file in the root directory of this source tree.
7
7
8
8
# pyre-strict
9
- from typing import Dict , List , Optional , Union
9
+ import logging as logger
10
+ from collections import Counter , OrderedDict
11
+ from typing import Dict , Iterable , List , Optional , Union
10
12
11
13
import torch
12
14
@@ -59,11 +61,15 @@ def __init__(
59
61
consumers : Optional [List [str ]] = None ,
60
62
delete_on_read : bool = True ,
61
63
mode : TrackingMode = TrackingMode .ID_ONLY ,
64
+ fqns_to_skip : Iterable [str ] = (),
62
65
) -> None :
63
66
self ._model = model
64
67
self ._consumers : List [str ] = consumers or [self .DEFAULT_CONSUMER ]
65
68
self ._delete_on_read = delete_on_read
66
69
self ._mode = mode
70
+ self ._fqn_to_feature_map : Dict [str , List [str ]] = {}
71
+ self ._fqns_to_skip : Iterable [str ] = fqns_to_skip
72
+ self .fqn_to_feature_names ()
67
73
pass
68
74
69
75
def record_lookup (self , kjt : KeyedJaggedTensor , states : torch .Tensor ) -> None :
@@ -85,14 +91,71 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
85
91
"""
86
92
return {}
87
93
88
- def fqn_to_feature_names (self , module : nn . Module ) -> Dict [str , List [str ]]:
94
def fqn_to_feature_names(self) -> Dict[str, List[str]]:
    """
    Returns a mapping from embedding table FQN to the feature names of that table.

    The mapping is built by walking ``self._model.named_modules()`` and is cached in
    ``self._fqn_to_feature_map``; once that cache is non-empty, later calls return it
    directly. Modules whose dotted FQN contains any entry of ``self._fqns_to_skip`` as
    an exact path component are ignored.

    Returns:
        Dict[str, List[str]]: ordered mapping of FQN (with the
        ``_dmp_wrapped_module.module.`` wrapper prefix removed) to feature names.
        Tables for which no FQN could be located are omitted with a warning.
    """
    if self._fqn_to_feature_map:
        return self._fqn_to_feature_map

    # Materialize the skip list once. It is typed as Iterable, so it may be a
    # one-shot iterator that would otherwise be exhausted after the first module.
    skip_components = frozenset(self._fqns_to_skip)

    table_to_feature_names: Dict[str, List[str]] = OrderedDict()
    table_to_fqn: Dict[str, str] = OrderedDict()
    for fqn, named_module in self._model.named_modules():
        split_fqn = fqn.split(".")

        # Compare against path components, not substrings, so that skipping "a"
        # does not also skip "ab".
        if not skip_components.isdisjoint(split_fqn):
            logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
            continue

        # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
        if isinstance(named_module, SUPPORTED_MODULES):
            for table_name, config in named_module._table_name_to_config.items():
                logger.info(
                    f"Found {table_name} for {fqn} with features {config.feature_names}"
                )
                table_to_feature_names[table_name] = config.feature_names

        # This runs for every module, not only supported ones: the FQN that carries
        # the table name belongs to a module visited after the table was registered.
        for table_name in table_to_feature_names:
            # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
            # will incorrectly match fqn with all the table names that have the same prefix.
            if table_name not in split_fqn:
                continue
            embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
            if table_name in table_to_fqn:
                # Sanity check: this table was already matched to an FQN; the later match wins.
                logger.warning(
                    f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                )
            table_to_fqn[table_name] = embedding_fqn
    # Logged once after the traversal so the final mapping is reported without per-module noise.
    logger.info(f"Table to fqn: {table_to_fqn}")

    # Some ads models have duplicate feature names, so duplicates are reported rather than rejected.
    feature_counts = Counter(
        name for names in table_to_feature_names.values() for name in names
    )
    duplicates = [item for item, count in feature_counts.items() if count > 1]
    if duplicates:
        logger.warning(f"duplicate feature names found: {duplicates}")

    fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
    for table_name, feature_names in table_to_feature_names.items():
        if table_name not in table_to_fqn:
            # This is likely unexpected, where we can't locate the FQN associated with this table.
            logger.warning(
                f"Table {table_name} not found in {table_to_fqn}, skipping"
            )
            continue
        fqn_to_feature_names[table_to_fqn[table_name]] = feature_names
    self._fqn_to_feature_map = fqn_to_feature_names
    return fqn_to_feature_names
96
159
97
160
def clear (self , consumer : Optional [str ] = None ) -> None :
98
161
"""
0 commit comments