diff --git a/docs/docs/guides/throttling.md b/docs/docs/guides/throttling.md index fead0fa52..08e029e00 100644 --- a/docs/docs/guides/throttling.md +++ b/docs/docs/guides/throttling.md @@ -105,3 +105,21 @@ class NoReadsThrottle(AnonRateThrottle): return True return super().allow_request(request) ``` + +## Customizing Client IP Address Lookups + +To use custom client IP address lookup logic, change the `NINJA_CLIENT_IP_CALLABLE` setting to a suitable callable path. + +Example + +```Python +from django.http import HttpRequest + +def get_client_ip(request: HttpRequest) -> str: + return request.META.get("REMOTE_ADDR") +``` + +`settings.py`: +```python +NINJA_CLIENT_IP_CALLABLE = "example.utils.get_client_ip" +``` diff --git a/ninja/conf.py b/ninja/conf.py index 3f2332250..70555bce8 100644 --- a/ninja/conf.py +++ b/ninja/conf.py @@ -21,6 +21,9 @@ class Settings(BaseModel): # Throttling NUM_PROXIES: Optional[int] = Field(None, alias="NINJA_NUM_PROXIES") + CLIENT_IP_CALLABLE: str = Field( + "ninja.throttling.get_client_ip", alias="NINJA_CLIENT_IP_CALLABLE" + ) DEFAULT_THROTTLE_RATES: Dict[str, Optional[str]] = Field( { "auth": "10000/day", diff --git a/ninja/throttling.py b/ninja/throttling.py index 8b1838d5c..b9017adab 100644 --- a/ninja/throttling.py +++ b/ninja/throttling.py @@ -1,10 +1,34 @@ import hashlib import time +import warnings from typing import Dict, List, Optional, Tuple from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured from django.http import HttpRequest +from django.utils.module_loading import import_string + + +def get_client_ip(request: HttpRequest) -> Optional[str]: + """ + Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR + if present and number of proxies is > 0. If not use all of + HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. + """ + from ninja.conf import settings + + xff = request.META.get("HTTP_X_FORWARDED_FOR") + remote_addr = request.META.get("REMOTE_ADDR") + num_proxies = settings.NUM_PROXIES + + if num_proxies is not None: + if num_proxies == 0 or xff is None: + return remote_addr + addrs: List[str] = xff.split(",") + client_addr = addrs[-min(num_proxies, len(addrs))] + return client_addr.strip() + + return "".join(xff.split()) if xff else remote_addr class BaseThrottle: @@ -12,6 +36,11 @@ class BaseThrottle: Rate throttling of requests. """ + def __init__(self) -> None: + from ninja.conf import settings + + self.client_ip = import_string(settings.CLIENT_IP_CALLABLE) + def allow_request(self, request: HttpRequest) -> bool: """ Return `True` if the request should be allowed, `False` otherwise. @@ -19,25 +48,12 @@ def allow_request(self, request: HttpRequest) -> bool: raise NotImplementedError(".allow_request() must be overridden") def get_ident(self, request: HttpRequest) -> Optional[str]: - """ - Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR - if present and number of proxies is > 0. If not use all of - HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. - """ - from ninja.conf import settings - - xff = request.META.get("HTTP_X_FORWARDED_FOR") - remote_addr = request.META.get("REMOTE_ADDR") - num_proxies = settings.NUM_PROXIES - - if num_proxies is not None: - if num_proxies == 0 or xff is None: - return remote_addr - addrs: List[str] = xff.split(",") - client_addr = addrs[-min(num_proxies, len(addrs))] - return client_addr.strip() - - return "".join(xff.split()) if xff else remote_addr + warnings.warn( + ".get_ident() is deprecated, use .client_ip() instead", + DeprecationWarning, + stacklevel=2, + ) + return get_client_ip(request) def wait(self) -> Optional[float]: """ @@ -79,6 +95,7 @@ class SimpleRateThrottle(BaseThrottle): } def __init__(self, rate: Optional[str] = None): + super().__init__() self.rate: Optional[str] if rate: self.rate = rate @@ -204,7 +221,7 @@ def get_cache_key(self, request: HttpRequest) -> Optional[str]: return self.cache_format % { "scope": self.scope, - "ident": self.get_ident(request), + "ident": self.client_ip(request), } @@ -224,7 +241,7 @@ def get_cache_key(self, request: HttpRequest) -> str: ident = hashlib.sha256(str(request.auth).encode()).hexdigest() # type: ignore # TODO: ^maybe auth should have an attribute that developer can overwrite else: - ident = self.get_ident(request) # type: ignore + ident = self.client_ip(request) return self.cache_format % {"scope": self.scope, "ident": ident} @@ -244,6 +261,6 @@ def get_cache_key(self, request: HttpRequest) -> str: if request.user and request.user.is_authenticated: ident = request.user.pk else: - ident = self.get_ident(request) + ident = self.client_ip(request) return self.cache_format % {"scope": self.scope, "ident": ident} diff --git a/tests/test_throttling.py b/tests/test_throttling.py index 0913f0f9d..a7860b59e 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -278,18 +278,40 @@ def test_proxy_throttle(): th = SimpleRateThrottle("1/s") request = build_request(x_forwarded_for=None) - assert th.get_ident(request) == "8.8.8.8" + assert th.client_ip(request) == "8.8.8.8" + assert th.get_ident(request) == "8.8.8.8" # Deprecated settings.NUM_PROXIES = 0 request = build_request(x_forwarded_for="8.8.8.8,127.0.0.1") - assert th.get_ident(request) == "8.8.8.8" + assert th.client_ip(request) == "8.8.8.8" + assert th.get_ident(request) == "8.8.8.8" # Deprecated settings.NUM_PROXIES = 1 - assert th.get_ident(request) == "127.0.0.1" + assert th.client_ip(request) == "127.0.0.1" + assert th.get_ident(request) == "127.0.0.1" # Deprecated settings.NUM_PROXIES = None +def custom_client_ip_callable(request) -> str: + return "custom-ip" + + +def test_proxy_throttle_custom_client_ip_callable(): + from ninja.conf import settings + + settings.CLIENT_IP_CALLABLE = "tests.test_throttling.custom_client_ip_callable" + + th = SimpleRateThrottle("1/s") + request = build_request() + assert th.client_ip(request) == "custom-ip" + assert ( + th.get_ident(request) == "8.8.8.8" + ) # Note: Deprecated function preserves old behavior and ignores new CLIENT_IP_CALLABLE setting. + + settings.CLIENT_IP_CALLABLE = "ninja.throttling.get_client_ip" + + def test_base_classes(): base = BaseThrottle() with pytest.raises(NotImplementedError):