Skip to content

Validation errors refactor #473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions openapi_core/contrib/django/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.validation.exceptions import InvalidSecurity
from openapi_core.validation.exceptions import MissingRequiredParameter
from openapi_core.templating.security.exceptions import SecurityNotFound


class DjangoOpenAPIErrorsHandler:

OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = {
MissingRequiredParameter: 400,
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
ServerNotFound: 400,
InvalidSecurity: 403,
SecurityNotFound: 403,
OperationNotFound: 405,
PathNotFound: 404,
MediaTypeNotFound: 415,
Expand All @@ -43,7 +41,9 @@ def handle(
return JsonResponse(data, status=data_error_max["status"])

@classmethod
def format_openapi_error(cls, error: Exception) -> Dict[str, Any]:
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
if error.__cause__ is not None:
error = error.__cause__
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
Expand Down
12 changes: 6 additions & 6 deletions openapi_core/contrib/falcon/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.validation.exceptions import InvalidSecurity
from openapi_core.validation.exceptions import MissingRequiredParameter
from openapi_core.templating.security.exceptions import SecurityNotFound


class FalconOpenAPIErrorsHandler:

OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = {
MissingRequiredParameter: 400,
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
ServerNotFound: 400,
InvalidSecurity: 403,
SecurityNotFound: 403,
OperationNotFound: 405,
PathNotFound: 404,
MediaTypeNotFound: 415,
Expand All @@ -49,7 +47,9 @@ def handle(
resp.complete = True

@classmethod
def format_openapi_error(cls, error: Exception) -> Dict[str, Any]:
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
if error.__cause__ is not None:
error = error.__cause__
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
Expand Down
10 changes: 7 additions & 3 deletions openapi_core/contrib/flask/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound


class FlaskOpenAPIErrorsHandler:

OPENAPI_ERROR_STATUS: Dict[Type[Exception], int] = {
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
ServerNotFound: 400,
SecurityNotFound: 403,
OperationNotFound: 405,
PathNotFound: 404,
MediaTypeNotFound: 415,
}

@classmethod
def handle(cls, errors: Iterable[Exception]) -> Response:
def handle(cls, errors: Iterable[BaseException]) -> Response:
data_errors = [cls.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
Expand All @@ -36,7 +38,9 @@ def handle(cls, errors: Iterable[Exception]) -> Response:
)

@classmethod
def format_openapi_error(cls, error: Exception) -> Dict[str, Any]:
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
if error.__cause__ is not None:
error = error.__cause__
return {
"title": str(error),
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
Expand Down
2 changes: 1 addition & 1 deletion openapi_core/security/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from openapi_core.exceptions import OpenAPIError


class SecurityError(OpenAPIError):
class SecurityProviderError(OpenAPIError):
pass
14 changes: 9 additions & 5 deletions openapi_core/security/providers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from typing import Any

from openapi_core.security.exceptions import SecurityError
from openapi_core.security.exceptions import SecurityProviderError
from openapi_core.spec import Spec
from openapi_core.validation.request.datatypes import RequestParameters

Expand All @@ -25,22 +25,26 @@ def __call__(self, parameters: RequestParameters) -> Any:
location = self.scheme["in"]
source = getattr(parameters, location)
if name not in source:
raise SecurityError("Missing api key parameter.")
raise SecurityProviderError("Missing api key parameter.")
return source[name]


class HttpProvider(BaseProvider):
def __call__(self, parameters: RequestParameters) -> Any:
if "Authorization" not in parameters.header:
raise SecurityError("Missing authorization header.")
raise SecurityProviderError("Missing authorization header.")
auth_header = parameters.header["Authorization"]
try:
auth_type, encoded_credentials = auth_header.split(" ", 1)
except ValueError:
raise SecurityError("Could not parse authorization header.")
raise SecurityProviderError(
"Could not parse authorization header."
)

scheme = self.scheme["scheme"]
if auth_type.lower() != scheme:
raise SecurityError(f"Unknown authorization method {auth_type}")
raise SecurityProviderError(
f"Unknown authorization method {auth_type}"
)

return encoded_credentials
Empty file.
18 changes: 18 additions & 0 deletions openapi_core/templating/security/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import List

from openapi_core.exceptions import OpenAPIError


class SecurityFinderError(OpenAPIError):
"""Security finder error"""


@dataclass
class SecurityNotFound(SecurityFinderError):
"""Find security error"""

schemes: List[List[str]]

def __str__(self) -> str:
return f"Security not found. Schemes not valid for any requirement: {str(self.schemes)}"
4 changes: 3 additions & 1 deletion openapi_core/validation/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from dataclasses import dataclass
from typing import Iterable

from openapi_core.exceptions import OpenAPIError


@dataclass
class BaseValidationResult:
errors: Iterable[Exception]
errors: Iterable[OpenAPIError]

def raise_for_errors(self) -> None:
for error in self.errors:
Expand Down
58 changes: 58 additions & 0 deletions openapi_core/validation/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from functools import wraps
from inspect import signature
from typing import Any
from typing import Callable
from typing import Optional
from typing import Type

from openapi_core.exceptions import OpenAPIError
from openapi_core.unmarshalling.schemas.exceptions import ValidateError

OpenAPIErrorType = Type[OpenAPIError]


class ValidationErrorWrapper:
def __init__(
self,
err_cls: OpenAPIErrorType,
err_validate_cls: Optional[OpenAPIErrorType] = None,
err_cls_init: Optional[str] = None,
**err_cls_kw: Any
):
self.err_cls = err_cls
self.err_validate_cls = err_validate_cls or err_cls
self.err_cls_init = err_cls_init
self.err_cls_kw = err_cls_kw

def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwds: Any) -> Any:
try:
return f(*args, **kwds)
except ValidateError as exc:
self._raise_error(exc, self.err_validate_cls, f, *args, **kwds)
except OpenAPIError as exc:
self._raise_error(exc, self.err_cls, f, *args, **kwds)

return wrapper

def _raise_error(
self,
exc: OpenAPIError,
cls: OpenAPIErrorType,
f: Callable[..., Any],
*args: Any,
**kwds: Any
) -> None:
if isinstance(exc, self.err_cls):
raise
sig = signature(f)
ba = sig.bind(*args, **kwds)
kw = {
name: ba.arguments[func_kw]
for name, func_kw in self.err_cls_kw.items()
}
init = cls
if self.err_cls_init is not None:
init = getattr(cls, self.err_cls_init)
raise init(**kw) from exc
56 changes: 2 additions & 54 deletions openapi_core/validation/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,59 +8,7 @@ class ValidatorDetectError(OpenAPIError):
pass


class ValidationError(OpenAPIError):
pass


@dataclass
class InvalidSecurity(ValidationError):
def __str__(self) -> str:
return "Security not valid for any requirement"


class OpenAPIParameterError(OpenAPIError):
pass


class MissingParameterError(OpenAPIParameterError):
"""Missing parameter error"""


@dataclass
class MissingParameter(MissingParameterError):
name: str

def __str__(self) -> str:
return f"Missing parameter (without default value): {self.name}"


@dataclass
class MissingRequiredParameter(MissingParameterError):
name: str

def __str__(self) -> str:
return f"Missing required parameter: {self.name}"


class OpenAPIHeaderError(OpenAPIError):
pass


class MissingHeaderError(OpenAPIHeaderError):
"""Missing header error"""


@dataclass
class MissingHeader(MissingHeaderError):
name: str

def __str__(self) -> str:
return f"Missing header (without default value): {self.name}"


@dataclass
class MissingRequiredHeader(MissingHeaderError):
name: str

class ValidationError(OpenAPIError):
def __str__(self) -> str:
return f"Missing required header: {self.name}"
return f"{self.__class__.__name__}: {self.__cause__}"
63 changes: 57 additions & 6 deletions openapi_core/validation/request/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
from typing import Iterable

from openapi_core.exceptions import OpenAPIError
from openapi_core.spec import Spec
from openapi_core.unmarshalling.schemas.exceptions import ValidateError
from openapi_core.validation.exceptions import ValidationError
from openapi_core.validation.request.datatypes import Parameters
from openapi_core.validation.request.protocols import Request


@dataclass
class ParametersError(Exception):
parameters: Parameters
errors: Iterable[Exception]
errors: Iterable[OpenAPIError]

@property
def context(self) -> Iterable[Exception]:
def context(self) -> Iterable[OpenAPIError]:
warnings.warn(
"context property of ParametersError is deprecated. "
"Use errors instead.",
Expand All @@ -22,11 +24,20 @@ def context(self) -> Iterable[Exception]:
return self.errors


class OpenAPIRequestBodyError(OpenAPIError):
pass
class RequestError(ValidationError):
"""Request error"""


class RequestBodyError(RequestError):
def __str__(self) -> str:
return "Request body error"


class MissingRequestBodyError(OpenAPIRequestBodyError):
class InvalidRequestBody(RequestBodyError, ValidateError):
"""Invalid request body"""


class MissingRequestBodyError(RequestBodyError):
"""Missing request body error"""


Expand All @@ -38,3 +49,43 @@ def __str__(self) -> str:
class MissingRequiredRequestBody(MissingRequestBodyError):
def __str__(self) -> str:
return "Missing required request body"


@dataclass
class ParameterError(RequestError):
name: str
location: str

@classmethod
def from_spec(cls, spec: Spec) -> "ParameterError":
return cls(spec["name"], spec["in"])

def __str__(self) -> str:
return f"{self.location.title()} parameter error: {self.name}"


class InvalidParameter(ParameterError, ValidateError):
def __str__(self) -> str:
return f"Invalid {self.location} parameter: {self.name}"


class MissingParameterError(ParameterError):
"""Missing parameter error"""


class MissingParameter(MissingParameterError):
def __str__(self) -> str:
return f"Missing {self.location} parameter: {self.name}"


class MissingRequiredParameter(MissingParameterError):
def __str__(self) -> str:
return f"Missing required {self.location} parameter: {self.name}"


class SecurityError(RequestError):
pass


class InvalidSecurity(SecurityError, ValidateError):
"""Invalid security"""
Loading