diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 385f970b..35275c38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: check-hooks-apply - repo: https://github.com/asottile/pyupgrade - rev: v2.19.0 + rev: v2.38.4 hooks: - id: pyupgrade args: ["--py36-plus"] diff --git a/docs/extensions.rst b/docs/extensions.rst index e67556a5..a02b5013 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -4,7 +4,7 @@ Extensions x-model ------- -By default, objects are unmarshalled to dynamically created dataclasses. You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator `__) by providing ``x-model`` property inside schema definition with location of your class. +By default, objects are unmarshalled to dictionaries. You can use dynamically created dataclasses. .. code-block:: yaml @@ -12,7 +12,27 @@ By default, objects are unmarshalled to dynamically created dataclasses. You can components: schemas: Coordinates: - x-model: foo.bar.Coordinates + x-model: Coordinates + type: object + required: + - lat + - lon + properties: + lat: + type: number + lon: + type: number + + +You can use your own dataclasses, pydantic models or models generated by third party generators (i.e. `datamodel-code-generator `__) by providing ``x-model-path`` property inside schema definition with location of your class. + +.. code-block:: yaml + + ... + components: + schemas: + Coordinates: + x-model-path: foo.bar.Coordinates type: object required: - lat diff --git a/openapi_core/extensions/models/factories.py b/openapi_core/extensions/models/factories.py index ecba0a15..86be6157 100644 --- a/openapi_core/extensions/models/factories.py +++ b/openapi_core/extensions/models/factories.py @@ -9,43 +9,40 @@ from typing import Type from openapi_core.extensions.models.types import Field +from openapi_core.spec import Spec class DictFactory: base_class = dict - def create(self, fields: Iterable[Field]) -> Type[Dict[Any, Any]]: + def create( + self, schema: Spec, fields: Iterable[Field] + ) -> Type[Dict[Any, Any]]: return self.base_class -class DataClassFactory(DictFactory): +class ModelFactory(DictFactory): def create( self, + schema: Spec, fields: Iterable[Field], - name: str = "Model", ) -> Type[Any]: + name = schema.getkey("x-model") + if name is None: + return super().create(schema, fields) + return make_dataclass(name, fields, frozen=True) -class ModelClassImporter(DataClassFactory): +class ModelPathFactory(ModelFactory): def create( self, + schema: Spec, fields: Iterable[Field], - name: str = "Model", - model: Optional[str] = None, ) -> Any: - if model is None: - return super().create(fields, name=name) - - model_class = self._get_class(model) - if model_class is not None: - return model_class - - return super().create(fields, name=model) + model_class_path = schema.getkey("x-model-path") + if model_class_path is None: + return super().create(schema, fields) - def _get_class(self, model_class_path: str) -> Optional[object]: - try: - return locate(model_class_path) - except ErrorDuringImport: - return None + return locate(model_class_path) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 872d74b0..9bddaddf 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -18,7 +18,7 @@ from openapi_schema_validator._format import oas30_format_checker from openapi_schema_validator._types import is_string -from openapi_core.extensions.models.factories import ModelClassImporter +from openapi_core.extensions.models.factories import ModelPathFactory from openapi_core.schema.schemas import get_all_properties from openapi_core.spec import Spec from openapi_core.unmarshalling.schemas.datatypes import FormattersDict @@ -199,15 +199,14 @@ class ObjectUnmarshaller(ComplexUnmarshaller): } @property - def object_class_factory(self) -> ModelClassImporter: - return ModelClassImporter() + def object_class_factory(self) -> ModelPathFactory: + return ModelPathFactory() def unmarshal(self, value: Any) -> Any: properties = self.unmarshal_raw(value) - model = self.schema.getkey("x-model") fields: Iterable[str] = properties and properties.keys() or [] - object_class = self.object_class_factory.create(fields, model=model) + object_class = self.object_class_factory.create(self.schema, fields) return object_class(**properties) diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index d4731a7c..b1647556 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -233,6 +233,7 @@ paths: components: schemas: Coordinates: + x-model: Coordinates type: object required: - lat @@ -243,6 +244,7 @@ components: lon: type: number Userdata: + x-model: Userdata type: object required: - name diff --git a/tests/integration/data/v3.0/read_only_write_only.yaml b/tests/integration/data/v3.0/read_only_write_only.yaml index be5a06a4..1f403df7 100644 --- a/tests/integration/data/v3.0/read_only_write_only.yaml +++ b/tests/integration/data/v3.0/read_only_write_only.yaml @@ -23,6 +23,7 @@ paths: components: schemas: User: + x-model: User type: object required: - id diff --git a/tests/integration/validation/test_petstore.py b/tests/integration/validation/test_petstore.py index fabe0434..c52feeb3 100644 --- a/tests/integration/validation/test_petstore.py +++ b/tests/integration/validation/test_petstore.py @@ -644,7 +644,7 @@ def test_get_pets_param_coordinates(self, spec): assert is_dataclass(result.parameters.query["coordinates"]) assert ( result.parameters.query["coordinates"].__class__.__name__ - == "Model" + == "Coordinates" ) assert result.parameters.query["coordinates"].lat == coordinates["lat"] assert result.parameters.query["coordinates"].lon == coordinates["lon"] @@ -705,7 +705,8 @@ def test_post_birds(self, spec, spec_dict): assert is_dataclass(result.parameters.cookie["userdata"]) assert ( - result.parameters.cookie["userdata"].__class__.__name__ == "Model" + result.parameters.cookie["userdata"].__class__.__name__ + == "Userdata" ) assert result.parameters.cookie["userdata"].name == "user1" diff --git a/tests/integration/validation/test_read_only_write_only.py b/tests/integration/validation/test_read_only_write_only.py index 8f3d79a7..6e1dad15 100644 --- a/tests/integration/validation/test_read_only_write_only.py +++ b/tests/integration/validation/test_read_only_write_only.py @@ -51,7 +51,7 @@ def test_read_only_property_response(self, spec): assert not result.errors assert is_dataclass(result.data) - assert result.data.__class__.__name__ == "Model" + assert result.data.__class__.__name__ == "User" assert result.data.id == 10 assert result.data.name == "Pedro" @@ -73,7 +73,7 @@ def test_write_only_property(self, spec): assert not result.errors assert is_dataclass(result.body) - assert result.body.__class__.__name__ == "Model" + assert result.body.__class__.__name__ == "User" assert result.body.name == "Pedro" assert result.body.hidden == False diff --git a/tests/integration/validation/test_validators.py b/tests/integration/validation/test_validators.py index 4bb00c0e..57d7d458 100644 --- a/tests/integration/validation/test_validators.py +++ b/tests/integration/validation/test_validators.py @@ -536,6 +536,7 @@ def test_request_object_deep_object_params(self, spec, spec_dict): "in": "query", "required": True, "schema": { + "x-model": "paramObj", "type": "object", "properties": { "count": {"type": "integer"}, diff --git a/tests/unit/extensions/test_factories.py b/tests/unit/extensions/test_factories.py index 89bc7b8f..66bf357f 100644 --- a/tests/unit/extensions/test_factories.py +++ b/tests/unit/extensions/test_factories.py @@ -6,7 +6,8 @@ import pytest -from openapi_core.extensions.models.factories import ModelClassImporter +from openapi_core.extensions.models.factories import ModelPathFactory +from openapi_core.spec import Spec class TestImportModelCreate: @@ -24,18 +25,20 @@ class BarModel: del modules["foo"] def test_dynamic_model(self): - factory = ModelClassImporter() + factory = ModelPathFactory() - test_model_class = factory.create(["name"], model="TestModel") + schema = Spec.from_dict({"x-model": "TestModel"}) + test_model_class = factory.create(schema, ["name"]) assert is_dataclass(test_model_class) assert test_model_class.__name__ == "TestModel" assert list(test_model_class.__dataclass_fields__.keys()) == ["name"] assert test_model_class.__dataclass_fields__["name"].type == str(Any) - def test_imported_model(self, loaded_model_class): - factory = ModelClassImporter() + def test_model_path(self, loaded_model_class): + factory = ModelPathFactory() - test_model_class = factory.create(["a", "b"], model="foo.BarModel") + schema = Spec.from_dict({"x-model-path": "foo.BarModel"}) + test_model_class = factory.create(schema, ["a", "b"]) assert test_model_class == loaded_model_class diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index 0a9a545a..f31a0f69 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -1,6 +1,5 @@ import datetime import uuid -from dataclasses import is_dataclass import pytest from isodate.tzinfo import UTC @@ -540,8 +539,9 @@ def test_object_nullable(self, unmarshaller_factory): value = {"foo": None} result = unmarshaller_factory(spec)(value) - assert is_dataclass(result) - assert result.foo == None + assert result == { + "foo": None, + } def test_schema_any_one_of(self, unmarshaller_factory): schema = { @@ -596,8 +596,9 @@ def test_schema_object_any_of(self, unmarshaller_factory): spec = Spec.from_dict(schema) result = unmarshaller_factory(spec)({"someint": 1}) - assert is_dataclass(result) - assert result.someint == 1 + assert result == { + "someint": 1, + } def test_schema_object_any_of_invalid(self, unmarshaller_factory): schema = { @@ -728,14 +729,7 @@ def test_schema_free_form_object( result = unmarshaller_factory(spec)(value) - assert is_dataclass(result) - for field, val in value.items(): - result_field = getattr(result, field) - if isinstance(val, dict): - for field2, val2 in val.items(): - assert getattr(result_field, field2) == val2 - else: - assert result_field == val + assert result == value def test_read_only_properties(self, unmarshaller_factory): schema = { @@ -755,8 +749,9 @@ def test_read_only_properties(self, unmarshaller_factory): {"id": 10} ) - assert is_dataclass(result) - assert result.id == 10 + assert result == { + "id": 10, + } def test_read_only_properties_invalid(self, unmarshaller_factory): schema = { @@ -795,8 +790,9 @@ def test_write_only_properties(self, unmarshaller_factory): {"id": 10} ) - assert is_dataclass(result) - assert result.id == 10 + assert result == { + "id": 10, + } def test_write_only_properties_invalid(self, unmarshaller_factory): schema = { @@ -825,5 +821,6 @@ def test_additional_properties_list(self, unmarshaller_factory): {"user_ids": [1, 2, 3, 4]} ) - assert is_dataclass(result) - assert result.user_ids == [1, 2, 3, 4] + assert result == { + "user_ids": [1, 2, 3, 4], + }