Skip to content

Commit 259f931

Browse files
fix: clear remaining security-quality findings (#1219)
* test: avoid substring assertions in metadata URL coverage * test: normalize sarif formatter imports * fix: address remaining quality findings
1 parent 30f4ef2 commit 259f931

6 files changed

Lines changed: 252 additions & 284 deletions

File tree

packages/modelaudit-picklescan/src/modelaudit_picklescan/call_graph.py

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import sys
88
import sysconfig
9+
import threading
910
from collections import deque
1011
from collections.abc import Callable, Iterable, Iterator, Mapping
1112
from contextlib import contextmanager
@@ -14,8 +15,16 @@
1415
from functools import lru_cache
1516
from importlib.machinery import EXTENSION_SUFFIXES, BuiltinImporter, FrozenImporter, ModuleSpec, PathFinder
1617
from pathlib import Path
18+
from typing import Protocol, TypeVar, cast
1719

20+
# Bound per-pass import/callable fan-out for untrusted inputs. The 32-reference
21+
# cap has kept call-graph enrichment useful while preventing pathological scan
22+
# growth; raising it improves completeness at a runtime cost, lowering it can
23+
# reduce detection coverage.
1824
_MAX_IMPORT_REFERENCES = 32
25+
# Limit per-module source reads to 1 MiB so AST parsing remains bounded on large
26+
# inputs. This is an explicit coverage/performance tradeoff and can be tuned if
27+
# scan precision or throughput needs change.
1928
_MAX_SOURCE_BYTES = 1024 * 1024
2029
_MAX_CALL_GRAPH_DEPTH = 4
2130
_MAX_VISITED_FUNCTIONS = 64
@@ -45,6 +54,21 @@
4554
"_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH",
4655
default=0,
4756
)
57+
_SHARED_SOURCE_SENSITIVE_CACHE_LOCK = threading.RLock()
58+
_CachedFunctionT = TypeVar("_CachedFunctionT", bound=Callable[..., object])
59+
60+
61+
class _CacheClearable(Protocol):
62+
def cache_clear(self) -> None: ...
63+
64+
65+
_SOURCE_SENSITIVE_CACHED_FUNCTIONS: set[_CacheClearable] = set()
66+
67+
68+
def _register_source_sensitive_cache(function: _CachedFunctionT) -> _CachedFunctionT:
69+
_SOURCE_SENSITIVE_CACHED_FUNCTIONS.add(cast(_CacheClearable, function))
70+
return function
71+
4872

