Skip to content

Commit 8791aa0

Browse files
authored
PYTHON-4790 Migrate test_retryable_writes.py to async (#1876)
1 parent c0f7810 commit 8791aa0

File tree

6 files changed

+1092
-22
lines changed

6 files changed

+1092
-22
lines changed

test/asynchronous/helpers.py

+360
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
# Copyright 2024-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
16+
from __future__ import annotations
17+
18+
import base64
19+
import gc
20+
import multiprocessing
21+
import os
22+
import signal
23+
import socket
24+
import subprocess
25+
import sys
26+
import threading
27+
import time
28+
import traceback
29+
import unittest
30+
import warnings
31+
from asyncio import iscoroutinefunction
32+
33+
try:
34+
import ipaddress
35+
36+
HAVE_IPADDRESS = True
37+
except ImportError:
38+
HAVE_IPADDRESS = False
39+
from functools import wraps
40+
from typing import Any, Callable, Dict, Generator, no_type_check
41+
from unittest import SkipTest
42+
43+
from bson.son import SON
44+
from pymongo import common, message
45+
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
46+
from pymongo.uri_parser import parse_uri
47+
48+
if HAVE_SSL:
49+
import ssl
50+
51+
_IS_SYNC = False
52+
53+
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
54+
if hasattr(gc, "set_debug"):
55+
gc.set_debug(
56+
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
57+
)
58+
59+
# The host and port of a single mongod or mongos, or the seed host
60+
# for a replica set.
61+
host = os.environ.get("DB_IP", "localhost")
62+
port = int(os.environ.get("DB_PORT", 27017))
63+
IS_SRV = "mongodb+srv" in host
64+
65+
db_user = os.environ.get("DB_USER", "user")
66+
db_pwd = os.environ.get("DB_PASSWORD", "password")
67+
68+
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
69+
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
70+
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
71+
72+
TLS_OPTIONS: Dict = {"tls": True}
73+
if CLIENT_PEM:
74+
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
75+
if CA_PEM:
76+
TLS_OPTIONS["tlsCAFile"] = CA_PEM
77+
78+
COMPRESSORS = os.environ.get("COMPRESSORS")
79+
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
80+
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
81+
TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS"))
82+
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
83+
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
84+
85+
if TEST_LOADBALANCER:
86+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
87+
host, port = res["nodelist"][0]
88+
db_user = res["username"] or db_user
89+
db_pwd = res["password"] or db_pwd
90+
elif TEST_SERVERLESS:
91+
TEST_LOADBALANCER = True
92+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
93+
host, port = res["nodelist"][0]
94+
db_user = res["username"] or db_user
95+
db_pwd = res["password"] or db_pwd
96+
TLS_OPTIONS = {"tls": True}
97+
# Spec says serverless tests must be run with compression.
98+
COMPRESSORS = COMPRESSORS or "zlib"
99+
100+
101+
# Shared KMS data.
102+
LOCAL_MASTER_KEY = base64.b64decode(
103+
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
104+
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
105+
)
106+
AWS_CREDS = {
107+
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
108+
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
109+
}
110+
AWS_CREDS_2 = {
111+
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
112+
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
113+
}
114+
AZURE_CREDS = {
115+
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
116+
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
117+
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
118+
}
119+
GCP_CREDS = {
120+
"email": os.environ.get("FLE_GCP_EMAIL", ""),
121+
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
122+
}
123+
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
124+
125+
# Ensure Evergreen metadata doesn't result in truncation
126+
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
127+
128+
129+
def is_server_resolvable():
130+
"""Returns True if 'server' is resolvable."""
131+
socket_timeout = socket.getdefaulttimeout()
132+
socket.setdefaulttimeout(1)
133+
try:
134+
try:
135+
socket.gethostbyname("server")
136+
return True
137+
except OSError:
138+
return False
139+
finally:
140+
socket.setdefaulttimeout(socket_timeout)
141+
142+
143+
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
144+
cmd = SON([("createUser", user)])
145+
# X509 doesn't use a password
146+
if pwd:
147+
cmd["pwd"] = pwd
148+
cmd["roles"] = roles or ["root"]
149+
cmd.update(**kwargs)
150+
return authdb.command(cmd)
151+
152+
153+
class client_knobs:
154+
def __init__(
155+
self,
156+
heartbeat_frequency=None,
157+
min_heartbeat_interval=None,
158+
kill_cursor_frequency=None,
159+
events_queue_frequency=None,
160+
):
161+
self.heartbeat_frequency = heartbeat_frequency
162+
self.min_heartbeat_interval = min_heartbeat_interval
163+
self.kill_cursor_frequency = kill_cursor_frequency
164+
self.events_queue_frequency = events_queue_frequency
165+
166+
self.old_heartbeat_frequency = None
167+
self.old_min_heartbeat_interval = None
168+
self.old_kill_cursor_frequency = None
169+
self.old_events_queue_frequency = None
170+
self._enabled = False
171+
self._stack = None
172+
173+
def enable(self):
174+
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
175+
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
176+
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
177+
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
178+
179+
if self.heartbeat_frequency is not None:
180+
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
181+
182+
if self.min_heartbeat_interval is not None:
183+
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
184+
185+
if self.kill_cursor_frequency is not None:
186+
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
187+
188+
if self.events_queue_frequency is not None:
189+
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
190+
self._enabled = True
191+
# Store the allocation traceback to catch non-disabled client_knobs.
192+
self._stack = "".join(traceback.format_stack())
193+
194+
def __enter__(self):
195+
self.enable()
196+
197+
@no_type_check
198+
def disable(self):
199+
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
200+
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
201+
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
202+
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
203+
self._enabled = False
204+
205+
def __exit__(self, exc_type, exc_val, exc_tb):
206+
self.disable()
207+
208+
def __call__(self, func):
209+
def make_wrapper(f):
210+
@wraps(f)
211+
async def wrap(*args, **kwargs):
212+
with self:
213+
return await f(*args, **kwargs)
214+
215+
return wrap
216+
217+
return make_wrapper(func)
218+
219+
def __del__(self):
220+
if self._enabled:
221+
msg = (
222+
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
223+
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
224+
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
225+
common.HEARTBEAT_FREQUENCY,
226+
common.MIN_HEARTBEAT_INTERVAL,
227+
common.KILL_CURSOR_FREQUENCY,
228+
common.EVENTS_QUEUE_FREQUENCY,
229+
self._stack,
230+
)
231+
)
232+
self.disable()
233+
raise Exception(msg)
234+
235+
236+
def _all_users(db):
237+
return {u["user"] for u in db.command("usersInfo").get("users", [])}
238+
239+
240+
def sanitize_cmd(cmd):
241+
cp = cmd.copy()
242+
cp.pop("$clusterTime", None)
243+
cp.pop("$db", None)
244+
cp.pop("$readPreference", None)
245+
cp.pop("lsid", None)
246+
if MONGODB_API_VERSION:
247+
# Stable API parameters
248+
cp.pop("apiVersion", None)
249+
# OP_MSG encoding may move the payload type one field to the
250+
# end of the command. Do the same here.
251+
name = next(iter(cp))
252+
try:
253+
identifier = message._FIELD_MAP[name]
254+
docs = cp.pop(identifier)
255+
cp[identifier] = docs
256+
except KeyError:
257+
pass
258+
return cp
259+
260+
261+
def sanitize_reply(reply):
262+
cp = reply.copy()
263+
cp.pop("$clusterTime", None)
264+
cp.pop("operationTime", None)
265+
return cp
266+
267+
268+
def print_thread_tracebacks() -> None:
269+
"""Print all Python thread tracebacks."""
270+
for thread_id, frame in sys._current_frames().items():
271+
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
272+
traceback.print_stack(frame, file=sys.stderr)
273+
274+
275+
def print_thread_stacks(pid: int) -> None:
276+
"""Print all C-level thread stacks for a given process id."""
277+
if sys.platform == "darwin":
278+
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
279+
else:
280+
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
281+
282+
try:
283+
res = subprocess.run(
284+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
285+
)
286+
except Exception as exc:
287+
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
288+
else:
289+
sys.stderr.write(res.stdout)
290+
291+
292+
# Global knobs to speed up the test suite.
293+
global_knobs = client_knobs(events_queue_frequency=0.05)
294+
295+
296+
def _get_executors(topology):
297+
executors = []
298+
for server in topology._servers.values():
299+
# Some MockMonitor do not have an _executor.
300+
if hasattr(server._monitor, "_executor"):
301+
executors.append(server._monitor._executor)
302+
if hasattr(server._monitor, "_rtt_monitor"):
303+
executors.append(server._monitor._rtt_monitor._executor)
304+
executors.append(topology._Topology__events_executor)
305+
if topology._srv_monitor:
306+
executors.append(topology._srv_monitor._executor)
307+
308+
return [e for e in executors if e is not None]
309+
310+
311+
def print_running_topology(topology):
312+
running = [e for e in _get_executors(topology) if not e._stopped]
313+
if running:
314+
print(
315+
"WARNING: found Topology with running threads:\n"
316+
f" Threads: {running}\n"
317+
f" Topology: {topology}\n"
318+
f" Creation traceback:\n{topology._settings._stack}"
319+
)
320+
321+
322+
def test_cases(suite):
323+
"""Iterator over all TestCases within a TestSuite."""
324+
for suite_or_case in suite._tests:
325+
if isinstance(suite_or_case, unittest.TestCase):
326+
# unittest.TestCase
327+
yield suite_or_case
328+
else:
329+
# unittest.TestSuite
330+
yield from test_cases(suite_or_case)
331+
332+
333+
# Helper method to workaround https://bugs.python.org/issue21724
334+
def clear_warning_registry():
335+
"""Clear the __warningregistry__ for all modules."""
336+
for _, module in list(sys.modules.items()):
337+
if hasattr(module, "__warningregistry__"):
338+
module.__warningregistry__ = {} # type:ignore[attr-defined]
339+
340+
341+
class SystemCertsPatcher:
342+
def __init__(self, ca_certs):
343+
if (
344+
ssl.OPENSSL_VERSION.lower().startswith("libressl")
345+
and sys.platform == "darwin"
346+
and not _ssl.IS_PYOPENSSL
347+
):
348+
raise SkipTest(
349+
"LibreSSL on OSX doesn't support setting CA certificates "
350+
"using SSL_CERT_FILE environment variable."
351+
)
352+
self.original_certs = os.environ.get("SSL_CERT_FILE")
353+
# Tell OpenSSL where CA certificates live.
354+
os.environ["SSL_CERT_FILE"] = ca_certs
355+
356+
def disable(self):
357+
if self.original_certs is None:
358+
os.environ.pop("SSL_CERT_FILE")
359+
else:
360+
os.environ["SSL_CERT_FILE"] = self.original_certs

0 commit comments

Comments
 (0)