@@ -184,6 +184,23 @@ def _get_leaf_modules_for_ops() -> List[type]:
184
184
return result
185
185
186
186
187
+ def _set_default_tracer_kwargs (original_tr_kwargs : Optional [Dict [str , Any ]]) -> Dict [str , Any ]:
188
+ default_autowrap_modules = (math , torchvision .ops )
189
+ default_leaf_modules = _get_leaf_modules_for_ops ()
190
+ result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
191
+ result_tracer_kwargs ["autowrap_modules" ] = (
192
+ tuple (set (result_tracer_kwargs ["autowrap_modules" ] + default_autowrap_modules ))
193
+ if "autowrap_modules" in result_tracer_kwargs
194
+ else default_autowrap_modules
195
+ )
196
+ result_tracer_kwargs ["leaf_modules" ] = (
197
+ list (set (result_tracer_kwargs ["leaf_modules" ] + default_leaf_modules ))
198
+ if "leaf_modules" in result_tracer_kwargs
199
+ else default_leaf_modules
200
+ )
201
+ return result_tracer_kwargs
202
+
203
+
187
204
def get_graph_node_names (
188
205
model : nn .Module ,
189
206
tracer_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -212,7 +229,11 @@ def get_graph_node_names(
212
229
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
213
230
``NodePathTracer`` (they are eventually passed onto
214
231
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
215
- By default it will be set to wrap and make leaf nodes all torchvision ops.
232
+ By default it will be set to wrap and make leaf nodes all torchvision ops:
233
+ {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
234
+ WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
235
+ provided dictionary.
236
+
216
237
suppress_diff_warning (bool, optional): whether to suppress a warning
217
238
when there are discrepancies between the train and eval version of
218
239
the graph. Defaults to False.
@@ -226,14 +247,7 @@ def get_graph_node_names(
226
247
>>> model = torchvision.models.resnet18()
227
248
>>> train_nodes, eval_nodes = get_graph_node_names(model)
228
249
"""
229
- if tracer_kwargs is None :
230
- tracer_kwargs = {
231
- "autowrap_modules" : (
232
- math ,
233
- torchvision .ops ,
234
- ),
235
- "leaf_modules" : _get_leaf_modules_for_ops (),
236
- }
250
+ tracer_kwargs = _set_default_tracer_kwargs (tracer_kwargs )
237
251
is_training = model .training
238
252
train_tracer = NodePathTracer (** tracer_kwargs )
239
253
train_tracer .trace (model .train ())
@@ -378,7 +392,10 @@ def create_feature_extractor(
378
392
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
379
393
``NodePathTracer`` (which passes them onto it's parent class
380
394
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
381
- By default it will be set to wrap and make leaf nodes all torchvision ops.
395
+ By default it will be set to wrap and make leaf nodes all torchvision ops:
396
+ {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
397
+ WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
398
+ provided dictionary.
382
399
suppress_diff_warning (bool, optional): whether to suppress a warning
383
400
when there are discrepancies between the train and eval version of
384
401
the graph. Defaults to False.
@@ -423,14 +440,7 @@ def create_feature_extractor(
423
440
>>> 'autowrap_functions': [leaf_function]})
424
441
425
442
"""
426
- if tracer_kwargs is None :
427
- tracer_kwargs = {
428
- "autowrap_modules" : (
429
- math ,
430
- torchvision .ops ,
431
- ),
432
- "leaf_modules" : _get_leaf_modules_for_ops (),
433
- }
443
+ tracer_kwargs = _set_default_tracer_kwargs (tracer_kwargs )
434
444
is_training = model .training
435
445
436
446
if all (arg is None for arg in [return_nodes , train_return_nodes , eval_return_nodes ]):
0 commit comments