Skip to content

Commit 5aa48f2

Browse files
test: address AI findings in recent test changes (#1234)
* test: address AI findings in recent test changes
* test: dedupe shared pickle child setup
1 parent a1efccb commit 5aa48f2

5 files changed

Lines changed: 92 additions & 66 deletions

File tree

packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py

Lines changed: 34 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -83,10 +83,10 @@ def test_wildcard_summary_and_analysis_share_module_parse(
8383
parse_calls = 0
8484
real_parse = call_graph.ast.parse
8585

86-
def tracking_parse(source: str, filename: str = "<unknown>") -> ast.Module:
86+
def tracking_parse(source_code: str, filename: str = "<unknown>") -> ast.Module:
8787
nonlocal parse_calls
8888
parse_calls += 1
89-
return real_parse(source, filename=filename)
89+
return real_parse(source_code, filename=filename)
9090

9191
monkeypatch.setattr(
9292
call_graph, "_resolve_module_source", lambda module_name: module_path if module_name == "module" else None
@@ -124,6 +124,21 @@ def _env_without_pythonpath() -> dict[str, str]:
124124
return {key: value for key, value in os.environ.items() if key != "PYTHONPATH"}
125125

126126

127+
def _pickle_exec_child_code(body: str) -> str:
128+
return f"""
129+
import pickle
130+
import sys
131+
from pathlib import Path
132+
133+
module_dir = Path(sys.argv[1])
134+
marker = Path(sys.argv[2])
135+
payload = bytes.fromhex(sys.argv[3])
136+
137+
sys.path.insert(0, str(module_dir))
138+
{body}
139+
"""
140+
141+
127142
def test_iter_call_nodes_reuses_cached_walk(monkeypatch: pytest.MonkeyPatch) -> None:
128143
module = ast.parse(
129144
"""
@@ -224,9 +239,12 @@ def counting_initial_parameter_controlled_names(
224239
def test_split_function_name_reuses_cached_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
225240
analyze_calls: list[str] = []
226241

227-
def fake_analyze_module(module_name: str) -> object | None:
242+
class _AnalyzedModule:
243+
pass
244+
245+
def fake_analyze_module(module_name: str) -> _AnalyzedModule | None:
228246
analyze_calls.append(module_name)
229-
return object() if module_name == "pkg.mod" else None
247+
return _AnalyzedModule() if module_name == "pkg.mod" else None
230248

231249
monkeypatch.setattr(call_graph, "_analyze_module", fake_analyze_module)
232250
call_graph._split_function_name.cache_clear()
@@ -1479,20 +1497,13 @@ def test_scan_bytes_ignores_uninvoked_nested_function_body_calls(
14791497

14801498
assert report.verdict == SafetyVerdict.CLEAN
14811499
assert not _has_critical_call_graph_finding(report, module_name, function_name, "subprocess.run")
1482-
child_code = """
1483-
import pickle
1484-
import sys
1485-
from pathlib import Path
1486-
1487-
module_dir = Path(sys.argv[1])
1488-
marker = Path(sys.argv[2])
1489-
payload = bytes.fromhex(sys.argv[3])
1490-
1491-
sys.path.insert(0, str(module_dir))
1500+
child_code = _pickle_exec_child_code(
1501+
"""
14921502
pickle.loads(payload)
14931503
if marker.exists():
14941504
raise SystemExit("nested body unexpectedly executed")
14951505
"""
1506+
)
14961507
result = _run_python_subprocess(
14971508
[sys.executable, "-c", child_code, str(module_dir), str(marker), payload.hex()],
14981509
cwd=tmp_path.parent,
@@ -1529,22 +1540,15 @@ def test_scan_bytes_does_not_treat_newobj_as_init_invocation(
15291540

15301541
assert report.verdict == SafetyVerdict.CLEAN
15311542
assert not _has_critical_call_graph_finding(report, module_name, "InitImports", "os.system")
1532-
child_code = """
1533-
import pickle
1534-
import sys
1535-
from pathlib import Path
1536-
1537-
module_dir = Path(sys.argv[1])
1538-
marker = Path(sys.argv[2])
1539-
payload = bytes.fromhex(sys.argv[3])
1540-
1541-
sys.path.insert(0, str(module_dir))
1543+
child_code = _pickle_exec_child_code(
1544+
"""
15421545
result = pickle.loads(payload)
15431546
if getattr(result, "value", None) != "safe":
15441547
raise SystemExit(f"unexpected state: {result.__dict__!r}")
15451548
if marker.exists():
15461549
raise SystemExit("NEWOBJ unexpectedly executed __init__")
15471550
"""
1551+
)
15481552
result = _run_python_subprocess(
15491553
[sys.executable, "-c", child_code, str(module_dir), str(marker), payload.hex()],
15501554
cwd=tmp_path.parent,
@@ -3642,7 +3646,7 @@ def test_scan_bytes_blocks_itertools_adapter_next_call_iterator_consumption_rce(
36423646

36433647

36443648
@pytest.mark.parametrize(
3645-
("payload", "values_literal", "expected_repr", "requires_python311"),
3649+
("payload", "values_literal", "expected_repr", "requires_python_3_11_plus"),
36463650
[
36473651
(
36483652
_builtins_help_call_iterator_stdlib_materializer_payload("array", "array", _unicode_operand("i"), b"h\x00"),
@@ -3859,7 +3863,7 @@ def test_scan_bytes_blocks_stdlib_eager_call_iterator_consumption_rce(
38593863
payload: bytes,
38603864
values_literal: str,
38613865
expected_repr: str,
3862-
requires_python311: bool,
3866+
requires_python_3_11_plus: bool,
38633867
) -> None:
38643868
module_dir = tmp_path / "modules"
38653869
module_dir.mkdir()
@@ -3891,7 +3895,7 @@ def test_scan_bytes_blocks_stdlib_eager_call_iterator_consumption_rce(
38913895
)
38923896

38933897
assert not marker.exists()
3894-
if requires_python311 and sys.version_info < (3, 11):
3898+
if requires_python_3_11_plus and sys.version_info < (3, 11):
38953899
return
38963900
child_code = """
38973901
import ast
@@ -3945,7 +3949,7 @@ def test_scan_bytes_blocks_stdlib_eager_call_iterator_consumption_rce(
39453949

39463950

39473951
@pytest.mark.parametrize(
3948-
("payload", "values_literal", "expected_repr", "requires_python311"),
3952+
("payload", "values_literal", "expected_repr", "requires_python_3_11_plus"),
39493953
[
39503954
(
39513955
_builtins_help_call_iterator_stdlib_materializer_payload(
@@ -3976,7 +3980,7 @@ def test_scan_bytes_blocks_weighted_statistics_call_iterator_consumption_rce(
39763980
payload: bytes,
39773981
values_literal: str,
39783982
expected_repr: str,
3979-
requires_python311: bool,
3983+
requires_python_3_11_plus: bool,
39803984
) -> None:
39813985
module_dir = tmp_path / "modules"
39823986
module_dir.mkdir()
@@ -4008,7 +4012,7 @@ def test_scan_bytes_blocks_weighted_statistics_call_iterator_consumption_rce(
40084012
)
40094013

40104014
assert not marker.exists()
4011-
if requires_python311 and sys.version_info < (3, 11):
4015+
if requires_python_3_11_plus and sys.version_info < (3, 11):
40124016
return
40134017

40144018
child_code = """

tests/test_lazy_loading_integration.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,15 @@
1111
from modelaudit import core
1212
from modelaudit.scanners import _registry
1313

14+
MAX_SCANNERS_FOR_SINGLE_FILE_SCAN = 5
15+
MAX_SCANNERS_FOR_DIRECTORY_SCAN = 10
16+
MAX_SCANNERS_FOR_INCREMENTAL_SCAN = 15
17+
1418

1519
class TestCoreIntegration:
1620
"""Test integration of lazy loading with core scanning functionality."""
1721

18-
def test_scan_file_uses_lazy_loading(self):
22+
def test_scan_file_uses_lazy_loading(self) -> None:
1923
"""Test that scan_file uses lazy loading correctly."""
2024
# Reset loaded scanners
2125
_registry._loaded_scanners.clear()
@@ -36,11 +40,11 @@ def test_scan_file_uses_lazy_loading(self):
3640

3741
# Should have loaded minimal scanners
3842
loaded_count = len(_registry._loaded_scanners)
39-
assert loaded_count <= 5 # Should be minimal
43+
assert loaded_count <= MAX_SCANNERS_FOR_SINGLE_FILE_SCAN
4044
finally:
4145
Path(f.name).unlink(missing_ok=True)
4246

43-
def test_scan_directory_uses_lazy_loading(self):
47+
def test_scan_directory_uses_lazy_loading(self) -> None:
4448
"""Test that directory scanning uses lazy loading efficiently."""
4549
_registry._loaded_scanners.clear()
4650

@@ -61,7 +65,7 @@ def test_scan_directory_uses_lazy_loading(self):
6165

6266
# Should have loaded only necessary scanners
6367
loaded_count = len(_registry._loaded_scanners)
64-
assert loaded_count <= 10 # Should be reasonable
68+
assert loaded_count <= MAX_SCANNERS_FOR_DIRECTORY_SCAN
6569

6670
def test_preferred_scanner_lazy_loading(self, tmp_path: Path) -> None:
6771
"""Test that preferred scanner detection uses lazy loading."""
@@ -79,7 +83,7 @@ def test_preferred_scanner_lazy_loading(self, tmp_path: Path) -> None:
7983
# Should have loaded pickle scanner
8084
assert "pickle" in _registry._loaded_scanners
8185

82-
def test_multiple_file_types_incremental_loading(self):
86+
def test_multiple_file_types_incremental_loading(self) -> None:
8387
"""Test that scanning multiple file types loads scanners incrementally."""
8488
_registry._loaded_scanners.clear()
8589

@@ -106,13 +110,13 @@ def test_multiple_file_types_incremental_loading(self):
106110
# Should show incremental loading (or at least not loading everything at once)
107111
assert loaded_counts[0] > 0 # Some scanners loaded for first file
108112
# Later scans might load more, but shouldn't load everything
109-
assert max(loaded_counts) <= 15 # Reasonable upper bound
113+
assert max(loaded_counts) <= MAX_SCANNERS_FOR_INCREMENTAL_SCAN
110114

111115

112116
class TestPerformanceCharacteristics:
113117
"""Test performance characteristics of lazy loading."""
114118

115-
def test_import_performance(self):
119+
def test_import_performance(self) -> None:
116120
"""Test that importing scanners is fast with lazy loading."""
117121
# This test measures import time
118122
start_time = time.time()
@@ -122,16 +126,18 @@ def test_import_performance(self):
122126

123127
import_time = time.time() - start_time
124128

125-
# Should be much faster than 1 second (was 7+ seconds before)
129+
# The historical eager-loading baseline was 7+ seconds; 1 second leaves
130+
# room for local and CI variance while still catching a real regression.
126131
assert import_time < 1.0
127132

128133
# Accessing the registry should also be fast
129134
start_time = time.time()
130135
_ = scanners.SCANNER_REGISTRY
131136
access_time = time.time() - start_time
132137

133-
# First access loads scanners, but should still be reasonable
134-
assert access_time < 5.0 # Much better than 7+ seconds
138+
# First access performs one-time lazy-loading work, so keep this looser
139+
# than the import guard while still catching a return to the old 7+ second path.
140+
assert access_time < 5.0
135141

136142
def test_single_scanner_access_performance(self) -> None:
137143
"""Test that accessing a single scanner is fast."""

tests/test_perf_workflow.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,17 @@
77

88

99
def _load_perf_workflow() -> dict[str, Any]:
10-
workflow_path = Path(__file__).resolve().parents[1] / ".github" / "workflows" / "perf.yml"
10+
current_path = Path(__file__).resolve()
11+
workflow_path = next(
12+
(
13+
candidate_root / ".github" / "workflows" / "perf.yml"
14+
for candidate_root in [current_path.parent, *current_path.parents]
15+
if (candidate_root / ".github" / "workflows" / "perf.yml").is_file()
16+
),
17+
None,
18+
)
19+
if workflow_path is None:
20+
raise AssertionError("Could not locate .github/workflows/perf.yml from test file path")
1121
workflow = yaml.safe_load(workflow_path.read_text(encoding="utf-8"))
1222
assert isinstance(workflow, dict)
1323
return workflow

tests/test_security_enhancements.py

Lines changed: 23 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -19,18 +19,21 @@
1919
class TestJoblibScannerSecurity:
2020
"""Test security enhancements for Joblib scanner."""
2121

22-
def test_compression_bomb_detection(self, tmp_path):
22+
def test_compression_bomb_detection(self, tmp_path: Path) -> None:
2323
"""Test that compression bombs are detected."""
2424
# Create a compression bomb (large data that compresses well)
2525
bomb_data = b"A" * (10 * 1024 * 1024) # 10MB of 'A's
2626
compressed = zlib.compress(bomb_data, level=9)
27+
max_decompression_ratio = 50.0
28+
actual_ratio = len(bomb_data) / len(compressed)
29+
assert actual_ratio > max_decompression_ratio
2730

2831
# Write to a .joblib file
2932
joblib_file = tmp_path / "bomb.joblib"
3033
joblib_file.write_bytes(compressed)
3134

3235
# Configure scanner with low compression ratio limit
33-
config = {"max_decompression_ratio": 50.0} # Lower than actual ratio
36+
config = {"max_decompression_ratio": max_decompression_ratio} # Lower than actual ratio
3437
scanner = JoblibScanner(config)
3538

3639
result = scanner.scan(str(joblib_file))
@@ -151,11 +154,11 @@ def test_zip_format_joblib(self, tmp_path):
151154
# Should delegate to ZIP scanner and succeed
152155
assert result.success is True
153156

154-
def test_direct_pickle_joblib(self, tmp_path):
157+
def test_direct_pickle_joblib(self, tmp_path: Path) -> None:
155158
"""Test joblib files that are direct pickle (not compressed)."""
156159
# Create direct pickle data with pickle magic bytes
157160
data = {"test": "direct_pickle"}
158-
pickled = pickle.dumps(data, protocol=2) # Protocol 2 starts with 0x80
161+
pickled = pickle.dumps(data, protocol=2) # Protocol 2 starts with 0x80 0x02
159162

160163
joblib_file = tmp_path / "direct.joblib"
161164
joblib_file.write_bytes(pickled)
@@ -204,7 +207,7 @@ def test_file_read_chunk_limit(self, tmp_path):
204207
class TestNumPyScannerSecurity:
205208
"""Test security enhancements for NumPy scanner."""
206209

207-
def test_negative_dimension_rejection(self, tmp_path):
210+
def test_negative_dimension_rejection(self, tmp_path: Path) -> None:
208211
"""Test rejection of arrays with negative dimensions."""
209212
# We'll need to create a malformed numpy file manually
210213
# since numpy.save() won't create invalid files
@@ -219,8 +222,8 @@ def test_negative_dimension_rejection(self, tmp_path):
219222
header_len = len(header)
220223
f.write(header_len.to_bytes(2, "little"))
221224
f.write(header.encode("latin1"))
222-
# Add some dummy data
223-
f.write(b"\x00" * 1600) # 20 * 8 bytes per float64
225+
# The scanner should reject the invalid header before reading a full payload.
226+
f.write(b"\x00")
224227

225228
scanner = NumPyScanner()
226229
result = scanner.scan(str(npy_file))
@@ -273,7 +276,13 @@ def test_dimension_size_limit(self, tmp_path):
273276
assert len(size_issues) > 0
274277

275278
def test_dangerous_dtype_reports_cve_info(self, tmp_path: Path) -> None:
276-
"""Object dtype arrays should scan successfully while emitting CVE-2019-6446 info."""
279+
"""Object dtype arrays should emit informational CVE-2019-6446 context.
280+
281+
CVE-2019-6446 concerns unsafe loading of NumPy object arrays when pickle
282+
deserialization is allowed. Object dtypes can embed pickled Python objects,
283+
so the scanner should surface that context while still allowing a clean file
284+
to scan successfully.
285+
"""
277286
scanner = NumPyScanner()
278287
npy_file = tmp_path / "object_dtype.npy"
279288
np.save(npy_file, np.array([{"key": "value"}], dtype=object), allow_pickle=True)
@@ -363,23 +372,15 @@ def test_valid_numpy_array_still_works(self, tmp_path):
363372
assert "shape" in result.metadata
364373
assert "dtype" in result.metadata
365374

366-
def test_numpy_version_2_format(self, tmp_path):
375+
def test_numpy_version_2_format(self, tmp_path: Path) -> None:
367376
"""Test NumPy format version 2.0 handling."""
368-
# Create array that will use version 2.0 format
369-
# Use a very long array description to trigger v2.0 format
370-
371-
# Create a large 4D array that should trigger version 2.0
372-
# due to large header size, not structured dtype
373-
arr = np.zeros((100, 50, 20, 10), dtype=np.float64)
374-
377+
arr = np.zeros((2, 2), dtype=np.float64)
375378
npy_file = tmp_path / "version2.npy"
376-
np.save(npy_file, arr)
379+
with npy_file.open("wb") as file_obj:
380+
np.lib.format.write_array(file_obj, arr, version=(2, 0))
381+
assert npy_file.read_bytes()[6:8] == b"\x02\x00"
377382

378-
# Allow structured arrays for this test
379-
config = {
380-
"max_array_bytes": 10 * 1024 * 1024 * 1024,
381-
} # 10GB limit to allow large test array
382-
scanner = NumPyScanner(config)
383+
scanner = NumPyScanner()
383384
result = scanner.scan(str(npy_file))
384385

385386
# Should succeed

tests/utils/file/test_advanced_file_handler.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,10 @@ def scan(self, shard_path: str) -> ScanResult:
2929
return result
3030

3131

32-
class FailingShardScanner:
32+
class OperationalFailureScanner:
3333
"""Scanner that simulates an operational shard scan failure."""
3434

35-
name = "failing_shard_scanner"
35+
name = "operational_failure_scanner"
3636

3737
def scan(self, shard_path: str) -> ScanResult:
3838
raise RuntimeError(f"cannot scan {Path(shard_path).name}")
@@ -297,6 +297,11 @@ def test_massive_file_without_bounded_support_fails_closed(
297297
class ScannerWithoutBoundedSupport:
298298
name = "test_scanner"
299299

300+
def scan(self, _file_path: str) -> ScanResult:
301+
result = ScanResult(scanner_name=self.name)
302+
result.finish(success=True)
303+
return result
304+
300305
handler = AdvancedFileHandler(str(model_path), ScannerWithoutBoundedSupport())
301306
result = handler.scan()
302307

@@ -356,7 +361,7 @@ def test_parallel_shard_errors_mark_scan_inconclusive(self, tmp_path: Path) -> N
356361
"total_shards": 1,
357362
"total_size": shard_path.stat().st_size,
358363
},
359-
FailingShardScanner,
364+
OperationalFailureScanner,
360365
)
361366

362367
result = handler.scan_shards()

0 commit comments

Comments
 (0)