Skip to content

Commit f080053

Browse files
committed
feat: Support Python subprocesses
Pytest-socket should be able to block socket calls in Python subprocesses created by tests (e.g., pip's test suite). To hook into new Python subprocesses, we can use a .pth file to run code during Python startup. This is what pytest-cov does to automagically support subprocess coverage tracking. State is passed to the .pth file via the _PYTEST_SOCKET_SUBPROCESS environment variable with JSON as the encoding format. Some refactoring was necessary to allow for the right pytest-socket state to be easily passed down to the .pth file (w/o recalculating or rerunning the entirety of pytest_socket.py). Testing-wise, majority of the tests contained in test_socket.py and test_restrict_hosts.py were copied as subprocess tests. While this doesn't cover every single surface, this should be sufficient to ensure the subprocess support is working properly.
1 parent 6093640 commit f080053

8 files changed

+461
-17
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ tests to ensure network calls are prevented.
1616
## Features
1717

1818
- Disables all network calls flowing through Python\'s `socket` interface.
19+
- Python subprocesses are supported
1920

2021
## Requirements
2122

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include = [
1212
{ path = "README.md", format = "sdist" },
1313
{ path = "tests", format = "sdist" },
1414
{ path = ".flake8", format = "sdist" },
15+
{ path = "pytest_socket.pth", format = ["sdist", "wheel"] }
1516
]
1617
classifiers = [
1718
"Development Status :: 4 - Beta",

pytest_socket.embed

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Inject pytest-socket into Python subprocesses via a .pth file.
2+
3+
To update the .pth file, simply run this script which will write the code
4+
below to a .pth file in the same directory as a single line.
5+
"""
6+
7+
import os
8+
os.environ["_PYTEST_SOCKET_SUBPROCESS"] = ""
9+
10+
# .PTH START
11+
if config := os.getenv("_PYTEST_SOCKET_SUBPROCESS", None):
12+
import json
13+
state = None
14+
try:
15+
import socket
16+
from pytest_socket import disable_socket, _create_guarded_connect
17+
state = json.loads(config)
18+
19+
if state["mode"] == "disable":
20+
disable_socket(allow_unix_socket=state["allow_unix_socket"])
21+
elif state["mode"] == "allow-hosts":
22+
socket.socket.connect = _create_guarded_connect(
23+
allowed_hosts=state["allowed_hosts"],
24+
allow_unix_socket=state["allow_unix_socket"],
25+
_pretty_allowed_list=state["_pretty_allowed_list"]
26+
)
27+
28+
except Exception as exc:
29+
import sys
30+
sys.stderr.write(
31+
"pytest-socket: Failed to set up subprocess socket patching.\n"
32+
f"Configuration: {state}\n"
33+
f"{exc.__class__.__name__}: {exc}\n"
34+
)
35+
# .PTH END
36+
37+
if __name__ == "__main__":
38+
from pathlib import Path
39+
40+
src = Path(__file__)
41+
dst = src.with_suffix(".pth")
42+
lines = src.read_text().splitlines()
43+
code = "\n".join(lines[lines.index("# .PTH START") + 1 : lines.index("# .PTH END")])
44+
45+
print(f"Writing to {dst}")
46+
# Only lines beginning with an import will be executed.
47+
# https://docs.python.org/3/library/site.html
48+
dst.write_text(f"import os; exec({code!r})\n")

pytest_socket.pth

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import os; exec('if config := os.getenv("_PYTEST_SOCKET_SUBPROCESS", None):\n import json\n state = None\n try:\n import socket\n from pytest_socket import disable_socket, _create_guarded_connect\n state = json.loads(config)\n\n if state["mode"] == "disable":\n disable_socket(allow_unix_socket=state["allow_unix_socket"])\n elif state["mode"] == "allow-hosts":\n socket.socket.connect = _create_guarded_connect(\n allowed_hosts=state["allowed_hosts"],\n allow_unix_socket=state["allow_unix_socket"],\n _pretty_allowed_list=state["_pretty_allowed_list"]\n )\n\n except Exception as exc:\n import sys\n sys.stderr.write(\n "pytest-socket: Failed to set up subprocess socket patching.\\n"\n f"Configuration: {state}\\n"\n f"{exc.__class__.__name__}: {exc}\\n"\n )')

pytest_socket.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
11
import ipaddress
22
import itertools
3+
import json
4+
import os
35
import socket
46
import typing
57
from collections import defaultdict
68
from dataclasses import dataclass, field
79

810
import pytest
911

12+
_SUBPROCESS_ENVVAR = "_PYTEST_SOCKET_SUBPROCESS"
1013
_true_socket = socket.socket
1114
_true_connect = socket.socket.connect
1215

1316

17+
def update_subprocess_config(config: typing.Dict[str, object]) -> None:
18+
"""Enable pytest-socket in Python subprocesses.
19+
20+
The configuration will be read by the .pth file to mirror the
21+
restrictions in the main process.
22+
"""
23+
os.environ[_SUBPROCESS_ENVVAR] = json.dumps(config)
24+
25+
26+
def delete_subprocess_config() -> None:
27+
"""Disable pytest-socket in Python subprocesses."""
28+
if _SUBPROCESS_ENVVAR in os.environ:
29+
del os.environ[_SUBPROCESS_ENVVAR]
30+
31+
1432
class SocketBlockedError(RuntimeError):
1533
def __init__(self, *_args, **_kwargs):
1634
super().__init__("A test tried to use socket.socket.")
@@ -103,11 +121,15 @@ def __new__(cls, family=-1, type=-1, proto=-1, fileno=None):
103121
raise SocketBlockedError()
104122

105123
socket.socket = GuardedSocket
124+
update_subprocess_config(
125+
{"mode": "disable", "allow_unix_socket": allow_unix_socket}
126+
)
106127

107128

108129
def enable_socket():
109130
"""re-enable socket.socket to enable the Internet. useful in testing."""
110131
socket.socket = _true_socket
132+
delete_subprocess_config()
111133

112134

113135
def pytest_configure(config):
@@ -249,6 +271,25 @@ def normalize_allowed_hosts(
249271
return ip_hosts
250272

251273

274+
def _create_guarded_connect(
275+
allowed_hosts: typing.Sequence[str],
276+
allow_unix_socket: bool,
277+
_pretty_allowed_list: typing.Sequence[str],
278+
) -> typing.Callable:
279+
"""Create a function to replace socket.connect."""
280+
281+
def guarded_connect(inst, *args):
282+
host = host_from_connect_args(args)
283+
if host in allowed_hosts or (
284+
_is_unix_socket(inst.family) and allow_unix_socket
285+
):
286+
return _true_connect(inst, *args)
287+
288+
raise SocketConnectBlockedError(_pretty_allowed_list, host)
289+
290+
return guarded_connect
291+
292+
252293
def socket_allow_hosts(
253294
allowed: typing.Union[str, typing.List[str], None] = None,
254295
allow_unix_socket: bool = False,
@@ -276,19 +317,21 @@ def socket_allow_hosts(
276317
]
277318
)
278319

279-
def guarded_connect(inst, *args):
280-
host = host_from_connect_args(args)
281-
if host in allowed_ip_hosts_and_hostnames or (
282-
_is_unix_socket(inst.family) and allow_unix_socket
283-
):
284-
return _true_connect(inst, *args)
285-
286-
raise SocketConnectBlockedError(allowed_list, host)
287-
288-
socket.socket.connect = guarded_connect
320+
socket.socket.connect = _create_guarded_connect(
321+
allowed_ip_hosts_and_hostnames, allow_unix_socket, allowed_list
322+
)
323+
update_subprocess_config(
324+
{
325+
"mode": "allow-hosts",
326+
"allowed_hosts": list(allowed_ip_hosts_and_hostnames),
327+
"allow_unix_socket": allow_unix_socket,
328+
"_pretty_allowed_list": allowed_list,
329+
}
330+
)
289331

290332

291333
def _remove_restrictions():
292334
"""restore socket.socket.* to allow access to the Internet. useful in testing."""
293335
socket.socket = _true_socket
294336
socket.socket.connect = _true_connect
337+
delete_subprocess_config()

tests/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,9 @@ def assert_socket_blocked(result, passed=0, skipped=0, failed=1):
1414
result.stdout.fnmatch_lines(
1515
"*Socket*Blocked*Error: A test tried to use socket.socket.*"
1616
)
17+
18+
19+
def assert_host_blocked(result, host):
20+
result.stdout.fnmatch_lines(
21+
f'*A test tried to use socket.socket.connect() with host "{host}"*'
22+
)

tests/test_restrict_hosts.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from pytest_socket import normalize_allowed_hosts
7+
from pytest_socket import assert_host_blocked, normalize_allowed_hosts
88

99
localhost = "127.0.0.1"
1010

@@ -46,12 +46,6 @@ def {2}():
4646
"""
4747

4848

49-
def assert_host_blocked(result, host):
50-
result.stdout.fnmatch_lines(
51-
f'*A test tried to use socket.socket.connect() with host "{host}"*'
52-
)
53-
54-
5549
@pytest.fixture
5650
def assert_connect(httpbin, testdir):
5751
def assert_socket_connect(should_pass, **kwargs):

0 commit comments

Comments
 (0)