Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/docs/guides/throttling.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,21 @@ class NoReadsThrottle(AnonRateThrottle):
return True
return super().allow_request(request)
```

## Customizing Client IP Address Lookups

To use custom client IP address lookup logic, change the `NINJA_CLIENT_IP_CALLABLE` setting to a suitable callable path.

Example

```Python
from django.http import HttpRequest

def get_client_ip(request: HttpRequest) -> str:
return request.META.get("REMOTE_ADDR")
```

`settings.py`:
```python
NINJA_CLIENT_IP_CALLABLE = "example.utils.get_client_ip"
```
3 changes: 3 additions & 0 deletions ninja/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Settings(BaseModel):

# Throttling
NUM_PROXIES: Optional[int] = Field(None, alias="NINJA_NUM_PROXIES")
CLIENT_IP_CALLABLE: str = Field(
"ninja.throttling.get_client_ip", alias="NINJA_CLIENT_IP_CALLABLE"
)
DEFAULT_THROTTLE_RATES: Dict[str, Optional[str]] = Field(
{
"auth": "10000/day",
Expand Down
61 changes: 39 additions & 22 deletions ninja/throttling.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,59 @@
import hashlib
import time
import warnings
from typing import Dict, List, Optional, Tuple

from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpRequest
from django.utils.module_loading import import_string


def get_client_ip(request: HttpRequest) -> Optional[str]:
"""
Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
from ninja.conf import settings

xff = request.META.get("HTTP_X_FORWARDED_FOR")
remote_addr = request.META.get("REMOTE_ADDR")
num_proxies = settings.NUM_PROXIES

if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs: List[str] = xff.split(",")
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()

return "".join(xff.split()) if xff else remote_addr


class BaseThrottle:
"""
Rate throttling of requests.
"""

def __init__(self) -> None:
from ninja.conf import settings

self.client_ip = import_string(settings.CLIENT_IP_CALLABLE)

def allow_request(self, request: HttpRequest) -> bool:
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
raise NotImplementedError(".allow_request() must be overridden")

def get_ident(self, request: HttpRequest) -> Optional[str]:
"""
Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
from ninja.conf import settings

xff = request.META.get("HTTP_X_FORWARDED_FOR")
remote_addr = request.META.get("REMOTE_ADDR")
num_proxies = settings.NUM_PROXIES

if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs: List[str] = xff.split(",")
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()

return "".join(xff.split()) if xff else remote_addr
warnings.warn(
".get_ident() is deprecated, use .client_ip() instead",
DeprecationWarning,
stacklevel=2,
)
return get_client_ip(request)

def wait(self) -> Optional[float]:
"""
Expand Down Expand Up @@ -79,6 +95,7 @@ class SimpleRateThrottle(BaseThrottle):
}

def __init__(self, rate: Optional[str] = None):
super().__init__()
self.rate: Optional[str]
if rate:
self.rate = rate
Expand Down Expand Up @@ -204,7 +221,7 @@ def get_cache_key(self, request: HttpRequest) -> Optional[str]:

return self.cache_format % {
"scope": self.scope,
"ident": self.get_ident(request),
"ident": self.client_ip(request),
}


Expand All @@ -224,7 +241,7 @@ def get_cache_key(self, request: HttpRequest) -> str:
ident = hashlib.sha256(str(request.auth).encode()).hexdigest() # type: ignore
# TODO: ^maybe auth should have an attribute that developer can overwrite
else:
ident = self.get_ident(request) # type: ignore
ident = self.client_ip(request)

return self.cache_format % {"scope": self.scope, "ident": ident}

Expand All @@ -244,6 +261,6 @@ def get_cache_key(self, request: HttpRequest) -> str:
if request.user and request.user.is_authenticated:
ident = request.user.pk
else:
ident = self.get_ident(request)
ident = self.client_ip(request)

return self.cache_format % {"scope": self.scope, "ident": ident}
28 changes: 25 additions & 3 deletions tests/test_throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,40 @@ def test_proxy_throttle():

th = SimpleRateThrottle("1/s")
request = build_request(x_forwarded_for=None)
assert th.get_ident(request) == "8.8.8.8"
assert th.client_ip(request) == "8.8.8.8"
assert th.get_ident(request) == "8.8.8.8" # Deprecated

settings.NUM_PROXIES = 0
request = build_request(x_forwarded_for="8.8.8.8,127.0.0.1")
assert th.get_ident(request) == "8.8.8.8"
assert th.client_ip(request) == "8.8.8.8"
assert th.get_ident(request) == "8.8.8.8" # Deprecated

settings.NUM_PROXIES = 1
assert th.get_ident(request) == "127.0.0.1"
assert th.client_ip(request) == "127.0.0.1"
assert th.get_ident(request) == "127.0.0.1" # Deprecated

settings.NUM_PROXIES = None


def custom_client_ip_callable(request) -> str:
return "custom-ip"


def test_proxy_throttle_custom_client_ip_callable():
from ninja.conf import settings

settings.CLIENT_IP_CALLABLE = "tests.test_throttling.custom_client_ip_callable"

th = SimpleRateThrottle("1/s")
request = build_request()
assert th.client_ip(request) == "custom-ip"
assert (
th.get_ident(request) == "8.8.8.8"
) # Note: Deprecated function preserves old behavior and ignores new CLIENT_IP_CALLABLE setting.

settings.CLIENT_IP_CALLABLE = "ninja.throttling.get_client_ip"


def test_base_classes():
base = BaseThrottle()
with pytest.raises(NotImplementedError):
Expand Down