Skip to content

Commit 50123cc

Browse files
committed
parse Redis commands in the mock server and shutdown server on failure
1 parent 87cf9b1 commit 50123cc

File tree

1 file changed

+123
-65
lines changed

1 file changed

+123
-65
lines changed

tests/test_connect.py

Lines changed: 123 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import re
23
import socket
34
import ssl
45
import threading
@@ -12,6 +13,43 @@
1213

1314

1415
_CLIENT_NAME = "test-suite-client"
16+
_CMD_SEP = b"\r\n"
17+
_SUCCESS_RESP = b"+OK" + _CMD_SEP
18+
_ERROR_RESP = b"-ERR" + _CMD_SEP
19+
_COMMANDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
20+
21+
22+
@pytest.fixture
23+
def tcp_address():
24+
with socket.socket() as sock:
25+
sock.bind(("127.0.0.1", 0))
26+
return sock.getsockname()
27+
28+
29+
@pytest.fixture
30+
def uds_address(tmpdir):
31+
return tmpdir / "uds.sock"
32+
33+
34+
@pytest.fixture
35+
def ssl_cert(tcp_address, tmpdir):
36+
"""More or less equivalent to
37+
38+
.. code::
39+
40+
openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem
41+
"""
42+
host, _ = tcp_address
43+
ca = trustme.CA()
44+
cert = ca.issue_cert(host, common_name="trustme")
45+
46+
server_pem = str(tmpdir / "server.pem")
47+
cert.private_key_and_cert_chain_pem.write_to_path(path=server_pem)
48+
49+
client_pem = str(tmpdir / "client.pem")
50+
ca.cert_pem.write_to_path(path=client_pem)
51+
52+
return client_pem, server_pem
1553

1654

1755
def test_tcp_connect(tcp_address):
@@ -35,7 +73,25 @@ def test_tcp_ssl_connect(tcp_address, ssl_cert):
3573
_assert_connect(conn, tcp_address, certfile=server_pem)
3674

3775

38-
def redis_mock_server(server_address, ready, commands, certfile=None):
76+
def _assert_connect(conn, server_address, certfile=None):
77+
ready = threading.Event()
78+
stop = threading.Event()
79+
t = threading.Thread(
80+
target=_redis_mock_server,
81+
args=(server_address, ready, stop),
82+
kwargs={"certfile": certfile},
83+
)
84+
t.start()
85+
try:
86+
ready.wait()
87+
conn.connect()
88+
conn.disconnect()
89+
finally:
90+
stop.set()
91+
t.join(timeout=5)
92+
93+
94+
def _redis_mock_server(server_address, ready, stop, certfile=None):
3995
try:
4096
if isinstance(server_address, str):
4197
family = socket.AF_UNIX
@@ -46,86 +102,88 @@ def redis_mock_server(server_address, ready, commands, certfile=None):
46102
else:
47103
family = socket.AF_INET
48104
mockname = "Redis mock server (TCP)"
105+
49106
with socket.socket(family, socket.SOCK_STREAM) as s:
50107
s.bind(server_address)
51108
s.listen(1)
109+
s.settimeout(0.1)
52110

53111
if certfile:
54112
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
113+
context.minimum_version = ssl.TLSVersion.TLSv1_2
55114
context.load_cert_chain(certfile=certfile)
56115

