|  | 
|  | 1 | +# Copyright 2024-present MongoDB, Inc. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +# http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +"""Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" | 
|  | 16 | +from __future__ import annotations | 
|  | 17 | + | 
|  | 18 | +import base64 | 
|  | 19 | +import gc | 
|  | 20 | +import multiprocessing | 
|  | 21 | +import os | 
|  | 22 | +import signal | 
|  | 23 | +import socket | 
|  | 24 | +import subprocess | 
|  | 25 | +import sys | 
|  | 26 | +import threading | 
|  | 27 | +import time | 
|  | 28 | +import traceback | 
|  | 29 | +import unittest | 
|  | 30 | +import warnings | 
|  | 31 | +from asyncio import iscoroutinefunction | 
|  | 32 | + | 
|  | 33 | +try: | 
|  | 34 | +    import ipaddress | 
|  | 35 | + | 
|  | 36 | +    HAVE_IPADDRESS = True | 
|  | 37 | +except ImportError: | 
|  | 38 | +    HAVE_IPADDRESS = False | 
|  | 39 | +from functools import wraps | 
|  | 40 | +from typing import Any, Callable, Dict, Generator, no_type_check | 
|  | 41 | +from unittest import SkipTest | 
|  | 42 | + | 
|  | 43 | +from bson.son import SON | 
|  | 44 | +from pymongo import common, message | 
|  | 45 | +from pymongo.ssl_support import HAVE_SSL, _ssl  # type:ignore[attr-defined] | 
|  | 46 | +from pymongo.uri_parser import parse_uri | 
|  | 47 | + | 
|  | 48 | +if HAVE_SSL: | 
|  | 49 | +    import ssl | 
|  | 50 | + | 
|  | 51 | +_IS_SYNC = False | 
|  | 52 | + | 
|  | 53 | +# Enable debug output for uncollectable objects. PyPy does not have set_debug. | 
|  | 54 | +if hasattr(gc, "set_debug"): | 
|  | 55 | +    gc.set_debug( | 
|  | 56 | +        gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0) | 
|  | 57 | +    ) | 
|  | 58 | + | 
|  | 59 | +# The host and port of a single mongod or mongos, or the seed host | 
|  | 60 | +# for a replica set. | 
|  | 61 | +host = os.environ.get("DB_IP", "localhost") | 
|  | 62 | +port = int(os.environ.get("DB_PORT", 27017)) | 
|  | 63 | +IS_SRV = "mongodb+srv" in host | 
|  | 64 | + | 
|  | 65 | +db_user = os.environ.get("DB_USER", "user") | 
|  | 66 | +db_pwd = os.environ.get("DB_PASSWORD", "password") | 
|  | 67 | + | 
|  | 68 | +CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") | 
|  | 69 | +CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) | 
|  | 70 | +CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) | 
|  | 71 | + | 
|  | 72 | +TLS_OPTIONS: Dict = {"tls": True} | 
|  | 73 | +if CLIENT_PEM: | 
|  | 74 | +    TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM | 
|  | 75 | +if CA_PEM: | 
|  | 76 | +    TLS_OPTIONS["tlsCAFile"] = CA_PEM | 
|  | 77 | + | 
|  | 78 | +COMPRESSORS = os.environ.get("COMPRESSORS") | 
|  | 79 | +MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") | 
|  | 80 | +TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) | 
|  | 81 | +TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS")) | 
|  | 82 | +SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") | 
|  | 83 | +MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") | 
|  | 84 | + | 
|  | 85 | +if TEST_LOADBALANCER: | 
|  | 86 | +    res = parse_uri(SINGLE_MONGOS_LB_URI or "") | 
|  | 87 | +    host, port = res["nodelist"][0] | 
|  | 88 | +    db_user = res["username"] or db_user | 
|  | 89 | +    db_pwd = res["password"] or db_pwd | 
|  | 90 | +elif TEST_SERVERLESS: | 
|  | 91 | +    TEST_LOADBALANCER = True | 
|  | 92 | +    res = parse_uri(SINGLE_MONGOS_LB_URI or "") | 
|  | 93 | +    host, port = res["nodelist"][0] | 
|  | 94 | +    db_user = res["username"] or db_user | 
|  | 95 | +    db_pwd = res["password"] or db_pwd | 
|  | 96 | +    TLS_OPTIONS = {"tls": True} | 
|  | 97 | +    # Spec says serverless tests must be run with compression. | 
|  | 98 | +    COMPRESSORS = COMPRESSORS or "zlib" | 
|  | 99 | + | 
|  | 100 | + | 
|  | 101 | +# Shared KMS data. | 
|  | 102 | +LOCAL_MASTER_KEY = base64.b64decode( | 
|  | 103 | +    b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ" | 
|  | 104 | +    b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" | 
|  | 105 | +) | 
|  | 106 | +AWS_CREDS = { | 
|  | 107 | +    "accessKeyId": os.environ.get("FLE_AWS_KEY", ""), | 
|  | 108 | +    "secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""), | 
|  | 109 | +} | 
|  | 110 | +AWS_CREDS_2 = { | 
|  | 111 | +    "accessKeyId": os.environ.get("FLE_AWS_KEY2", ""), | 
|  | 112 | +    "secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""), | 
|  | 113 | +} | 
|  | 114 | +AZURE_CREDS = { | 
|  | 115 | +    "tenantId": os.environ.get("FLE_AZURE_TENANTID", ""), | 
|  | 116 | +    "clientId": os.environ.get("FLE_AZURE_CLIENTID", ""), | 
|  | 117 | +    "clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""), | 
|  | 118 | +} | 
|  | 119 | +GCP_CREDS = { | 
|  | 120 | +    "email": os.environ.get("FLE_GCP_EMAIL", ""), | 
|  | 121 | +    "privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""), | 
|  | 122 | +} | 
|  | 123 | +KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")} | 
|  | 124 | + | 
|  | 125 | +# Ensure Evergreen metadata doesn't result in truncation | 
|  | 126 | +os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000") | 
|  | 127 | + | 
|  | 128 | + | 
|  | 129 | +def is_server_resolvable(): | 
|  | 130 | +    """Returns True if 'server' is resolvable.""" | 
|  | 131 | +    socket_timeout = socket.getdefaulttimeout() | 
|  | 132 | +    socket.setdefaulttimeout(1) | 
|  | 133 | +    try: | 
|  | 134 | +        try: | 
|  | 135 | +            socket.gethostbyname("server") | 
|  | 136 | +            return True | 
|  | 137 | +        except OSError: | 
|  | 138 | +            return False | 
|  | 139 | +    finally: | 
|  | 140 | +        socket.setdefaulttimeout(socket_timeout) | 
|  | 141 | + | 
|  | 142 | + | 
|  | 143 | +def _create_user(authdb, user, pwd=None, roles=None, **kwargs): | 
|  | 144 | +    cmd = SON([("createUser", user)]) | 
|  | 145 | +    # X509 doesn't use a password | 
|  | 146 | +    if pwd: | 
|  | 147 | +        cmd["pwd"] = pwd | 
|  | 148 | +    cmd["roles"] = roles or ["root"] | 
|  | 149 | +    cmd.update(**kwargs) | 
|  | 150 | +    return authdb.command(cmd) | 
|  | 151 | + | 
|  | 152 | + | 
|  | 153 | +class client_knobs: | 
|  | 154 | +    def __init__( | 
|  | 155 | +        self, | 
|  | 156 | +        heartbeat_frequency=None, | 
|  | 157 | +        min_heartbeat_interval=None, | 
|  | 158 | +        kill_cursor_frequency=None, | 
|  | 159 | +        events_queue_frequency=None, | 
|  | 160 | +    ): | 
|  | 161 | +        self.heartbeat_frequency = heartbeat_frequency | 
|  | 162 | +        self.min_heartbeat_interval = min_heartbeat_interval | 
|  | 163 | +        self.kill_cursor_frequency = kill_cursor_frequency | 
|  | 164 | +        self.events_queue_frequency = events_queue_frequency | 
|  | 165 | + | 
|  | 166 | +        self.old_heartbeat_frequency = None | 
|  | 167 | +        self.old_min_heartbeat_interval = None | 
|  | 168 | +        self.old_kill_cursor_frequency = None | 
|  | 169 | +        self.old_events_queue_frequency = None | 
|  | 170 | +        self._enabled = False | 
|  | 171 | +        self._stack = None | 
|  | 172 | + | 
|  | 173 | +    def enable(self): | 
|  | 174 | +        self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY | 
|  | 175 | +        self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL | 
|  | 176 | +        self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY | 
|  | 177 | +        self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY | 
|  | 178 | + | 
|  | 179 | +        if self.heartbeat_frequency is not None: | 
|  | 180 | +            common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency | 
|  | 181 | + | 
|  | 182 | +        if self.min_heartbeat_interval is not None: | 
|  | 183 | +            common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval | 
|  | 184 | + | 
|  | 185 | +        if self.kill_cursor_frequency is not None: | 
|  | 186 | +            common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency | 
|  | 187 | + | 
|  | 188 | +        if self.events_queue_frequency is not None: | 
|  | 189 | +            common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency | 
|  | 190 | +        self._enabled = True | 
|  | 191 | +        # Store the allocation traceback to catch non-disabled client_knobs. | 
|  | 192 | +        self._stack = "".join(traceback.format_stack()) | 
|  | 193 | + | 
|  | 194 | +    def __enter__(self): | 
|  | 195 | +        self.enable() | 
|  | 196 | + | 
|  | 197 | +    @no_type_check | 
|  | 198 | +    def disable(self): | 
|  | 199 | +        common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency | 
|  | 200 | +        common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval | 
|  | 201 | +        common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency | 
|  | 202 | +        common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency | 
|  | 203 | +        self._enabled = False | 
|  | 204 | + | 
|  | 205 | +    def __exit__(self, exc_type, exc_val, exc_tb): | 
|  | 206 | +        self.disable() | 
|  | 207 | + | 
|  | 208 | +    def __call__(self, func): | 
|  | 209 | +        def make_wrapper(f): | 
|  | 210 | +            @wraps(f) | 
|  | 211 | +            async def wrap(*args, **kwargs): | 
|  | 212 | +                with self: | 
|  | 213 | +                    return await f(*args, **kwargs) | 
|  | 214 | + | 
|  | 215 | +            return wrap | 
|  | 216 | + | 
|  | 217 | +        return make_wrapper(func) | 
|  | 218 | + | 
|  | 219 | +    def __del__(self): | 
|  | 220 | +        if self._enabled: | 
|  | 221 | +            msg = ( | 
|  | 222 | +                "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, " | 
|  | 223 | +                "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, " | 
|  | 224 | +                "EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format( | 
|  | 225 | +                    common.HEARTBEAT_FREQUENCY, | 
|  | 226 | +                    common.MIN_HEARTBEAT_INTERVAL, | 
|  | 227 | +                    common.KILL_CURSOR_FREQUENCY, | 
|  | 228 | +                    common.EVENTS_QUEUE_FREQUENCY, | 
|  | 229 | +                    self._stack, | 
|  | 230 | +                ) | 
|  | 231 | +            ) | 
|  | 232 | +            self.disable() | 
|  | 233 | +            raise Exception(msg) | 
|  | 234 | + | 
|  | 235 | + | 
|  | 236 | +def _all_users(db): | 
|  | 237 | +    return {u["user"] for u in db.command("usersInfo").get("users", [])} | 
|  | 238 | + | 
|  | 239 | + | 
|  | 240 | +def sanitize_cmd(cmd): | 
|  | 241 | +    cp = cmd.copy() | 
|  | 242 | +    cp.pop("$clusterTime", None) | 
|  | 243 | +    cp.pop("$db", None) | 
|  | 244 | +    cp.pop("$readPreference", None) | 
|  | 245 | +    cp.pop("lsid", None) | 
|  | 246 | +    if MONGODB_API_VERSION: | 
|  | 247 | +        # Stable API parameters | 
|  | 248 | +        cp.pop("apiVersion", None) | 
|  | 249 | +    # OP_MSG encoding may move the payload type one field to the | 
|  | 250 | +    # end of the command. Do the same here. | 
|  | 251 | +    name = next(iter(cp)) | 
|  | 252 | +    try: | 
|  | 253 | +        identifier = message._FIELD_MAP[name] | 
|  | 254 | +        docs = cp.pop(identifier) | 
|  | 255 | +        cp[identifier] = docs | 
|  | 256 | +    except KeyError: | 
|  | 257 | +        pass | 
|  | 258 | +    return cp | 
|  | 259 | + | 
|  | 260 | + | 
|  | 261 | +def sanitize_reply(reply): | 
|  | 262 | +    cp = reply.copy() | 
|  | 263 | +    cp.pop("$clusterTime", None) | 
|  | 264 | +    cp.pop("operationTime", None) | 
|  | 265 | +    return cp | 
|  | 266 | + | 
|  | 267 | + | 
|  | 268 | +def print_thread_tracebacks() -> None: | 
|  | 269 | +    """Print all Python thread tracebacks.""" | 
|  | 270 | +    for thread_id, frame in sys._current_frames().items(): | 
|  | 271 | +        sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n") | 
|  | 272 | +        traceback.print_stack(frame, file=sys.stderr) | 
|  | 273 | + | 
|  | 274 | + | 
|  | 275 | +def print_thread_stacks(pid: int) -> None: | 
|  | 276 | +    """Print all C-level thread stacks for a given process id.""" | 
|  | 277 | +    if sys.platform == "darwin": | 
|  | 278 | +        cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"'] | 
|  | 279 | +    else: | 
|  | 280 | +        cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"'] | 
|  | 281 | + | 
|  | 282 | +    try: | 
|  | 283 | +        res = subprocess.run( | 
|  | 284 | +            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8" | 
|  | 285 | +        ) | 
|  | 286 | +    except Exception as exc: | 
|  | 287 | +        sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}") | 
|  | 288 | +    else: | 
|  | 289 | +        sys.stderr.write(res.stdout) | 
|  | 290 | + | 
|  | 291 | + | 
|  | 292 | +# Global knobs to speed up the test suite. | 
|  | 293 | +global_knobs = client_knobs(events_queue_frequency=0.05) | 
|  | 294 | + | 
|  | 295 | + | 
|  | 296 | +def _get_executors(topology): | 
|  | 297 | +    executors = [] | 
|  | 298 | +    for server in topology._servers.values(): | 
|  | 299 | +        # Some MockMonitor do not have an _executor. | 
|  | 300 | +        if hasattr(server._monitor, "_executor"): | 
|  | 301 | +            executors.append(server._monitor._executor) | 
|  | 302 | +        if hasattr(server._monitor, "_rtt_monitor"): | 
|  | 303 | +            executors.append(server._monitor._rtt_monitor._executor) | 
|  | 304 | +    executors.append(topology._Topology__events_executor) | 
|  | 305 | +    if topology._srv_monitor: | 
|  | 306 | +        executors.append(topology._srv_monitor._executor) | 
|  | 307 | + | 
|  | 308 | +    return [e for e in executors if e is not None] | 
|  | 309 | + | 
|  | 310 | + | 
|  | 311 | +def print_running_topology(topology): | 
|  | 312 | +    running = [e for e in _get_executors(topology) if not e._stopped] | 
|  | 313 | +    if running: | 
|  | 314 | +        print( | 
|  | 315 | +            "WARNING: found Topology with running threads:\n" | 
|  | 316 | +            f"  Threads: {running}\n" | 
|  | 317 | +            f"  Topology: {topology}\n" | 
|  | 318 | +            f"  Creation traceback:\n{topology._settings._stack}" | 
|  | 319 | +        ) | 
|  | 320 | + | 
|  | 321 | + | 
|  | 322 | +def test_cases(suite): | 
|  | 323 | +    """Iterator over all TestCases within a TestSuite.""" | 
|  | 324 | +    for suite_or_case in suite._tests: | 
|  | 325 | +        if isinstance(suite_or_case, unittest.TestCase): | 
|  | 326 | +            # unittest.TestCase | 
|  | 327 | +            yield suite_or_case | 
|  | 328 | +        else: | 
|  | 329 | +            # unittest.TestSuite | 
|  | 330 | +            yield from test_cases(suite_or_case) | 
|  | 331 | + | 
|  | 332 | + | 
|  | 333 | +# Helper method to workaround https://bugs.python.org/issue21724 | 
|  | 334 | +def clear_warning_registry(): | 
|  | 335 | +    """Clear the __warningregistry__ for all modules.""" | 
|  | 336 | +    for _, module in list(sys.modules.items()): | 
|  | 337 | +        if hasattr(module, "__warningregistry__"): | 
|  | 338 | +            module.__warningregistry__ = {}  # type:ignore[attr-defined] | 
|  | 339 | + | 
|  | 340 | + | 
|  | 341 | +class SystemCertsPatcher: | 
|  | 342 | +    def __init__(self, ca_certs): | 
|  | 343 | +        if ( | 
|  | 344 | +            ssl.OPENSSL_VERSION.lower().startswith("libressl") | 
|  | 345 | +            and sys.platform == "darwin" | 
|  | 346 | +            and not _ssl.IS_PYOPENSSL | 
|  | 347 | +        ): | 
|  | 348 | +            raise SkipTest( | 
|  | 349 | +                "LibreSSL on OSX doesn't support setting CA certificates " | 
|  | 350 | +                "using SSL_CERT_FILE environment variable." | 
|  | 351 | +            ) | 
|  | 352 | +        self.original_certs = os.environ.get("SSL_CERT_FILE") | 
|  | 353 | +        # Tell OpenSSL where CA certificates live. | 
|  | 354 | +        os.environ["SSL_CERT_FILE"] = ca_certs | 
|  | 355 | + | 
|  | 356 | +    def disable(self): | 
|  | 357 | +        if self.original_certs is None: | 
|  | 358 | +            os.environ.pop("SSL_CERT_FILE") | 
|  | 359 | +        else: | 
|  | 360 | +            os.environ["SSL_CERT_FILE"] = self.original_certs | 
0 commit comments