Skip to content

Commit 3068b62

Browse files
committed
new tests for security
1 parent 63fce5f commit 3068b62

File tree

2 files changed

+80
-32
lines changed

2 files changed

+80
-32
lines changed

rest_framework/schemas/openapi.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,11 @@ def get_schema(self, request=None, public=False):
7373
security_schemes_schemas = {}
7474
root_security_requirements = []
7575

76-
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
77-
req = auth_class.openapi_security_requirement(None, None)
78-
if req:
79-
root_security_requirements += req
76+
if api_settings.DEFAULT_AUTHENTICATION_CLASSES:
77+
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
78+
req = auth_class.openapi_security_requirement(None, None)
79+
if req:
80+
root_security_requirements += req
8081

8182
# Iterate endpoints generating per method path operations.
8283
paths = {}
@@ -721,6 +722,34 @@ def get_tags(self, path, method):
721722

722723
return [path.split('/')[0].replace('_', '-')]
723724

725+
def get_security_schemes(self, path, method):
726+
"""
727+
Get components.schemas.securitySchemes required by this path.
728+
returns dict of securitySchemes.
729+
"""
730+
schemes = {}
731+
for auth_class in self.view.authentication_classes:
732+
if hasattr(auth_class, 'openapi_security_scheme'):
733+
schemes.update(auth_class.openapi_security_scheme())
734+
return schemes
735+
736+
def get_security_requirements(self, path, method):
737+
"""
738+
Get Security Requirement Object list for this operation.
739+
Returns a list of security requirement objects based on the view's authentication classes
740+
unless this view's authentication classes are the same as the root-level defaults.
741+
"""
742+
# references the securityScheme names described above in get_security_schemes()
743+
security = []
744+
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
745+
return None
746+
for auth_class in self.view.authentication_classes:
747+
if hasattr(auth_class, 'openapi_security_requirement'):
748+
req = auth_class.openapi_security_requirement(self.view, method)
749+
if req:
750+
security += req
751+
return security
752+
724753
def _get_path_parameters(self, path, method):
725754
warnings.warn(
726755
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
@@ -816,31 +845,3 @@ def _allows_filters(self, path, method):
816845
RemovedInDRF314Warning, stacklevel=2
817846
)
818847
return self.allows_filters(path, method)
819-
820-
def get_security_schemes(self, path, method):
821-
"""
822-
Get components.schemas.securitySchemes required by this path.
823-
returns dict of securitySchemes.
824-
"""
825-
schemes = {}
826-
for auth_class in self.view.authentication_classes:
827-
if hasattr(auth_class, 'openapi_security_scheme'):
828-
schemes.update(auth_class.openapi_security_scheme())
829-
return schemes
830-
831-
def get_security_requirements(self, path, method):
832-
"""
833-
Get Security Requirement Object list for this operation.
834-
Returns a list of security requirement objects based on the view's authentication classes
835-
unless this view's authentication classes are the same as the root-level defaults.
836-
"""
837-
# references the securityScheme names described above in get_security_schemes()
838-
security = []
839-
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
840-
return None
841-
for auth_class in self.view.authentication_classes:
842-
if hasattr(auth_class, 'openapi_security_requirement'):
843-
req = auth_class.openapi_security_requirement(self.view, method)
844-
if req:
845-
security += req
846-
return security

tests/schemas/test_openapi.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.utils.translation import gettext_lazy as _
88

99
from rest_framework import filters, generics, pagination, routers, serializers
10+
from rest_framework.authentication import TokenAuthentication
1011
from rest_framework.authtoken.views import obtain_auth_token
1112
from rest_framework.compat import uritemplate
1213
from rest_framework.parsers import JSONParser, MultiPartParser
@@ -1110,3 +1111,49 @@ class ExampleView(generics.DestroyAPIView):
11101111
schema = generator.get_schema(request=create_request('/'))
11111112
assert 'schemas' not in schema['components']
11121113
assert 'content' not in schema['paths']['/example/']['delete']['responses']['204']
1114+
1115+
def test_default_root_security_schemes(self):
1116+
patterns = [
1117+
url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
1118+
]
1119+
1120+
generator = SchemaGenerator(patterns=patterns)
1121+
1122+
request = create_request('/')
1123+
schema = generator.get_schema(request=request)
1124+
assert 'security' in schema
1125+
assert {'sessionAuth': []} in schema['security']
1126+
assert {'basicAuth': []} in schema['security']
1127+
assert 'security' not in schema['paths']['/example/']['get']
1128+
1129+
@override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None})
1130+
def test_no_default_root_security_schemes(self):
1131+
patterns = [
1132+
url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
1133+
]
1134+
1135+
generator = SchemaGenerator(patterns=patterns)
1136+
1137+
request = create_request('/')
1138+
schema = generator.get_schema(request=request)
1139+
assert 'security' not in schema
1140+
1141+
def test_operation_security_schemes(self):
1142+
class MyExample(views.ExampleAutoSchemaComponentName):
1143+
authentication_classes = [TokenAuthentication]
1144+
1145+
patterns = [
1146+
url(r'^example/?$', MyExample.as_view()),
1147+
]
1148+
1149+
generator = SchemaGenerator(patterns=patterns)
1150+
1151+
request = create_request('/')
1152+
schema = generator.get_schema(request=request)
1153+
assert 'security' in schema
1154+
assert {'sessionAuth': []} in schema['security']
1155+
assert {'basicAuth': []} in schema['security']
1156+
get_operation = schema['paths']['/example/']['get']
1157+
assert 'security' in get_operation
1158+
assert {'tokenAuth': []} in get_operation['security']
1159+
assert len(get_operation['security']) == 1

0 commit comments

Comments
 (0)