Skip to content

Commit 4ab1dbd

Browse files
authored
Merge branch 'main' into pin-jinja2
2 parents 1ee1cfe + 52ba090 commit 4ab1dbd

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

torchvision/models/feature_extraction.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,23 @@ def _get_leaf_modules_for_ops() -> List[type]:
184184
return result
185185

186186

187+
def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
188+
default_autowrap_modules = (math, torchvision.ops)
189+
default_leaf_modules = _get_leaf_modules_for_ops()
190+
result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
191+
result_tracer_kwargs["autowrap_modules"] = (
192+
tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules))
193+
if "autowrap_modules" in result_tracer_kwargs
194+
else default_autowrap_modules
195+
)
196+
result_tracer_kwargs["leaf_modules"] = (
197+
list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules))
198+
if "leaf_modules" in result_tracer_kwargs
199+
else default_leaf_modules
200+
)
201+
return result_tracer_kwargs
202+
203+
187204
def get_graph_node_names(
188205
model: nn.Module,
189206
tracer_kwargs: Optional[Dict[str, Any]] = None,
@@ -212,7 +229,11 @@ def get_graph_node_names(
212229
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
213230
``NodePathTracer`` (they are eventually passed onto
214231
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
215-
By default it will be set to wrap and make leaf nodes all torchvision ops.
232+
By default it will be set to wrap and make leaf nodes all torchvision ops:
233+
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
234+
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
235+
provided dictionary.
236+
216237
suppress_diff_warning (bool, optional): whether to suppress a warning
217238
when there are discrepancies between the train and eval version of
218239
the graph. Defaults to False.
@@ -226,14 +247,7 @@ def get_graph_node_names(
226247
>>> model = torchvision.models.resnet18()
227248
>>> train_nodes, eval_nodes = get_graph_node_names(model)
228249
"""
229-
if tracer_kwargs is None:
230-
tracer_kwargs = {
231-
"autowrap_modules": (
232-
math,
233-
torchvision.ops,
234-
),
235-
"leaf_modules": _get_leaf_modules_for_ops(),
236-
}
250+
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
237251
is_training = model.training
238252
train_tracer = NodePathTracer(**tracer_kwargs)
239253
train_tracer.trace(model.train())
@@ -378,7 +392,10 @@ def create_feature_extractor(
378392
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
379393
``NodePathTracer`` (which passes them onto it's parent class
380394
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
381-
By default it will be set to wrap and make leaf nodes all torchvision ops.
395+
By default it will be set to wrap and make leaf nodes all torchvision ops:
396+
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
397+
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
398+
provided dictionary.
382399
suppress_diff_warning (bool, optional): whether to suppress a warning
383400
when there are discrepancies between the train and eval version of
384401
the graph. Defaults to False.
@@ -423,14 +440,7 @@ def create_feature_extractor(
423440
>>> 'autowrap_functions': [leaf_function]})
424441
425442
"""
426-
if tracer_kwargs is None:
427-
tracer_kwargs = {
428-
"autowrap_modules": (
429-
math,
430-
torchvision.ops,
431-
),
432-
"leaf_modules": _get_leaf_modules_for_ops(),
433-
}
443+
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
434444
is_training = model.training
435445

436446
if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):

0 commit comments

Comments
 (0)