diff --git a/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py b/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py index 54e3ed080..861b0e27b 100644 --- a/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py +++ b/packages/modelaudit-picklescan/tests/test_call_graph_import_statements.py @@ -1863,6 +1863,52 @@ def test_scan_bytes_refreshes_invoked_import_fallback_after_source_rewrite( assert _has_critical_call_graph_finding(dangerous_report, module_name, "invoke", "builtins.__import__") +def test_startup_hook_write_call_graph_refreshes_after_source_rewrite( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + module_dir = tmp_path / "modules" + module_dir.mkdir() + module_name = "modelaudit_tp_rewritten_startup_hook_source" + module_path = module_dir / f"{module_name}.py" + module_path.write_text( + "def open_payload(path):\n return path\n\ndef write_payload(handle, value):\n return value\n", + encoding="utf-8", + ) + monkeypatch.syspath_prepend(str(module_dir)) + importlib.invalidate_caches() + _clear_call_graph_caches() + import_references = ( + {"module": module_name, "name": "open_payload"}, + {"module": module_name, "name": "write_payload"}, + ) + callable_invocations = import_references + + try: + safe_findings = call_graph.find_startup_hook_write_call_graphs(import_references, callable_invocations) + + module_path.write_text( + "def open_payload(path):\n return open(path, 'w', encoding='utf-8')\n\n" + "def write_payload(handle, value):\n return handle.write(value)\n", + encoding="utf-8", + ) + importlib.invalidate_caches() + dangerous_findings = call_graph.find_startup_hook_write_call_graphs( + import_references, + callable_invocations, + ) + finally: + _clear_call_graph_caches() + + assert safe_findings == () + assert len(dangerous_findings) == 1 + finding = dangerous_findings[0] + assert finding.opener_import_reference == f"{module_name}.open_payload" + assert finding.writer_import_reference == f"{module_name}.write_payload" + assert finding.open_sink == "builtins.open" + assert finding.write_sink == "handle.write" + + def test_call_graph_propagates_wrapper_import_execution_fallbacks() -> None: calls = call_graph._calls_for_function("platform.mac_ver") or ()