From e396bc03434b9d19a80576abf8adc9304d4530ba Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 16:58:03 +0330 Subject: [PATCH 01/21] add test for middleware exception handling --- tests/test_middlewares/middlewares.py | 135 +++++++++ tests/test_middlewares/test_exceptions.py | 345 ++++++++++++++++++++++ tests/test_middlewares/urls.py | 16 + tests/test_middlewares/views.py | 39 +++ 4 files changed, 535 insertions(+) create mode 100644 tests/test_middlewares/middlewares.py create mode 100644 tests/test_middlewares/test_exceptions.py create mode 100644 tests/test_middlewares/urls.py create mode 100644 tests/test_middlewares/views.py diff --git a/tests/test_middlewares/middlewares.py b/tests/test_middlewares/middlewares.py new file mode 100644 index 0000000..d56f467 --- /dev/null +++ b/tests/test_middlewares/middlewares.py @@ -0,0 +1,135 @@ +from asgiref.sync import iscoroutinefunction, markcoroutinefunction + +from django.http import Http404, HttpResponse +from django.template import engines +from django.template.response import TemplateResponse +from django.utils.decorators import ( + async_only_middleware, + sync_and_async_middleware, +) + +log = [] + + +class BaseMiddleware: + async_capable = True + sync_capable = False + + def __init__(self, get_response): + self.get_response = get_response + if iscoroutinefunction(self.get_response): + markcoroutinefunction(self) + + async def __call__(self, request): + return await self.get_response(request) + + +class ProcessExceptionMiddleware(BaseMiddleware): + def process_exception(self, request, exception): + return HttpResponse("Exception caught") + + +@async_only_middleware +class AsyncProcessExceptionMiddleware(BaseMiddleware): + async def process_exception(self, request, exception): + return HttpResponse("Exception caught") + + +class ProcessExceptionLogMiddleware(BaseMiddleware): + def process_exception(self, request, exception): + log.append("process-exception") + + +class ProcessExceptionExcMiddleware(BaseMiddleware): + def process_exception(self, request, exception): + raise Exception("from process-exception") + + +class ProcessViewMiddleware(BaseMiddleware): + def process_view(self, request, view_func, view_args, view_kwargs): + return HttpResponse("Processed view %s" % view_func.__name__) + + +@async_only_middleware +class AsyncProcessViewMiddleware(BaseMiddleware): + async def process_view(self, request, view_func, view_args, view_kwargs): + return HttpResponse("Processed view %s" % view_func.__name__) + + +class ProcessViewNoneMiddleware(BaseMiddleware): + def process_view(self, request, view_func, view_args, view_kwargs): + log.append("processed view %s" % view_func.__name__) + return None + + +class ProcessViewTemplateResponseMiddleware(BaseMiddleware): + def process_view(self, request, view_func, view_args, view_kwargs): + template = engines["django"].from_string( + "Processed view {{ view }}{% for m in mw %}\n{{ m }}{% endfor %}" + ) + return TemplateResponse( + request, + template, + {"mw": [self.__class__.__name__], "view": view_func.__name__}, + ) + + +class TemplateResponseMiddleware(BaseMiddleware): + def process_template_response(self, request, response): + response.context_data["mw"].append(self.__class__.__name__) + return response + + +@async_only_middleware +class AsyncTemplateResponseMiddleware(BaseMiddleware): + async def process_template_response(self, request, response): + response.context_data["mw"].append(self.__class__.__name__) + return response + + +class LogMiddleware(BaseMiddleware): + async def __call__(self, request): + response = await self.get_response(request) + log.append((response.status_code, response.content)) + return response + + +class NoTemplateResponseMiddleware(BaseMiddleware): + def process_template_response(self, request, response): + return None + + +@async_only_middleware +class AsyncNoTemplateResponseMiddleware(BaseMiddleware): + async def process_template_response(self, request, response): + return None + + +class NotFoundMiddleware(BaseMiddleware): + async def __call__(self, request): + raise Http404("not found") + + +@async_only_middleware +def async_payment_middleware(get_response): + async def middleware(request): + response = await get_response(request) + response.status_code = 402 + return response + + return middleware + + +@sync_and_async_middleware +class SyncAndAsyncMiddleware(BaseMiddleware): + pass + + +class NotSyncOrAsyncMiddleware(BaseMiddleware): + """Middleware that is deliberately neither sync or async.""" + + sync_capable = False + async_capable = False + + async def __call__(self, request): + return await self.get_response(request) diff --git a/tests/test_middlewares/test_exceptions.py b/tests/test_middlewares/test_exceptions.py new file mode 100644 index 0000000..6f77678 --- /dev/null +++ b/tests/test_middlewares/test_exceptions.py @@ -0,0 +1,345 @@ +import logging + +import pytest + +from django.core.exceptions import MiddlewareNotUsed +from django.http import HttpResponse +from django.test import override_settings + +from . import middlewares as mw + + +class TestMiddleware: + @pytest.fixture(autouse=True) + def setup(self, settings): + settings.ROOT_URLCONF = "test_middlewares.urls" + yield + mw.log = [] + + def test_process_view_return_none(self, settings, client): + settings.MIDDLEWARE = ["test_middlewares.middlewares.ProcessViewNoneMiddleware"] + response = client.get("/middleware_exceptions/view/") + assert mw.log == ["processed view normal_view"] + assert response.content == b"OK" + + def test_process_view_return_response(self, settings, client): + settings.MIDDLEWARE = ["test_middlewares.middlewares.ProcessViewMiddleware"] + response = client.get("/middleware_exceptions/view/") + assert response.content == b"Processed view normal_view" + + def test_templateresponse_from_process_view_rendered(self, settings, client): + """ + TemplateResponses returned from process_view() must be rendered before + being passed to any middleware that tries to access response.content, + such as test_middlewares.middlewares.LogMiddleware. + """ + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.ProcessViewTemplateResponseMiddleware", + "test_middlewares.middlewares.LogMiddleware", + ] + response = client.get("/middleware_exceptions/view/") + assert ( + response.content + == b"Processed view normal_view\nProcessViewTemplateResponseMiddleware" + ) + + def test_templateresponse_from_process_view_passed_to_process_template_response( + self, settings, client + ): + """ + TemplateResponses returned from process_view() should be passed to any + template response middleware. + """ + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.ProcessViewTemplateResponseMiddleware", + "test_middlewares.middlewares.TemplateResponseMiddleware", + ] + response = client.get("/middleware_exceptions/view/") + expected_lines = [ + b"Processed view normal_view", + b"ProcessViewTemplateResponseMiddleware", + b"TemplateResponseMiddleware", + ] + assert response.content == b"\n".join(expected_lines) + + def test_process_template_response(self, settings, client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.TemplateResponseMiddleware" + ] + response = client.get("/middleware_exceptions/template_response/") + assert response.content == b"template_response OK\nTemplateResponseMiddleware" + + def test_process_template_response_returns_none(self, settings, client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.NoTemplateResponseMiddleware" + ] + msg = ( + "NoTemplateResponseMiddleware.process_template_response didn't " + "return an HttpResponse object. It returned None instead." + ) + with pytest.raises(ValueError, match=msg): + client.get("/middleware_exceptions/template_response/") + + def test_view_exception_converted_before_middleware(self, settings, client): + settings.MIDDLEWARE = ["test_middlewares.middlewares.LogMiddleware"] + response = client.get("/middleware_exceptions/permission_denied/") + assert mw.log == [(response.status_code, response.content)] + assert response.status_code == 403 + + def test_view_exception_handled_by_process_exception(self, settings, client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.ProcessExceptionMiddleware" + ] + response = client.get("/middleware_exceptions/error/") + assert response.content == b"Exception caught" + + def test_response_from_process_exception_short_circuits_remainder( + self, settings, client + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.ProcessExceptionLogMiddleware", + "test_middlewares.middlewares.ProcessExceptionMiddleware", + ] + response = client.get("/middleware_exceptions/error/") + assert mw.log == [] + assert response.content == b"Exception caught" + + def test_response_from_process_exception_when_return_response( + self, settings, client + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.ProcessExceptionMiddleware", + "test_middlewares.middlewares.ProcessExceptionLogMiddleware", + ] + response = client.get("/middleware_exceptions/error/") + assert mw.log == ["process-exception"] + assert response.content == b"Exception caught" + + @override_settings( + MIDDLEWARE=[ + "test_middlewares.middlewares.LogMiddleware", + "test_middlewares.middlewares.NotFoundMiddleware", + ] + ) + def test_exception_in_middleware_converted_before_prior_middleware( + self, settings, client + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.LogMiddleware", + "test_middlewares.middlewares.NotFoundMiddleware", + ] + response = client.get("/middleware_exceptions/view/") + assert mw.log == [(404, response.content)] + assert response.status_code == 404 + + def test_exception_in_render_passed_to_process_exception(self, settings, client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.ProcessExceptionMiddleware" + ] + response = client.get("/middleware_exceptions/exception_in_render/") + assert response.content == b"Exception caught" + + +class TestRootUrlconf: + @pytest.fixture(autouse=True) + def setup(self, settings): + settings.ROOT_URLCONF = "test_middlewares.urls" + + def test_missing_root_urlconf(self, settings, client): + # Removing ROOT_URLCONF is safe, as override_settings will restore + # the previously defined settings. + del settings.ROOT_URLCONF + with pytest.raises(AttributeError): + client.get("/middleware_exceptions/view/") + + +class MyMiddleware: + def __init__(self, get_response): + raise MiddlewareNotUsed + + async def process_request(self, request): + pass + + +class MyMiddlewareWithExceptionMessage: + def __init__(self, get_response): + raise MiddlewareNotUsed("spam eggs") + + async def process_request(self, request): + pass + + +class TestMiddlewareNotUsed: + @pytest.fixture(autouse=True) + def setup(self, settings): + settings.DEBUG = True + settings.ROOT_URLCONF = "test_middlewares.urls" + settings.MIDDLEWARE = ["django.middleware.common.CommonMiddleware"] + + async def test_raise_exception(self, rf): + request = rf.get("test_middlewares/view/") + with pytest.raises(MiddlewareNotUsed): + await MyMiddleware(lambda req: HttpResponse()).process_request(request) + + def test_log(self, settings, client, caplog): + settings.MIDDLEWARE = ["test_middlewares.test_exceptions.MyMiddleware"] + with caplog.at_level(logging.DEBUG, logger="django.request"): + client.get("/middleware_exceptions/view/") + assert ( + "MiddlewareNotUsed: 'test_middlewares.test_exceptions.MyMiddleware'" + in caplog.text + ) + + def test_log_custom_message(self, settings, client, caplog): + settings.MIDDLEWARE = [ + "test_middlewares.test_exceptions.MyMiddlewareWithExceptionMessage" + ] + with caplog.at_level(logging.DEBUG, logger="django.request"): + client.get("/middleware_exceptions/view/") + assert ( + "MiddlewareNotUsed('test_middlewares.test_exceptions." + "MyMiddlewareWithExceptionMessage'): spam eggs" in caplog.text + ) + + def test_do_not_log_when_debug_is_false(self, settings, client, caplog): + settings.DEBUG = False + settings.MIDDLEWARE = ["test_middlewares.test_exceptions.MyMiddleware"] + with caplog.at_level(logging.DEBUG, logger="django.request"): + client.get("/middleware_exceptions/view/") + assert not caplog.records + + async def test_async_and_sync_middleware_chain_async_call( + self, settings, async_client, caplog + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.SyncAndAsyncMiddleware", + "test_middlewares.test_exceptions.MyMiddleware", + ] + with caplog.at_level(logging.DEBUG, logger="django.request"): + response = await async_client.get("/middleware_exceptions/view/") + assert response.content == b"OK" + assert response.status_code == 200 + assert ( + "Asynchronous handler adapted for middleware " + "test_middlewares.test_exceptions.MyMiddleware." in caplog.text + ) + assert ( + "MiddlewareNotUsed: 'test_middlewares.test_exceptions.MyMiddleware'" + in caplog.text + ) + + +class TestMiddlewareSyncAsync: + @pytest.fixture(autouse=True) + def setup(self, settings): + settings.DEBUG = True + settings.ROOT_URLCONF = "test_middlewares.urls" + + def test_async_middleware(self, settings, client, caplog): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.async_payment_middleware", + ] + with caplog.at_level(logging.DEBUG, "django.request"): + response = client.get("/middleware_exceptions/view/") + assert response.status_code == 402 + assert ( + "Synchronous handler adapted for middleware " + "test_middlewares.middlewares.async_payment_middleware." in caplog.text + ) + + def test_not_sync_or_async_middleware(self, settings, client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.NotSyncOrAsyncMiddleware", + ] + msg = ( + "Middleware " + "test_middlewares.middlewares.NotSyncOrAsyncMiddleware must " + "have at least one of sync_capable/async_capable set to True." + ) + with pytest.raises(RuntimeError, match=msg): + client.get("/middleware_exceptions/view/") + + async def test_async_middleware_async(self, settings, async_client, caplog): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.async_payment_middleware", + ] + with caplog.at_level("WARNING", "django.request"): + response = await async_client.get("/middleware_exceptions/view/") + assert response.status_code == 402 + assert "Payment Required: /middleware_exceptions/view/" in caplog.text + + def test_async_process_template_response_returns_none_with_sync_client( + self, settings, client + ): + settings.DEBUG = False + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncNoTemplateResponseMiddleware", + ] + msg = ( + "AsyncNoTemplateResponseMiddleware.process_template_response " + "didn't return an HttpResponse object." + ) + with pytest.raises(ValueError, match=msg): + client.get("/middleware_exceptions/template_response/") + + +class TestAsyncMiddleware: + @pytest.fixture(autouse=True) + def setup(self, settings): + settings.ROOT_URLCONF = "test_middlewares.urls" + + async def test_process_template_response(self, settings, async_client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncTemplateResponseMiddleware", + ] + response = await async_client.get("/middleware_exceptions/template_response/") + assert ( + response.content == b"template_response OK\nAsyncTemplateResponseMiddleware" + ) + + async def test_process_template_response_returns_none(self, settings, async_client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncNoTemplateResponseMiddleware", + ] + msg = ( + "AsyncNoTemplateResponseMiddleware.process_template_response " + "didn't return an HttpResponse object. It returned None instead." + ) + with pytest.raises(ValueError, match=msg): + await async_client.get("/middleware_exceptions/template_response/") + + async def test_exception_in_render_passed_to_process_exception( + self, settings, async_client + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncProcessExceptionMiddleware", + ] + response = await async_client.get("/middleware_exceptions/exception_in_render/") + assert response.content == b"Exception caught" + + async def test_exception_in_async_render_passed_to_process_exception( + self, settings, async_client + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncProcessExceptionMiddleware", + ] + response = await async_client.get( + "/middleware_exceptions/async_exception_in_render/" + ) + assert response.content == b"Exception caught" + + async def test_view_exception_handled_by_process_exception( + self, settings, async_client + ): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncProcessExceptionMiddleware", + ] + response = await async_client.get("/middleware_exceptions/error/") + assert response.content == b"Exception caught" + + async def test_process_view_return_response(self, settings, async_client): + settings.MIDDLEWARE = [ + "test_middlewares.middlewares.AsyncProcessViewMiddleware", + ] + response = await async_client.get("/middleware_exceptions/view/") + assert response.content == b"Processed view normal_view" diff --git a/tests/test_middlewares/urls.py b/tests/test_middlewares/urls.py new file mode 100644 index 0000000..80cbb2c --- /dev/null +++ b/tests/test_middlewares/urls.py @@ -0,0 +1,16 @@ +from django.urls import path + +from . import views + +urlpatterns = [ + path("middleware_exceptions/view/", views.normal_view), + path("middleware_exceptions/error/", views.server_error), + path("middleware_exceptions/permission_denied/", views.permission_denied), + path("middleware_exceptions/exception_in_render/", views.exception_in_render), + path("middleware_exceptions/template_response/", views.template_response), + # Async views. + path( + "middleware_exceptions/async_exception_in_render/", + views.async_exception_in_render, + ), +] diff --git a/tests/test_middlewares/views.py b/tests/test_middlewares/views.py new file mode 100644 index 0000000..0f1595b --- /dev/null +++ b/tests/test_middlewares/views.py @@ -0,0 +1,39 @@ +from django.core.exceptions import PermissionDenied +from django.http import HttpResponse +from django.template import engines +from django.template.response import TemplateResponse + + +def normal_view(request): + return HttpResponse("OK") + + +def template_response(request): + template = engines["django"].from_string( + "template_response OK{% for m in mw %}\n{{ m }}{% endfor %}" + ) + return TemplateResponse(request, template, context={"mw": []}) + + +def server_error(request): + raise Exception("Error in view") + + +def permission_denied(request): + raise PermissionDenied() + + +def exception_in_render(request): + class CustomHttpResponse(HttpResponse): + def render(self): + raise Exception("Exception in HttpResponse.render()") + + return CustomHttpResponse("Error") + + +async def async_exception_in_render(request): + class CustomHttpResponse(HttpResponse): + async def render(self): + raise Exception("Exception in HttpResponse.render()") + + return CustomHttpResponse("Error") From 8aaa846571b28c8b53f3d070f252b2cc6fc34610 Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 02:45:02 +0330 Subject: [PATCH 02/21] implement async security middleware --- .../middleware/security.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 django_async_extensions/middleware/security.py diff --git a/django_async_extensions/middleware/security.py b/django_async_extensions/middleware/security.py new file mode 100644 index 0000000..2dab776 --- /dev/null +++ b/django_async_extensions/middleware/security.py @@ -0,0 +1,67 @@ +import re + +from django.conf import settings +from django.http import HttpResponsePermanentRedirect + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + + +class AsyncSecurityMiddleware(AsyncMiddlewareMixin): + def __init__(self, get_response): + super().__init__(get_response) + self.sts_seconds = settings.SECURE_HSTS_SECONDS + self.sts_include_subdomains = settings.SECURE_HSTS_INCLUDE_SUBDOMAINS + self.sts_preload = settings.SECURE_HSTS_PRELOAD + self.content_type_nosniff = settings.SECURE_CONTENT_TYPE_NOSNIFF + self.redirect = settings.SECURE_SSL_REDIRECT + self.redirect_host = settings.SECURE_SSL_HOST + self.redirect_exempt = [re.compile(r) for r in settings.SECURE_REDIRECT_EXEMPT] + self.referrer_policy = settings.SECURE_REFERRER_POLICY + self.cross_origin_opener_policy = settings.SECURE_CROSS_ORIGIN_OPENER_POLICY + + async def process_request(self, request): + path = request.path.lstrip("/") + if ( + self.redirect + and not request.is_secure() + and not any(pattern.search(path) for pattern in self.redirect_exempt) + ): + host = self.redirect_host or request.get_host() + return HttpResponsePermanentRedirect( + "https://%s%s" % (host, request.get_full_path()) + ) + + async def process_response(self, request, response): + if ( + self.sts_seconds + and request.is_secure() + and "Strict-Transport-Security" not in response + ): + sts_header = "max-age=%s" % self.sts_seconds + if self.sts_include_subdomains: + sts_header += "; includeSubDomains" + if self.sts_preload: + sts_header += "; preload" + response.headers["Strict-Transport-Security"] = sts_header + + if self.content_type_nosniff: + response.headers.setdefault("X-Content-Type-Options", "nosniff") + + if self.referrer_policy: + # Support a comma-separated string or iterable of values to allow + # fallback. + response.headers.setdefault( + "Referrer-Policy", + ",".join( + [v.strip() for v in self.referrer_policy.split(",")] + if isinstance(self.referrer_policy, str) + else self.referrer_policy + ), + ) + + if self.cross_origin_opener_policy: + response.setdefault( + "Cross-Origin-Opener-Policy", + self.cross_origin_opener_policy, + ) + return response From ef407c8b0b710e5c129ce67aa7e372f22c545093 Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 02:53:22 +0330 Subject: [PATCH 03/21] document async security middleware --- docs/middleware/security_middleware.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 docs/middleware/security_middleware.md diff --git a/docs/middleware/security_middleware.md b/docs/middleware/security_middleware.md new file mode 100644 index 0000000..72bba9c --- /dev/null +++ b/docs/middleware/security_middleware.md @@ -0,0 +1,10 @@ +# AsyncSecurityMiddleware + +it works exactly like django's [SecurityMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#module-django.middleware.security) +except that it's fully async + +### Usage: +remove django's `django.middleware.security.SecurityMiddleware` from the `MIDDLEWARE` setting and add +`django_async_extensions.middleware.security.AsyncSecurityMiddleware` in it's place. + +**note**: this middleware like other middlewares provided in this package can work alongside sync middlewares, and can handle sync views. From cc0db1c36c3d59c26fae4823a15e4755067b47ce Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 02:53:41 +0330 Subject: [PATCH 04/21] add tests for async security middleware --- .../test_middlewares/test_middleware_mixin.py | 27 ++ tests/test_middlewares/test_security.py | 381 ++++++++++++++++++ 2 files changed, 408 insertions(+) create mode 100644 tests/test_middlewares/test_security.py diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py index 66156df..1abce34 100644 --- a/tests/test_middlewares/test_middleware_mixin.py +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -6,6 +6,7 @@ from django.http.response import HttpResponse from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.middleware.security import AsyncSecurityMiddleware req = HttpResponse() resp = HttpResponse() @@ -30,6 +31,10 @@ async def process_request(self, request): class TestMiddlewareMixin: + middlewares = [ + AsyncSecurityMiddleware, + ] + def test_repr(self): class GetResponse: async def __call__(self): @@ -48,6 +53,28 @@ async def get_response(): "TestMiddlewareMixin.test_repr..get_response>" ) + def test_passing_explicit_none(self, subtests): + msg = "get_response must be provided" + for middleware in self.middlewares: + with subtests.test(middleware=middleware): + with pytest.raises(ValueError, match=msg): + middleware(None) + + def test_coroutine(self, subtests): + async def async_get_response(request): + return HttpResponse() + + def sync_get_response(request): + return HttpResponse() + + for middleware in self.middlewares: + with subtests.test(middleware=middleware.__qualname__): + middleware_instance = middleware(async_get_response) + assert iscoroutinefunction(middleware_instance) + + with pytest.raises(ImproperlyConfigured): + middleware(sync_get_response) + def test_call_is_async(self): assert iscoroutinefunction(AsyncMiddlewareMixin.__call__) diff --git a/tests/test_middlewares/test_security.py b/tests/test_middlewares/test_security.py new file mode 100644 index 0000000..d9bb4ff --- /dev/null +++ b/tests/test_middlewares/test_security.py @@ -0,0 +1,381 @@ +import re + +import pytest + +from django.http import HttpResponse +from django.test import AsyncRequestFactory +from django.test.utils import override_settings + + +class TestSecurityMiddleware: + @pytest.fixture + def set_setting_fixture(self, request, settings): + # request.param should be a list of [settings_name, value] + setattr(settings, request.param[0], request.param[1]) + + @pytest.fixture + def hsts_seconds_fixture(self, request, settings): + settings.SECURE_HSTS_SECONDS = request.param + + @pytest.fixture + def hsts_with_subdomains_fixture(self, request, settings): + settings.SECURE_HSTS_INCLUDE_SUBDOMAINS = request.param + + @pytest.fixture + def ssl_redirect_fixture(self, request, settings): + settings.SECURE_SSL_REDIRECT = request.param + + def middleware(self, *args, **kwargs): + from django_async_extensions.middleware.security import AsyncSecurityMiddleware + + return AsyncSecurityMiddleware(self.response(*args, **kwargs)) + + @property + def secure_request_kwargs(self): + return {"type": "https"} + + def response(self, *args, headers=None, **kwargs): + async def get_response(req): + response = HttpResponse(*args, **kwargs) + if headers: + for k, v in headers.items(): + response.headers[k] = v + return response + + return get_response + + async def process_response(self, *args, secure=False, request=None, **kwargs): + request_kwargs = {} + if secure: + request_kwargs.update(self.secure_request_kwargs) + if request is None: + request = self.request.get("/some/url", secure=secure, **request_kwargs) + ret = await self.middleware(*args, **kwargs).process_request(request) + if ret: + return ret + return await self.middleware(*args, **kwargs)(request) + + request = AsyncRequestFactory() + + async def process_request(self, method, *args, secure=False, **kwargs): + if secure: + kwargs.update(self.secure_request_kwargs) + req = getattr(self.request, method.lower())(*args, secure=secure, **kwargs) + return await self.middleware().process_request(req) + + def test_middleware_instance_attributes(self, settings): + middleware = self.middleware() + assert middleware.sts_seconds == settings.SECURE_HSTS_SECONDS + assert ( + middleware.sts_include_subdomains == settings.SECURE_HSTS_INCLUDE_SUBDOMAINS + ) + assert middleware.sts_preload == settings.SECURE_HSTS_PRELOAD + assert middleware.content_type_nosniff == settings.SECURE_CONTENT_TYPE_NOSNIFF + assert middleware.redirect == settings.SECURE_SSL_REDIRECT + assert middleware.redirect_host == settings.SECURE_SSL_HOST + assert middleware.redirect_exempt == [ + re.compile(r) for r in settings.SECURE_REDIRECT_EXEMPT + ] + assert middleware.referrer_policy == settings.SECURE_REFERRER_POLICY + assert ( + middleware.cross_origin_opener_policy + == settings.SECURE_CROSS_ORIGIN_OPENER_POLICY + ) + + @pytest.mark.parametrize("hsts_seconds_fixture", [3600], indirect=True) + async def test_sts_on(self, hsts_seconds_fixture): + """ + With SECURE_HSTS_SECONDS=3600, the middleware adds + "Strict-Transport-Security: max-age=3600" to the response. + """ + response = await self.process_response(secure=True) + assert response.headers["Strict-Transport-Security"] == "max-age=3600" + + @pytest.mark.parametrize("hsts_seconds_fixture", [3600], indirect=True) + async def test_sts_already_present(self, hsts_seconds_fixture): + """ + The middleware will not override a "Strict-Transport-Security" header + already present in the response. + """ + response = await self.process_response( + secure=True, headers={"Strict-Transport-Security": "max-age=7200"} + ) + assert response.headers["Strict-Transport-Security"] == "max-age=7200" + + @pytest.mark.parametrize("hsts_seconds_fixture", [3600], indirect=True) + async def test_sts_only_if_secure(self, hsts_seconds_fixture): + """ + The "Strict-Transport-Security" header is not added to responses going + over an insecure connection. + """ + response = await self.process_response(secure=False) + assert "Strict-Transport-Security" not in response.headers + + @pytest.mark.parametrize("hsts_seconds_fixture", [0], indirect=True) + async def test_sts_off(self, hsts_seconds_fixture): + """ + With SECURE_HSTS_SECONDS=0, the middleware does not add a + "Strict-Transport-Security" header to the response. + """ + response = await self.process_response(secure=True) + assert "Strict-Transport-Security" not in response.headers + + @pytest.mark.parametrize("hsts_seconds_fixture", [600], indirect=True) + @pytest.mark.parametrize("hsts_with_subdomains_fixture", [True], indirect=True) + async def test_sts_include_subdomains( + self, hsts_seconds_fixture, hsts_with_subdomains_fixture + ): + """ + With SECURE_HSTS_SECONDS non-zero and SECURE_HSTS_INCLUDE_SUBDOMAINS + True, the middleware adds a "Strict-Transport-Security" header with the + "includeSubDomains" directive to the response. + """ + response = await self.process_response(secure=True) + assert ( + response.headers["Strict-Transport-Security"] + == "max-age=600; includeSubDomains" + ) + + @pytest.mark.parametrize("hsts_seconds_fixture", [600], indirect=True) + @pytest.mark.parametrize("hsts_with_subdomains_fixture", [False], indirect=True) + async def test_sts_no_include_subdomains( + self, hsts_seconds_fixture, hsts_with_subdomains_fixture + ): + """ + With SECURE_HSTS_SECONDS non-zero and SECURE_HSTS_INCLUDE_SUBDOMAINS + False, the middleware adds a "Strict-Transport-Security" header without + the "includeSubDomains" directive to the response. + """ + response = await self.process_response(secure=True) + assert response.headers["Strict-Transport-Security"] == "max-age=600" + + @pytest.mark.parametrize("hsts_seconds_fixture", [10886400], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_HSTS_PRELOAD", True]], indirect=True + ) + async def test_sts_preload(self, hsts_seconds_fixture, set_setting_fixture): + """ + With SECURE_HSTS_SECONDS non-zero and SECURE_HSTS_PRELOAD True, the + middleware adds a "Strict-Transport-Security" header with the "preload" + directive to the response. + """ + response = await self.process_response(secure=True) + assert ( + response.headers["Strict-Transport-Security"] == "max-age=10886400; preload" + ) + + @pytest.mark.parametrize("hsts_seconds_fixture", [10886400], indirect=True) + @pytest.mark.parametrize("hsts_with_subdomains_fixture", [True], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_HSTS_PRELOAD", True]], indirect=True + ) + async def test_sts_subdomains_and_preload( + self, hsts_seconds_fixture, hsts_with_subdomains_fixture, set_setting_fixture + ): + """ + With SECURE_HSTS_SECONDS non-zero, SECURE_HSTS_INCLUDE_SUBDOMAINS and + SECURE_HSTS_PRELOAD True, the middleware adds a "Strict-Transport-Security" + header containing both the "includeSubDomains" and "preload" directives + to the response. + """ + response = await self.process_response(secure=True) + assert ( + response.headers["Strict-Transport-Security"] + == "max-age=10886400; includeSubDomains; preload" + ) + + @pytest.mark.parametrize("hsts_seconds_fixture", [10886400], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_HSTS_PRELOAD", False]], indirect=True + ) + async def test_sts_no_preload(self, hsts_seconds_fixture, set_setting_fixture): + """ + With SECURE_HSTS_SECONDS non-zero and SECURE_HSTS_PRELOAD + False, the middleware adds a "Strict-Transport-Security" header without + the "preload" directive to the response. + """ + response = await self.process_response(secure=True) + assert response.headers["Strict-Transport-Security"] == "max-age=10886400" + + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_CONTENT_TYPE_NOSNIFF", True]], indirect=True + ) + async def test_content_type_on(self, set_setting_fixture): + """ + With SECURE_CONTENT_TYPE_NOSNIFF set to True, the middleware adds + "X-Content-Type-Options: nosniff" header to the response. + """ + response = await self.process_response() + assert response.headers["X-Content-Type-Options"] == "nosniff" + + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_CONTENT_TYPE_NOSNIFF", True]], indirect=True + ) + async def test_content_type_already_present(self, set_setting_fixture): + """ + The middleware will not override an "X-Content-Type-Options" header + already present in the response. + """ + response = await self.process_response( + secure=True, headers={"X-Content-Type-Options": "foo"} + ) + assert response.headers["X-Content-Type-Options"] == "foo" + + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_CONTENT_TYPE_NOSNIFF", False]], indirect=True + ) + async def test_content_type_off(self, set_setting_fixture): + """ + With SECURE_CONTENT_TYPE_NOSNIFF False, the middleware does not add an + "X-Content-Type-Options" header to the response. + """ + response = await self.process_response() + assert "X-Content-Type-Options" not in response.headers + + @pytest.mark.parametrize("ssl_redirect_fixture", [True], indirect=True) + async def test_ssl_redirect_on(self, ssl_redirect_fixture): + """ + With SECURE_SSL_REDIRECT True, the middleware redirects any non-secure + requests to the https:// version of the same URL. + """ + ret = await self.process_request("get", "/some/url?query=string") + assert ret.status_code == 301 + assert ret["Location"] == "https://testserver/some/url?query=string" + + @pytest.mark.parametrize("ssl_redirect_fixture", [True], indirect=True) + async def test_no_redirect_ssl(self, ssl_redirect_fixture): + """ + The middleware does not redirect secure requests. + """ + ret = await self.process_request("get", "/some/url", secure=True) + assert ret is None + + @pytest.mark.parametrize( + "set_setting_fixture", + [["SECURE_REDIRECT_EXEMPT", ["^insecure/"]]], + indirect=True, + ) + @pytest.mark.parametrize("ssl_redirect_fixture", [True], indirect=True) + async def test_redirect_exempt(self, set_setting_fixture, ssl_redirect_fixture): + """ + The middleware does not redirect requests with URL path matching an + exempt pattern. + """ + ret = await self.process_request("get", "/insecure/page") + assert ret is None + + @pytest.mark.parametrize( + "set_setting_fixture", + [["SECURE_SSL_HOST", "secure.example.com"]], + indirect=True, + ) + @pytest.mark.parametrize("ssl_redirect_fixture", [True], indirect=True) + async def test_redirect_ssl_host(self, ssl_redirect_fixture, set_setting_fixture): + """ + The middleware redirects to SECURE_SSL_HOST if given. + """ + ret = await self.process_request("get", "/some/url") + assert ret.status_code == 301 + assert ret["Location"] == "https://secure.example.com/some/url" + + @pytest.mark.parametrize("ssl_redirect_fixture", [False], indirect=True) + async def test_ssl_redirect_off(self, ssl_redirect_fixture): + """ + With SECURE_SSL_REDIRECT False, the middleware does not redirect. + """ + ret = await self.process_request("get", "/some/url") + assert ret is None + + @pytest.mark.parametrize( + "set_setting_fixture", [["SECURE_REFERRER_POLICY", None]], indirect=True + ) + async def test_referrer_policy_off(self, set_setting_fixture): + """ + With SECURE_REFERRER_POLICY set to None, the middleware does not add a + "Referrer-Policy" header to the response. + """ + response = await self.process_response() + assert "Referrer-Policy" not in response.headers + + async def test_referrer_policy_on(self, subtests): + """ + With SECURE_REFERRER_POLICY set to a valid value, the middleware adds a + "Referrer-Policy" header to the response. + """ + tests = ( + ("strict-origin", "strict-origin"), + ("strict-origin,origin", "strict-origin,origin"), + ("strict-origin, origin", "strict-origin,origin"), + (["strict-origin", "origin"], "strict-origin,origin"), + (("strict-origin", "origin"), "strict-origin,origin"), + ) + for value, expected in tests: + with ( + subtests.test(value=value), + override_settings(SECURE_REFERRER_POLICY=value), + ): + response = await self.process_response() + assert response.headers["Referrer-Policy"] == expected + + @pytest.mark.parametrize( + "set_setting_fixture", + [["SECURE_REFERRER_POLICY", "strict-origin"]], + indirect=True, + ) + async def test_referrer_policy_already_present(self, set_setting_fixture): + """ + The middleware will not override a "Referrer-Policy" header already + present in the response. + """ + response = await self.process_response( + headers={"Referrer-Policy": "unsafe-url"} + ) + assert response.headers["Referrer-Policy"] == "unsafe-url" + + @pytest.mark.parametrize( + "set_setting_fixture", + [["SECURE_CROSS_ORIGIN_OPENER_POLICY", None]], + indirect=True, + ) + async def test_coop_off(self, set_setting_fixture): + """ + With SECURE_CROSS_ORIGIN_OPENER_POLICY set to None, the middleware does + not add a "Cross-Origin-Opener-Policy" header to the response. + """ + assert "Cross-Origin-Opener-Policy" not in await self.process_response() + + async def test_coop_default(self): + """SECURE_CROSS_ORIGIN_OPENER_POLICY defaults to same-origin.""" + response = await self.process_response() + assert response.headers["Cross-Origin-Opener-Policy"] == "same-origin" + + async def test_coop_on(self, subtests): + """ + With SECURE_CROSS_ORIGIN_OPENER_POLICY set to a valid value, the + middleware adds a "Cross-Origin_Opener-Policy" header to the response. + """ + tests = ["same-origin", "same-origin-allow-popups", "unsafe-none"] + for value in tests: + with ( + subtests.test(value=value), + override_settings( + SECURE_CROSS_ORIGIN_OPENER_POLICY=value, + ), + ): + response = await self.process_response() + assert response.headers["Cross-Origin-Opener-Policy"] == value + + @pytest.mark.parametrize( + "set_setting_fixture", + [["SECURE_CROSS_ORIGIN_OPENER_POLICY", "unsafe-none"]], + indirect=True, + ) + async def test_coop_already_present(self, set_setting_fixture): + """ + The middleware doesn't override a "Cross-Origin-Opener-Policy" header + already present in the response. + """ + response = await self.process_response( + headers={"Cross-Origin-Opener-Policy": "same-origin"} + ) + assert response.headers["Cross-Origin-Opener-Policy"] == "same-origin" From 37043c663f838bfe448f5db47e0b2a4fca213588 Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 02:56:22 +0330 Subject: [PATCH 05/21] implement async locale middleware --- django_async_extensions/middleware/locale.py | 81 ++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 django_async_extensions/middleware/locale.py diff --git a/django_async_extensions/middleware/locale.py b/django_async_extensions/middleware/locale.py new file mode 100644 index 0000000..2cb130e --- /dev/null +++ b/django_async_extensions/middleware/locale.py @@ -0,0 +1,81 @@ +from django.conf import settings +from django.conf.urls.i18n import is_language_prefix_patterns_used +from django.http import HttpResponseRedirect +from django.urls import get_script_prefix, is_valid_path +from django.utils import translation +from django.utils.cache import patch_vary_headers + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + + +class AsyncLocaleMiddleware(AsyncMiddlewareMixin): + """ + Parse a request and decide what translation object to install in the + current thread context. This allows pages to be dynamically translated to + the language the user desires (if the language is available). + """ + + response_redirect_class = HttpResponseRedirect + + async def process_request(self, request): + urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF) + ( + i18n_patterns_used, + prefixed_default_language, + ) = is_language_prefix_patterns_used(urlconf) + language = translation.get_language_from_request( + request, check_path=i18n_patterns_used + ) + language_from_path = translation.get_language_from_path(request.path_info) + if ( + not language_from_path + and i18n_patterns_used + and not prefixed_default_language + ): + language = settings.LANGUAGE_CODE + translation.activate(language) + request.LANGUAGE_CODE = translation.get_language() + + async def process_response(self, request, response): + language = translation.get_language() + language_from_path = translation.get_language_from_path(request.path_info) + urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF) + ( + i18n_patterns_used, + prefixed_default_language, + ) = is_language_prefix_patterns_used(urlconf) + + if ( + response.status_code == 404 + and not language_from_path + and i18n_patterns_used + and prefixed_default_language + ): + # Maybe the language code is missing in the URL? Try adding the + # language prefix and redirecting to that URL. + language_path = "/%s%s" % (language, request.path_info) + path_valid = is_valid_path(language_path, urlconf) + path_needs_slash = not path_valid and ( + settings.APPEND_SLASH + and not language_path.endswith("/") + and is_valid_path("%s/" % language_path, urlconf) + ) + + if path_valid or path_needs_slash: + script_prefix = get_script_prefix() + # Insert language after the script prefix and before the + # rest of the URL + language_url = request.get_full_path( + force_append_slash=path_needs_slash + ).replace(script_prefix, "%s%s/" % (script_prefix, language), 1) + # Redirect to the language-specific URL as detected by + # get_language_from_request(). HTTP caches may cache this + # redirect, so add the Vary header. + redirect = self.response_redirect_class(language_url) + patch_vary_headers(redirect, ("Accept-Language", "Cookie")) + return redirect + + if not (i18n_patterns_used and language_from_path): + patch_vary_headers(response, ("Accept-Language",)) + response.headers.setdefault("Content-Language", language) + return response From e7fa022305ce4a3d0aca353a9fcc626bc7860d0d Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 19:49:59 +0330 Subject: [PATCH 06/21] document async locale middleware --- docs/middleware/locale_middleware.md | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 docs/middleware/locale_middleware.md diff --git a/docs/middleware/locale_middleware.md b/docs/middleware/locale_middleware.md new file mode 100644 index 0000000..96c4f01 --- /dev/null +++ b/docs/middleware/locale_middleware.md @@ -0,0 +1,11 @@ +# AsyncLocaleMiddleware + +it works exactly like django's [LocaleMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#module-django.middleware.locale) +except that it's fully async +read django's [Internationalization](https://docs.djangoproject.com/en/5.1/topics/i18n/translation/) documentations for a more in depth look. + +### Usage: +remove django's `django.middleware.locale.LocaleMiddleware` from the `MIDDLEWARE` setting and add +`django_async_extensions.middleware.locale.AsyncLocaleMiddleware` in it's place. + +**note**: this middleware like other middlewares provided in this package can work alongside sync middlewares, and can handle sync views. From 3879749ebd90732128766aff0ca4fe94b9e32545 Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 19:50:40 +0330 Subject: [PATCH 07/21] add tests for async locale middleware --- tests/test_middlewares/test_locale.py | 25 +++++++++++++++++++ .../test_middlewares/test_middleware_mixin.py | 2 ++ tests/test_middlewares/urls.py | 16 ++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 tests/test_middlewares/test_locale.py diff --git a/tests/test_middlewares/test_locale.py b/tests/test_middlewares/test_locale.py new file mode 100644 index 0000000..5eb4465 --- /dev/null +++ b/tests/test_middlewares/test_locale.py @@ -0,0 +1,25 @@ +import pytest + + +@pytest.mark.django_db +class TestLocaleMiddleware: + @pytest.fixture(autouse=True) + def set_settings(self, settings): + settings.USE_I18N = True + settings.LANGUAGES = [("en", "English"), ("fr", "French")] + settings.MIDDLEWARE = [ + "django_async_extensions.middleware.locale.AsyncLocaleMiddleware", + "django.middleware.common.CommonMiddleware", + ] + settings.ROOT_URLCONF = "test_middlewares.urls" + + async def test_streaming_response(self, async_client): + # Regression test for #5241 + response = await async_client.get("/fr/streaming/") + assert b"Oui/Non" in b"".join( + [content async for content in response.streaming_content] + ) + response = await async_client.get("/en/streaming/") + assert b"Yes/No" in b"".join( + [content async for content in response.streaming_content] + ) diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py index 1abce34..d3ab5cb 100644 --- a/tests/test_middlewares/test_middleware_mixin.py +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -6,6 +6,7 @@ from django.http.response import HttpResponse from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.middleware.locale import AsyncLocaleMiddleware from django_async_extensions.middleware.security import AsyncSecurityMiddleware req = HttpResponse() @@ -33,6 +34,7 @@ async def process_request(self, request): class TestMiddlewareMixin: middlewares = [ AsyncSecurityMiddleware, + AsyncLocaleMiddleware, ] def test_repr(self): diff --git a/tests/test_middlewares/urls.py b/tests/test_middlewares/urls.py index 80cbb2c..4d07945 100644 --- a/tests/test_middlewares/urls.py +++ b/tests/test_middlewares/urls.py @@ -1,7 +1,18 @@ +from django.conf.urls.i18n import i18n_patterns +from django.http import HttpResponse, StreamingHttpResponse from django.urls import path +from django.utils.translation import gettext_lazy as _ + from . import views + +async def stream_http_generator(): + yield _("Yes") + yield "/" + yield _("No") + + urlpatterns = [ path("middleware_exceptions/view/", views.normal_view), path("middleware_exceptions/error/", views.server_error), @@ -14,3 +25,8 @@ views.async_exception_in_render, ), ] + +urlpatterns += i18n_patterns( + path("simple/", lambda r: HttpResponse()), + path("streaming/", lambda r: StreamingHttpResponse(stream_http_generator())), +) From 686161466e933726b0129e8ffc86af97a3fba402 Mon Sep 17 00:00:00 2001 From: amirreza Date: Wed, 26 Mar 2025 21:15:50 +0330 Subject: [PATCH 08/21] handle sync views before this you could use sync views, but had to override the middleware to handle that, now it works out of the box --- django_async_extensions/utils/decorators.py | 25 ++-- docs/middleware/decorate_views.md | 33 +++++- tests/test_async_utils/test_decorators.py | 120 ++++++++++++++++++-- 3 files changed, 157 insertions(+), 21 deletions(-) diff --git a/django_async_extensions/utils/decorators.py b/django_async_extensions/utils/decorators.py index f7414e2..0c77e40 100644 --- a/django_async_extensions/utils/decorators.py +++ b/django_async_extensions/utils/decorators.py @@ -3,7 +3,7 @@ from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async -def decorator_from_middleware_with_args(middleware_class): +def decorator_from_middleware_with_args(middleware_class, async_only=True): """ Like decorator_from_middleware, but return a function that accepts the arguments to be passed to the middleware_class. @@ -16,22 +16,31 @@ def decorator_from_middleware_with_args(middleware_class): def my_view(request): # ... """ - return make_middleware_decorator(middleware_class) + return make_middleware_decorator(middleware_class, async_only=async_only) -def decorator_from_middleware(middleware_class): +def decorator_from_middleware(middleware_class, async_only=True): """ Given a middleware class (not an instance), return a view decorator. This lets you use middleware functionality on a per-view basis. The middleware is created with no params passed. """ - return make_middleware_decorator(middleware_class)() + return make_middleware_decorator(middleware_class, async_only=async_only)() -def make_middleware_decorator(middleware_class): +def make_middleware_decorator(middleware_class, async_only=True): def _make_decorator(*m_args, **m_kwargs): def _decorator(view_func): - middleware = middleware_class(view_func, *m_args, **m_kwargs) + _view_func = view_func + if all( + [ + not iscoroutinefunction(view_func), + not iscoroutinefunction(getattr(view_func, "__call__", None)), + async_only, + ] + ): + _view_func = sync_to_async(view_func) + middleware = middleware_class(_view_func, *m_args, **m_kwargs) async def _pre_process_request(request, *args, **kwargs): if hasattr(middleware, "process_request"): @@ -87,7 +96,9 @@ async def callback(response): return await middleware.process_response(request, response) return response - if iscoroutinefunction(view_func): + if iscoroutinefunction(view_func) or iscoroutinefunction( + getattr(view_func, "__call__", view_func) + ): async def _view_wrapper(request, *args, **kwargs): result = await _pre_process_request(request, *args, **kwargs) diff --git a/docs/middleware/decorate_views.md b/docs/middleware/decorate_views.md index 491f3a1..ad11442 100644 --- a/docs/middleware/decorate_views.md +++ b/docs/middleware/decorate_views.md @@ -6,9 +6,6 @@ they work almost exactly like django's [decorator_from_middleware](https://docs. and [decorator_from_middleware_with_args](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware_with_args) but it expects an async middleware as described in [AsyncMiddlewareMixin](base.md) -**Important:** if you are using a middleware that inherits from [AsyncMiddlewareMixin](base.md) you can only decorate async views -if you need to decorate a sync view change middleware's `__init__()` method to accept async `get_response` argument. - with an async view ```python from django.http.response import HttpResponse @@ -30,12 +27,18 @@ async def my_view(request): ``` -if you need to use a sync view design your middleware like this +if your view is sync, it'll be wrapped in `sync_to_async` before getting passed down to middleware. + +if you need, you can disable this by passing `async_only=False`. +note that the middlewares presented in this package will error if you do that, so you have to override the `__init__()` and `__call__()` methods to handle that. + ```python -from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django.http.response import HttpResponse from asgiref.sync import iscoroutinefunction, markcoroutinefunction +from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.utils.decorators import decorator_from_middleware class MyMiddleware(AsyncMiddlewareMixin): sync_capable = True @@ -51,6 +54,24 @@ class MyMiddleware(AsyncMiddlewareMixin): if self.async_mode: # Mark the class as async-capable. markcoroutinefunction(self) + + async def __call__(self, request): + response = None + if hasattr(self, "process_request"): + response = await self.process_request(request) + response = response or self.get_response(request) # here call the method in a sync manner, or handle it in another way + if hasattr(self, "process_response"): + response = await self.process_response(request, response) + return response - super().__init__() + async def process_request(self, request): + return HttpResponse() + + +deco = decorator_from_middleware(MyMiddleware, async_only=False) + + +@deco +def my_view(request): + return HttpResponse() ``` diff --git a/tests/test_async_utils/test_decorators.py b/tests/test_async_utils/test_decorators.py index 50f9b2b..8012be0 100644 --- a/tests/test_async_utils/test_decorators.py +++ b/tests/test_async_utils/test_decorators.py @@ -12,9 +12,6 @@ class ProcessViewMiddleware(AsyncMiddlewareMixin): - def __init__(self, get_response): - self.get_response = get_response - async def process_view(self, request, view_func, view_args, view_kwargs): pass @@ -49,9 +46,6 @@ async def __call__(self, request): class FullMiddleware(AsyncMiddlewareMixin): - def __init__(self, get_response): - self.get_response = get_response - async def process_request(self, request): request.process_request_reached = True @@ -72,6 +66,30 @@ async def process_response(self, request, response): full_dec = decorator_from_middleware(FullMiddleware) +class MiddlewareSyncGetResponse(FullMiddleware): + sync_capable = True + + def __init__(self, get_response): + self.get_response = get_response + + async def __call__(self, request): + response = None + if hasattr(self, "process_request"): + response = await self.process_request(request) + response = response or self.get_response(request) + if hasattr(self, "process_response"): + response = await self.process_response(request, response) + return response + + +full_sync_dec = decorator_from_middleware(MiddlewareSyncGetResponse, async_only=False) + + +@full_sync_dec +def process_view_sync(request): + return HttpResponse() + + class TestDecoratorFromMiddleware: """ Tests for view decorators created using @@ -89,7 +107,7 @@ def test_process_view_middleware(self): async def test_process_view_middleware_async(self, async_rf): await async_process_view(async_rf.get("/")) - async def test_sync_process_view_raises_in_async_context(self): + async def test_sync_process_view_in_async_context_errors(self): msg = ( "You cannot use AsyncToSync in the same thread as an async event loop" " - just await the async function directly." @@ -104,7 +122,7 @@ def test_callable_process_view_middleware(self): class_process_view(self.rf.get("/")) async def test_callable_process_view_middleware_async(self, async_rf): - await async_process_view(async_rf.get("/")) + await async_class_process_view(async_rf.get("/")) def test_full_dec_normal(self): """ @@ -142,6 +160,38 @@ async def normal_view(request): assert getattr(request, "process_template_response_reached", False) is False assert getattr(request, "process_response_reached", False) + def test_full_sync_dec_normal(self): + @full_sync_dec + def normal_view(request): + template = engines["django"].from_string("Hello world") + return HttpResponse(template.render()) + + request = self.rf.get("/") + normal_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + # process_template_response must not be called for HttpResponse + assert getattr(request, "process_template_response_reached", False) is False + assert getattr(request, "process_response_reached", False) + + async def test_full_sync_dec_normal_async(self, async_rf): + """ + All methods of middleware are called for normal HttpResponses + """ + + @full_sync_dec + async def normal_view(request): + template = engines["django"].from_string("Hello world") + return HttpResponse(template.render()) + + request = async_rf.get("/") + await normal_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + # process_template_response must not be called for HttpResponse + assert getattr(request, "process_template_response_reached", False) is False + assert getattr(request, "process_response_reached", False) + def test_full_dec_templateresponse(self): """ All methods of middleware are called for TemplateResponses in @@ -195,3 +245,57 @@ async def template_response_view(request): assert getattr(request, "process_response_reached", False) # process_response saw the rendered content assert request.process_response_content == b"Hello world" + + def test_full_sync_dec_templateresponse(self): + """ + All methods of middleware are called for TemplateResponses in + the right sequence. + """ + + @full_sync_dec + def template_response_view(request): + template = engines["django"].from_string("Hello world") + return TemplateResponse(request, template) + + request = self.rf.get("/") + response = template_response_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + assert getattr(request, "process_template_response_reached", False) + # response must not be rendered yet. + assert response._is_rendered is False + # process_response must not be called until after response is rendered, + # otherwise some decorators like csrf_protect and gzip_page will not + # work correctly. See #16004 + assert getattr(request, "process_response_reached", False) is False + response.render() + assert getattr(request, "process_response_reached", False) + # process_response saw the rendered content + assert request.process_response_content == b"Hello world" + + async def test_full_sync_dec_templateresponse_async(self, async_rf): + """ + All methods of middleware are called for TemplateResponses in + the right sequence. + """ + + @full_sync_dec + async def template_response_view(request): + template = engines["django"].from_string("Hello world") + return TemplateResponse(request, template) + + request = async_rf.get("/") + response = await template_response_view(request) + assert getattr(request, "process_request_reached", False) + assert getattr(request, "process_view_reached", False) + assert getattr(request, "process_template_response_reached", False) + # response must not be rendered yet. + assert response._is_rendered is False + # process_response must not be called until after response is rendered, + # otherwise some decorators like csrf_protect and gzip_page will not + # work correctly. See #16004 + assert getattr(request, "process_response_reached", False) is False + await sync_to_async(response.render)() + assert getattr(request, "process_response_reached", False) + # process_response saw the rendered content + assert request.process_response_content == b"Hello world" From 52a36542436bf3a6fd09c94671e8636feb98028a Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 03:47:02 +0330 Subject: [PATCH 09/21] implemented async http middleware --- django_async_extensions/middleware/http.py | 41 ++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 django_async_extensions/middleware/http.py diff --git a/django_async_extensions/middleware/http.py b/django_async_extensions/middleware/http.py new file mode 100644 index 0000000..6b7f85f --- /dev/null +++ b/django_async_extensions/middleware/http.py @@ -0,0 +1,41 @@ +from django.utils.cache import cc_delim_re, get_conditional_response, set_response_etag +from django.utils.http import parse_http_date_safe + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + + +class AsyncConditionalGetMiddleware(AsyncMiddlewareMixin): + """ + Handle conditional GET operations. If the response has an ETag or + Last-Modified header and the request has If-None-Match or If-Modified-Since, + replace the response with HttpNotModified. Add an ETag header if needed. + """ + + async def process_response(self, request, response): + # It's too late to prevent an unsafe request with a 412 response, and + # for a HEAD request, the response body is always empty so computing + # an accurate ETag isn't possible. + if request.method != "GET": + return response + + if self.needs_etag(response) and not response.has_header("ETag"): + set_response_etag(response) + + etag = response.get("ETag") + last_modified = response.get("Last-Modified") + last_modified = last_modified and parse_http_date_safe(last_modified) + + if etag or last_modified: + return get_conditional_response( + request, + etag=etag, + last_modified=last_modified, + response=response, + ) + + return response + + def needs_etag(self, response): + """Return True if an ETag header should be added to response.""" + cache_control_headers = cc_delim_re.split(response.get("Cache-Control", "")) + return all(header.lower() != "no-store" for header in cache_control_headers) From fcae99c7452edbfa882c8bdde9c8d196c711f602 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 03:47:23 +0330 Subject: [PATCH 10/21] document async http middleware --- docs/middleware/conditional_get_middleware.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 docs/middleware/conditional_get_middleware.md diff --git a/docs/middleware/conditional_get_middleware.md b/docs/middleware/conditional_get_middleware.md new file mode 100644 index 0000000..225cff8 --- /dev/null +++ b/docs/middleware/conditional_get_middleware.md @@ -0,0 +1,10 @@ +# AsyncConditionalGetMiddleware + +it works exactly like django's [ConditionalGetMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#module-django.middleware.http) +except that it's fully async + +### Usage: +remove django's `django.middleware.http.ConditionalGetMiddleware` from the `MIDDLEWARE` setting and add +`django_async_extensions.middleware.http.AsyncConditionalGetMiddleware` in it's place. + +**note**: this middleware like other middlewares provided in this package can work alongside sync middlewares, and can handle sync views. From 809b9a8fc1366122943343b4082f210cd724de0b Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 03:49:30 +0330 Subject: [PATCH 11/21] add tests for async http middleware --- tests/test_middlewares/cond_get_urls.py | 6 + tests/test_middlewares/test_http.py | 239 ++++++++++++++++++ .../test_middlewares/test_middleware_mixin.py | 2 + 3 files changed, 247 insertions(+) create mode 100644 tests/test_middlewares/cond_get_urls.py create mode 100644 tests/test_middlewares/test_http.py diff --git a/tests/test_middlewares/cond_get_urls.py b/tests/test_middlewares/cond_get_urls.py new file mode 100644 index 0000000..884ef6a --- /dev/null +++ b/tests/test_middlewares/cond_get_urls.py @@ -0,0 +1,6 @@ +from django.http import HttpResponse +from django.urls import path + +urlpatterns = [ + path("", lambda request: HttpResponse("root is here")), +] diff --git a/tests/test_middlewares/test_http.py b/tests/test_middlewares/test_http.py new file mode 100644 index 0000000..7ac1f64 --- /dev/null +++ b/tests/test_middlewares/test_http.py @@ -0,0 +1,239 @@ +import pytest + +from django.http import StreamingHttpResponse, HttpResponse +from django.test import AsyncRequestFactory, AsyncClient + +from django_async_extensions.middleware.http import AsyncConditionalGetMiddleware + + +client = AsyncClient() + + +@pytest.fixture(autouse=True) +def urlconf_setting_set(settings): + old_urlconf = settings.ROOT_URLCONF + settings.ROOT_URLCONF = "test_middlewares.cond_get_urls" + yield settings + settings.ROOT_URLCONF = old_urlconf + + +class TestConditionalGetMiddleware: + request_factory = AsyncRequestFactory() + + @pytest.fixture(autouse=True) + def setup(self): + self.req = self.request_factory.get("/") + self.resp_headers = {} + + async def get_response(self, req): + resp = await client.get(req.path_info) + for key, value in self.resp_headers.items(): + resp[key] = value + return resp + + # Tests for the ETag header + + async def test_middleware_calculates_etag(self): + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + assert "" != resp["ETag"] + + async def test_middleware_wont_overwrite_etag(self): + self.resp_headers["ETag"] = "eggs" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + assert "eggs" == resp["ETag"] + + async def test_no_etag_streaming_response(self): + async def get_response(req): + return StreamingHttpResponse(["content"]) + + response = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert response.has_header("ETag") is False + + async def test_no_etag_response_empty_content(self): + async def get_response(req): + return HttpResponse() + + response = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert response.has_header("ETag") is False + + async def test_no_etag_no_store_cache(self): + self.resp_headers["Cache-Control"] = "No-Cache, No-Store, Max-age=0" + response = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert response.has_header("ETag") is False + + async def test_etag_extended_cache_control(self): + self.resp_headers["Cache-Control"] = 'my-directive="my-no-store"' + response = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert response.has_header("ETag") + + async def test_if_none_match_and_no_etag(self): + self.req.META["HTTP_IF_NONE_MATCH"] = "spam" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + + async def test_no_if_none_match_and_etag(self): + self.resp_headers["ETag"] = "eggs" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + + async def test_if_none_match_and_same_etag(self): + self.req.META["HTTP_IF_NONE_MATCH"] = '"spam"' + self.resp_headers["ETag"] = '"spam"' + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 304 + + async def test_if_none_match_and_different_etag(self): + self.req.META["HTTP_IF_NONE_MATCH"] = "spam" + self.resp_headers["ETag"] = "eggs" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + + async def test_if_none_match_and_redirect(self): + async def get_response(req): + resp = await client.get(req.path_info) + resp["ETag"] = "spam" + resp["Location"] = "/" + resp.status_code = 301 + return resp + + self.req.META["HTTP_IF_NONE_MATCH"] = "spam" + resp = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert resp.status_code == 301 + + async def test_if_none_match_and_client_error(self): + async def get_response(req): + resp = await client.get(req.path_info) + resp["ETag"] = "spam" + resp.status_code = 400 + return resp + + self.req.META["HTTP_IF_NONE_MATCH"] = "spam" + resp = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert resp.status_code == 400 + + # Tests for the Last-Modified header + + async def test_if_modified_since_and_no_last_modified(self): + self.req.META["HTTP_IF_MODIFIED_SINCE"] = "Sat, 12 Feb 2011 17:38:44 GMT" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + + async def test_no_if_modified_since_and_last_modified(self): + self.resp_headers["Last-Modified"] = "Sat, 12 Feb 2011 17:38:44 GMT" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 200 + + async def test_if_modified_since_and_same_last_modified(self): + self.req.META["HTTP_IF_MODIFIED_SINCE"] = "Sat, 12 Feb 2011 17:38:44 GMT" + self.resp_headers["Last-Modified"] = "Sat, 12 Feb 2011 17:38:44 GMT" + self.resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert self.resp.status_code == 304 + + async def test_if_modified_since_and_last_modified_in_the_past(self): + self.req.META["HTTP_IF_MODIFIED_SINCE"] = "Sat, 12 Feb 2011 17:38:44 GMT" + self.resp_headers["Last-Modified"] = "Sat, 12 Feb 2011 17:35:44 GMT" + resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert resp.status_code == 304 + + async def test_if_modified_since_and_last_modified_in_the_future(self): + self.req.META["HTTP_IF_MODIFIED_SINCE"] = "Sat, 12 Feb 2011 17:38:44 GMT" + self.resp_headers["Last-Modified"] = "Sat, 12 Feb 2011 17:41:44 GMT" + self.resp = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + assert self.resp.status_code == 200 + + async def test_if_modified_since_and_redirect(self): + async def get_response(req): + resp = await client.get(req.path_info) + resp["Last-Modified"] = "Sat, 12 Feb 2011 17:35:44 GMT" + resp["Location"] = "/" + resp.status_code = 301 + return resp + + self.req.META["HTTP_IF_MODIFIED_SINCE"] = "Sat, 12 Feb 2011 17:38:44 GMT" + resp = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert resp.status_code == 301 + + async def test_if_modified_since_and_client_error(self): + async def get_response(req): + resp = await client.get(req.path_info) + resp["Last-Modified"] = "Sat, 12 Feb 2011 17:35:44 GMT" + resp.status_code = 400 + return resp + + self.req.META["HTTP_IF_MODIFIED_SINCE"] = "Sat, 12 Feb 2011 17:38:44 GMT" + resp = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert resp.status_code == 400 + + async def test_not_modified_headers(self): + """ + The 304 Not Modified response should include only the headers required + by RFC 9110 Section 15.4.5, Last-Modified, and the cookies. + """ + + async def get_response(req): + resp = await client.get(req.path_info) + resp["Date"] = "Sat, 12 Feb 2011 17:35:44 GMT" + resp["Last-Modified"] = "Sat, 12 Feb 2011 17:35:44 GMT" + resp["Expires"] = "Sun, 13 Feb 2011 17:35:44 GMT" + resp["Vary"] = "Cookie" + resp["Cache-Control"] = "public" + resp["Content-Location"] = "/alt" + resp["Content-Language"] = "en" # shouldn't be preserved + resp["ETag"] = '"spam"' + resp.set_cookie("key", "value") + return resp + + self.req.META["HTTP_IF_NONE_MATCH"] = '"spam"' + + new_response = await AsyncConditionalGetMiddleware(get_response)(self.req) + assert new_response.status_code == 304 + base_response = await get_response(self.req) + for header in ( + "Cache-Control", + "Content-Location", + "Date", + "ETag", + "Expires", + "Last-Modified", + "Vary", + ): + assert new_response.headers[header] == base_response.headers[header] + assert new_response.cookies == base_response.cookies + assert "Content-Language" not in new_response + + async def test_no_unsafe(self): + """ + ConditionalGetMiddleware shouldn't return a conditional response on an + unsafe request. A response has already been generated by the time + ConditionalGetMiddleware is called, so it's too late to return a 412 + Precondition Failed. + """ + + async def get_200_response(req): + return HttpResponse(status=200) + + response = await AsyncConditionalGetMiddleware(self.get_response)(self.req) + etag = response.headers["ETag"] + put_request = self.request_factory.put("/", headers={"if-match": etag}) + conditional_get_response = await AsyncConditionalGetMiddleware( + get_200_response + )(put_request) + assert conditional_get_response.status_code == 200 # should never be a 412 + + async def test_no_head(self): + """ + ConditionalGetMiddleware shouldn't compute and return an ETag on a + HEAD request since it can't do so accurately without access to the + response body of the corresponding GET. + """ + + async def get_200_response(req): + return HttpResponse(status=200) + + request = self.request_factory.head("/") + conditional_get_response = await AsyncConditionalGetMiddleware( + get_200_response + )(request) + assert "ETag" not in conditional_get_response diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py index d3ab5cb..3c4ae48 100644 --- a/tests/test_middlewares/test_middleware_mixin.py +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -6,6 +6,7 @@ from django.http.response import HttpResponse from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.middleware.http import AsyncConditionalGetMiddleware from django_async_extensions.middleware.locale import AsyncLocaleMiddleware from django_async_extensions.middleware.security import AsyncSecurityMiddleware @@ -35,6 +36,7 @@ class TestMiddlewareMixin: middlewares = [ AsyncSecurityMiddleware, AsyncLocaleMiddleware, + AsyncConditionalGetMiddleware, ] def test_repr(self): From 6021ab5d315cae1bbb75aa61318c42f7724cd3e0 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 22:41:25 +0330 Subject: [PATCH 12/21] update django-upgrade's config --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d42c187..ed11efa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: rev: 1.23.1 hooks: - id: django-upgrade - args: [--target-version, "5.1"] + args: [--target-version, "5.2", "--skip", "request_headers"] - repo: https://github.com/psf/black rev: 25.1.0 From 3b05fc00202ed660118a42a369e4eda267d1d1ed Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 22:41:40 +0330 Subject: [PATCH 13/21] implement async gzip middleware --- django_async_extensions/middleware/gzip.py | 76 ++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 django_async_extensions/middleware/gzip.py diff --git a/django_async_extensions/middleware/gzip.py b/django_async_extensions/middleware/gzip.py new file mode 100644 index 0000000..bea7aef --- /dev/null +++ b/django_async_extensions/middleware/gzip.py @@ -0,0 +1,76 @@ +from django.utils.cache import patch_vary_headers +from django.utils.regex_helper import _lazy_re_compile +from django.utils.text import compress_sequence, compress_string + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + + +re_accepts_gzip = _lazy_re_compile(r"\bgzip\b") + + +class AsyncGZipMiddleware(AsyncMiddlewareMixin): + """ + Compress content if the browser allows gzip compression. + Set the Vary header accordingly, so that caches will base their storage + on the Accept-Encoding header. + """ + + max_random_bytes = 100 + + async def process_response(self, request, response): + # It's not worth attempting to compress really short responses. + if not response.streaming and len(response.content) < 200: + return response + + # Avoid gzipping if we've already got a content-encoding. + if response.has_header("Content-Encoding"): + return response + + patch_vary_headers(response, ("Accept-Encoding",)) + + ae = request.headers.get("accept-encoding", "") + if not re_accepts_gzip.search(ae): + return response + + if response.streaming: + if response.is_async: + # pull to lexical scope to capture fixed reference in case + # streaming_content is set again later. + orignal_iterator = response.streaming_content + + async def gzip_wrapper(): + async for chunk in orignal_iterator: + yield compress_string( + chunk, + max_random_bytes=self.max_random_bytes, + ) + + response.streaming_content = gzip_wrapper() + else: + response.streaming_content = compress_sequence( + response.streaming_content, + max_random_bytes=self.max_random_bytes, + ) + # Delete the `Content-Length` header for streaming content, because + # we won't know the compressed size until we stream it. + del response.headers["Content-Length"] + else: + # Return the compressed content only if it's actually shorter. + compressed_content = compress_string( + response.content, + max_random_bytes=self.max_random_bytes, + ) + if len(compressed_content) >= len(response.content): + return response + response.content = compressed_content + response.headers["Content-Length"] = str(len(response.content)) + + # If there is a strong ETag, make it weak to fulfill the requirements + # of RFC 9110 Section 8.8.1 while also allowing conditional request + # matches on ETags. + etag = response.get("ETag") + if etag and etag.startswith('"'): + response.headers["ETag"] = "W/" + etag + response.headers["Content-Encoding"] = "gzip" + + return response From b965f734b60179c6fe9e558dfadbdd005342eda2 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 22:42:16 +0330 Subject: [PATCH 14/21] document async gzip middleware --- docs/middleware/gzip_middleware.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 docs/middleware/gzip_middleware.md diff --git a/docs/middleware/gzip_middleware.md b/docs/middleware/gzip_middleware.md new file mode 100644 index 0000000..432b9db --- /dev/null +++ b/docs/middleware/gzip_middleware.md @@ -0,0 +1,20 @@ +# AsyncGZipMiddleware + +it works exactly like django's [GZipMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#module-django.middleware.gzip) +except that it's fully async + +------------------------ +**important:** +Security researchers revealed that when compression techniques (including GZipMiddleware) are used on a website, the site may become exposed to a number of possible attacks. + +To mitigate attacks, Django implements a technique called Heal The Breach (HTB). It adds up to 100 bytes (see [max_random_bytes](https://docs.djangoproject.com/en/5.1/ref/middleware/#django.middleware.gzip.GZipMiddleware.max_random_bytes)) of random bytes to each response to make the attacks less effective. + +For more details, see the [BREACH paper (PDF)](https://www.breachattack.com/resources/BREACH%20-%20SSL,%20gone%20in%2030%20seconds.pdf), [breachattack.com](https://www.breachattack.com/), and the [Heal The Breach (HTB) paper](https://ieeexplore.ieee.org/document/9754554). + +------------------------- + +### Usage: +remove django's `django.middleware.gzip.GZipMiddleware` from the `MIDDLEWARE` setting (if it's in there) and add +`django_async_extensions.middleware.gzip.AsyncGZipMiddleware` in it's place. + +**note**: this middleware like other middlewares provided in this package can work alongside sync middlewares, and can handle sync views. From 9384f2a5403abbea47e8977a3980113d172d1671 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 29 Mar 2025 22:42:26 +0330 Subject: [PATCH 15/21] add tests for async gzip middleware --- tests/test_middlewares/test_gzip.py | 279 ++++++++++++++++++ .../test_middlewares/test_middleware_mixin.py | 2 + 2 files changed, 281 insertions(+) create mode 100644 tests/test_middlewares/test_gzip.py diff --git a/tests/test_middlewares/test_gzip.py b/tests/test_middlewares/test_gzip.py new file mode 100644 index 0000000..a00f018 --- /dev/null +++ b/tests/test_middlewares/test_gzip.py @@ -0,0 +1,279 @@ +import gzip +import random +import struct +from io import BytesIO + +import pytest + +from django.http import FileResponse, HttpResponse, StreamingHttpResponse +from django.test import AsyncRequestFactory + +from django_async_extensions.middleware.gzip import AsyncGZipMiddleware +from django_async_extensions.middleware.http import AsyncConditionalGetMiddleware + +int2byte = struct.Struct(">B").pack + + +class TestGZipMiddleware: + """ + Tests the GZipMiddleware. + """ + + short_string = b"This string is too short to be worth compressing." + compressible_string = b"a" * 500 + incompressible_string = b"".join( + int2byte(random.randint(0, 255)) for _ in range(500) # noqa: S311 + ) + sequence = [b"a" * 500, b"b" * 200, b"a" * 300] + sequence_unicode = ["a" * 500, "é" * 200, "a" * 300] + request_factory = AsyncRequestFactory() + + @pytest.fixture(autouse=True) + def setup(self): + self.req = self.request_factory.get("/") + self.req.META["HTTP_ACCEPT_ENCODING"] = "gzip, deflate" + self.req.META["HTTP_USER_AGENT"] = ( + "Mozilla/5.0 (Windows NT 5.1; rv:9.0.1) Gecko/20100101 Firefox/9.0.1" + ) + self.resp = HttpResponse() + self.resp.status_code = 200 + self.resp.content = self.compressible_string + self.resp["Content-Type"] = "text/html; charset=UTF-8" + + async def get_response(self, request): + return self.resp + + @staticmethod + def decompress(gzipped_string): + with gzip.GzipFile(mode="rb", fileobj=BytesIO(gzipped_string)) as f: + return f.read() + + @staticmethod + def get_mtime(gzipped_string): + with gzip.GzipFile(mode="rb", fileobj=BytesIO(gzipped_string)) as f: + f.read() # must read the data before accessing the header + return f.mtime + + async def test_compress_response(self): + """ + Compression is performed on responses with compressible content. + """ + r = await AsyncGZipMiddleware(self.get_response)(self.req) + assert self.decompress(r.content) == self.compressible_string + assert r.get("Content-Encoding") == "gzip" + assert r.get("Content-Length") == str(len(r.content)) + + async def test_compress_streaming_response(self): + """ + Compression is performed on responses with streaming content. + """ + + async def get_stream_response(request): + resp = StreamingHttpResponse(self.sequence) + resp["Content-Type"] = "text/html; charset=UTF-8" + return resp + + r = await AsyncGZipMiddleware(get_stream_response)(self.req) + assert self.decompress(b"".join(r)) == b"".join(self.sequence) + assert r.get("Content-Encoding") == "gzip" + assert r.has_header("Content-Length") is False + + async def test_compress_async_streaming_response(self): + """ + Compression is performed on responses with async streaming content. + """ + + async def get_stream_response(request): + async def iterator(): + for chunk in self.sequence: + yield chunk + + resp = StreamingHttpResponse(iterator()) + resp["Content-Type"] = "text/html; charset=UTF-8" + return resp + + r = await AsyncGZipMiddleware(get_stream_response)(self.req) + assert self.decompress(b"".join([chunk async for chunk in r])) == b"".join( + self.sequence + ) + assert r.get("Content-Encoding") == "gzip" + assert r.has_header("Content-Length") is False + + async def test_compress_streaming_response_unicode(self): + """ + Compression is performed on responses with streaming Unicode content. + """ + + async def get_stream_response_unicode(request): + resp = StreamingHttpResponse(self.sequence_unicode) + resp["Content-Type"] = "text/html; charset=UTF-8" + return resp + + r = await AsyncGZipMiddleware(get_stream_response_unicode)(self.req) + + assert self.decompress(b"".join(r)) == b"".join( + x.encode() for x in self.sequence_unicode + ) + assert r.get("Content-Encoding") == "gzip" + assert r.has_header("Content-Length") is False + + async def test_compress_file_response(self): + """ + Compression is performed on FileResponse. + """ + with open(__file__, "rb") as file1: + + async def get_response(req): + file_resp = FileResponse(file1) + file_resp["Content-Type"] = "text/html; charset=UTF-8" + return file_resp + + r = await AsyncGZipMiddleware(get_response)(self.req) + with open(__file__, "rb") as file2: + assert self.decompress(b"".join(r)) == file2.read() + assert r.get("Content-Encoding") == "gzip" + assert r.file_to_stream is not file1 + + async def test_compress_non_200_response(self): + """ + Compression is performed on responses with a status other than 200 + (#10762). + """ + self.resp.status_code = 404 + r = await AsyncGZipMiddleware(self.get_response)(self.req) + assert self.decompress(r.content) == self.compressible_string + assert r.get("Content-Encoding") == "gzip" + + async def test_no_compress_short_response(self): + """ + Compression isn't performed on responses with short content. + """ + self.resp.content = self.short_string + r = await AsyncGZipMiddleware(self.get_response)(self.req) + assert r.content == self.short_string + assert r.get("Content-Encoding") is None + + async def test_no_compress_compressed_response(self): + """ + Compression isn't performed on responses that are already compressed. + """ + self.resp["Content-Encoding"] = "deflate" + r = await AsyncGZipMiddleware(self.get_response)(self.req) + assert r.content == self.compressible_string + assert r.get("Content-Encoding") == "deflate" + + async def test_no_compress_incompressible_response(self): + """ + Compression isn't performed on responses with incompressible content. + """ + self.resp.content = self.incompressible_string + r = await AsyncGZipMiddleware(self.get_response)(self.req) + assert r.content == self.incompressible_string + assert r.get("Content-Encoding") is None + + async def test_compress_deterministic(self): + """ + Compression results are the same for the same content and don't + include a modification time (since that would make the results + of compression non-deterministic and prevent + ConditionalGetMiddleware from recognizing conditional matches + on gzipped content). + """ + + class DeterministicGZipMiddleware(AsyncGZipMiddleware): + max_random_bytes = 0 + + r1 = await DeterministicGZipMiddleware(self.get_response)(self.req) + r2 = await DeterministicGZipMiddleware(self.get_response)(self.req) + assert r1.content == r2.content + assert self.get_mtime(r1.content) == 0 + assert self.get_mtime(r2.content) == 0 + + async def test_random_bytes(self, mocker): + """A random number of bytes is added to mitigate the BREACH attack.""" + mocker.patch( + "django.utils.text.secrets.randbelow", autospec=True, return_value=3 + ) + r = await AsyncGZipMiddleware(self.get_response)(self.req) + # The fourth byte of a gzip stream contains flags. + assert r.content[3] == gzip.FNAME + # A 3 byte filename "aaa" and a null byte are added. + assert r.content[10:14] == b"aaa\x00" + assert self.decompress(r.content) == self.compressible_string + + async def test_random_bytes_streaming_response(self, mocker): + """A random number of bytes is added to mitigate the BREACH attack.""" + + async def get_stream_response(request): + resp = StreamingHttpResponse(self.sequence) + resp["Content-Type"] = "text/html; charset=UTF-8" + return resp + + mocker.patch( + "django.utils.text.secrets.randbelow", autospec=True, return_value=3 + ) + r = await AsyncGZipMiddleware(get_stream_response)(self.req) + content = b"".join(r) + # The fourth byte of a gzip stream contains flags. + assert content[3] == gzip.FNAME + # A 3 byte filename "aaa" and a null byte are added. + assert content[10:14] == b"aaa\x00" + assert self.decompress(content) == b"".join(self.sequence) + + +class TestETagGZipMiddleware: + """ + ETags are handled properly by GZipMiddleware. + """ + + rf = AsyncRequestFactory() + compressible_string = b"a" * 500 + + async def test_strong_etag_modified(self): + """ + GZipMiddleware makes a strong ETag weak. + """ + + async def get_response(req): + response = HttpResponse(self.compressible_string) + response.headers["ETag"] = '"eggs"' + return response + + request = self.rf.get("/", headers={"accept-encoding": "gzip, deflate"}) + gzip_response = await AsyncGZipMiddleware(get_response)(request) + assert gzip_response.headers["ETag"] == 'W/"eggs"' + + async def test_weak_etag_not_modified(self): + """ + GZipMiddleware doesn't modify a weak ETag. + """ + + async def get_response(req): + response = HttpResponse(self.compressible_string) + response.headers["ETag"] = 'W/"eggs"' + return response + + request = self.rf.get("/", headers={"accept-encoding": "gzip, deflate"}) + gzip_response = await AsyncGZipMiddleware(get_response)(request) + assert gzip_response.headers["ETag"] == 'W/"eggs"' + + async def test_etag_match(self): + """ + GZipMiddleware allows 304 Not Modified responses. + """ + + async def get_response(req): + return HttpResponse(self.compressible_string) + + async def get_cond_response(req): + return await AsyncConditionalGetMiddleware(get_response)(req) + + request = self.rf.get("/", headers={"accept-encoding": "gzip, deflate"}) + response = await AsyncGZipMiddleware(get_cond_response)(request) + gzip_etag = response.headers["ETag"] + next_request = self.rf.get( + "/", + headers={"accept-encoding": "gzip, deflate", "if-none-match": gzip_etag}, + ) + next_response = await AsyncConditionalGetMiddleware(get_response)(next_request) + assert next_response.status_code == 304 diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py index 3c4ae48..8263028 100644 --- a/tests/test_middlewares/test_middleware_mixin.py +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -6,6 +6,7 @@ from django.http.response import HttpResponse from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.middleware.gzip import AsyncGZipMiddleware from django_async_extensions.middleware.http import AsyncConditionalGetMiddleware from django_async_extensions.middleware.locale import AsyncLocaleMiddleware from django_async_extensions.middleware.security import AsyncSecurityMiddleware @@ -37,6 +38,7 @@ class TestMiddlewareMixin: AsyncSecurityMiddleware, AsyncLocaleMiddleware, AsyncConditionalGetMiddleware, + AsyncGZipMiddleware, ] def test_repr(self): From 813fbc68e4ba2f496dc19441f59dfe371f82c959 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sun, 30 Mar 2025 00:00:07 +0330 Subject: [PATCH 16/21] implement async common middlewares --- django_async_extensions/middleware/common.py | 179 +++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 django_async_extensions/middleware/common.py diff --git a/django_async_extensions/middleware/common.py b/django_async_extensions/middleware/common.py new file mode 100644 index 0000000..d9acf0e --- /dev/null +++ b/django_async_extensions/middleware/common.py @@ -0,0 +1,179 @@ +import re +from urllib.parse import urlsplit + +from django.conf import settings +from django.core.exceptions import PermissionDenied +from django.core.mail import mail_managers +from django.http import HttpResponsePermanentRedirect +from django.urls import is_valid_path +from django.utils.http import escape_leading_slashes + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + + +class AsyncCommonMiddleware(AsyncMiddlewareMixin): + """ + "Common" middleware for taking care of some basic operations: + + - Forbid access to User-Agents in settings.DISALLOWED_USER_AGENTS + + - URL rewriting: Based on the APPEND_SLASH and PREPEND_WWW settings, + append missing slashes and/or prepends missing "www."s. + + - If APPEND_SLASH is set and the initial URL doesn't end with a + slash, and it is not found in urlpatterns, form a new URL by + appending a slash at the end. If this new URL is found in + urlpatterns, return an HTTP redirect to this new URL; otherwise + process the initial URL as usual. + + This behavior can be customized by subclassing AsyncCommonMiddleware and + overriding the response_redirect_class attribute. + """ + + response_redirect_class = HttpResponsePermanentRedirect + + async def process_request(self, request): + """ + Check for denied User-Agents and rewrite the URL based on + settings.APPEND_SLASH and settings.PREPEND_WWW + """ + + # Check for denied User-Agents + user_agent = request.META.get("HTTP_USER_AGENT") + if user_agent is not None: + for user_agent_regex in settings.DISALLOWED_USER_AGENTS: + if user_agent_regex.search(user_agent): + raise PermissionDenied("Forbidden user agent") + + # Check for a redirect based on settings.PREPEND_WWW + host = request.get_host() + + if settings.PREPEND_WWW and host and not host.startswith("www."): + # Check if we also need to append a slash so we can do it all + # with a single redirect. (This check may be somewhat expensive, + # so we only do it if we already know we're sending a redirect, + # or in process_response if we get a 404.) + if self.should_redirect_with_slash(request): + path = self.get_full_path_with_slash(request) + else: + path = request.get_full_path() + + return self.response_redirect_class(f"{request.scheme}://www.{host}{path}") + + def should_redirect_with_slash(self, request): + """ + Return True if settings.APPEND_SLASH is True and appending a slash to + the request path turns an invalid path into a valid one. + """ + if settings.APPEND_SLASH and not request.path_info.endswith("/"): + urlconf = getattr(request, "urlconf", None) + if not is_valid_path(request.path_info, urlconf): + match = is_valid_path("%s/" % request.path_info, urlconf) + if match: + view = match.func + return getattr(view, "should_append_slash", True) + return False + + def get_full_path_with_slash(self, request): + """ + Return the full path of the request with a trailing slash appended. + + Raise a RuntimeError if settings.DEBUG is True and request.method is + DELETE, POST, PUT, or PATCH. + """ + new_path = request.get_full_path(force_append_slash=True) + # Prevent construction of scheme relative urls. + new_path = escape_leading_slashes(new_path) + if settings.DEBUG and request.method in ("DELETE", "POST", "PUT", "PATCH"): + raise RuntimeError( + "You called this URL via %(method)s, but the URL doesn't end " + "in a slash and you have APPEND_SLASH set. Django can't " + "redirect to the slash URL while maintaining %(method)s data. " + "Change your form to point to %(url)s (note the trailing " + "slash), or set APPEND_SLASH=False in your Django settings." + % { + "method": request.method, + "url": request.get_host() + new_path, + } + ) + return new_path + + async def process_response(self, request, response): + """ + When the status code of the response is 404, it may redirect to a path + with an appended slash if should_redirect_with_slash() returns True. + """ + # If the given URL is "Not Found", then check if we should redirect to + # a path with a slash appended. + if response.status_code == 404 and self.should_redirect_with_slash(request): + return self.response_redirect_class(self.get_full_path_with_slash(request)) + + # Add the Content-Length header to non-streaming responses if not + # already set. + if not response.streaming and not response.has_header("Content-Length"): + response.headers["Content-Length"] = str(len(response.content)) + + return response + + +class AsyncBrokenLinkEmailsMiddleware(AsyncMiddlewareMixin): + async def process_response(self, request, response): + """Send broken link emails for relevant 404 NOT FOUND responses.""" + if response.status_code == 404 and not settings.DEBUG: + domain = request.get_host() + path = request.get_full_path() + referer = request.META.get("HTTP_REFERER", "") + + if not self.is_ignorable_request(request, path, domain, referer): + ua = request.META.get("HTTP_USER_AGENT", "") + ip = request.META.get("REMOTE_ADDR", "") + mail_managers( + "Broken %slink on %s" + % ( + ( + "INTERNAL " + if self.is_internal_request(domain, referer) + else "" + ), + domain, + ), + "Referrer: %s\nRequested URL: %s\nUser agent: %s\n" + "IP address: %s\n" % (referer, path, ua, ip), + fail_silently=True, + ) + return response + + def is_internal_request(self, domain, referer): + """ + Return True if the referring URL is the same domain as the current + request. + """ + # Different subdomains are treated as different domains. + return bool(re.match("^https?://%s/" % re.escape(domain), referer)) + + def is_ignorable_request(self, request, uri, domain, referer): + """ + Return True if the given request *shouldn't* notify the site managers + according to project settings or in situations outlined by the inline + comments. + """ + # The referer is empty. + if not referer: + return True + + # APPEND_SLASH is enabled and the referer is equal to the current URL + # without a trailing slash indicating an internal redirect. + if settings.APPEND_SLASH and uri.endswith("/") and referer == uri[:-1]: + return True + + # A '?' in referer is identified as a search engine source. + if not self.is_internal_request(domain, referer) and "?" in referer: + return True + + # The referer is equal to the current URL, ignoring the scheme (assumed + # to be a poorly implemented bot). + parsed_referer = urlsplit(referer) + if parsed_referer.netloc in ["", domain] and parsed_referer.path == uri: + return True + + return any(pattern.search(uri) for pattern in settings.IGNORABLE_404_URLS) From 08231601b398705ad37b71b2238a25b3e23ce3b8 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sun, 30 Mar 2025 00:06:40 +0330 Subject: [PATCH 17/21] document async common middlewares --- docs/middleware/common_middlewares.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 docs/middleware/common_middlewares.md diff --git a/docs/middleware/common_middlewares.md b/docs/middleware/common_middlewares.md new file mode 100644 index 0000000..75d71e8 --- /dev/null +++ b/docs/middleware/common_middlewares.md @@ -0,0 +1,22 @@ +# Common Middlewares + +## AsyncCommonMiddleware +it works exactly like django's [CommonMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#django.middleware.common.CommonMiddleware) +except that it's fully async + +### Usage: +remove django's `django.middleware.common.CommonMiddleware` from the `MIDDLEWARE` setting and add +`django_async_extensions.middleware.common.AsyncCommonMiddleware` in it's place. + + + +## AsyncBrokenLinkEmailsMiddleware +it works exactly like django's [BrokenLinkEmailsMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#django.middleware.common.BrokenLinkEmailsMiddleware) +except that it's fully async + +### Usage: +remove django's `django.middleware.common.LinkEmailsMiddleware` from the `MIDDLEWARE` setting if it's in there and add +`django_async_extensions.middleware.common.AsyncBrokenLinkEmailsMiddleware` in it's place. + + +**note**: these two middlewares like other middlewares provided in this package can work alongside sync middlewares, and can handle sync views. From 20c5939395810c1a22f82ba5cdb43fe05b4f7c59 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sun, 30 Mar 2025 00:06:44 +0330 Subject: [PATCH 18/21] add tests for async common middleware --- tests/test_middlewares/extra_urls.py | 9 + tests/test_middlewares/test_common.py | 580 ++++++++++++++++++ .../test_middlewares/test_middleware_mixin.py | 6 + tests/test_middlewares/urls.py | 10 +- tests/test_middlewares/views.py | 19 + 5 files changed, 623 insertions(+), 1 deletion(-) create mode 100644 tests/test_middlewares/extra_urls.py create mode 100644 tests/test_middlewares/test_common.py diff --git a/tests/test_middlewares/extra_urls.py b/tests/test_middlewares/extra_urls.py new file mode 100644 index 0000000..faa6ed9 --- /dev/null +++ b/tests/test_middlewares/extra_urls.py @@ -0,0 +1,9 @@ +from django.urls import path + +from . import views + +urlpatterns = [ + path("customurlconf/noslash", views.empty_view), + path("customurlconf/slash/", views.empty_view), + path("customurlconf/needsquoting#/", views.empty_view), +] diff --git a/tests/test_middlewares/test_common.py b/tests/test_middlewares/test_common.py new file mode 100644 index 0000000..c3147e9 --- /dev/null +++ b/tests/test_middlewares/test_common.py @@ -0,0 +1,580 @@ +import re +from urllib.parse import quote +import pytest +from django.core import mail +from django.core.exceptions import PermissionDenied + +from django.http import ( + HttpResponse, + HttpResponseNotFound, + HttpResponsePermanentRedirect, + StreamingHttpResponse, + HttpResponseRedirect, +) +from django.test import AsyncRequestFactory, AsyncClient + +from django_async_extensions.middleware.common import ( + AsyncCommonMiddleware, + AsyncBrokenLinkEmailsMiddleware, +) + +client = AsyncClient() + + +async def get_response_empty(request): + return HttpResponse() + + +async def get_response_404(request): + return HttpResponseNotFound() + + +@pytest.fixture +def append_slash_fixture(request, settings): + old_append_slash = settings.APPEND_SLASH + settings.APPEND_SLASH = request.param + + yield settings + settings.APPEND_SLASH = old_append_slash + + +class TestCommonMiddleware: + rf = AsyncRequestFactory() + + @pytest.fixture(autouse=True) + def urlconf_setting_set(self, settings): + old_urlconf = settings.ROOT_URLCONF + settings.ROOT_URLCONF = "test_middlewares.urls" + yield settings + settings.ROOT_URLCONF = old_urlconf + + @pytest.fixture + def set_setting_fixture(self, request, settings): + # request.param should be a list of [settings_name, value] + old_setting = getattr(settings, request.param[0]) + setattr(settings, request.param[0], request.param[1]) + yield settings + setattr(settings, request.param[0], old_setting) + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_have_slash(self, append_slash_fixture): + """ + URLs with slashes should go unmolested. + """ + request = self.rf.get("/slash/") + assert ( + await AsyncCommonMiddleware(get_response_404).process_request(request) + is None + ) + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_slashless_resource(self, append_slash_fixture): + """ + Matches to explicit slashless URLs should go unmolested. + """ + + async def get_response(req): + return HttpResponse("Here's the text of the web page.") + + request = self.rf.get("/noslash") + assert ( + await AsyncCommonMiddleware(get_response).process_request(request) is None + ) + response = await AsyncCommonMiddleware(get_response)(request) + assert response.content == b"Here's the text of the web page." + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_slashless_unknown(self, append_slash_fixture): + """ + APPEND_SLASH should not redirect to unknown resources. + """ + request = self.rf.get("/unknown") + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_redirect(self, append_slash_fixture, settings): + """ + APPEND_SLASH should redirect slashless URLs to a valid pattern. + """ + request = self.rf.get("/slash") + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r is None + response = HttpResponseNotFound() + r = await AsyncCommonMiddleware(get_response_empty).process_response( + request, response + ) + assert r.status_code == 301 + assert r.url == "/slash/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_redirect_querystring(self, append_slash_fixture): + """ + APPEND_SLASH should preserve querystrings when redirecting. + """ + request = self.rf.get("/slash?test=1") + resp = await AsyncCommonMiddleware(get_response_404)(request) + assert resp.url == "/slash/?test=1" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_redirect_querystring_have_slash( + self, append_slash_fixture + ): + """ + APPEND_SLASH should append slash to path when redirecting a request + with a querystring ending with slash. + """ + request = self.rf.get("/slash?test=slash/") + resp = await AsyncCommonMiddleware(get_response_404)(request) + assert isinstance(resp, HttpResponsePermanentRedirect) + assert resp.url == "/slash/?test=slash/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + @pytest.mark.parametrize("set_setting_fixture", [["DEBUG", True]], indirect=True) + async def test_append_slash_no_redirect_in_DEBUG( + self, append_slash_fixture, set_setting_fixture + ): + """ + While in debug mode, an exception is raised with a warning + when a failed attempt is made to DELETE, POST, PUT, or PATCH to an URL + which would normally be redirected to a slashed version. + """ + msg = "maintaining %s data. Change your form to point to testserver/slash/" + request = self.rf.get("/slash") + request.method = "POST" + with pytest.raises(RuntimeError, match=msg % request.method): + await AsyncCommonMiddleware(get_response_404)(request) + request = self.rf.get("/slash") + request.method = "PUT" + with pytest.raises(RuntimeError, match=msg % request.method): + await AsyncCommonMiddleware(get_response_404)(request) + request = self.rf.get("/slash") + request.method = "PATCH" + with pytest.raises(RuntimeError, match=msg % request.method): + await AsyncCommonMiddleware(get_response_404)(request) + request = self.rf.delete("/slash") + with pytest.raises(RuntimeError, match=msg % request.method): + await AsyncCommonMiddleware(get_response_404)(request) + + @pytest.mark.parametrize("append_slash_fixture", [False], indirect=True) + async def test_append_slash_disabled(self, append_slash_fixture): + """ + Disabling append slash functionality should leave slashless URLs alone. + """ + request = self.rf.get("/slash") + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_opt_out(self, append_slash_fixture): + """ + Views marked with @no_append_slash should be left alone. + """ + request = self.rf.get("/sensitive_fbv") + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + request = self.rf.get("/sensitive_cbv") + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_quoted(self, append_slash_fixture): + """ + URLs which require quoting should be redirected to their slash version. + """ + request = self.rf.get(quote("/needsquoting#")) + r = await AsyncCommonMiddleware(get_response_404)(request) + assert r.status_code == 301 + assert r.url == "/needsquoting%23/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_leading_slashes(self, append_slash_fixture): + """ + Paths starting with two slashes are escaped to prevent open redirects. + If there's a URL pattern that allows paths to start with two slashes, a + request with path //evil.com must not redirect to //evil.com/ (appended + slash) which is a schemaless absolute URL. The browser would navigate + to evil.com/. + """ + # Use 4 slashes because of RequestFactory behavior. + request = self.rf.get("////evil.com/security") + r = await AsyncCommonMiddleware(get_response_404).process_request(request) + assert r is None + response = HttpResponseNotFound() + r = await AsyncCommonMiddleware(get_response_404).process_response( + request, response + ) + assert r.status_code == 301 + assert r.url == "/%2Fevil.com/security/" + r = await AsyncCommonMiddleware(get_response_404)(request) + assert r.status_code == 301 + assert r.url == "/%2Fevil.com/security/" + + @pytest.mark.parametrize("append_slash_fixture", [False], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["PREPEND_WWW", True]], indirect=True + ) + async def test_prepend_www(self, append_slash_fixture, set_setting_fixture): + request = self.rf.get("/path/") + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r.status_code == 301 + assert r.url == "http://www.testserver/path/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["PREPEND_WWW", True]], indirect=True + ) + async def test_prepend_www_append_slash_have_slash( + self, append_slash_fixture, set_setting_fixture + ): + request = self.rf.get("/slash/") + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r.status_code == 301 + assert r.url == "http://www.testserver/slash/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["PREPEND_WWW", True]], indirect=True + ) + async def test_prepend_www_append_slash_slashless( + self, append_slash_fixture, set_setting_fixture + ): + request = self.rf.get("/slash") + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r.status_code == 301 + assert r.url == "http://www.testserver/slash/" + + # The following tests examine expected behavior given a custom URLconf that + # overrides the default one through the request object. + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_have_slash_custom_urlconf(self, append_slash_fixture): + """ + URLs with slashes should go unmolested. + """ + request = self.rf.get("/customurlconf/slash/") + request.urlconf = "test_middlewares.extra_urls" + assert ( + await AsyncCommonMiddleware(get_response_404).process_request(request) + is None + ) + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_slashless_resource_custom_urlconf( + self, append_slash_fixture + ): + """ + Matches to explicit slashless URLs should go unmolested. + """ + + async def get_response(req): + return HttpResponse("web content") + + request = self.rf.get("/customurlconf/noslash") + request.urlconf = "test_middlewares.extra_urls" + assert ( + await AsyncCommonMiddleware(get_response).process_request(request) is None + ) + response = await AsyncCommonMiddleware(get_response)(request) + assert response.content == b"web content" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_slashless_unknown_custom_urlconf( + self, append_slash_fixture + ): + """ + APPEND_SLASH should not redirect to unknown resources. + """ + request = self.rf.get("/customurlconf/unknown") + request.urlconf = "test_middlewares.extra_urls" + assert ( + await AsyncCommonMiddleware(get_response_404).process_request(request) + is None + ) + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code == 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_redirect_custom_urlconf(self, append_slash_fixture): + """ + APPEND_SLASH should redirect slashless URLs to a valid pattern. + """ + request = self.rf.get("/customurlconf/slash") + request.urlconf = "test_middlewares.extra_urls" + r = await AsyncCommonMiddleware(get_response_404)(request) + assert r, ( + "CommonMiddleware failed to return APPEND_SLASH redirect" + " using request.urlconf" + ) + + assert r.status_code == 301 + assert r.url == "/customurlconf/slash/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + @pytest.mark.parametrize("set_setting_fixture", [["DEBUG", True]], indirect=True) + async def test_append_slash_no_redirect_on_POST_in_DEBUG_custom_urlconf( + self, append_slash_fixture, set_setting_fixture + ): + """ + While in debug mode, an exception is raised with a warning + when a failed attempt is made to POST to an URL which would normally be + redirected to a slashed version. + """ + request = self.rf.get("/customurlconf/slash") + request.urlconf = "test_middlewares.extra_urls" + request.method = "POST" + with pytest.raises(RuntimeError, match="end in a slash"): + await AsyncCommonMiddleware(get_response_404)(request) + + @pytest.mark.parametrize("append_slash_fixture", [False], indirect=True) + async def test_append_slash_disabled_custom_urlconf(self, append_slash_fixture): + """ + Disabling append slash functionality should leave slashless URLs alone. + """ + request = self.rf.get("/customurlconf/slash") + request.urlconf = "test_middlewares.extra_urls" + assert ( + await AsyncCommonMiddleware(get_response_404).process_request(request) + is None + ) + response = await AsyncCommonMiddleware(get_response_404)(request) + assert response.status_code, 404 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_append_slash_quoted_custom_urlconf(self, append_slash_fixture): + """ + URLs which require quoting should be redirected to their slash version. + """ + request = self.rf.get(quote("/customurlconf/needsquoting#")) + request.urlconf = "test_middlewares.extra_urls" + r = await AsyncCommonMiddleware(get_response_404)(request) + assert r is not None, ( + "CommonMiddleware failed to return APPEND_SLASH" + " redirect using request.urlconf" + ) + + assert r.status_code == 301 + assert r.url == "/customurlconf/needsquoting%23/" + + @pytest.mark.parametrize("append_slash_fixture", [False], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["PREPEND_WWW", True]], indirect=True + ) + async def test_prepend_www_custom_urlconf( + self, append_slash_fixture, set_setting_fixture + ): + request = self.rf.get("/customurlconf/path/") + request.urlconf = "test_middlewares.extra_urls" + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r.status_code == 301 + assert r.url == "http://www.testserver/customurlconf/path/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["PREPEND_WWW", True]], indirect=True + ) + async def test_prepend_www_append_slash_have_slash_custom_urlconf( + self, append_slash_fixture, set_setting_fixture + ): + request = self.rf.get("/customurlconf/slash/") + request.urlconf = "test_middlewares.extra_urls" + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r.status_code == 301 + assert r.url == "http://www.testserver/customurlconf/slash/" + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + @pytest.mark.parametrize( + "set_setting_fixture", [["PREPEND_WWW", True]], indirect=True + ) + async def test_prepend_www_append_slash_slashless_custom_urlconf( + self, append_slash_fixture, set_setting_fixture + ): + request = self.rf.get("/customurlconf/slash") + request.urlconf = "test_middlewares.extra_urls" + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r.status_code == 301 + assert r.url == "http://www.testserver/customurlconf/slash/" + + # Tests for the Content-Length header + + async def test_content_length_header_added(self): + async def get_response(req): + response = HttpResponse("content") + assert b"Content-Length" not in response.headers + return response + + response = await AsyncCommonMiddleware(get_response)(self.rf.get("/")) + assert int(response.headers["Content-Length"]) == len(response.content) + + async def test_content_length_header_not_added_for_streaming_response(self): + async def get_response(req): + response = StreamingHttpResponse("content") + assert b"Content-Length" not in response + return response + + response = await AsyncCommonMiddleware(get_response)(self.rf.get("/")) + assert b"Content-Length" not in response + + async def test_content_length_header_not_changed(self): + bad_content_length = 500 + + async def get_response(req): + response = HttpResponse() + response.headers["Content-Length"] = bad_content_length + return response + + response = await AsyncCommonMiddleware(get_response)(self.rf.get("/")) + assert int(response.headers["Content-Length"]) == bad_content_length + + # Other tests + + @pytest.mark.parametrize( + "set_setting_fixture", + [["DISALLOWED_USER_AGENTS", [re.compile(r"foo")]]], + indirect=True, + ) + async def test_disallowed_user_agents(self, set_setting_fixture): + request = self.rf.get("/slash") + request.META["HTTP_USER_AGENT"] = "foo" + with pytest.raises(PermissionDenied, match="Forbidden user agent"): + await AsyncCommonMiddleware(get_response_empty).process_request(request) + + async def test_non_ascii_query_string_does_not_crash(self): + """Regression test for #15152""" + request = self.rf.get("/slash") + request.META["QUERY_STRING"] = "drink=café" + r = await AsyncCommonMiddleware(get_response_empty).process_request(request) + assert r is None + response = HttpResponseNotFound() + r = await AsyncCommonMiddleware(get_response_empty).process_response( + request, response + ) + assert r.status_code == 301 + + async def test_response_redirect_class(self): + request = self.rf.get("/slash") + r = await AsyncCommonMiddleware(get_response_404)(request) + assert r.status_code == 301 + assert r.url == "/slash/" + assert isinstance(r, HttpResponsePermanentRedirect) + + async def test_response_redirect_class_subclass(self): + class MyCommonMiddleware(AsyncCommonMiddleware): + response_redirect_class = HttpResponseRedirect + + request = self.rf.get("/slash") + r = await MyCommonMiddleware(get_response_404)(request) + assert r.status_code == 302 + assert r.url == "/slash/" + assert isinstance(r, HttpResponseRedirect) + + +class TestBrokenLinkEmailsMiddleware: + rf = AsyncRequestFactory() + + @pytest.fixture(autouse=True) + def setting_fixture(self, settings): + old_ignorable_404_urls = settings.IGNORABLE_404_URLS + old_managers = settings.MANAGERS + settings.IGNORABLE_404_URLS = [re.compile(r"foo")] + settings.MANAGERS = [("PHD", "PHB@dilbert.com")] + yield settings + settings.IGNORABLE_404_URLS = old_ignorable_404_urls + settings.MANAGERS = old_managers + + @pytest.fixture(autouse=True) + def setup(self): + self.req = self.rf.get("/regular_url/that/does/not/exist") + + async def get_response(self, req): + return await client.get(req.path) + + async def test_404_error_reporting(self): + self.req.META["HTTP_REFERER"] = "/another/url/" + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 1 + assert "Broken" in mail.outbox[0].subject + + async def test_404_error_reporting_no_referer(self): + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + + async def test_404_error_reporting_ignored_url(self): + self.req.path = self.req.path_info = "foo_url/that/does/not/exist" + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + + async def test_custom_request_checker(self): + class SubclassedMiddleware(AsyncBrokenLinkEmailsMiddleware): + ignored_user_agent_patterns = ( + re.compile(r"Spider.*"), + re.compile(r"Robot.*"), + ) + + def is_ignorable_request(self, request, uri, domain, referer): + """Check user-agent in addition to normal checks.""" + if super().is_ignorable_request(request, uri, domain, referer): + return True + user_agent = request.META["HTTP_USER_AGENT"] + return any( + pattern.search(user_agent) + for pattern in self.ignored_user_agent_patterns + ) + + self.req.META["HTTP_REFERER"] = "/another/url/" + self.req.META["HTTP_USER_AGENT"] = "Spider machine 3.4" + await SubclassedMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + self.req.META["HTTP_USER_AGENT"] = "My user agent" + await SubclassedMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 1 + + async def test_referer_equal_to_requested_url(self, settings): + """ + Some bots set the referer to the current URL to avoid being blocked by + an referer check (#25302). + """ + self.req.META["HTTP_REFERER"] = self.req.path + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + + # URL with scheme and domain should also be ignored + self.req.META["HTTP_REFERER"] = "http://testserver%s" % self.req.path + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + + # URL with a different scheme should be ignored as well because bots + # tend to use http:// in referers even when browsing HTTPS websites. + self.req.META["HTTP_X_PROTO"] = "https" + self.req.META["SERVER_PORT"] = 443 + settings.SECURE_PROXY_SSL_HEADER = ("HTTP_X_PROTO", "https") + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + + async def test_referer_equal_to_requested_url_on_another_domain(self): + self.req.META["HTTP_REFERER"] = "http://anotherserver%s" % self.req.path + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 1 + + @pytest.mark.parametrize("append_slash_fixture", [True], indirect=True) + async def test_referer_equal_to_requested_url_without_trailing_slash_with_append_slash( # noqa: E501 + self, append_slash_fixture + ): + self.req.path = self.req.path_info = "/regular_url/that/does/not/exist/" + self.req.META["HTTP_REFERER"] = self.req.path_info[:-1] + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 0 + + @pytest.mark.parametrize("append_slash_fixture", [False], indirect=True) + async def test_referer_equal_to_requested_url_without_trailing_slash_with_no_append_slash( # noqa: E501 + self, append_slash_fixture + ): + self.req.path = self.req.path_info = "/regular_url/that/does/not/exist/" + self.req.META["HTTP_REFERER"] = self.req.path_info[:-1] + await AsyncBrokenLinkEmailsMiddleware(self.get_response)(self.req) + assert len(mail.outbox) == 1 diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py index 8263028..236ffb3 100644 --- a/tests/test_middlewares/test_middleware_mixin.py +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -6,6 +6,10 @@ from django.http.response import HttpResponse from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.middleware.common import ( + AsyncBrokenLinkEmailsMiddleware, + AsyncCommonMiddleware, +) from django_async_extensions.middleware.gzip import AsyncGZipMiddleware from django_async_extensions.middleware.http import AsyncConditionalGetMiddleware from django_async_extensions.middleware.locale import AsyncLocaleMiddleware @@ -39,6 +43,8 @@ class TestMiddlewareMixin: AsyncLocaleMiddleware, AsyncConditionalGetMiddleware, AsyncGZipMiddleware, + AsyncCommonMiddleware, + AsyncBrokenLinkEmailsMiddleware, ] def test_repr(self): diff --git a/tests/test_middlewares/urls.py b/tests/test_middlewares/urls.py index 4d07945..283d407 100644 --- a/tests/test_middlewares/urls.py +++ b/tests/test_middlewares/urls.py @@ -1,6 +1,6 @@ from django.conf.urls.i18n import i18n_patterns from django.http import HttpResponse, StreamingHttpResponse -from django.urls import path +from django.urls import path, re_path from django.utils.translation import gettext_lazy as _ @@ -14,6 +14,14 @@ async def stream_http_generator(): urlpatterns = [ + path("noslash", views.empty_view), + path("slash/", views.empty_view), + path("needsquoting#/", views.empty_view), + # Accepts paths with two leading slashes. + re_path(r"^(.+)/security/$", views.empty_view), + # Should not append slash. + path("sensitive_fbv/", views.sensitive_fbv), + path("sensitive_cbv/", views.SensitiveCBV.as_view()), path("middleware_exceptions/view/", views.normal_view), path("middleware_exceptions/error/", views.server_error), path("middleware_exceptions/permission_denied/", views.permission_denied), diff --git a/tests/test_middlewares/views.py b/tests/test_middlewares/views.py index 0f1595b..29aba38 100644 --- a/tests/test_middlewares/views.py +++ b/tests/test_middlewares/views.py @@ -2,6 +2,10 @@ from django.http import HttpResponse from django.template import engines from django.template.response import TemplateResponse +from django.utils.decorators import method_decorator +from django.views.decorators.common import no_append_slash + +from django_async_extensions.views.generic.base import AsyncView def normal_view(request): @@ -37,3 +41,18 @@ async def render(self): raise Exception("Exception in HttpResponse.render()") return CustomHttpResponse("Error") + + +async def empty_view(request, *args, **kwargs): + return HttpResponse() + + +@no_append_slash +async def sensitive_fbv(request, *args, **kwargs): + return HttpResponse() + + +@method_decorator(no_append_slash, name="dispatch") +class SensitiveCBV(AsyncView): + async def get(self, *args, **kwargs): + return HttpResponse() From c272a822a0b2559132ecff610a2bd0af6b2b596c Mon Sep 17 00:00:00 2001 From: amirreza Date: Sun, 30 Mar 2025 00:39:20 +0330 Subject: [PATCH 19/21] implemented async clickjacking middleware --- .../middleware/clickjacking.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 django_async_extensions/middleware/clickjacking.py diff --git a/django_async_extensions/middleware/clickjacking.py b/django_async_extensions/middleware/clickjacking.py new file mode 100644 index 0000000..4f1a0ac --- /dev/null +++ b/django_async_extensions/middleware/clickjacking.py @@ -0,0 +1,49 @@ +""" +Clickjacking Protection Middleware. + +This module provides a middleware that implements protection against a +malicious site loading resources from your site in a hidden frame. +""" + +from django.conf import settings + +from django_async_extensions.middleware.base import AsyncMiddlewareMixin + + +class AsyncXFrameOptionsMiddleware(AsyncMiddlewareMixin): + """ + Set the X-Frame-Options HTTP header in HTTP responses. + + Do not set the header if it's already set or if the response contains + a xframe_options_exempt value set to True. + + By default, set the X-Frame-Options header to 'DENY', meaning the response + cannot be displayed in a frame, regardless of the site attempting to do so. + To enable the response to be loaded on a frame within the same site, set + X_FRAME_OPTIONS in your project's Django settings to 'SAMEORIGIN'. + """ + + async def process_response(self, request, response): + # Don't set it if it's already in the response + if response.get("X-Frame-Options") is not None: + return response + + # Don't set it if they used @xframe_options_exempt + if getattr(response, "xframe_options_exempt", False): + return response + + response.headers["X-Frame-Options"] = self.get_xframe_options_value( + request, + response, + ) + return response + + def get_xframe_options_value(self, request, response): + """ + Get the value to set for the X_FRAME_OPTIONS header. Use the value from + the X_FRAME_OPTIONS setting, or 'DENY' if not set. + + This method can be overridden if needed, allowing it to vary based on + the request or response. + """ + return getattr(settings, "X_FRAME_OPTIONS", "DENY").upper() From 5272b94f93439cb26fe93997bd058567894c669a Mon Sep 17 00:00:00 2001 From: amirreza Date: Sun, 30 Mar 2025 00:39:43 +0330 Subject: [PATCH 20/21] added tests for async clickjacking middleware --- tests/test_middlewares/test_clickjacking.py | 138 ++++++++++++++++++ .../test_middlewares/test_middleware_mixin.py | 2 + 2 files changed, 140 insertions(+) create mode 100644 tests/test_middlewares/test_clickjacking.py diff --git a/tests/test_middlewares/test_clickjacking.py b/tests/test_middlewares/test_clickjacking.py new file mode 100644 index 0000000..19f3462 --- /dev/null +++ b/tests/test_middlewares/test_clickjacking.py @@ -0,0 +1,138 @@ +from django.http import HttpResponse, HttpResponseNotFound, HttpRequest + +from django_async_extensions.middleware.clickjacking import ( + AsyncXFrameOptionsMiddleware, +) + + +async def get_response_empty(request): + return HttpResponse() + + +async def get_response_404(request): + return HttpResponseNotFound() + + +class TestXFrameOptionsMiddleware: + """ + Tests for the X-Frame-Options clickjacking prevention middleware. + """ + + async def test_same_origin(self, settings): + """ + The X_FRAME_OPTIONS setting can be set to SAMEORIGIN to have the + middleware use that value for the HTTP header. + """ + settings.X_FRAME_OPTIONS = "SAMEORIGIN" + r = await AsyncXFrameOptionsMiddleware(get_response_empty)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "SAMEORIGIN" + + settings.X_FRAME_OPTIONS = "sameorigin" + r = await AsyncXFrameOptionsMiddleware(get_response_empty)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "SAMEORIGIN" + + async def test_deny(self, settings): + """ + The X_FRAME_OPTIONS setting can be set to DENY to have the middleware + use that value for the HTTP header. + """ + settings.X_FRAME_OPTIONS = "DENY" + r = await AsyncXFrameOptionsMiddleware(get_response_empty)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "DENY" + + settings.X_FRAME_OPTIONS = "deny" + r = await AsyncXFrameOptionsMiddleware(get_response_empty)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "DENY" + + async def test_defaults_sameorigin(self, settings): + """ + If the X_FRAME_OPTIONS setting is not set then it defaults to + DENY. + """ + settings.X_FRAME_OPTIONS = None + del settings.X_FRAME_OPTIONS # restored by override_settings + r = await AsyncXFrameOptionsMiddleware(get_response_empty)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "DENY" + + async def test_dont_set_if_set(self, settings): + """ + If the X-Frame-Options header is already set then the middleware does + not attempt to override it. + """ + + async def same_origin_response(request): + response = HttpResponse() + response.headers["X-Frame-Options"] = "SAMEORIGIN" + return response + + async def deny_response(request): + response = HttpResponse() + response.headers["X-Frame-Options"] = "DENY" + return response + + settings.X_FRAME_OPTIONS = "DENY" + r = await AsyncXFrameOptionsMiddleware(same_origin_response)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "SAMEORIGIN" + + settings.X_FRAME_OPTIONS = "SAMEORIGIN" + r = await AsyncXFrameOptionsMiddleware(deny_response)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "DENY" + + async def test_response_exempt(self, settings): + """ + If the response has an xframe_options_exempt attribute set to False + then it still sets the header, but if it's set to True then it doesn't. + """ + + async def xframe_exempt_response(request): + response = HttpResponse() + response.xframe_options_exempt = True + return response + + async def xframe_not_exempt_response(request): + response = HttpResponse() + response.xframe_options_exempt = False + return response + + settings.X_FRAME_OPTIONS = "SAMEORIGIN" + r = await AsyncXFrameOptionsMiddleware(xframe_not_exempt_response)( + HttpRequest() + ) + assert r.headers["X-Frame-Options"] == "SAMEORIGIN" + + r = await AsyncXFrameOptionsMiddleware(xframe_exempt_response)(HttpRequest()) + assert r.headers.get("X-Frame-Options") is None + + async def test_is_extendable(self, settings): + """ + The XFrameOptionsMiddleware method that determines the X-Frame-Options + header value can be overridden based on something in the request or + response. + """ + + class OtherXFrameOptionsMiddleware(AsyncXFrameOptionsMiddleware): + # This is just an example for testing purposes... + def get_xframe_options_value(self, request, response): + if getattr(request, "sameorigin", False): + return "SAMEORIGIN" + if getattr(response, "sameorigin", False): + return "SAMEORIGIN" + return "DENY" + + async def same_origin_response(request): + response = HttpResponse() + response.sameorigin = True + return response + + settings.X_FRAME_OPTIONS = "DENY" + r = await OtherXFrameOptionsMiddleware(same_origin_response)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "SAMEORIGIN" + + request = HttpRequest() + request.sameorigin = True + r = await OtherXFrameOptionsMiddleware(get_response_empty)(request) + assert r.headers["X-Frame-Options"] == "SAMEORIGIN" + + settings.X_FRAME_OPTIONS = "SAMEORIGIN" + r = await OtherXFrameOptionsMiddleware(get_response_empty)(HttpRequest()) + assert r.headers["X-Frame-Options"] == "DENY" diff --git a/tests/test_middlewares/test_middleware_mixin.py b/tests/test_middlewares/test_middleware_mixin.py index 236ffb3..0251250 100644 --- a/tests/test_middlewares/test_middleware_mixin.py +++ b/tests/test_middlewares/test_middleware_mixin.py @@ -6,6 +6,7 @@ from django.http.response import HttpResponse from django_async_extensions.middleware.base import AsyncMiddlewareMixin +from django_async_extensions.middleware.clickjacking import AsyncXFrameOptionsMiddleware from django_async_extensions.middleware.common import ( AsyncBrokenLinkEmailsMiddleware, AsyncCommonMiddleware, @@ -45,6 +46,7 @@ class TestMiddlewareMixin: AsyncGZipMiddleware, AsyncCommonMiddleware, AsyncBrokenLinkEmailsMiddleware, + AsyncXFrameOptionsMiddleware, ] def test_repr(self): From 33166847a1dc1cd04271be11789aece24f091d37 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sun, 30 Mar 2025 00:39:54 +0330 Subject: [PATCH 21/21] document async clickjacking middleware --- docs/middleware/x_frame_options_middleware.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 docs/middleware/x_frame_options_middleware.md diff --git a/docs/middleware/x_frame_options_middleware.md b/docs/middleware/x_frame_options_middleware.md new file mode 100644 index 0000000..ef3d38c --- /dev/null +++ b/docs/middleware/x_frame_options_middleware.md @@ -0,0 +1,10 @@ +# AsyncXFrameOptionsMiddleware + +it works exactly like django's [XFrameOptionsMiddleware](https://docs.djangoproject.com/en/5.1/ref/middleware/#x-frame-options-middleware) +except that it's fully async + +### Usage: +remove django's `django.middleware.clickjacking.XFrameOptionsMiddleware` from the `MIDDLEWARE` setting and add +`django_async_extensions.middleware.clickjacking.AsyncXFrameOptionsMiddleware` in it's place. + +**note**: this middleware like other middlewares provided in this package can work alongside sync middlewares, and can handle sync views.