From a3e8cc0d28850eea7b9ff3073d6f31c256e83a49 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Wed, 2 Feb 2022 21:56:51 -0300 Subject: [PATCH 1/7] Fix sql recording for async views --- debug_toolbar/panels/sql/tracking.py | 38 ++++++++++------------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index 2ed691344..e3b225e9a 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -1,6 +1,6 @@ +import contextvars import datetime import json -from threading import local from time import time from django.utils.encoding import force_str @@ -13,30 +13,12 @@ except ImportError: PostgresJson = None +recording = contextvars.ContextVar("debug-toolbar-recording", default=True) + class SQLQueryTriggered(Exception): """Thrown when template panel triggers a query""" - pass - - -class ThreadLocalState(local): - def __init__(self): - self.enabled = True - - @property - def Wrapper(self): - if self.enabled: - return NormalCursorWrapper - return ExceptionCursorWrapper - - def recording(self, v): - self.enabled = v - - -state = ThreadLocalState() -recording = state.recording # export function - def wrap_cursor(connection, panel): if not hasattr(connection, "_djdt_cursor"): @@ -50,16 +32,22 @@ def cursor(*args, **kwargs): # See: # https://github.com/jazzband/django-debug-toolbar/pull/615 # https://github.com/jazzband/django-debug-toolbar/pull/896 - return state.Wrapper( - connection._djdt_cursor(*args, **kwargs), connection, panel - ) + if recording.get(): + wrapper = NormalCursorWrapper + else: + wrapper = ExceptionCursorWrapper + return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel) def chunked_cursor(*args, **kwargs): # prevent double wrapping # solves https://github.com/jazzband/django-debug-toolbar/issues/1239 cursor = connection._djdt_chunked_cursor(*args, **kwargs) if not isinstance(cursor, BaseCursorWrapper): - return state.Wrapper(cursor, connection, panel) + if recording.get(): + wrapper = NormalCursorWrapper + else: + wrapper = ExceptionCursorWrapper + return wrapper(cursor, connection, panel) return cursor connection.cursor = cursor From a5a7a819d43882f2822830c8b874be2ae9843a19 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 6 Feb 2022 10:54:03 -0300 Subject: [PATCH 2/7] Recording is a contextvars now, it should be set with `.set()` --- debug_toolbar/panels/templates/panel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug_toolbar/panels/templates/panel.py b/debug_toolbar/panels/templates/panel.py index 66f6c60fb..8c114fd8e 100644 --- a/debug_toolbar/panels/templates/panel.py +++ b/debug_toolbar/panels/templates/panel.py @@ -118,7 +118,7 @@ def _store_template_info(self, sender, **kwargs): value.model._meta.label, ) else: - recording(False) + recording.set(False) try: saferepr(value) # this MAY trigger a db query except SQLQueryTriggered: @@ -130,7 +130,7 @@ def _store_template_info(self, sender, **kwargs): else: temp_layer[key] = value finally: - recording(True) + recording.set(True) pformatted = pformat(temp_layer) self.pformat_layers.append((context_layer, pformatted)) context_list.append(pformatted) From bb8a38896040d86b7b239ad7738b2c92ce6e25e6 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 6 Feb 2022 10:54:20 -0300 Subject: [PATCH 3/7] Fix broken tests --- tests/panels/test_sql.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index 4dd1d3fcc..30f890fb7 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -54,19 +54,25 @@ def test_recording_chunked_cursor(self): # ensure query was logged self.assertEqual(len(self.panel._queries), 1) - @patch("debug_toolbar.panels.sql.tracking.state", wraps=sql_tracking.state) - def test_cursor_wrapper_singleton(self, mock_state): + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + def test_cursor_wrapper_singleton(self, mock_wrapper): list(User.objects.all()) # ensure that cursor wrapping is applied only once - self.assertEqual(mock_state.Wrapper.call_count, 1) + self.assertEqual(mock_wrapper.call_count, 1) - @patch("debug_toolbar.panels.sql.tracking.state", wraps=sql_tracking.state) - def test_chunked_cursor_wrapper_singleton(self, mock_state): + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + def test_chunked_cursor_wrapper_singleton(self, mock_wrapper): list(User.objects.all().iterator()) # ensure that cursor wrapping is applied only once - self.assertEqual(mock_state.Wrapper.call_count, 1) + self.assertEqual(mock_wrapper.call_count, 1) def test_generate_server_timing(self): self.assertEqual(len(self.panel._queries), 0) From 9d2a9c6bc1be5b3f23d21474dfd02c50ee2d4371 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sun, 6 Feb 2022 11:33:59 -0300 Subject: [PATCH 4/7] Add some basic asyncio tests for sql tracking --- tests/panels/test_sql.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index 30f890fb7..d242a2fbd 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -1,7 +1,9 @@ +import asyncio import datetime import os import unittest from unittest.mock import patch +from asgiref.sync import sync_to_async import django from django.contrib.auth.models import User @@ -74,6 +76,33 @@ def test_chunked_cursor_wrapper_singleton(self, mock_wrapper): # ensure that cursor wrapping is applied only once self.assertEqual(mock_wrapper.call_count, 1) + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + async def test_cursor_wrapper_async(self, mock_wrapper): + await sync_to_async(list)(User.objects.all()) + + self.assertEqual(mock_wrapper.call_count, 1) + + @patch( + "debug_toolbar.panels.sql.tracking.NormalCursorWrapper", + wraps=sql_tracking.NormalCursorWrapper, + ) + async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper): + self.assertTrue(sql_tracking.recording.get()) + await sync_to_async(list)(User.objects.all()) + + async def task(): + sql_tracking.recording.set(False) + await sync_to_async(list)(User.objects.all()) + + # Ensure this is called in another context + await asyncio.create_task(task()) + # Because it was called in another context, it should not have affected ours + self.assertTrue(sql_tracking.recording.get()) + self.assertEqual(mock_wrapper.call_count, 1) + def test_generate_server_timing(self): self.assertEqual(len(self.panel._queries), 0) From 75a62b7bdf692d2d8ba3abf26279beaee7ec22a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Feb 2022 14:34:34 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/panels/test_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index d242a2fbd..8c4cdff10 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -3,9 +3,9 @@ import os import unittest from unittest.mock import patch -from asgiref.sync import sync_to_async import django +from asgiref.sync import sync_to_async from django.contrib.auth.models import User from django.db import connection from django.db.models import Count From ee558b5c55efdbf2215fecf4ef7a4c8d4ba342e6 Mon Sep 17 00:00:00 2001 From: tschilling Date: Sun, 20 Feb 2022 09:43:36 -0600 Subject: [PATCH 6/7] Close DB connections when testing async functionality. --- tests/panels/test_sql.py | 28 +++++++++++++++++++--------- tests/sync.py | 22 ++++++++++++++++++++++ 2 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 tests/sync.py diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index 8c4cdff10..2445827c7 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -18,6 +18,14 @@ from ..base import BaseTestCase from ..models import PostgresJSON +from ..sync import database_sync_to_async + + +def sql_call(use_iterator=False): + qs = User.objects.all() + if use_iterator: + qs = qs.iterator() + return list(qs) class SQLPanelTestCase(BaseTestCase): @@ -32,7 +40,7 @@ def test_disabled(self): def test_recording(self): self.assertEqual(len(self.panel._queries), 0) - list(User.objects.all()) + sql_call() # ensure query was logged self.assertEqual(len(self.panel._queries), 1) @@ -51,7 +59,7 @@ def test_recording(self): def test_recording_chunked_cursor(self): self.assertEqual(len(self.panel._queries), 0) - list(User.objects.all().iterator()) + sql_call(use_iterator=True) # ensure query was logged self.assertEqual(len(self.panel._queries), 1) @@ -61,7 +69,7 @@ def test_recording_chunked_cursor(self): wraps=sql_tracking.NormalCursorWrapper, ) def test_cursor_wrapper_singleton(self, mock_wrapper): - list(User.objects.all()) + sql_call() # ensure that cursor wrapping is applied only once self.assertEqual(mock_wrapper.call_count, 1) @@ -71,7 +79,7 @@ def test_cursor_wrapper_singleton(self, mock_wrapper): wraps=sql_tracking.NormalCursorWrapper, ) def test_chunked_cursor_wrapper_singleton(self, mock_wrapper): - list(User.objects.all().iterator()) + sql_call(use_iterator=True) # ensure that cursor wrapping is applied only once self.assertEqual(mock_wrapper.call_count, 1) @@ -81,7 +89,7 @@ def test_chunked_cursor_wrapper_singleton(self, mock_wrapper): wraps=sql_tracking.NormalCursorWrapper, ) async def test_cursor_wrapper_async(self, mock_wrapper): - await sync_to_async(list)(User.objects.all()) + await sync_to_async(sql_call)() self.assertEqual(mock_wrapper.call_count, 1) @@ -91,11 +99,13 @@ async def test_cursor_wrapper_async(self, mock_wrapper): ) async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper): self.assertTrue(sql_tracking.recording.get()) - await sync_to_async(list)(User.objects.all()) + await sync_to_async(sql_call)() async def task(): sql_tracking.recording.set(False) - await sync_to_async(list)(User.objects.all()) + # Calling this in another context requires the db connections + # to be closed properly. + await database_sync_to_async(sql_call)() # Ensure this is called in another context await asyncio.create_task(task()) @@ -106,7 +116,7 @@ async def task(): def test_generate_server_timing(self): self.assertEqual(len(self.panel._queries), 0) - list(User.objects.all()) + sql_call() response = self.panel.process_request(self.request) self.panel.generate_stats(self.request, response) @@ -372,7 +382,7 @@ def test_disable_stacktraces(self): self.assertEqual(len(self.panel._queries), 0) with self.settings(DEBUG_TOOLBAR_CONFIG={"ENABLE_STACKTRACES": False}): - list(User.objects.all()) + sql_call() # ensure query was logged self.assertEqual(len(self.panel._queries), 1) diff --git a/tests/sync.py b/tests/sync.py new file mode 100644 index 000000000..d71298089 --- /dev/null +++ b/tests/sync.py @@ -0,0 +1,22 @@ +""" +Taken from channels.db +""" +from asgiref.sync import SyncToAsync +from django.db import close_old_connections + + +class DatabaseSyncToAsync(SyncToAsync): + """ + SyncToAsync version that cleans up old database connections when it exits. + """ + + def thread_handler(self, loop, *args, **kwargs): + close_old_connections() + try: + return super().thread_handler(loop, *args, **kwargs) + finally: + close_old_connections() + + +# The class is TitleCased, but we want to encourage use as a callable/decorator +database_sync_to_async = DatabaseSyncToAsync From 468660f5aaea92e87fc6d2859e0475e156592366 Mon Sep 17 00:00:00 2001 From: tschilling Date: Sun, 20 Feb 2022 09:44:13 -0600 Subject: [PATCH 7/7] Reset ContextVar to previous value. --- debug_toolbar/panels/templates/panel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug_toolbar/panels/templates/panel.py b/debug_toolbar/panels/templates/panel.py index 8c114fd8e..0615a7601 100644 --- a/debug_toolbar/panels/templates/panel.py +++ b/debug_toolbar/panels/templates/panel.py @@ -118,7 +118,7 @@ def _store_template_info(self, sender, **kwargs): value.model._meta.label, ) else: - recording.set(False) + token = recording.set(False) try: saferepr(value) # this MAY trigger a db query except SQLQueryTriggered: @@ -130,7 +130,7 @@ def _store_template_info(self, sender, **kwargs): else: temp_layer[key] = value finally: - recording.set(True) + recording.reset(token) pformatted = pformat(temp_layer) self.pformat_layers.append((context_layer, pformatted)) context_list.append(pformatted)