Skip to content

Commit aa52759

Browse files
perf: cache call graph call nodes (#1215)
1 parent d56eef2 commit aa52759

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,7 @@ def _clear_source_sensitive_caches_now() -> None:
757757
_source_function_context,
758758
_source_class_context,
759759
_constructor_parameter_self_attribute_targets,
760+
_iter_call_nodes,
760761
_collect_function_import_aliases,
761762
_parameter_controlled_names,
762763
_can_invoke_function_with_positional_args,
@@ -2602,6 +2603,7 @@ def _function_instance_alias_value(
26022603
return _resolve_class_target(call_target)
26032604

26042605

2606+
@lru_cache(maxsize=4096)
26052607
def _iter_call_nodes(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[ast.Call, ...]:
26062608
calls: list[ast.Call] = []
26072609

packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from importlib.machinery import ModuleSpec
1515
from importlib.util import find_spec
1616
from pathlib import Path
17+
from typing import Any
1718

1819
import pytest
1920

@@ -79,6 +80,7 @@ def _clear_call_graph_caches() -> None:
7980
"_resolve_function_target",
8081
"_resolve_module_source",
8182
"_safe_call_graph_entrypoints",
83+
"_iter_call_nodes",
8284
"_collect_function_import_aliases",
8385
"_parameter_controlled_names",
8486
"_split_function_name",
@@ -140,6 +142,37 @@ def _env_without_pythonpath() -> dict[str, str]:
140142
return {key: value for key, value in os.environ.items() if key != "PYTHONPATH"}
141143

142144

145+
def test_iter_call_nodes_reuses_cached_walk(monkeypatch: pytest.MonkeyPatch) -> None:
146+
module = ast.parse(
147+
"""
148+
def bridge(target, command):
149+
callback = getattr(target, command)
150+
callback(command)
151+
"""
152+
)
153+
bridge = module.body[0]
154+
assert isinstance(bridge, ast.FunctionDef)
155+
156+
visits = 0
157+
original_visit = ast.NodeVisitor.visit
158+
159+
def counting_visit(self: ast.NodeVisitor, node: ast.AST) -> Any:
160+
nonlocal visits
161+
visits += 1
162+
return original_visit(self, node)
163+
164+
monkeypatch.setattr(ast.NodeVisitor, "visit", counting_visit)
165+
call_graph._iter_call_nodes.cache_clear()
166+
167+
first = call_graph._iter_call_nodes(bridge)
168+
first_visit_count = visits
169+
second = call_graph._iter_call_nodes(bridge)
170+
171+
assert first == second
172+
assert first_visit_count > 0
173+
assert visits == first_visit_count
174+
175+
143176
def test_collect_function_import_aliases_reuses_cached_walk(monkeypatch: pytest.MonkeyPatch) -> None:
144177
module = ast.parse(
145178
"""

0 commit comments

Comments
 (0)