diff --git a/README.rst b/README.rst index 89041094..e74cfcfb 100644 --- a/README.rst +++ b/README.rst @@ -58,7 +58,7 @@ Alternatively you can download the code and install from the repository: Usage ##### -Firstly create your specification object. By default, OpenAPI spec version is detected: +Firstly create your specification object. .. code-block:: python @@ -66,10 +66,12 @@ Firstly create your specification object. By default, OpenAPI spec version is de spec = Spec.from_file_path('openapi.json') +Now you can use it to validate against requests and/or responses. + Request ******* -Now you can use it to validate against requests +Use ``validate_request`` function to validate request against a given spec. .. code-block:: python @@ -78,7 +80,7 @@ Now you can use it to validate against requests # raise error if request is invalid result = validate_request(request, spec=spec) -and unmarshal request data from validation result +Retrieve request data from validation result .. code-block:: python @@ -98,7 +100,7 @@ Request object should implement OpenAPI Request protocol (See `Integrations HttpResponse: openapi_request = self._get_openapi_request(request) - req_result = self.validation_processor.process_request( - settings.OPENAPI_SPEC, openapi_request - ) + req_result = self.validation_processor.process_request(openapi_request) if req_result.errors: response = self._handle_request_errors(req_result, request) else: @@ -48,7 +40,7 @@ def __call__(self, request: HttpRequest) -> HttpResponse: openapi_response = self._get_openapi_response(response) resp_result = self.validation_processor.process_response( - settings.OPENAPI_SPEC, openapi_request, openapi_response + openapi_request, openapi_response ) if resp_result.errors: return self._handle_response_errors(resp_result, request, response) diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index c2d509f7..142bf63a 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -11,13 +11,13 @@ from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse from openapi_core.spec import Spec from openapi_core.validation.processors import OpenAPIProcessor -from openapi_core.validation.request import openapi_request_validator from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.response import openapi_response_validator +from openapi_core.validation.request.protocols import RequestValidator from openapi_core.validation.response.datatypes import ResponseValidationResult +from openapi_core.validation.response.protocols import ResponseValidator -class FalconOpenAPIMiddleware: +class FalconOpenAPIMiddleware(OpenAPIProcessor): request_class = FalconOpenAPIRequest response_class = FalconOpenAPIResponse @@ -26,13 +26,17 @@ class FalconOpenAPIMiddleware: def __init__( self, spec: Spec, - validation_processor: OpenAPIProcessor, + request_validator_cls: Optional[Type[RequestValidator]] = None, + response_validator_cls: Optional[Type[ResponseValidator]] = None, request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, ): - self.spec = spec - self.validation_processor = validation_processor + super().__init__( + spec, + request_validator_cls=request_validator_cls, + response_validator_cls=response_validator_cls, + ) self.request_class = request_class or self.request_class self.response_class = response_class or self.response_class self.errors_handler = errors_handler or self.errors_handler @@ -41,33 +45,33 @@ def __init__( def from_spec( cls, spec: Spec, + request_validator_cls: Optional[Type[RequestValidator]] = None, + response_validator_cls: Optional[Type[ResponseValidator]] = None, request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, ) -> "FalconOpenAPIMiddleware": - validation_processor = OpenAPIProcessor( - openapi_request_validator, openapi_response_validator - ) return cls( spec, - validation_processor, + request_validator_cls=request_validator_cls, + response_validator_cls=response_validator_cls, request_class=request_class, response_class=response_class, errors_handler=errors_handler, ) - def process_request(self, req: Request, resp: Response) -> None: + def process_request(self, req: Request, resp: Response) -> None: # type: ignore openapi_req = self._get_openapi_request(req) - req.context.openapi = self._process_openapi_request(openapi_req) + req.context.openapi = super().process_request(openapi_req) if req.context.openapi.errors: return self._handle_request_errors(req, resp, req.context.openapi) - def process_response( + def process_response( # type: ignore self, req: Request, resp: Response, resource: Any, req_succeeded: bool ) -> None: openapi_req = self._get_openapi_request(req) openapi_resp = self._get_openapi_response(resp) - resp.context.openapi = self._process_openapi_response( + resp.context.openapi = super().process_response( openapi_req, openapi_resp ) if resp.context.openapi.errors: @@ -98,19 +102,3 @@ def _get_openapi_response( self, response: Response ) -> FalconOpenAPIResponse: return self.response_class(response) - - def _process_openapi_request( - self, openapi_request: FalconOpenAPIRequest - ) -> RequestValidationResult: - return self.validation_processor.process_request( - self.spec, openapi_request - ) - - def _process_openapi_response( - self, - opneapi_request: FalconOpenAPIRequest, - openapi_response: FalconOpenAPIResponse, - ) -> ResponseValidationResult: - return self.validation_processor.process_response( - self.spec, opneapi_request, openapi_response - ) diff --git a/openapi_core/contrib/flask/decorators.py b/openapi_core/contrib/flask/decorators.py index 9e2eb182..91066d85 100644 --- a/openapi_core/contrib/flask/decorators.py +++ b/openapi_core/contrib/flask/decorators.py @@ -2,6 +2,7 @@ from functools import wraps from typing import Any from typing import Callable +from typing import Optional from typing import Type from flask.globals import request @@ -14,10 +15,8 @@ from openapi_core.contrib.flask.responses import FlaskOpenAPIResponse from openapi_core.spec import Spec from openapi_core.validation.processors import OpenAPIProcessor -from openapi_core.validation.request import openapi_request_validator from openapi_core.validation.request.datatypes import RequestValidationResult from openapi_core.validation.request.protocols import RequestValidator -from openapi_core.validation.response import openapi_response_validator from openapi_core.validation.response.datatypes import ResponseValidationResult from openapi_core.validation.response.protocols import ResponseValidator @@ -26,8 +25,8 @@ class FlaskOpenAPIViewDecorator(OpenAPIProcessor): def __init__( self, spec: Spec, - request_validator: RequestValidator, - response_validator: ResponseValidator, + request_validator_cls: Optional[Type[RequestValidator]] = None, + response_validator_cls: Optional[Type[ResponseValidator]] = None, request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, @@ -35,8 +34,11 @@ def __init__( FlaskOpenAPIErrorsHandler ] = FlaskOpenAPIErrorsHandler, ): - super().__init__(request_validator, response_validator) - self.spec = spec + super().__init__( + spec, + request_validator_cls=request_validator_cls, + response_validator_cls=response_validator_cls, + ) self.request_class = request_class self.response_class = response_class self.request_provider = request_provider @@ -47,7 +49,7 @@ def __call__(self, view: Callable[..., Any]) -> Callable[..., Any]: def decorated(*args: Any, **kwargs: Any) -> Response: request = self._get_request() openapi_request = self._get_openapi_request(request) - request_result = self.process_request(self.spec, openapi_request) + request_result = self.process_request(openapi_request) if request_result.errors: return self._handle_request_errors(request_result) response = self._handle_request_view( @@ -55,7 +57,7 @@ def decorated(*args: Any, **kwargs: Any) -> Response: ) openapi_response = self._get_openapi_response(response) response_result = self.process_response( - self.spec, openapi_request, openapi_response + openapi_request, openapi_response ) if response_result.errors: return self._handle_response_errors(response_result) @@ -99,6 +101,8 @@ def _get_openapi_response( def from_spec( cls, spec: Spec, + request_validator_cls: Optional[Type[RequestValidator]] = None, + response_validator_cls: Optional[Type[ResponseValidator]] = None, request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, @@ -108,8 +112,8 @@ def from_spec( ) -> "FlaskOpenAPIViewDecorator": return cls( spec, - request_validator=openapi_request_validator, - response_validator=openapi_response_validator, + request_validator_cls=request_validator_cls, + response_validator_cls=response_validator_cls, request_class=request_class, response_class=response_class, request_provider=request_provider, diff --git a/openapi_core/contrib/flask/views.py b/openapi_core/contrib/flask/views.py index 499a37ba..23754bf4 100644 --- a/openapi_core/contrib/flask/views.py +++ b/openapi_core/contrib/flask/views.py @@ -6,8 +6,6 @@ from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator from openapi_core.contrib.flask.handlers import FlaskOpenAPIErrorsHandler from openapi_core.spec import Spec -from openapi_core.validation.request import openapi_request_validator -from openapi_core.validation.response import openapi_response_validator class FlaskOpenAPIView(MethodView): @@ -22,8 +20,6 @@ def __init__(self, spec: Spec): def dispatch_request(self, *args: Any, **kwargs: Any) -> Any: decorator = FlaskOpenAPIViewDecorator( self.spec, - request_validator=openapi_request_validator, - response_validator=openapi_response_validator, openapi_errors_handler=self.openapi_errors_handler, ) return decorator(super().dispatch_request)(*args, **kwargs) diff --git a/openapi_core/templating/paths/finders.py b/openapi_core/templating/paths/finders.py index 377ff68d..0eb37430 100644 --- a/openapi_core/templating/paths/finders.py +++ b/openapi_core/templating/paths/finders.py @@ -29,15 +29,8 @@ def __init__(self, spec: Spec, base_url: Optional[str] = None): def find( self, method: str, - host_url: str, - path: str, - path_pattern: Optional[str] = None, + full_url: str, ) -> ServerOperationPath: - if path_pattern is not None: - full_url = urljoin(host_url, path_pattern) - else: - full_url = urljoin(host_url, path) - paths_iter = self._get_paths_iter(full_url) paths_iter_peek = peekable(paths_iter) diff --git a/openapi_core/validation/processors.py b/openapi_core/validation/processors.py index c2d9356d..3b21c71a 100644 --- a/openapi_core/validation/processors.py +++ b/openapi_core/validation/processors.py @@ -1,18 +1,26 @@ """OpenAPI core validation processors module""" +from typing import Optional +from typing import Type + from openapi_core.spec import Spec from openapi_core.validation.request.datatypes import RequestValidationResult from openapi_core.validation.request.protocols import Request from openapi_core.validation.request.protocols import RequestValidator +from openapi_core.validation.request.proxies import SpecRequestValidatorProxy from openapi_core.validation.response.datatypes import ResponseValidationResult from openapi_core.validation.response.protocols import Response from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.validation.response.proxies import SpecResponseValidatorProxy +from openapi_core.validation.shortcuts import get_validators +from openapi_core.validation.shortcuts import validate_request +from openapi_core.validation.shortcuts import validate_response -class OpenAPIProcessor: +class OpenAPISpecProcessor: def __init__( self, - request_validator: RequestValidator, - response_validator: ResponseValidator, + request_validator: SpecRequestValidatorProxy, + response_validator: SpecResponseValidatorProxy, ): self.request_validator = request_validator self.response_validator = response_validator @@ -26,3 +34,29 @@ def process_response( self, spec: Spec, request: Request, response: Response ) -> ResponseValidationResult: return self.response_validator.validate(spec, request, response) + + +class OpenAPIProcessor: + def __init__( + self, + spec: Spec, + request_validator_cls: Optional[Type[RequestValidator]] = None, + response_validator_cls: Optional[Type[ResponseValidator]] = None, + ): + self.spec = spec + if request_validator_cls is None or response_validator_cls is None: + validators = get_validators(self.spec) + if request_validator_cls is None: + request_validator_cls = validators.request + if response_validator_cls is None: + response_validator_cls = validators.response + self.request_validator = request_validator_cls(self.spec) + self.response_validator = response_validator_cls(self.spec) + + def process_request(self, request: Request) -> RequestValidationResult: + return self.request_validator.validate(request) + + def process_response( + self, request: Request, response: Response + ) -> ResponseValidationResult: + return self.response_validator.validate(request, response) diff --git a/openapi_core/validation/request/__init__.py b/openapi_core/validation/request/__init__.py index d4c57fc4..9ff42510 100644 --- a/openapi_core/validation/request/__init__.py +++ b/openapi_core/validation/request/__init__.py @@ -1,19 +1,49 @@ """OpenAPI core validation request module""" +from functools import partial + from openapi_core.unmarshalling.schemas import ( oas30_request_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, ) -from openapi_core.validation.request.proxies import DetectRequestValidatorProxy +from openapi_core.validation.request.proxies import ( + DetectSpecRequestValidatorProxy, +) +from openapi_core.validation.request.proxies import SpecRequestValidatorProxy from openapi_core.validation.request.validators import RequestBodyValidator from openapi_core.validation.request.validators import ( RequestParametersValidator, ) from openapi_core.validation.request.validators import RequestSecurityValidator from openapi_core.validation.request.validators import RequestValidator +from openapi_core.validation.request.validators import V30RequestBodyValidator +from openapi_core.validation.request.validators import ( + V30RequestParametersValidator, +) +from openapi_core.validation.request.validators import ( + V30RequestSecurityValidator, +) +from openapi_core.validation.request.validators import V30RequestValidator +from openapi_core.validation.request.validators import V31RequestBodyValidator +from openapi_core.validation.request.validators import ( + V31RequestParametersValidator, +) +from openapi_core.validation.request.validators import ( + V31RequestSecurityValidator, +) +from openapi_core.validation.request.validators import V31RequestValidator __all__ = [ + "V30RequestBodyValidator", + "V30RequestParametersValidator", + "V30RequestSecurityValidator", + "V30RequestValidator", + "V31RequestBodyValidator", + "V31RequestParametersValidator", + "V31RequestSecurityValidator", + "V31RequestValidator", + "V3RequestValidator", "openapi_v30_request_body_validator", "openapi_v30_request_parameters_validator", "openapi_v30_request_security_validator", @@ -32,33 +62,45 @@ "openapi_request_validator", ] -openapi_v30_request_body_validator = RequestBodyValidator( +# alias to the latest v3 version +V3RequestValidator = V31RequestValidator + +# spec validators +openapi_v30_request_body_validator = SpecRequestValidatorProxy( + RequestBodyValidator, schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, ) -openapi_v30_request_parameters_validator = RequestParametersValidator( +openapi_v30_request_parameters_validator = SpecRequestValidatorProxy( + RequestParametersValidator, schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, ) -openapi_v30_request_security_validator = RequestSecurityValidator( +openapi_v30_request_security_validator = SpecRequestValidatorProxy( + RequestSecurityValidator, schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, ) -openapi_v30_request_validator = RequestValidator( +openapi_v30_request_validator = SpecRequestValidatorProxy( + RequestValidator, schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, ) -openapi_v31_request_body_validator = RequestBodyValidator( +openapi_v31_request_body_validator = SpecRequestValidatorProxy( + RequestBodyValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -openapi_v31_request_parameters_validator = RequestParametersValidator( +openapi_v31_request_parameters_validator = SpecRequestValidatorProxy( + RequestParametersValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -openapi_v31_request_security_validator = RequestSecurityValidator( +openapi_v31_request_security_validator = SpecRequestValidatorProxy( + RequestSecurityValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -openapi_v31_request_validator = RequestValidator( +openapi_v31_request_validator = SpecRequestValidatorProxy( + RequestValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -# alias to the latest v3 version +# spec validators alias to the latest v3 version openapi_v3_request_body_validator = openapi_v31_request_body_validator openapi_v3_request_parameters_validator = ( openapi_v31_request_parameters_validator @@ -67,25 +109,25 @@ openapi_v3_request_validator = openapi_v31_request_validator # detect version spec -openapi_request_body_validator = DetectRequestValidatorProxy( +openapi_request_body_validator = DetectSpecRequestValidatorProxy( { ("openapi", "3.0"): openapi_v30_request_body_validator, ("openapi", "3.1"): openapi_v31_request_body_validator, }, ) -openapi_request_parameters_validator = DetectRequestValidatorProxy( +openapi_request_parameters_validator = DetectSpecRequestValidatorProxy( { ("openapi", "3.0"): openapi_v30_request_parameters_validator, ("openapi", "3.1"): openapi_v31_request_parameters_validator, }, ) -openapi_request_security_validator = DetectRequestValidatorProxy( +openapi_request_security_validator = DetectSpecRequestValidatorProxy( { ("openapi", "3.0"): openapi_v30_request_security_validator, ("openapi", "3.1"): openapi_v31_request_security_validator, }, ) -openapi_request_validator = DetectRequestValidatorProxy( +openapi_request_validator = DetectSpecRequestValidatorProxy( { ("openapi", "3.0"): openapi_v30_request_validator, ("openapi", "3.1"): openapi_v31_request_validator, diff --git a/openapi_core/validation/request/protocols.py b/openapi_core/validation/request/protocols.py index bb527b19..a3506952 100644 --- a/openapi_core/validation/request/protocols.py +++ b/openapi_core/validation/request/protocols.py @@ -87,10 +87,11 @@ def path_pattern(self) -> str: @runtime_checkable class RequestValidator(Protocol): + def __init__(self, spec: Spec, base_url: Optional[str] = None): + ... + def validate( self, - spec: Spec, request: Request, - base_url: Optional[str] = None, ) -> RequestValidationResult: ... diff --git a/openapi_core/validation/request/proxies.py b/openapi_core/validation/request/proxies.py index 725853ac..c667af75 100644 --- a/openapi_core/validation/request/proxies.py +++ b/openapi_core/validation/request/proxies.py @@ -1,12 +1,12 @@ """OpenAPI spec validator validation proxies module.""" +import warnings from typing import Any -from typing import Hashable from typing import Iterator from typing import Mapping from typing import Optional from typing import Tuple +from typing import Type -from openapi_core.exceptions import OpenAPIError from openapi_core.spec import Spec from openapi_core.validation.exceptions import ValidatorDetectError from openapi_core.validation.request.datatypes import RequestValidationResult @@ -14,13 +14,62 @@ from openapi_core.validation.request.validators import BaseRequestValidator -class DetectRequestValidatorProxy: +class SpecRequestValidatorProxy: def __init__( - self, choices: Mapping[Tuple[str, str], BaseRequestValidator] + self, + validator_cls: Type[BaseRequestValidator], + **validator_kwargs: Any, + ): + self.validator_cls = validator_cls + self.validator_kwargs = validator_kwargs + + def validate( + self, + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> RequestValidationResult: + warnings.warn( + "openapi_request_validator is deprecated. " + f"Use {self.validator_cls.__name__} class instead.", + DeprecationWarning, + ) + validator = self.validator_cls( + spec, base_url=base_url, **self.validator_kwargs + ) + return validator.validate(request) + + def is_valid( + self, + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> bool: + validator = self.validator_cls( + spec, base_url=base_url, **self.validator_kwargs + ) + error = next(validator.iter_errors(request), None) + return error is None + + def iter_errors( + self, + spec: Spec, + request: Request, + base_url: Optional[str] = None, + ) -> Iterator[Exception]: + validator = self.validator_cls( + spec, base_url=base_url, **self.validator_kwargs + ) + yield from validator.iter_errors(request) + + +class DetectSpecRequestValidatorProxy: + def __init__( + self, choices: Mapping[Tuple[str, str], SpecRequestValidatorProxy] ): self.choices = choices - def detect(self, spec: Spec) -> BaseRequestValidator: + def detect(self, spec: Spec) -> SpecRequestValidatorProxy: for (key, value), validator in self.choices.items(): if key in spec and spec[key].startswith(value): return validator diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 1f431fa6..29a4ef53 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -28,7 +28,12 @@ from openapi_core.spec.paths import Spec from openapi_core.templating.media_types.exceptions import MediaTypeFinderError from openapi_core.templating.paths.exceptions import PathError -from openapi_core.unmarshalling.schemas.enums import UnmarshalContext +from openapi_core.unmarshalling.schemas import ( + oas30_request_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas import ( + oas31_schema_unmarshallers_factory, +) from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError from openapi_core.unmarshalling.schemas.exceptions import ValidateError from openapi_core.unmarshalling.schemas.factories import ( @@ -52,35 +57,31 @@ class BaseRequestValidator(BaseValidator): def __init__( self, - schema_unmarshallers_factory: SchemaUnmarshallersFactory, + spec: Spec, + base_url: Optional[str] = None, + schema_unmarshallers_factory: Optional[ + SchemaUnmarshallersFactory + ] = None, schema_casters_factory: SchemaCastersFactory = schema_casters_factory, parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, security_provider_factory: SecurityProviderFactory = security_provider_factory, ): super().__init__( - schema_unmarshallers_factory, + spec, + base_url=base_url, + schema_unmarshallers_factory=schema_unmarshallers_factory, schema_casters_factory=schema_casters_factory, parameter_deserializers_factory=parameter_deserializers_factory, media_type_deserializers_factory=media_type_deserializers_factory, ) self.security_provider_factory = security_provider_factory - def iter_errors( - self, - spec: Spec, - request: Request, - base_url: Optional[str] = None, - ) -> Iterator[Exception]: - result = self.validate(spec, request, base_url=base_url) + def iter_errors(self, request: Request) -> Iterator[Exception]: + result = self.validate(request) yield from result.errors - def validate( - self, - spec: Spec, - request: Request, - base_url: Optional[str] = None, - ) -> RequestValidationResult: + def validate(self, request: Request) -> RequestValidationResult: raise NotImplementedError def _get_parameters( @@ -143,11 +144,11 @@ def _get_parameter(self, param: Spec, request: Request) -> Any: raise MissingParameter(name) def _get_security( - self, spec: Spec, request: Request, operation: Spec + self, request: Request, operation: Spec ) -> Optional[Dict[str, str]]: security = None - if "security" in spec: - security = spec / "security" + if "security" in self.spec: + security = self.spec / "security" if "security" in operation: security = operation / "security" @@ -157,9 +158,7 @@ def _get_security( for security_requirement in security: try: return { - scheme_name: self._get_security_value( - spec, scheme_name, request - ) + scheme_name: self._get_security_value(scheme_name, request) for scheme_name in list(security_requirement.keys()) } except SecurityError: @@ -167,10 +166,8 @@ def _get_security( raise InvalidSecurity - def _get_security_value( - self, spec: Spec, scheme_name: str, request: Request - ) -> Any: - security_schemes = spec / "components#securitySchemes" + def _get_security_value(self, scheme_name: str, request: Request) -> Any: + security_schemes = self.spec / "components#securitySchemes" if scheme_name not in security_schemes: return scheme = security_schemes[scheme_name] @@ -207,16 +204,9 @@ def _get_body_value(self, request_body: Spec, request: Request) -> Any: class RequestParametersValidator(BaseRequestValidator): - def validate( - self, - spec: Spec, - request: Request, - base_url: Optional[str] = None, - ) -> RequestValidationResult: + def validate(self, request: Request) -> RequestValidationResult: try: - path, operation, _, path_result, _ = self._find_path( - spec, request, base_url=base_url - ) + path, operation, _, path_result, _ = self._find_path(request) except PathError as exc: return RequestValidationResult(errors=[exc]) @@ -239,16 +229,9 @@ def validate( class RequestBodyValidator(BaseRequestValidator): - def validate( - self, - spec: Spec, - request: Request, - base_url: Optional[str] = None, - ) -> RequestValidationResult: + def validate(self, request: Request) -> RequestValidationResult: try: - _, operation, _, _, _ = self._find_path( - spec, request, base_url=base_url - ) + _, operation, _, _, _ = self._find_path(request) except PathError as exc: return RequestValidationResult(errors=[exc]) @@ -277,21 +260,14 @@ def validate( class RequestSecurityValidator(BaseRequestValidator): - def validate( - self, - spec: Spec, - request: Request, - base_url: Optional[str] = None, - ) -> RequestValidationResult: + def validate(self, request: Request) -> RequestValidationResult: try: - _, operation, _, _, _ = self._find_path( - spec, request, base_url=base_url - ) + _, operation, _, _, _ = self._find_path(request) except PathError as exc: return RequestValidationResult(errors=[exc]) try: - security = self._get_security(spec, request, operation) + security = self._get_security(request, operation) except InvalidSecurity as exc: return RequestValidationResult(errors=[exc]) @@ -302,22 +278,15 @@ def validate( class RequestValidator(BaseRequestValidator): - def validate( - self, - spec: Spec, - request: Request, - base_url: Optional[str] = None, - ) -> RequestValidationResult: + def validate(self, request: Request) -> RequestValidationResult: try: - path, operation, _, path_result, _ = self._find_path( - spec, request, base_url=base_url - ) + path, operation, _, path_result, _ = self._find_path(request) # don't process if operation errors except PathError as exc: return RequestValidationResult(errors=[exc]) try: - security = self._get_security(spec, request, operation) + security = self._get_security(request, operation) except InvalidSecurity as exc: return RequestValidationResult(errors=[exc]) @@ -358,3 +327,35 @@ def validate( parameters=params, security=security, ) + + +class V30RequestBodyValidator(RequestBodyValidator): + schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + + +class V30RequestParametersValidator(RequestParametersValidator): + schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + + +class V30RequestSecurityValidator(RequestSecurityValidator): + schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + + +class V30RequestValidator(RequestValidator): + schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + + +class V31RequestBodyValidator(RequestBodyValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31RequestParametersValidator(RequestParametersValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31RequestSecurityValidator(RequestSecurityValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31RequestValidator(RequestValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory diff --git a/openapi_core/validation/response/__init__.py b/openapi_core/validation/response/__init__.py index 3bbc3001..09ec44f4 100644 --- a/openapi_core/validation/response/__init__.py +++ b/openapi_core/validation/response/__init__.py @@ -1,4 +1,6 @@ """OpenAPI core validation response module""" +from functools import partial + from openapi_core.unmarshalling.schemas import ( oas30_response_schema_unmarshallers_factory, ) @@ -8,13 +10,35 @@ from openapi_core.validation.response.proxies import ( DetectResponseValidatorProxy, ) +from openapi_core.validation.response.proxies import SpecResponseValidatorProxy from openapi_core.validation.response.validators import ResponseDataValidator from openapi_core.validation.response.validators import ( ResponseHeadersValidator, ) from openapi_core.validation.response.validators import ResponseValidator +from openapi_core.validation.response.validators import ( + V30ResponseDataValidator, +) +from openapi_core.validation.response.validators import ( + V30ResponseHeadersValidator, +) +from openapi_core.validation.response.validators import V30ResponseValidator +from openapi_core.validation.response.validators import ( + V31ResponseDataValidator, +) +from openapi_core.validation.response.validators import ( + V31ResponseHeadersValidator, +) +from openapi_core.validation.response.validators import V31ResponseValidator __all__ = [ + "V30ResponseDataValidator", + "V30ResponseHeadersValidator", + "V30ResponseValidator", + "V31ResponseDataValidator", + "V31ResponseHeadersValidator", + "V31ResponseValidator", + "V3ResponseValidator", "openapi_v30_response_data_validator", "openapi_v30_response_headers_validator", "openapi_v30_response_validator", @@ -29,27 +53,37 @@ "openapi_response_validator", ] -openapi_v30_response_data_validator = ResponseDataValidator( +# alias to the latest v3 version +V3ResponseValidator = V31ResponseValidator + +# spec validators +openapi_v30_response_data_validator = SpecResponseValidatorProxy( + ResponseDataValidator, schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, ) -openapi_v30_response_headers_validator = ResponseHeadersValidator( +openapi_v30_response_headers_validator = SpecResponseValidatorProxy( + ResponseHeadersValidator, schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, ) -openapi_v30_response_validator = ResponseValidator( +openapi_v30_response_validator = SpecResponseValidatorProxy( + ResponseValidator, schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, ) -openapi_v31_response_data_validator = ResponseDataValidator( +openapi_v31_response_data_validator = SpecResponseValidatorProxy( + ResponseDataValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -openapi_v31_response_headers_validator = ResponseHeadersValidator( +openapi_v31_response_headers_validator = SpecResponseValidatorProxy( + ResponseHeadersValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -openapi_v31_response_validator = ResponseValidator( +openapi_v31_response_validator = SpecResponseValidatorProxy( + ResponseValidator, schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, ) -# alias to the latest v3 version +# spec validators alias to the latest v3 version openapi_v3_response_data_validator = openapi_v31_response_data_validator openapi_v3_response_headers_validator = openapi_v31_response_headers_validator openapi_v3_response_validator = openapi_v31_response_validator diff --git a/openapi_core/validation/response/protocols.py b/openapi_core/validation/response/protocols.py index c8247854..dc06ae6b 100644 --- a/openapi_core/validation/response/protocols.py +++ b/openapi_core/validation/response/protocols.py @@ -50,11 +50,12 @@ def headers(self) -> Mapping[str, Any]: @runtime_checkable class ResponseValidator(Protocol): + def __init__(self, spec: Spec, base_url: Optional[str] = None): + ... + def validate( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> ResponseValidationResult: ... diff --git a/openapi_core/validation/response/proxies.py b/openapi_core/validation/response/proxies.py index 750d0337..16cdc276 100644 --- a/openapi_core/validation/response/proxies.py +++ b/openapi_core/validation/response/proxies.py @@ -1,12 +1,12 @@ """OpenAPI spec validator validation proxies module.""" +import warnings from typing import Any -from typing import Hashable from typing import Iterator from typing import Mapping from typing import Optional from typing import Tuple +from typing import Type -from openapi_core.exceptions import OpenAPIError from openapi_core.spec import Spec from openapi_core.validation.exceptions import ValidatorDetectError from openapi_core.validation.request.protocols import Request @@ -15,13 +15,68 @@ from openapi_core.validation.response.validators import BaseResponseValidator +class SpecResponseValidatorProxy: + def __init__( + self, + validator_cls: Type[BaseResponseValidator], + **validator_kwargs: Any, + ): + self.validator_cls = validator_cls + self.validator_kwargs = validator_kwargs + + def validate( + self, + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> ResponseValidationResult: + warnings.warn( + "openapi_response_validator is deprecated. " + f"Use {self.validator_cls.__name__} class instead.", + DeprecationWarning, + ) + validator = self.validator_cls( + spec, base_url=base_url, **self.validator_kwargs + ) + return validator.validate(request, response) + + def is_valid( + self, + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> bool: + validator = self.validator_cls( + spec, base_url=base_url, **self.validator_kwargs + ) + error = next( + validator.iter_errors(request, response), + None, + ) + return error is None + + def iter_errors( + self, + spec: Spec, + request: Request, + response: Response, + base_url: Optional[str] = None, + ) -> Iterator[Exception]: + validator = self.validator_cls( + spec, base_url=base_url, **self.validator_kwargs + ) + yield from validator.iter_errors(request, response) + + class DetectResponseValidatorProxy: def __init__( - self, choices: Mapping[Tuple[str, str], BaseResponseValidator] + self, choices: Mapping[Tuple[str, str], SpecResponseValidatorProxy] ): self.choices = choices - def detect(self, spec: Spec) -> BaseResponseValidator: + def detect(self, spec: Spec) -> SpecResponseValidatorProxy: for (key, value), validator in self.choices.items(): if key in spec and spec[key].startswith(value): return validator diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index 9c884a06..6a32db57 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -13,12 +13,14 @@ from openapi_core.templating.media_types.exceptions import MediaTypeFinderError from openapi_core.templating.paths.exceptions import PathError from openapi_core.templating.responses.exceptions import ResponseFinderError -from openapi_core.unmarshalling.schemas.enums import UnmarshalContext +from openapi_core.unmarshalling.schemas import ( + oas30_response_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas import ( + oas31_schema_unmarshallers_factory, +) from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError from openapi_core.unmarshalling.schemas.exceptions import ValidateError -from openapi_core.unmarshalling.schemas.factories import ( - SchemaUnmarshallersFactory, -) from openapi_core.util import chainiters from openapi_core.validation.exceptions import MissingHeader from openapi_core.validation.exceptions import MissingRequiredHeader @@ -33,33 +35,25 @@ class BaseResponseValidator(BaseValidator): def iter_errors( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> Iterator[Exception]: - result = self.validate(spec, request, response, base_url=base_url) + result = self.validate(request, response) yield from result.errors def validate( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> ResponseValidationResult: raise NotImplementedError def _find_operation_response( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> Spec: - _, operation, _, _, _ = self._find_path( - spec, request, base_url=base_url - ) + _, operation, _, _, _ = self._find_path(request) return self._get_operation_response(operation, response) def _get_operation_response( @@ -152,17 +146,13 @@ def _get_header(self, name: str, header: Spec, response: Response) -> Any: class ResponseDataValidator(BaseResponseValidator): def validate( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> ResponseValidationResult: try: operation_response = self._find_operation_response( - spec, request, response, - base_url=base_url, ) # don't process if operation errors except (PathError, ResponseFinderError) as exc: @@ -192,17 +182,13 @@ def validate( class ResponseHeadersValidator(BaseResponseValidator): def validate( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> ResponseValidationResult: try: operation_response = self._find_operation_response( - spec, request, response, - base_url=base_url, ) # don't process if operation errors except (PathError, ResponseFinderError) as exc: @@ -225,17 +211,13 @@ def validate( class ResponseValidator(BaseResponseValidator): def validate( self, - spec: Spec, request: Request, response: Response, - base_url: Optional[str] = None, ) -> ResponseValidationResult: try: operation_response = self._find_operation_response( - spec, request, response, - base_url=base_url, ) # don't process if operation errors except (PathError, ResponseFinderError) as exc: @@ -270,3 +252,27 @@ def validate( data=data, headers=headers, ) + + +class V30ResponseDataValidator(ResponseDataValidator): + schema_unmarshallers_factory = oas30_response_schema_unmarshallers_factory + + +class V30ResponseHeadersValidator(ResponseHeadersValidator): + schema_unmarshallers_factory = oas30_response_schema_unmarshallers_factory + + +class V30ResponseValidator(ResponseValidator): + schema_unmarshallers_factory = oas30_response_schema_unmarshallers_factory + + +class V31ResponseDataValidator(ResponseDataValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31ResponseHeadersValidator(ResponseHeadersValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31ResponseValidator(ResponseValidator): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory diff --git a/openapi_core/validation/shortcuts.py b/openapi_core/validation/shortcuts.py index e30c56f0..bf94c3c9 100644 --- a/openapi_core/validation/shortcuts.py +++ b/openapi_core/validation/shortcuts.py @@ -1,24 +1,74 @@ """OpenAPI core validation shortcuts module""" +import warnings +from typing import Any +from typing import Dict +from typing import NamedTuple from typing import Optional +from typing import Type from openapi_core.spec import Spec -from openapi_core.validation.request import openapi_request_validator +from openapi_core.validation.exceptions import ValidatorDetectError +from openapi_core.validation.request import V30RequestValidator +from openapi_core.validation.request import V31RequestValidator from openapi_core.validation.request.datatypes import RequestValidationResult from openapi_core.validation.request.protocols import Request from openapi_core.validation.request.protocols import RequestValidator -from openapi_core.validation.response import openapi_response_validator +from openapi_core.validation.request.proxies import SpecRequestValidatorProxy +from openapi_core.validation.response import V30ResponseValidator +from openapi_core.validation.response import V31ResponseValidator from openapi_core.validation.response.datatypes import ResponseValidationResult from openapi_core.validation.response.protocols import Response from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.validation.response.proxies import SpecResponseValidatorProxy + + +class SpecVersion(NamedTuple): + name: str + version: str + + +class SpecValidators(NamedTuple): + request: Type[RequestValidator] + response: Type[ResponseValidator] + + +SPECS: Dict[SpecVersion, SpecValidators] = { + SpecVersion("openapi", "3.0"): SpecValidators( + V30RequestValidator, V30ResponseValidator + ), + SpecVersion("openapi", "3.1"): SpecValidators( + V31RequestValidator, V31ResponseValidator + ), +} + + +def get_validators(spec: Spec) -> SpecValidators: + for v, validators in SPECS.items(): + if v.name in spec and spec[v.name].startswith(v.version): + return validators + raise ValidatorDetectError("Spec schema version not detected") def validate_request( request: Request, spec: Spec, base_url: Optional[str] = None, - validator: RequestValidator = openapi_request_validator, + validator: Optional[SpecRequestValidatorProxy] = None, + cls: Optional[Type[RequestValidator]] = None, + **validator_kwargs: Any, ) -> RequestValidationResult: - result = validator.validate(spec, request, base_url=base_url) + if validator is not None: + warnings.warn( + "validator parameter is deprecated. Use cls instead.", + DeprecationWarning, + ) + result = validator.validate(spec, request, base_url=base_url) + else: + if cls is None: + validators = get_validators(spec) + cls = getattr(validators, "request") + v = cls(spec, base_url=base_url, **validator_kwargs) + result = v.validate(request) result.raise_for_errors() return result @@ -28,8 +78,21 @@ def validate_response( response: Response, spec: Spec, base_url: Optional[str] = None, - validator: ResponseValidator = openapi_response_validator, + validator: Optional[SpecResponseValidatorProxy] = None, + cls: Optional[Type[ResponseValidator]] = None, + **validator_kwargs: Any, ) -> ResponseValidationResult: - result = validator.validate(spec, request, response, base_url=base_url) + if validator is not None: + warnings.warn( + "validator parameter is deprecated. Use cls instead.", + DeprecationWarning, + ) + result = validator.validate(spec, request, response, base_url=base_url) + else: + if cls is None: + validators = get_validators(spec) + cls = getattr(validators, "response") + v = cls(spec, base_url=base_url, **validator_kwargs) + result = v.validate(request, response) result.raise_for_errors() return result diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 8689a181..45758489 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,7 +1,10 @@ """OpenAPI core validation validators module""" from typing import Any +from typing import Dict from typing import Mapping from typing import Optional +from typing import Tuple +from urllib.parse import urljoin from openapi_core.casting.schemas import schema_casters_factory from openapi_core.casting.schemas.factories import SchemaCastersFactory @@ -30,28 +33,42 @@ class BaseValidator: + + schema_unmarshallers_factory: SchemaUnmarshallersFactory = NotImplemented + def __init__( self, - schema_unmarshallers_factory: SchemaUnmarshallersFactory, + spec: Spec, + base_url: Optional[str] = None, + schema_unmarshallers_factory: Optional[ + SchemaUnmarshallersFactory + ] = None, schema_casters_factory: SchemaCastersFactory = schema_casters_factory, parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, ): - self.schema_unmarshallers_factory = schema_unmarshallers_factory + self.spec = spec + self.base_url = base_url + + self.schema_unmarshallers_factory = ( + schema_unmarshallers_factory or self.schema_unmarshallers_factory + ) + if self.schema_unmarshallers_factory is NotImplemented: + raise NotImplementedError( + "schema_unmarshallers_factory is not assigned" + ) + self.schema_casters_factory = schema_casters_factory self.parameter_deserializers_factory = parameter_deserializers_factory self.media_type_deserializers_factory = ( media_type_deserializers_factory ) - def _find_path( - self, spec: Spec, request: Request, base_url: Optional[str] = None - ) -> ServerOperationPath: - path_finder = PathFinder(spec, base_url=base_url) - path_pattern = getattr(request, "path_pattern", None) - return path_finder.find( - request.method, request.host_url, request.path, path_pattern - ) + def _find_path(self, request: Request) -> ServerOperationPath: + path_finder = PathFinder(self.spec, base_url=self.base_url) + path_pattern = getattr(request, "path_pattern", None) or request.path + full_url = urljoin(request.host_url, path_pattern) + return path_finder.find(request.method, full_url) def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: from openapi_core.templating.media_types.finders import MediaTypeFinder diff --git a/tests/unit/templating/test_paths_finders.py b/tests/unit/templating/test_paths_finders.py index 7c410cd8..183dd9a3 100644 --- a/tests/unit/templating/test_paths_finders.py +++ b/tests/unit/templating/test_paths_finders.py @@ -181,11 +181,11 @@ def servers(self): @pytest.mark.xfail(reason="returns default server") def test_raises(self, finder): - request_uri = "/resource" - request = MockRequest("http://petstore.swagger.io", "get", request_uri) + method = "get" + full_url = "http://petstore.swagger.io/resource" with pytest.raises(ServerNotFound): - finder.find(request.method, request.host_url, request.path) + finder.find(method, full_url) class BaseTestOperationNotFound: @@ -194,22 +194,19 @@ def operations(self): return {} def test_raises(self, finder): - request_uri = "/resource" - request = MockRequest("http://petstore.swagger.io", "get", request_uri) + method = "get" + full_url = "http://petstore.swagger.io/resource" with pytest.raises(OperationNotFound): - finder.find(request.method, request.host_url, request.path) + finder.find(method, full_url) class BaseTestValid: def test_simple(self, finder, spec): - request_uri = "/resource" method = "get" - request = MockRequest( - "http://petstore.swagger.io", method, request_uri - ) + full_url = "http://petstore.swagger.io/resource" - result = finder.find(request.method, request.host_url, request.path) + result = finder.find(method, full_url) path = spec / "paths" / self.path_name operation = spec / "paths" / self.path_name / method @@ -228,13 +225,10 @@ def test_simple(self, finder, spec): class BaseTestVariableValid: @pytest.mark.parametrize("version", ["v1", "v2"]) def test_variable(self, finder, spec, version): - request_uri = f"/{version}/resource" method = "get" - request = MockRequest( - "http://petstore.swagger.io", method, request_uri - ) + full_url = f"http://petstore.swagger.io/{version}/resource" - result = finder.find(request.method, request.host_url, request.path) + result = finder.find(method, full_url) path = spec / "paths" / self.path_name operation = spec / "paths" / self.path_name / method @@ -253,13 +247,10 @@ def test_variable(self, finder, spec, version): class BaseTestPathVariableValid: @pytest.mark.parametrize("res_id", ["111", "222"]) def test_path_variable(self, finder, spec, res_id): - request_uri = f"/resource/{res_id}" method = "get" - request = MockRequest( - "http://petstore.swagger.io", method, request_uri - ) + full_url = f"http://petstore.swagger.io/resource/{res_id}" - result = finder.find(request.method, request.host_url, request.path) + result = finder.find(method, full_url) path = spec / "paths" / self.path_name operation = spec / "paths" / self.path_name / method @@ -281,11 +272,11 @@ def paths(self): return {} def test_raises(self, finder): - request_uri = "/resource" - request = MockRequest("http://petstore.swagger.io", "get", request_uri) + method = "get" + full_url = "http://petstore.swagger.io/resource" with pytest.raises(PathNotFound): - finder.find(request.method, request.host_url, request.path) + finder.find(method, full_url) class TestSpecSimpleServerServerNotFound( @@ -559,13 +550,10 @@ def paths(self, path, path_2): def test_valid(self, finder, spec): token_id = "123" - request_uri = f"/keys/{token_id}/tokens" method = "get" - request = MockRequest( - "http://petstore.swagger.io", method, request_uri - ) + full_url = f"http://petstore.swagger.io/keys/{token_id}/tokens" - result = finder.find(request.method, request.host_url, request.path) + result = finder.find(method, full_url) path_2 = spec / "paths" / self.path_2_name operation_2 = spec / "paths" / self.path_2_name / method @@ -614,12 +602,9 @@ def paths(self, path, path_2): } def test_valid(self, finder, spec): - request_uri = "/keys/master/tokens" method = "get" - request = MockRequest( - "http://petstore.swagger.io", method, request_uri - ) - result = finder.find(request.method, request.host_url, request.path) + full_url = "http://petstore.swagger.io/keys/master/tokens" + result = finder.find(method, full_url) path_2 = spec / "paths" / self.path_2_name operation_2 = spec / "paths" / self.path_2_name / method @@ -669,12 +654,9 @@ def paths(self, path, path_2): def test_valid(self, finder, spec): token_id = "123" - request_uri = f"/keys/{token_id}/tokens/master" method = "get" - request = MockRequest( - "http://petstore.swagger.io", method, request_uri - ) - result = finder.find(request.method, request.host_url, request.path) + full_url = f"http://petstore.swagger.io/keys/{token_id}/tokens/master" + result = finder.find(method, full_url) path_2 = spec / "paths" / self.path_2_name operation_2 = spec / "paths" / self.path_2_name / method diff --git a/tests/unit/validation/test_request_shortcuts.py b/tests/unit/validation/test_request_shortcuts.py deleted file mode 100644 index 20a514ca..00000000 --- a/tests/unit/validation/test_request_shortcuts.py +++ /dev/null @@ -1,36 +0,0 @@ -from unittest import mock - -import pytest - -from openapi_core.testing.datatypes import ResultMock -from openapi_core.validation.shortcuts import validate_request - - -class TestValidateRequest: - @mock.patch( - "openapi_core.validation.shortcuts.openapi_request_validator.validate" - ) - def test_validator_valid(self, mock_validate): - spec = mock.sentinel.spec - request = mock.sentinel.request - parameters = mock.sentinel.parameters - validation_result = ResultMock(parameters=parameters) - mock_validate.return_value = validation_result - - result = validate_request(request, spec=spec) - - assert result == validation_result - mock_validate.aasert_called_once_with(request) - - @mock.patch( - "openapi_core.validation.shortcuts.openapi_request_validator.validate" - ) - def test_validator_error(self, mock_validate): - spec = mock.sentinel.spec - request = mock.sentinel.request - mock_validate.return_value = ResultMock(error_to_raise=ValueError) - - with pytest.raises(ValueError): - validate_request(request, spec=spec) - - mock_validate.aasert_called_once_with(request) diff --git a/tests/unit/validation/test_response_shortcuts.py b/tests/unit/validation/test_response_shortcuts.py deleted file mode 100644 index 05987d37..00000000 --- a/tests/unit/validation/test_response_shortcuts.py +++ /dev/null @@ -1,38 +0,0 @@ -from unittest import mock - -import pytest - -from openapi_core.testing.datatypes import ResultMock -from openapi_core.validation.shortcuts import validate_response - - -class TestSpecValidateData: - @mock.patch( - "openapi_core.validation.shortcuts.openapi_response_validator.validate" - ) - def test_validator_valid(self, mock_validate): - spec = mock.sentinel.spec - request = mock.sentinel.request - response = mock.sentinel.response - data = mock.sentinel.data - validation_result = ResultMock(data=data) - mock_validate.return_value = validation_result - - result = validate_response(request, response, spec=spec) - - assert result == validation_result - mock_validate.aasert_called_once_with(request, response) - - @mock.patch( - "openapi_core.validation.shortcuts.openapi_response_validator.validate" - ) - def test_validator_error(self, mock_validate): - spec = mock.sentinel.spec - request = mock.sentinel.request - response = mock.sentinel.response - mock_validate.return_value = ResultMock(error_to_raise=ValueError) - - with pytest.raises(ValueError): - validate_response(request, response, spec=spec) - - mock_validate.aasert_called_once_with(request, response) diff --git a/tests/unit/validation/test_shortcuts.py b/tests/unit/validation/test_shortcuts.py new file mode 100644 index 00000000..b48406ea --- /dev/null +++ b/tests/unit/validation/test_shortcuts.py @@ -0,0 +1,113 @@ +from unittest import mock + +import pytest + +from openapi_core import validate_request +from openapi_core import validate_response +from openapi_core.testing.datatypes import ResultMock +from openapi_core.validation.request.validators import RequestValidator +from openapi_core.validation.response.validators import ResponseValidator + + +class TestValidateRequest: + @mock.patch( + "openapi_core.validation.request.validators.RequestValidator.validate", + ) + def test_valid(self, mock_validate): + spec = {"openapi": "3.1"} + request = mock.sentinel.request + + result = validate_request(request, spec=spec) + + assert result == mock_validate.return_value + mock_validate.validate.aasert_called_once_with(request) + + @mock.patch( + "openapi_core.validation.request.validators.RequestValidator.validate", + ) + def test_error(self, mock_validate): + spec = {"openapi": "3.1"} + request = mock.sentinel.request + mock_validate.return_value = ResultMock(error_to_raise=ValueError) + + with pytest.raises(ValueError): + validate_request(request, spec=spec) + + mock_validate.aasert_called_once_with(request) + + def test_validator(self): + spec = mock.sentinel.spec + request = mock.sentinel.request + validator = mock.Mock(spec=RequestValidator) + + with pytest.warns(DeprecationWarning): + result = validate_request(request, spec=spec, validator=validator) + + assert result == validator.validate.return_value + validator.validate.aasert_called_once_with(request) + + def test_cls(self): + spec = mock.sentinel.spec + request = mock.sentinel.request + validator_cls = mock.Mock(spec=RequestValidator) + + result = validate_request(request, spec=spec, cls=validator_cls) + + assert result == validator_cls().validate.return_value + validator_cls().validate.aasert_called_once_with(request) + + +class TestSpecValidateData: + @mock.patch( + "openapi_core.validation.response.validators.ResponseValidator.validate", + ) + def test_valid(self, mock_validate): + spec = {"openapi": "3.1"} + request = mock.sentinel.request + response = mock.sentinel.response + + result = validate_response(request, response, spec=spec) + + assert result == mock_validate.return_value + mock_validate.aasert_called_once_with(request, response) + + @mock.patch( + "openapi_core.validation.response.validators.ResponseValidator.validate", + ) + def test_error(self, mock_validate): + spec = {"openapi": "3.1"} + request = mock.sentinel.request + response = mock.sentinel.response + mock_validate.return_value = ResultMock(error_to_raise=ValueError) + + with pytest.raises(ValueError): + validate_response(request, response, spec=spec) + + mock_validate.aasert_called_once_with(request, response) + + def test_validator(self): + spec = mock.sentinel.spec + request = mock.sentinel.request + response = mock.sentinel.response + validator = mock.Mock(spec=ResponseValidator) + + with pytest.warns(DeprecationWarning): + result = validate_response( + request, response, spec=spec, validator=validator + ) + + assert result == validator.validate.return_value + validator.validate.aasert_called_once_with(request) + + def test_cls(self): + spec = mock.sentinel.spec + request = mock.sentinel.request + response = mock.sentinel.response + validator_cls = mock.Mock(spec=ResponseValidator) + + result = validate_response( + request, response, spec=spec, cls=validator_cls + ) + + assert result == validator_cls().validate.return_value + validator_cls().validate.aasert_called_once_with(request)