diff --git a/debug_toolbar/panels/sql/panel.py b/debug_toolbar/panels/sql/panel.py index 90e2ba812..c8576e16f 100644 --- a/debug_toolbar/panels/sql/panel.py +++ b/debug_toolbar/panels/sql/panel.py @@ -10,7 +10,7 @@ from debug_toolbar.panels import Panel from debug_toolbar.panels.sql import views from debug_toolbar.panels.sql.forms import SQLSelectForm -from debug_toolbar.panels.sql.tracking import unwrap_cursor, wrap_cursor +from debug_toolbar.panels.sql.tracking import wrap_cursor from debug_toolbar.panels.sql.utils import contrasting_color_generator, reformat_sql from debug_toolbar.utils import render_stacktrace @@ -190,11 +190,12 @@ def get_urls(cls): def enable_instrumentation(self): # This is thread-safe because database connections are thread-local. for connection in connections.all(): - wrap_cursor(connection, self) + wrap_cursor(connection) + connection._djdt_logger = self def disable_instrumentation(self): for connection in connections.all(): - unwrap_cursor(connection) + connection._djdt_logger = None def generate_stats(self, request, response): colors = contrasting_color_generator() diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index 565d9244b..425e4e5cc 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -3,6 +3,8 @@ import json from time import time +import django.test.testcases +from django.db.backends.utils import CursorWrapper from django.utils.encoding import force_str from debug_toolbar import settings as dt_settings @@ -31,10 +33,15 @@ class SQLQueryTriggered(Exception): """Thrown when template panel triggers a query""" -def wrap_cursor(connection, panel): +def wrap_cursor(connection): + # If running a Django SimpleTestCase, which isn't allowed to access the database, + # don't perform any monkey patching. + if isinstance(connection.cursor, django.test.testcases._DatabaseFailure): + return if not hasattr(connection, "_djdt_cursor"): connection._djdt_cursor = connection.cursor connection._djdt_chunked_cursor = connection.chunked_cursor + connection._djdt_logger = None def cursor(*args, **kwargs): # Per the DB API cursor() does not accept any arguments. There's @@ -43,78 +50,55 @@ def cursor(*args, **kwargs): # See: # https://github.com/jazzband/django-debug-toolbar/pull/615 # https://github.com/jazzband/django-debug-toolbar/pull/896 + logger = connection._djdt_logger + cursor = connection._djdt_cursor(*args, **kwargs) + if logger is None: + return cursor if allow_sql.get(): wrapper = NormalCursorWrapper else: wrapper = ExceptionCursorWrapper - return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel) + return wrapper(cursor.cursor, connection, logger) def chunked_cursor(*args, **kwargs): # prevent double wrapping # solves https://github.com/jazzband/django-debug-toolbar/issues/1239 + logger = connection._djdt_logger cursor = connection._djdt_chunked_cursor(*args, **kwargs) - if not isinstance(cursor, BaseCursorWrapper): + if logger is not None and not isinstance(cursor, DjDTCursorWrapper): if allow_sql.get(): wrapper = NormalCursorWrapper else: wrapper = ExceptionCursorWrapper - return wrapper(cursor, connection, panel) + return wrapper(cursor.cursor, connection, logger) return cursor connection.cursor = cursor connection.chunked_cursor = chunked_cursor - return cursor - - -def unwrap_cursor(connection): - if hasattr(connection, "_djdt_cursor"): - # Sometimes the cursor()/chunked_cursor() methods of the DatabaseWrapper - # instance are already monkey patched before wrap_cursor() is called. (In - # particular, Django's SimpleTestCase monkey patches those methods for any - # disallowed databases to raise an exception if they are accessed.) Thus only - # delete our monkey patch if the method we saved is the same as the class - # method. Otherwise, restore the prior monkey patch from our saved method. - if connection._djdt_cursor == connection.__class__.cursor: - del connection.cursor - else: - connection.cursor = connection._djdt_cursor - del connection._djdt_cursor - if connection._djdt_chunked_cursor == connection.__class__.chunked_cursor: - del connection.chunked_cursor - else: - connection.chunked_cursor = connection._djdt_chunked_cursor - del connection._djdt_chunked_cursor -class BaseCursorWrapper: - pass +class DjDTCursorWrapper(CursorWrapper): + def __init__(self, cursor, db, logger): + super().__init__(cursor, db) + # logger must implement a ``record`` method + self.logger = logger -class ExceptionCursorWrapper(BaseCursorWrapper): +class ExceptionCursorWrapper(DjDTCursorWrapper): """ Wraps a cursor and raises an exception on any operation. Used in Templates panel. """ - def __init__(self, cursor, db, logger): - pass - def __getattr__(self, attr): raise SQLQueryTriggered() -class NormalCursorWrapper(BaseCursorWrapper): +class NormalCursorWrapper(DjDTCursorWrapper): """ Wraps a cursor and logs queries. """ - def __init__(self, cursor, db, logger): - self.cursor = cursor - # Instance of a BaseDatabaseWrapper subclass - self.db = db - # logger must implement a ``record`` method - self.logger = logger - def _quote_expr(self, element): if isinstance(element, str): return "'%s'" % element.replace("'", "''") @@ -154,6 +138,21 @@ def _decode(self, param): except UnicodeDecodeError: return "(encoded string)" + def _last_executed_query(self, sql, params): + """Get the last executed query from the connection.""" + # Django's psycopg3 backend creates a new cursor in its implementation of the + # .last_executed_query() method. To avoid wrapping that cursor, temporarily set + # the DatabaseWrapper's ._djdt_logger attribute to None. This will cause the + # monkey-patched .cursor() and .chunked_cursor() methods to skip the wrapping + # process during the .last_executed_query() call. + self.db._djdt_logger = None + try: + return self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params) + ) + finally: + self.db._djdt_logger = self.logger + def _record(self, method, sql, params): alias = self.db.alias vendor = self.db.vendor @@ -186,9 +185,7 @@ def _record(self, method, sql, params): params = { "vendor": vendor, "alias": alias, - "sql": self.db.ops.last_executed_query( - self.cursor, sql, self._quote_params(params) - ), + "sql": self._last_executed_query(sql, params), "duration": duration, "raw_sql": sql, "params": _params, @@ -196,7 +193,9 @@ def _record(self, method, sql, params): "stacktrace": get_stack_trace(skip=2), "start_time": start_time, "stop_time": stop_time, - "is_slow": duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"], + "is_slow": ( + duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"] + ), "is_select": sql.lower().strip().startswith("select"), "template_info": template_info, } @@ -241,22 +240,10 @@ def _record(self, method, sql, params): self.logger.record(**params) def callproc(self, procname, params=None): - return self._record(self.cursor.callproc, procname, params) + return self._record(super().callproc, procname, params) def execute(self, sql, params=None): - return self._record(self.cursor.execute, sql, params) + return self._record(super().execute, sql, params) def executemany(self, sql, param_list): - return self._record(self.cursor.executemany, sql, param_list) - - def __getattr__(self, attr): - return getattr(self.cursor, attr) - - def __iter__(self): - return iter(self.cursor) - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - self.close() + return self._record(super().executemany, sql, param_list)