diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index 2ed691344..e3b225e9a 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -1,6 +1,6 @@ +import contextvars import datetime import json -from threading import local from time import time from django.utils.encoding import force_str @@ -13,30 +13,12 @@ except ImportError: PostgresJson = None +recording = contextvars.ContextVar("debug-toolbar-recording", default=True) + class SQLQueryTriggered(Exception): """Thrown when template panel triggers a query""" - pass - - -class ThreadLocalState(local): - def __init__(self): - self.enabled = True - - @property - def Wrapper(self): - if self.enabled: - return NormalCursorWrapper - return ExceptionCursorWrapper - - def recording(self, v): - self.enabled = v - - -state = ThreadLocalState() -recording = state.recording # export function - def wrap_cursor(connection, panel): if not hasattr(connection, "_djdt_cursor"): @@ -50,16 +32,22 @@ def cursor(*args, **kwargs): # See: # https://github.com/jazzband/django-debug-toolbar/pull/615 # https://github.com/jazzband/django-debug-toolbar/pull/896 - return state.Wrapper( - connection._djdt_cursor(*args, **kwargs), connection, panel - ) + if recording.get(): + wrapper = NormalCursorWrapper + else: + wrapper = ExceptionCursorWrapper + return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel) def chunked_cursor(*args, **kwargs): # prevent double wrapping # solves https://github.com/jazzband/django-debug-toolbar/issues/1239 cursor = connection._djdt_chunked_cursor(*args, **kwargs) if not isinstance(cursor, BaseCursorWrapper): - return state.Wrapper(cursor, connection, panel) + if recording.get(): + wrapper = NormalCursorWrapper + else: + wrapper = ExceptionCursorWrapper + return wrapper(cursor, connection, panel) return cursor connection.cursor = cursor diff --git a/debug_toolbar/panels/templates/panel.py b/debug_toolbar/panels/templates/panel.py index 66f6c60fb..0615a7601 100644 --- a/debug_toolbar/panels/templates/panel.py +++ b/debug_toolbar/panels/templates/panel.py @@ -118,7 +118,7 @@ def _store_template_info(self, sender, **kwargs): value.model._meta.label, ) else: - recording(False) + token = recording.set(False) try: saferepr(value) # this MAY trigger a db query except SQLQueryTriggered: @@ -130,7 +130,7 @@ def _store_template_info(self, sender, **kwargs): else: temp_layer[key] = value finally: - recording(True) + recording.reset(token) pformatted = pformat(temp_layer) self.pformat_layers.append((context_layer, pformatted)) context_list.append(pformatted) diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index 4dd1d3fcc..2445827c7 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -1,9 +1,11 @@ +import asyncio import datetime import os import unittest from unittest.mock import patch import django +from asgiref.sync import sync_to_async from django.contrib.auth.models import User from django.db import connection from django.db.models import Count @@ -16,6 +18,14 @@ from ..base import BaseTestCase from ..models import PostgresJSON +from ..sync import database_sync_to_async + + +def sql_call(use_iterator=False): + qs = User.objects.all() + if use_iterator: + qs = qs.iterator() + return list(qs) class SQLPanelTestCase(BaseTestCase): @@ -30,7 +40,7 @@ def test_disabled(self): def test_recording(self): self.assertEqual(len(self.panel._queries), 0) - list(User.objects.all()) + sql_call() # ensure query was logged self.assertEqual(len(self.panel._queries), 1) @@ -49,29 +59,64 @@ def test_recording(self): def test_recording_chunked_cursor(self): self.assertEqual(len(self.panel._queries), 0) - list(User.objects.all().iterator()) + sql_call(use_iterator=True) # ensure query was logged self.assertEqual(len(self.panel._queries), 1) - @patch("debug_toolbar.panels.sql.tracking.state", wraps=sql_tracking.state) - def test_cursor_wrapper_singleton(self, mock_state): - list(User.objects.all()) + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + def test_cursor_wrapper_singleton(self, mock_wrapper): + sql_call() # ensure that cursor wrapping is applied only once - self.assertEqual(mock_state.Wrapper.call_count, 1) + self.assertEqual(mock_wrapper.call_count, 1) - @patch("debug_toolbar.panels.sql.tracking.state", wraps=sql_tracking.state) - def test_chunked_cursor_wrapper_singleton(self, mock_state): - list(User.objects.all().iterator()) + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + def test_chunked_cursor_wrapper_singleton(self, mock_wrapper): + sql_call(use_iterator=True) # ensure that cursor wrapping is applied only once - self.assertEqual(mock_state.Wrapper.call_count, 1) + self.assertEqual(mock_wrapper.call_count, 1) + + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + async def test_cursor_wrapper_async(self, mock_wrapper): + await sync_to_async(sql_call)() + + self.assertEqual(mock_wrapper.call_count, 1) + + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper): + self.assertTrue(sql_tracking.recording.get()) + await sync_to_async(sql_call)() + + async def task(): + sql_tracking.recording.set(False) + # Calling this in another context requires the db connections + # to be closed properly. + await database_sync_to_async(sql_call)() + + # Ensure this is called in another context + await asyncio.create_task(task()) + # Because it was called in another context, it should not have affected ours + self.assertTrue(sql_tracking.recording.get()) + self.assertEqual(mock_wrapper.call_count, 1) def test_generate_server_timing(self): self.assertEqual(len(self.panel._queries), 0) - list(User.objects.all()) + sql_call() response = self.panel.process_request(self.request) self.panel.generate_stats(self.request, response) @@ -337,7 +382,7 @@ def test_disable_stacktraces(self): self.assertEqual(len(self.panel._queries), 0) with self.settings(DEBUG_TOOLBAR_CONFIG={"ENABLE_STACKTRACES": False}): - list(User.objects.all()) + sql_call() # ensure query was logged self.assertEqual(len(self.panel._queries), 1) diff --git a/tests/sync.py b/tests/sync.py new file mode 100644 index 000000000..d71298089 --- /dev/null +++ b/tests/sync.py @@ -0,0 +1,22 @@ +""" +Taken from channels.db +""" +from asgiref.sync import SyncToAsync +from django.db import close_old_connections + + +class DatabaseSyncToAsync(SyncToAsync): + """ + SyncToAsync version that cleans up old database connections when it exits. + """ + + def thread_handler(self, loop, *args, **kwargs): + close_old_connections() + try: + return super().thread_handler(loop, *args, **kwargs) + finally: + close_old_connections() + + +# The class is TitleCased, but we want to encourage use as a callable/decorator +database_sync_to_async = DatabaseSyncToAsync