4973
_CLASS_ENTRYPOINT_METHODS = (
5074
"__getattribute__",
@@ -290,7 +314,9 @@ def find_dangerous_call_graphs(
290314
positional_arg_counts = _callable_invocation_positional_arg_counts(callable_invocations)
291315
callable_references = _iter_callable_invocation_references(callable_invocations)
292316
invoked_references = {
293-
(str(reference.get("module", "")), str(reference.get("name", ""))) for reference in callable_references
317+
(str(reference.get("module", "")), str(reference.get("name", "")))
318+
for reference in callable_references
319+
if str(reference.get("module", "")) and str(reference.get("name", ""))
294320
}
295321

296322
for reference in _iter_call_graph_references(import_references, callable_references, invoked_references):
@@ -454,15 +480,18 @@ def find_unanalyzed_callable_call_graph_references(
454480
@contextmanager
455481
def shared_source_sensitive_caches() -> Iterator[None]:
456482
"""Share one fresh cache generation across related enrichment passes."""
457-
if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() == 0:
458-
_clear_source_sensitive_caches_now()
459-
token = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.set(_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() + 1)
460-
try:
461-
yield
462-
finally:
463-
_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.reset(token)
483+
with _SHARED_SOURCE_SENSITIVE_CACHE_LOCK:
484+
depth = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get()
485+
if depth == 0:
486+
_clear_source_sensitive_caches_now()
487+
token = _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.set(depth + 1)
488+
try:
489+
yield
490+
finally:
491+
_SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.reset(token)
464492

465493

494+
@_register_source_sensitive_cache
466495
@lru_cache(maxsize=4096)
467496
def _safe_call_graph_entrypoints(function_name: str) -> tuple[str, ...]:
468497
try:
@@ -700,6 +729,7 @@ def _is_skippable_torch_extension_global_reference(module: str, name: str) -> bo
700729
return not _has_static_torch_extension_global_target(module, name)
701730

702731

732+
@_register_source_sensitive_cache
703733
@lru_cache(maxsize=256)
704734
def _has_static_torch_extension_global_target(module: str, name: str) -> bool:
705735
analysis = _analyze_module(module)
@@ -735,37 +765,14 @@ def has_unanalyzed_call_graph_import_references(import_references: object) -> bo
735765

736766

737767
def _clear_source_sensitive_caches() -> None:
738-
if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() > 0:
739-
return
740-
_clear_source_sensitive_caches_now()
768+
with _SHARED_SOURCE_SENSITIVE_CACHE_LOCK:
769+
if _SHARED_SOURCE_SENSITIVE_CACHE_DEPTH.get() > 0:
770+
return
771+
_clear_source_sensitive_caches_now()
741772

742773

743774
def _clear_source_sensitive_caches_now() -> None:
744-
for function in (
745-
_safe_call_graph_entrypoints,
746-
_has_static_torch_extension_global_target,
747-
_find_sink_path,
748-
_find_invoked_import_execution_path,
749-
_find_file_open_path,
750-
_find_file_write_path,
751-
_call_graph_entrypoints,
752-
_resolve_function_target,
753-
_resolve_wildcard_reexport_alias,
754-
_wildcard_export_summary,
755-
_resolve_class_target,
756-
_split_function_name,
757-
_module_source_context,
758-
_analyze_module,
759-
_source_function_context,
760-
_source_class_context,
761-
_constructor_parameter_self_attribute_targets,
762-
_iter_call_nodes,
763-
_collect_function_import_aliases,
764-
_parameter_controlled_names,
765-
_can_invoke_function_with_positional_args,
766-
_can_follow_import_execution_fallback,
767-
_resolve_module_source,
768-
):
775+
for function in _SOURCE_SENSITIVE_CACHED_FUNCTIONS:
769776
function.cache_clear()
770777

771778

@@ -842,6 +849,7 @@ def _find_meta_path_module_spec_without_imports(module_name: str) -> ModuleSpec
842849
return None
843850

844851

852+
@_register_source_sensitive_cache
845853
@lru_cache(maxsize=4096)
846854
def _find_sink_path(start: str) -> tuple[str, ...] | None:
847855
return _find_matching_call_path(
@@ -851,6 +859,7 @@ def _find_sink_path(start: str) -> tuple[str, ...] | None:
851859
)
852860

853861

862+
@_register_source_sensitive_cache
854863
@lru_cache(maxsize=4096)
855864
def _find_invoked_import_execution_path(
856865
start: str,
@@ -881,11 +890,13 @@ def _find_invoked_import_execution_path(
881890
return path
882891

883892

893+
@_register_source_sensitive_cache
884894
@lru_cache(maxsize=4096)
885895
def _find_file_open_path(start: str) -> tuple[str, ...] | None:
886896
return _find_matching_call_path(start, _file_open_sink)
887897

888898

899+
@_register_source_sensitive_cache
889900
@lru_cache(maxsize=4096)
890901
def _find_file_write_path(start: str) -> tuple[str, ...] | None:
891902
return _find_matching_call_path(start, _file_write_sink)
@@ -997,6 +1008,7 @@ def _calls_for_function(function_name: str) -> tuple[str, ...] | None:
9971008
return analysis.calls_by_function.get(f"{module_name}.{qualified_name}")
9981009

9991010

1011+
@_register_source_sensitive_cache
10001012
@lru_cache(maxsize=4096)
10011013
def _call_graph_entrypoints(function_name: str) -> tuple[str, ...]:
10021014
resolved = _resolve_function_target(function_name)
@@ -1009,6 +1021,7 @@ def _call_graph_entrypoints(function_name: str) -> tuple[str, ...]:
10091021
return _class_entrypoints(class_target)
10101022

10111023

1024+
@_register_source_sensitive_cache
10121025
@lru_cache(maxsize=4096)
10131026
def _resolve_function_target(function_name: str) -> str | None:
10141027
alias_target = _static_import_reference_alias(function_name)
@@ -1113,6 +1126,7 @@ def _static_import_reference_alias(function_name: str) -> str | None:
11131126
return None
11141127

11151128

1129+
@_register_source_sensitive_cache
11161130
@lru_cache(maxsize=4096)
11171131
def _resolve_wildcard_reexport_alias(module_name: str, qualified_name: str) -> str | None:
11181132
return _resolve_wildcard_reexport_alias_inner(module_name, qualified_name, set(), 0)
@@ -1150,6 +1164,7 @@ def _resolve_wildcard_reexport_alias_inner(
11501164
return None
11511165

11521166

1167+
@_register_source_sensitive_cache
11531168
@lru_cache(maxsize=4096)
11541169
def _wildcard_export_summary(module_name: str) -> _WildcardExportSummary | None:
11551170
context = _module_source_context(module_name)
@@ -1158,6 +1173,7 @@ def _wildcard_export_summary(module_name: str) -> _WildcardExportSummary | None:
11581173
return _collect_module_export_summary(context.module_statements, module_name, context.is_package)
11591174

11601175

1176+
@_register_source_sensitive_cache
11611177
@lru_cache(maxsize=4096)
11621178
def _resolve_class_target(function_name: str) -> str | None:
11631179
module_name, qualified_name = _split_function_name(function_name)
@@ -1193,6 +1209,7 @@ def _class_entrypoints(class_name: str) -> tuple[str, ...]:
11931209
return analysis.class_entrypoints.get(f"{module_name}.{qualified_name}", ())
11941210

11951211

1212+
@_register_source_sensitive_cache
11961213
@lru_cache(maxsize=4096)
11971214
def _split_function_name(function_name: str) -> tuple[str | None, str]:
11981215
parts = function_name.split(".")
@@ -1205,6 +1222,7 @@ def _split_function_name(function_name: str) -> tuple[str | None, str]:
12051222
return None, function_name
12061223

12071224

1225+
@_register_source_sensitive_cache
12081226
@lru_cache(maxsize=1024)
12091227
def _analyze_module(module_name: str) -> _ModuleAnalysis | None:
12101228
context = _module_source_context(module_name)
@@ -1259,6 +1277,7 @@ def _analyze_module(module_name: str) -> _ModuleAnalysis | None:
12591277
)
12601278

12611279

1280+
@_register_source_sensitive_cache
12621281
@lru_cache(maxsize=4096)
12631282
def _module_source_context(module_name: str) -> _ModuleSourceContext | None:
12641283
source_path = _resolve_module_source(module_name)
@@ -1367,6 +1386,7 @@ def _collect_import_aliases(nodes: Iterable[ast.AST], module_name: str, is_packa
13671386
return aliases
13681387

13691388

1389+
@_register_source_sensitive_cache
13701390
@lru_cache(maxsize=4096)
13711391
def _collect_function_import_aliases(
13721392
function_node: ast.FunctionDef | ast.AsyncFunctionDef,
@@ -1530,6 +1550,7 @@ def _class_source_context_for_target(
15301550
return _source_class_context(class_target)
15311551

15321552

1553+
@_register_source_sensitive_cache
15331554
@lru_cache(maxsize=4096)
15341555
def _source_function_context(
15351556
function_name: str,
@@ -1584,6 +1605,7 @@ def _inherited_source_function_context(
15841605
return inherited_method.module_name, inherited_is_package, inherited_method.node
15851606

15861607

1608+
@_register_source_sensitive_cache
15871609
@lru_cache(maxsize=4096)
15881610
def _source_class_context(class_name: str) -> _ClassSourceContext | None:
15891611
module_name, qualified_name = _split_source_qualified_name(class_name)
@@ -1927,6 +1949,7 @@ def _class_base_targets(
19271949
return tuple(targets)
19281950

19291951

1952+
@_register_source_sensitive_cache
19301953
@lru_cache(maxsize=4096)
19311954
def _constructor_parameter_self_attribute_targets(class_name: str, parameter_name: str) -> tuple[str, ...]:
19321955
module_name, qualified_name = _split_source_qualified_name(class_name)
@@ -2295,6 +2318,7 @@ def _has_required_user_arguments(function_node: ast.FunctionDef | ast.AsyncFunct
22952318
return has_required_keyword_only
22962319

22972320

2321+
@_register_source_sensitive_cache
22982322
@lru_cache(maxsize=4096)
22992323
def _can_invoke_function_with_positional_args(function_name: str, positional_arg_count: int) -> bool:
23002324
resolved = _resolve_function_target(function_name)
@@ -2307,6 +2331,7 @@ def _can_invoke_function_with_positional_args(function_name: str, positional_arg
23072331
return _can_enter_function_with_positional_args(function_node, positional_arg_count)
23082332

23092333

2334+
@_register_source_sensitive_cache
23102335
@lru_cache(maxsize=4096)
23112336
def _can_follow_import_execution_fallback(function_name: str, positional_arg_count: int) -> bool:
23122337
resolved = _resolve_function_target(function_name)
@@ -2608,6 +2633,7 @@ def _function_instance_alias_value(
26082633
return _resolve_class_target(call_target)
26092634

26102635

2636+
@_register_source_sensitive_cache
26112637
@lru_cache(maxsize=4096)
26122638
def _iter_call_nodes(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[ast.Call, ...]:
26132639
calls: list[ast.Call] = []
@@ -2647,6 +2673,7 @@ def _visit_nested_function_signature(self, node: ast.FunctionDef | ast.AsyncFunc
26472673
return tuple(calls)
26482674

26492675

2676+
@_register_source_sensitive_cache
26502677
@lru_cache(maxsize=4096)
26512678
def _parameter_controlled_names(function_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
26522679
controlled = _initial_parameter_controlled_names(function_node)
@@ -2819,6 +2846,7 @@ def _resolve_import_from_module(module_name: str, is_package: bool, level: int,
28192846
return ".".join(part for part in parts if part)
28202847

28212848

2849+
@_register_source_sensitive_cache
28222850
@lru_cache(maxsize=1024)
28232851
def _resolve_module_source(module_name: str) -> Path | None:
28242852
parts = module_name.split(".")

packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,8 @@ def _global_call_payload(module: str, name: str, *arg_operands: bytes) -> bytes:
7070

7171

7272
def _clear_call_graph_caches() -> None:
73-
call_graph_module = sys.modules["modelaudit_picklescan.call_graph"]
74-
for function_name in (
75-
"_analyze_module",
76-
"_call_graph_entrypoints",
77-
"_find_sink_path",
78-
"_has_static_torch_extension_global_target",
79-
"_resolve_class_target",
80-
"_resolve_function_target",
81-
"_resolve_module_source",
82-
"_safe_call_graph_entrypoints",
83-
"_iter_call_nodes",
84-
"_collect_function_import_aliases",
85-
"_parameter_controlled_names",
86-
"_split_function_name",
87-
"_wildcard_export_summary",
88-
"_module_source_context",
89-
):
90-
cache_clear = getattr(getattr(call_graph_module, function_name), "cache_clear", None)
91-
if cache_clear is not None:
92-
cache_clear()
73+
for function in call_graph._SOURCE_SENSITIVE_CACHED_FUNCTIONS:
74+
function.cache_clear()
9375

9476

9577
def test_wildcard_summary_and_analysis_share_module_parse(

0 commit comments

Comments
 (0)