From c2e166e5ee28e6ecdcd0495b0ec75f223961d4b3 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Tue, 13 Sep 2022 13:44:37 +0100 Subject: [PATCH] x-model extension import model class --- docs/extensions.rst | 34 ++++++++++ docs/index.rst | 1 + openapi_core/contrib/django/handlers.py | 2 +- openapi_core/contrib/falcon/handlers.py | 2 +- openapi_core/extensions/models/factories.py | 54 ++++++++++------ openapi_core/extensions/models/models.py | 29 --------- openapi_core/extensions/models/types.py | 5 ++ .../unmarshalling/schemas/unmarshallers.py | 28 +++++---- .../contrib/django/test_django_project.py | 14 ++--- .../contrib/falcon/test_falcon_project.py | 14 ++--- tests/integration/data/v3.0/petstore.yaml | 24 +++---- tests/integration/validation/test_petstore.py | 62 +++++++++---------- .../validation/test_read_only_write_only.py | 17 ++--- .../integration/validation/test_validators.py | 4 +- tests/unit/extensions/test_factories.py | 41 ++++++++++++ tests/unit/extensions/test_models.py | 43 ------------- tests/unit/unmarshalling/test_unmarshal.py | 26 +++++--- tests/unit/unmarshalling/test_validate.py | 3 +- 18 files changed, 219 insertions(+), 184 deletions(-) create mode 100644 docs/extensions.rst delete mode 100644 openapi_core/extensions/models/models.py create mode 100644 openapi_core/extensions/models/types.py create mode 100644 tests/unit/extensions/test_factories.py delete mode 100644 tests/unit/extensions/test_models.py diff --git a/docs/extensions.rst b/docs/extensions.rst new file mode 100644 index 00000000..e67556a5 --- /dev/null +++ b/docs/extensions.rst @@ -0,0 +1,34 @@ +Extensions +========== + +x-model +------- + +By default, objects are unmarshalled to dynamically created dataclasses. You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator `__) by providing ``x-model`` property inside schema definition with location of your class. + +.. code-block:: yaml + + ... + components: + schemas: + Coordinates: + x-model: foo.bar.Coordinates + type: object + required: + - lat + - lon + properties: + lat: + type: number + lon: + type: number + +.. code-block:: python + + # foo/bar.py + from dataclasses import dataclass + + @dataclass + class Coordinates: + lat: float + lon: float diff --git a/docs/index.rst b/docs/index.rst index 8090a33b..f5decbf1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,6 +31,7 @@ Table of contents installation usage + extensions customizations integrations diff --git a/openapi_core/contrib/django/handlers.py b/openapi_core/contrib/django/handlers.py index 05bbb742..0ddee347 100644 --- a/openapi_core/contrib/django/handlers.py +++ b/openapi_core/contrib/django/handlers.py @@ -47,7 +47,7 @@ def format_openapi_error(cls, error: Exception) -> Dict[str, Any]: return { "title": str(error), "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), - "class": str(type(error)), + "type": str(type(error)), } @classmethod diff --git a/openapi_core/contrib/falcon/handlers.py b/openapi_core/contrib/falcon/handlers.py index 6bd59f25..14d71d47 100644 --- a/openapi_core/contrib/falcon/handlers.py +++ b/openapi_core/contrib/falcon/handlers.py @@ -53,7 +53,7 @@ def format_openapi_error(cls, error: Exception) -> Dict[str, Any]: return { "title": str(error), "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400), - "class": str(type(error)), + "type": str(type(error)), } @classmethod diff --git a/openapi_core/extensions/models/factories.py b/openapi_core/extensions/models/factories.py index af6074f1..ecba0a15 100644 --- a/openapi_core/extensions/models/factories.py +++ b/openapi_core/extensions/models/factories.py @@ -1,33 +1,51 @@ """OpenAPI X-Model extension factories module""" +from dataclasses import make_dataclass +from pydoc import ErrorDuringImport +from pydoc import locate from typing import Any from typing import Dict +from typing import Iterable from typing import Optional from typing import Type -from openapi_core.extensions.models.models import Model +from openapi_core.extensions.models.types import Field -class ModelClassFactory: +class DictFactory: - base_class = Model + base_class = dict - def create(self, name: str) -> Type[Model]: - return type(name, (self.base_class,), {}) + def create(self, fields: Iterable[Field]) -> Type[Dict[Any, Any]]: + return self.base_class -class ModelFactory: - def __init__( - self, model_class_factory: Optional[ModelClassFactory] = None - ): - self.model_class_factory = model_class_factory or ModelClassFactory() - +class DataClassFactory(DictFactory): def create( - self, properties: Optional[Dict[str, Any]], name: Optional[str] = None - ) -> Model: - name = name or "Model" + self, + fields: Iterable[Field], + name: str = "Model", + ) -> Type[Any]: + return make_dataclass(name, fields, frozen=True) - model_class = self._create_class(name) - return model_class(properties) - def _create_class(self, name: str) -> Type[Model]: - return self.model_class_factory.create(name) +class ModelClassImporter(DataClassFactory): + def create( + self, + fields: Iterable[Field], + name: str = "Model", + model: Optional[str] = None, + ) -> Any: + if model is None: + return super().create(fields, name=name) + + model_class = self._get_class(model) + if model_class is not None: + return model_class + + return super().create(fields, name=model) + + def _get_class(self, model_class_path: str) -> Optional[object]: + try: + return locate(model_class_path) + except ErrorDuringImport: + return None diff --git a/openapi_core/extensions/models/models.py b/openapi_core/extensions/models/models.py deleted file mode 100644 index c27abf15..00000000 --- a/openapi_core/extensions/models/models.py +++ /dev/null @@ -1,29 +0,0 @@ -"""OpenAPI X-Model extension models module""" -from typing import Any -from typing import Dict -from typing import Optional - - -class BaseModel: - """Base class for OpenAPI X-Model.""" - - @property - def __dict__(self) -> Dict[Any, Any]: # type: ignore - raise NotImplementedError - - -class Model(BaseModel): - """Model class for OpenAPI X-Model.""" - - def __init__(self, properties: Optional[Dict[str, Any]] = None): - self.__properties = properties or {} - - @property - def __dict__(self) -> Dict[Any, Any]: # type: ignore - return self.__properties - - def __getattr__(self, name: str) -> Any: - if name not in self.__properties: - raise AttributeError - - return self.__properties[name] diff --git a/openapi_core/extensions/models/types.py b/openapi_core/extensions/models/types.py new file mode 100644 index 00000000..c97af344 --- /dev/null +++ b/openapi_core/extensions/models/types.py @@ -0,0 +1,5 @@ +from typing import Any +from typing import Tuple +from typing import Union + +Field = Union[str, Tuple[str, Any]] diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 0001c8fc..6c855cff 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Dict +from typing import Iterable from typing import List from typing import Optional @@ -16,7 +17,7 @@ from openapi_schema_validator._format import oas30_format_checker from openapi_schema_validator._types import is_string -from openapi_core.extensions.models.factories import ModelFactory +from openapi_core.extensions.models.factories import ModelClassImporter from openapi_core.schema.schemas import get_all_properties from openapi_core.schema.schemas import get_all_properties_names from openapi_core.spec import Spec @@ -196,8 +197,8 @@ class ObjectUnmarshaller(ComplexUnmarshaller): } @property - def model_factory(self) -> ModelFactory: - return ModelFactory() + def object_class_factory(self) -> ModelClassImporter: + return ModelClassImporter() def unmarshal(self, value: Any) -> Any: try: @@ -230,11 +231,11 @@ def _unmarshal_object(self, value: Any) -> Any: else: properties = self._unmarshal_properties(value) - if "x-model" in self.schema: - name = self.schema["x-model"] - return self.model_factory.create(properties, name=name) + model = self.schema.getkey("x-model") + fields: Iterable[str] = properties and properties.keys() or [] + object_class = self.object_class_factory.create(fields, model=model) - return properties + return object_class(**properties) def _unmarshal_properties( self, value: Any, one_of_schema: Optional[Spec] = None @@ -253,17 +254,18 @@ def _unmarshal_properties( additional_properties = self.schema.getkey( "additionalProperties", True ) - if isinstance(additional_properties, dict): - additional_prop_schema = self.schema / "additionalProperties" + if additional_properties is not False: + # free-form object + if additional_properties is True: + additional_prop_schema = Spec.from_dict({}) + # defined schema + else: + additional_prop_schema = self.schema / "additionalProperties" for prop_name in extra_props: prop_value = value[prop_name] properties[prop_name] = self.unmarshallers_factory.create( additional_prop_schema )(prop_value) - elif additional_properties is True: - for prop_name in extra_props: - prop_value = value[prop_name] - properties[prop_name] = prop_value for prop_name, prop in list(all_props.items()): read_only = prop.getkey("readOnly", False) diff --git a/tests/integration/contrib/django/test_django_project.py b/tests/integration/contrib/django/test_django_project.py index 0170bdc2..faf64387 100644 --- a/tests/integration/contrib/django/test_django_project.py +++ b/tests/integration/contrib/django/test_django_project.py @@ -54,7 +54,7 @@ def test_get_no_required_param(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -101,7 +101,7 @@ def test_post_server_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -148,7 +148,7 @@ def test_post_required_header_param_missing(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -176,7 +176,7 @@ def test_post_media_type_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -213,7 +213,7 @@ def test_post_required_cookie_param_missing(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -267,7 +267,7 @@ def test_get_unauthorized(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -289,7 +289,7 @@ def test_delete_method_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), diff --git a/tests/integration/contrib/falcon/test_falcon_project.py b/tests/integration/contrib/falcon/test_falcon_project.py index 921de4e0..547fda0f 100644 --- a/tests/integration/contrib/falcon/test_falcon_project.py +++ b/tests/integration/contrib/falcon/test_falcon_project.py @@ -65,7 +65,7 @@ def test_post_server_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -119,7 +119,7 @@ def test_post_required_header_param_missing(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -155,7 +155,7 @@ def test_post_media_type_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -198,7 +198,7 @@ def test_post_required_cookie_param_missing(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -249,7 +249,7 @@ def test_get_server_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -283,7 +283,7 @@ def test_get_unauthorized(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), @@ -324,7 +324,7 @@ def test_delete_method_invalid(self, client): expected_data = { "errors": [ { - "class": ( + "type": ( "" ), diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index dbebd363..d4731a7c 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -81,15 +81,7 @@ paths: content: application/json: schema: - type: object - required: - - lat - - lon - properties: - lat: - type: number - lon: - type: number + $ref: "#/components/schemas/Coordinates" responses: '200': $ref: "#/components/responses/PetsResponse" @@ -240,6 +232,16 @@ paths: $ref: "#/components/responses/ErrorResponse" components: schemas: + Coordinates: + type: object + required: + - lat + - lon + properties: + lat: + type: number + lon: + type: number Userdata: type: object required: @@ -411,7 +413,7 @@ components: required: - title - status - - class + - type properties: title: type: string @@ -419,7 +421,7 @@ components: type: integer format: int32 default: 400 - class: + type: type: string StandardErrors: type: object diff --git a/tests/integration/validation/test_petstore.py b/tests/integration/validation/test_petstore.py index fd9e9f5c..fabe0434 100644 --- a/tests/integration/validation/test_petstore.py +++ b/tests/integration/validation/test_petstore.py @@ -1,5 +1,7 @@ import json from base64 import b64encode +from dataclasses import is_dataclass +from dataclasses import make_dataclass from datetime import datetime from uuid import UUID @@ -11,7 +13,6 @@ from openapi_core.deserializing.parameters.exceptions import ( EmptyQueryParameterValue, ) -from openapi_core.extensions.models.models import BaseModel from openapi_core.spec import Spec from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import ServerNotFound @@ -111,7 +112,7 @@ def test_get_pets(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.data == [] assert response_result.headers == { "x-next": "next-url", @@ -170,7 +171,7 @@ def test_get_pets_response(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert len(response_result.data.data) == 1 assert response_result.data.data[0].id == 1 assert response_result.data.data[0].name == "Cat" @@ -338,7 +339,7 @@ def test_get_pets_ids_param(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.data == [] def test_get_pets_tags_param(self, spec): @@ -388,7 +389,7 @@ def test_get_pets_tags_param(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.data == [] def test_get_pets_parameter_deserialization_error(self, spec): @@ -640,14 +641,13 @@ def test_get_pets_param_coordinates(self, spec): validator=openapi_v30_request_parameters_validator, ) - assert result.parameters == Parameters( - query={ - "limit": None, - "page": 1, - "search": "", - "coordinates": coordinates, - } + assert is_dataclass(result.parameters.query["coordinates"]) + assert ( + result.parameters.query["coordinates"].__class__.__name__ + == "Model" ) + assert result.parameters.query["coordinates"].lat == coordinates["lat"] + assert result.parameters.query["coordinates"].lon == coordinates["lon"] result = validate_request( spec, request, validator=openapi_v30_request_body_validator @@ -703,17 +703,11 @@ def test_post_birds(self, spec, spec_dict): spec, request, validator=openapi_v30_request_parameters_validator ) - assert result.parameters == Parameters( - header={ - "api-key": self.api_key, - }, - cookie={ - "user": 123, - "userdata": { - "name": "user1", - }, - }, + assert is_dataclass(result.parameters.cookie["userdata"]) + assert ( + result.parameters.cookie["userdata"].__class__.__name__ == "Model" ) + assert result.parameters.cookie["userdata"].name == "user1" result = validate_request( spec, request, validator=openapi_v30_request_body_validator @@ -1221,8 +1215,8 @@ def test_get_pet(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) - assert isinstance(response_result.data.data, BaseModel) + assert is_dataclass(response_result.data) + assert is_dataclass(response_result.data.data) assert response_result.data.data.id == data_id assert response_result.data.data.name == data_name @@ -1270,7 +1264,7 @@ def test_get_pet_not_found(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.code == code assert response_result.data.message == message assert response_result.data.rootCause == rootCause @@ -1453,7 +1447,7 @@ def test_post_tags_additional_properties(self, spec): spec, request, validator=openapi_v30_request_body_validator ) - assert isinstance(result.body, BaseModel) + assert is_dataclass(result.body) assert result.body.name == pet_name code = 400 @@ -1472,7 +1466,7 @@ def test_post_tags_additional_properties(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.code == code assert response_result.data.message == message assert response_result.data.rootCause == rootCause @@ -1507,7 +1501,7 @@ def test_post_tags_created_now(self, spec): spec, request, validator=openapi_v30_request_body_validator ) - assert isinstance(result.body, BaseModel) + assert is_dataclass(result.body) assert result.body.created == created assert result.body.name == pet_name @@ -1527,7 +1521,7 @@ def test_post_tags_created_now(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.code == code assert response_result.data.message == message assert response_result.data.rootCause == rootCause @@ -1562,7 +1556,7 @@ def test_post_tags_created_datetime(self, spec): spec, request, validator=openapi_v30_request_body_validator ) - assert isinstance(result.body, BaseModel) + assert is_dataclass(result.body) assert result.body.created == datetime( 2016, 4, 16, 16, 6, 5, tzinfo=UTC ) @@ -1588,7 +1582,7 @@ def test_post_tags_created_datetime(self, spec): validator=openapi_v30_response_data_validator, ) - assert isinstance(result.data, BaseModel) + assert is_dataclass(result.data) assert result.data.code == code assert result.data.message == message assert result.data.rootCause == rootCause @@ -1597,7 +1591,7 @@ def test_post_tags_created_datetime(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.code == code assert response_result.data.message == message assert response_result.data.rootCause == rootCause @@ -1650,7 +1644,7 @@ def test_post_tags_created_invalid_type(self, spec): response_result = validate_response(spec, request, response) assert response_result.errors == [] - assert isinstance(response_result.data, BaseModel) + assert is_dataclass(response_result.data) assert response_result.data.code == code assert response_result.data.message == message assert response_result.data.correlationId == correlationId @@ -1683,7 +1677,7 @@ def test_delete_tags_with_requestbody(self, spec): spec, request, validator=openapi_v30_request_body_validator ) - assert isinstance(result.body, BaseModel) + assert is_dataclass(result.body) assert result.body.ids == ids data = None diff --git a/tests/integration/validation/test_read_only_write_only.py b/tests/integration/validation/test_read_only_write_only.py index e4bc1fda..8f3d79a7 100644 --- a/tests/integration/validation/test_read_only_write_only.py +++ b/tests/integration/validation/test_read_only_write_only.py @@ -1,4 +1,5 @@ import json +from dataclasses import is_dataclass import pytest @@ -49,10 +50,10 @@ def test_read_only_property_response(self, spec): ) assert not result.errors - assert result.data == { - "id": 10, - "name": "Pedro", - } + assert is_dataclass(result.data) + assert result.data.__class__.__name__ == "Model" + assert result.data.id == 10 + assert result.data.name == "Pedro" class TestWriteOnly: @@ -71,10 +72,10 @@ def test_write_only_property(self, spec): result = openapi_v30_request_validator.validate(spec, request) assert not result.errors - assert result.body == { - "name": "Pedro", - "hidden": False, - } + assert is_dataclass(result.body) + assert result.body.__class__.__name__ == "Model" + assert result.body.name == "Pedro" + assert result.body.hidden == False def test_read_a_write_only_property(self, spec): data = json.dumps( diff --git a/tests/integration/validation/test_validators.py b/tests/integration/validation/test_validators.py index 63a8ea74..220d3ede 100644 --- a/tests/integration/validation/test_validators.py +++ b/tests/integration/validation/test_validators.py @@ -1,5 +1,6 @@ import json from base64 import b64encode +from dataclasses import is_dataclass import pytest @@ -7,7 +8,6 @@ from openapi_core.deserializing.media_types.exceptions import ( MediaTypeDeserializeError, ) -from openapi_core.extensions.models.models import BaseModel from openapi_core.spec import Spec from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import OperationNotFound @@ -655,7 +655,7 @@ def test_get_pets(self, spec): result = openapi_response_validator.validate(spec, request, response) assert result.errors == [] - assert isinstance(result.data, BaseModel) + assert is_dataclass(result.data) assert len(result.data.data) == 1 assert result.data.data[0].id == 1 assert result.data.data[0].name == "Sparky" diff --git a/tests/unit/extensions/test_factories.py b/tests/unit/extensions/test_factories.py new file mode 100644 index 00000000..89bc7b8f --- /dev/null +++ b/tests/unit/extensions/test_factories.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from dataclasses import is_dataclass +from sys import modules +from types import ModuleType +from typing import Any + +import pytest + +from openapi_core.extensions.models.factories import ModelClassImporter + + +class TestImportModelCreate: + @pytest.fixture + def loaded_model_class(self): + @dataclass + class BarModel: + a: str + b: int + + foo_module = ModuleType("foo") + foo_module.BarModel = BarModel + modules["foo"] = foo_module + yield BarModel + del modules["foo"] + + def test_dynamic_model(self): + factory = ModelClassImporter() + + test_model_class = factory.create(["name"], model="TestModel") + + assert is_dataclass(test_model_class) + assert test_model_class.__name__ == "TestModel" + assert list(test_model_class.__dataclass_fields__.keys()) == ["name"] + assert test_model_class.__dataclass_fields__["name"].type == str(Any) + + def test_imported_model(self, loaded_model_class): + factory = ModelClassImporter() + + test_model_class = factory.create(["a", "b"], model="foo.BarModel") + + assert test_model_class == loaded_model_class diff --git a/tests/unit/extensions/test_models.py b/tests/unit/extensions/test_models.py deleted file mode 100644 index 5878b62d..00000000 --- a/tests/unit/extensions/test_models.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest - -from openapi_core.extensions.models.models import BaseModel -from openapi_core.extensions.models.models import Model - - -class TestBaseModelDict: - def test_not_implemented(self): - model = BaseModel() - - with pytest.raises(NotImplementedError): - model.__dict__ - - -class TestModelDict: - def test_dict_empty(self): - model = Model() - - result = model.__dict__ - - assert result == {} - - def test_dict(self): - properties = { - "prop1": "value1", - "prop2": "value2", - } - model = Model(properties) - - result = model.__dict__ - - assert result == properties - - def test_attribute(self): - prop_value = "value1" - properties = { - "prop1": prop_value, - } - model = Model(properties) - - result = model.prop1 - - assert result == prop_value diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index e3d0aa66..3b33e133 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -1,5 +1,6 @@ import datetime import uuid +from dataclasses import is_dataclass import pytest from isodate.tzinfo import UTC @@ -539,7 +540,8 @@ def test_object_nullable(self, unmarshaller_factory): value = {"foo": None} result = unmarshaller_factory(spec)(value) - assert result == {"foo": None} + assert is_dataclass(result) + assert result.foo == None def test_schema_any_one_of(self, unmarshaller_factory): schema = { @@ -666,7 +668,15 @@ def test_schema_free_form_object( spec = Spec.from_dict(schema) result = unmarshaller_factory(spec)(value) - assert result == value + + assert is_dataclass(result) + for field, val in value.items(): + result_field = getattr(result, field) + if isinstance(val, dict): + for field2, val2 in val.items(): + assert getattr(result_field, field2) == val2 + else: + assert result_field == val def test_read_only_properties(self, unmarshaller_factory): schema = { @@ -685,9 +695,9 @@ def test_read_only_properties(self, unmarshaller_factory): result = unmarshaller_factory(spec, context=UnmarshalContext.RESPONSE)( {"id": 10} ) - assert result == { - "id": 10, - } + + assert is_dataclass(result) + assert result.id == 10 def test_read_only_properties_invalid(self, unmarshaller_factory): schema = { @@ -725,9 +735,9 @@ def test_write_only_properties(self, unmarshaller_factory): result = unmarshaller_factory(spec, context=UnmarshalContext.REQUEST)( {"id": 10} ) - assert result == { - "id": 10, - } + + assert is_dataclass(result) + assert result.id == 10 def test_write_only_properties_invalid(self, unmarshaller_factory): schema = { diff --git a/tests/unit/unmarshalling/test_validate.py b/tests/unit/unmarshalling/test_validate.py index 62ce34f7..07547d10 100644 --- a/tests/unit/unmarshalling/test_validate.py +++ b/tests/unit/unmarshalling/test_validate.py @@ -3,7 +3,6 @@ import pytest -from openapi_core.extensions.models.models import Model from openapi_core.spec.paths import Spec from openapi_core.unmarshalling.schemas import ( oas30_request_schema_unmarshallers_factory, @@ -729,7 +728,7 @@ def test_object_not_an_object(self, value, validator_factory): @pytest.mark.parametrize( "value", [ - Model(), + dict(), ], ) def test_object_multiple_one_of(self, value, validator_factory):