1
1
import logging
2
+ import re
2
3
import socket
3
4
import ssl
4
5
import threading
12
13
13
14
14
15
_CLIENT_NAME = "test-suite-client"
16
+ _CMD_SEP = b"\r \n "
17
+ _SUCCESS_RESP = b"+OK" + _CMD_SEP
18
+ _ERROR_RESP = b"-ERR" + _CMD_SEP
19
+ _COMMANDS = {f"CLIENT SETNAME { _CLIENT_NAME } " : _SUCCESS_RESP }
20
+
21
+
22
+ @pytest .fixture
23
+ def tcp_address ():
24
+ with socket .socket () as sock :
25
+ sock .bind (("127.0.0.1" , 0 ))
26
+ return sock .getsockname ()
27
+
28
+
29
+ @pytest .fixture
30
+ def uds_address (tmpdir ):
31
+ return tmpdir / "uds.sock"
32
+
33
+
34
+ @pytest .fixture
35
+ def ssl_cert (tcp_address , tmpdir ):
36
+ """More or less equivalent to
37
+
38
+ .. code::
39
+
40
+ openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem
41
+ """
42
+ host , _ = tcp_address
43
+ ca = trustme .CA ()
44
+ cert = ca .issue_cert (host , common_name = "trustme" )
45
+
46
+ server_pem = str (tmpdir / "server.pem" )
47
+ cert .private_key_and_cert_chain_pem .write_to_path (path = server_pem )
48
+
49
+ client_pem = str (tmpdir / "client.pem" )
50
+ ca .cert_pem .write_to_path (path = client_pem )
51
+
52
+ return client_pem , server_pem
15
53
16
54
17
55
def test_tcp_connect (tcp_address ):
@@ -35,7 +73,25 @@ def test_tcp_ssl_connect(tcp_address, ssl_cert):
35
73
_assert_connect (conn , tcp_address , certfile = server_pem )
36
74
37
75
38
- def redis_mock_server (server_address , ready , commands , certfile = None ):
76
+ def _assert_connect (conn , server_address , certfile = None ):
77
+ ready = threading .Event ()
78
+ stop = threading .Event ()
79
+ t = threading .Thread (
80
+ target = _redis_mock_server ,
81
+ args = (server_address , ready , stop ),
82
+ kwargs = {"certfile" : certfile },
83
+ )
84
+ t .start ()
85
+ try :
86
+ ready .wait ()
87
+ conn .connect ()
88
+ conn .disconnect ()
89
+ finally :
90
+ stop .set ()
91
+ t .join (timeout = 5 )
92
+
93
+
94
+ def _redis_mock_server (server_address , ready , stop , certfile = None ):
39
95
try :
40
96
if isinstance (server_address , str ):
41
97
family = socket .AF_UNIX
@@ -46,86 +102,88 @@ def redis_mock_server(server_address, ready, commands, certfile=None):
46
102
else :
47
103
family = socket .AF_INET
48
104
mockname = "Redis mock server (TCP)"
105
+
49
106
with socket .socket (family , socket .SOCK_STREAM ) as s :
50
107
s .bind (server_address )
51
108
s .listen (1 )
109
+ s .settimeout (0.1 )
52
110
53
111
if certfile :
54
112
context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
113
+ context .minimum_version = ssl .TLSVersion .TLSv1_2
55
114
context .load_cert_chain (certfile = certfile )
56
115
57
116
_logger .info ("Start %s: %s" , mockname , server_address )
58
117
ready .set ()
59
- ssock , _ = s .accept ()
60
- with ssock :
118
+
119
+ # Wait a client connection
120
+ while not stop .is_set ():
121
+ try :
122
+ sconn , _ = s .accept ()
123
+ sconn .settimeout (0.1 )
124
+ break
125
+ except socket .timeout :
126
+ pass
127
+ if stop .is_set ():
128
+ _logger .info ("Exit %s: %s" , mockname , server_address )
129
+ return
130
+
131
+ # Receive commands from the client
132
+ with sconn :
61
133
if certfile :
62
- conn = context .wrap_socket (ssock , server_side = True )
134
+ conn = context .wrap_socket (sconn , server_side = True )
63
135
else :
64
- conn = ssock
136
+ conn = sconn
65
137
try :
66
- while True :
67
- data = conn .recv (1024 )
68
- if not data :
69
- _logger .info ("Exit %s: %s" , mockname , server_address )
70
- break
71
- _logger .info ("Command in %s: %s" , mockname , data )
72
- resp = b"+ERROR\r \n "
73
- resp = commands .get (data , resp )
74
- _logger .info ("Response from %s: %s" , mockname , resp )
75
- conn .sendall (resp )
138
+ buffer = b""
139
+ command = None
140
+ command_ptr = None
141
+ fragment_length = None
142
+ while not stop .is_set () or buffer :
143
+ try :
144
+ buffer += conn .recv (1024 )
145
+ except socket .timeout :
146
+ continue
147
+ if not buffer :
148
+ continue
149
+ parts = re .split (_CMD_SEP , buffer )
150
+ buffer = parts [- 1 ]
151
+ for fragment in parts [:- 1 ]:
152
+ fragment = fragment .decode ()
153
+ _logger .info (
154
+ "Command fragment in %s: %s" , mockname , fragment
155
+ )
156
+
157
+ if fragment .startswith ("*" ) and command is None :
158
+ command = [None for _ in range (int (fragment [1 :]))]
159
+ command_ptr = 0
160
+ fragment_length = None
161
+ continue
162
+
163
+ if (
164
+ fragment .startswith ("$" )
165
+ and command [command_ptr ] is None
166
+ ):
167
+ fragment_length = int (fragment [1 :])
168
+ continue
169
+
170
+ assert len (fragment ) == fragment_length
171
+ command [command_ptr ] = fragment
172
+ command_ptr += 1
173
+
174
+ if command_ptr < len (command ):
175
+ continue
176
+
177
+ command = " " .join (command )
178
+ _logger .info ("Command in %s: %s" , mockname , command )
179
+ resp = _COMMANDS .get (command , _ERROR_RESP )
180
+ _logger .info ("Response from %s: %s" , mockname , resp )
181
+ conn .sendall (resp )
182
+ command = None
76
183
finally :
77
184
if certfile :
78
185
conn .close ()
186
+ _logger .info ("Exit %s: %s" , mockname , server_address )
79
187
except BaseException as e :
80
188
_logger .exception ("Error in %s: %s" , mockname , e )
81
189
raise
82
-
83
-
84
- def _assert_connect (conn , server_address , ** server_kwargs ):
85
- command = conn .pack_command ("CLIENT" , "SETNAME" , _CLIENT_NAME )[0 ]
86
- commands = {command : b"+OK\r \n " }
87
-
88
- ready = threading .Event ()
89
- t = threading .Thread (
90
- target = redis_mock_server ,
91
- args = (server_address , ready , commands ),
92
- kwargs = server_kwargs ,
93
- )
94
- t .start ()
95
- ready .wait ()
96
- conn .connect ()
97
- conn .disconnect ()
98
- t .join ()
99
-
100
-
101
- @pytest .fixture
102
- def tcp_address ():
103
- with socket .socket () as sock :
104
- sock .bind (("127.0.0.1" , 0 ))
105
- return sock .getsockname ()
106
-
107
-
108
- @pytest .fixture
109
- def uds_address (tmpdir ):
110
- return tmpdir / "uds.sock"
111
-
112
-
113
- @pytest .fixture
114
- def ssl_cert (tcp_address , tmpdir ):
115
- """More or less equivalent to
116
-
117
- .. code::
118
-
119
- openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem
120
- """
121
- host , _ = tcp_address
122
- ca = trustme .CA ()
123
- cert = ca .issue_cert (host , common_name = "trustme" )
124
-
125
- server_pem = str (tmpdir / "server.pem" )
126
- cert .private_key_and_cert_chain_pem .write_to_path (path = server_pem )
127
-
128
- client_pem = str (tmpdir / "client.pem" )
129
- ca .cert_pem .write_to_path (path = client_pem )
130
-
131
- return client_pem , server_pem
0 commit comments