57116
_logger.info("Start %s: %s", mockname, server_address)
58117
ready.set()
59-
ssock, _ = s.accept()
60-
with ssock:
118+
119+
# Wait a client connection
120+
while not stop.is_set():
121+
try:
122+
sconn, _ = s.accept()
123+
sconn.settimeout(0.1)
124+
break
125+
except socket.timeout:
126+
pass
127+
if stop.is_set():
128+
_logger.info("Exit %s: %s", mockname, server_address)
129+
return
130+
131+
# Receive commands from the client
132+
with sconn:
61133
if certfile:
62-
conn = context.wrap_socket(ssock, server_side=True)
134+
conn = context.wrap_socket(sconn, server_side=True)
63135
else:
64-
conn = ssock
136+
conn = sconn
65137
try:
66-
while True:
67-
data = conn.recv(1024)
68-
if not data:
69-
_logger.info("Exit %s: %s", mockname, server_address)
70-
break
71-
_logger.info("Command in %s: %s", mockname, data)
72-
resp = b"+ERROR\r\n"
73-
resp = commands.get(data, resp)
74-
_logger.info("Response from %s: %s", mockname, resp)
75-
conn.sendall(resp)
138+
buffer = b""
139+
command = None
140+
command_ptr = None
141+
fragment_length = None
142+
while not stop.is_set() or buffer:
143+
try:
144+
buffer += conn.recv(1024)
145+
except socket.timeout:
146+
continue
147+
if not buffer:
148+
continue
149+
parts = re.split(_CMD_SEP, buffer)
150+
buffer = parts[-1]
151+
for fragment in parts[:-1]:
152+
fragment = fragment.decode()
153+
_logger.info(
154+
"Command fragment in %s: %s", mockname, fragment
155+
)
156+
157+
if fragment.startswith("*") and command is None:
158+
command = [None for _ in range(int(fragment[1:]))]
159+
command_ptr = 0
160+
fragment_length = None
161+
continue
162+
163+
if (
164+
fragment.startswith("$")
165+
and command[command_ptr] is None
166+
):
167+
fragment_length = int(fragment[1:])
168+
continue
169+
170+
assert len(fragment) == fragment_length
171+
command[command_ptr] = fragment
172+
command_ptr += 1
173+
174+
if command_ptr < len(command):
175+
continue
176+
177+
command = " ".join(command)
178+
_logger.info("Command in %s: %s", mockname, command)
179+
resp = _COMMANDS.get(command, _ERROR_RESP)
180+
_logger.info("Response from %s: %s", mockname, resp)
181+
conn.sendall(resp)
182+
command = None
76183
finally:
77184
if certfile:
78185
conn.close()
186+
_logger.info("Exit %s: %s", mockname, server_address)
79187
except BaseException as e:
80188
_logger.exception("Error in %s: %s", mockname, e)
81189
raise
82-
83-
84-
def _assert_connect(conn, server_address, **server_kwargs):
85-
command = conn.pack_command("CLIENT", "SETNAME", _CLIENT_NAME)[0]
86-
commands = {command: b"+OK\r\n"}
87-
88-
ready = threading.Event()
89-
t = threading.Thread(
90-
target=redis_mock_server,
91-
args=(server_address, ready, commands),
92-
kwargs=server_kwargs,
93-
)
94-
t.start()
95-
ready.wait()
96-
conn.connect()
97-
conn.disconnect()
98-
t.join()
99-
100-
101-
@pytest.fixture
102-
def tcp_address():
103-
with socket.socket() as sock:
104-
sock.bind(("127.0.0.1", 0))
105-
return sock.getsockname()
106-
107-
108-
@pytest.fixture
109-
def uds_address(tmpdir):
110-
return tmpdir / "uds.sock"
111-
112-
113-
@pytest.fixture
114-
def ssl_cert(tcp_address, tmpdir):
115-
"""More or less equivalent to
116-
117-
.. code::
118-
119-
openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem
120-
"""
121-
host, _ = tcp_address
122-
ca = trustme.CA()
123-
cert = ca.issue_cert(host, common_name="trustme")
124-
125-
server_pem = str(tmpdir / "server.pem")
126-
cert.private_key_and_cert_chain_pem.write_to_path(path=server_pem)
127-
128-
client_pem = str(tmpdir / "client.pem")
129-
ca.cert_pem.write_to_path(path=client_pem)
130-
131-
return client_pem, server_pem

0 commit comments

Comments
 (0)