diff --git a/.gitignore b/.gitignore index 85fdfa60b..ac083802b 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ src/prisma/bases.py src/prisma/types.py src/prisma/enums.py src/prisma/client.py +src/prisma/composite_types.py src/prisma/models.py src/prisma/actions.py src/prisma/metadata.py diff --git a/src/prisma/_builder.py b/src/prisma/_builder.py index bb5a39234..c1a1f9979 100644 --- a/src/prisma/_builder.py +++ b/src/prisma/_builder.py @@ -20,6 +20,7 @@ from .errors import InvalidModelError, UnknownModelError, UnknownRelationalFieldError from ._compat import get_args, is_union, get_origin, model_fields, model_field_type from ._typing import is_list_type +from .generator import PartialModelField, MetaFieldsInterface from ._constants import QUERY_BUILDER_ALIASES if TYPE_CHECKING: @@ -204,7 +205,6 @@ def _create_root_node(self) -> 'RootNode': def get_default_fields(self, model: type[PrismaModel]) -> list[str]: """Returns a list of all the scalar fields of a model - Raises UnknownModelError if the current model cannot be found. """ name = getattr(model, '__prisma_model__', MISSING) @@ -223,7 +223,7 @@ def get_default_fields(self, model: type[PrismaModel]) -> list[str]: if not _field_is_prisma_model(info, name=field, parent=model) ] - def get_relational_model(self, current_model: type[PrismaModel], field: str) -> type[PrismaModel]: + def get_relational_model(self, current_model: type[BaseModel], field: str) -> type[PrismaModel]: """Returns the model that the field is related to. Raises UnknownModelError if the current model is invalid. @@ -313,7 +313,6 @@ def _prisma_model_for_field( def _field_is_prisma_model(field: FieldInfo, *, name: str, parent: type[BaseModel]) -> bool: """Whether or not the given field info represents a model at the database level. - This will return `True` for cases where the field represents a list of models or a single model. """ return _prisma_model_for_field(field, name=name, parent=parent) is not None @@ -694,7 +693,7 @@ class Selection(Node): } """ - model: type[PrismaModel] | None + model: type[MetaFieldsInterface] | None include: dict[str, Any] | None root_selection: list[str] | None @@ -706,7 +705,7 @@ class Selection(Node): def __init__( self, - model: type[PrismaModel] | None = None, + model: type[MetaFieldsInterface] | None = None, include: dict[str, Any] | None = None, root_selection: list[str] | None = None, **kwargs: Any, @@ -742,12 +741,20 @@ def create_children(self) -> list[ChildType]: if root_selection is not None: children.extend(root_selection) elif model is not None: - children.extend(builder.get_default_fields(model)) + if hasattr(model, 'get_meta_fields'): + for field, info in model.get_meta_fields().items(): + children.append(self._get_child_from_model(field, info)) + else: + children.extend(builder.get_default_fields(model)) # type: ignore if include is not None: if model is None: raise ValueError('Cannot include fields when model is None.') + if not isinstance(model, type(BaseModel)): + raise ValueError(f'Expected model to be a Pydantic model but got {type(model)} instead.') + model = cast(type[BaseModel], model) + for key, value in include.items(): if value is True: # e.g. posts { post_fields } @@ -788,6 +795,27 @@ def create_children(self) -> list[ChildType]: return children + def _get_child_from_model(self, field: str, info: PartialModelField) -> ChildType: + builder = self.builder + + composite_type = info.get('composite_type') + + if composite_type is not None: + return Key( + field, + sep=' ', + node=Selection.create( + builder, + include=None, + model=composite_type, + ), + ) + + if info.get('is_relational'): + return '' + + return field + class Key(AbstractNode): """Node for rendering a child node with a prefixed key""" diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index 40e48eeb8..d2bf7f8d5 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -2,6 +2,7 @@ import sys import enum import textwrap +import warnings import importlib from typing import ( TYPE_CHECKING, @@ -652,19 +653,13 @@ class DMMF(BaseModel): class Datamodel(BaseModel): enums: List['Enum'] models: List['Model'] + types: List['Types'] - # not implemented yet - types: List[object] - @field_validator('types') - @classmethod - def no_composite_types_validator(cls, types: List[object]) -> object: - if types: - raise ValueError( - 'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314' - ) - - return types +class Types(BaseModel): + name: str + db_name: Optional[str] = FieldInfo(alias='dbName') + fields: List['TypeField'] class Enum(BaseModel): @@ -761,7 +756,7 @@ def relational_fields(self) -> Iterator['Field']: @property def scalar_fields(self) -> Iterator['Field']: for field in self.all_fields: - if not field.is_relational: + if not field.is_relational and not field.is_composite_type: yield field @property @@ -960,6 +955,9 @@ def _actual_python_type(self) -> str: if self.kind == 'enum': return f"'enums.{self.type}'" + if self.is_composite_type: + return f"'{self.type}'" + if self.kind == 'object': return f"'models.{self.type}'" @@ -973,6 +971,9 @@ def _actual_python_type(self) -> str: @property def create_input_type(self) -> str: + if self.is_composite_type: + return self.python_type + if self.kind != 'object': return self.python_type @@ -984,6 +985,14 @@ def create_input_type(self) -> str: @property def where_input_type(self) -> str: typ = self.type + + if self.is_composite_type: + if self.is_list: + return f"'types.{typ}ListFilter'" + if self.is_optional: + return f"'types.{typ}OptionalFilter'" + return f"'types.{typ}Filter'" + if self.is_relational: if self.is_list: return f"'{typ}ListRelationFilter'" @@ -1026,6 +1035,18 @@ def required_on_create(self) -> bool: and not self.is_list ) + @property + def is_composite_type(self) -> bool: + if self.kind != 'object': + return False + + types = get_datamodel().types + for typ in types: + if typ.name == self.type: + return True + + return False + @property def is_optional(self) -> bool: return not (self.is_required and not self.relation_name) @@ -1042,13 +1063,16 @@ def is_atomic(self) -> bool: def is_number(self) -> bool: return self.type in {'Int', 'BigInt', 'Float'} - def maybe_optional(self, typ: str) -> str: + def maybe_optional(self, typ: str, force: bool = False) -> str: """Wrap the given type string within `Optional` if applicable""" - if self.is_required or self.is_relational: + if (not force) and (self.is_required or self.is_relational or self.is_composite_type): return typ return f'Optional[{typ}]' def get_update_input_type(self) -> str: + if self.is_composite_type: + return self.python_type + if self.kind == 'object': if self.is_list: return f"'{self.type}UpdateManyWithoutRelationsInput'" @@ -1106,6 +1130,11 @@ def _get_sample_data(self) -> str: assert enum is not None, self.type return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}' + if self.kind == 'object': + # TODO + warnings.warn('Data sampling for object fields not supported yet', stacklevel=2) + return f'{{}}' + typ = self.type if typ == 'Boolean': return str(FAKER.boolean()) @@ -1130,6 +1159,11 @@ def _get_sample_data(self) -> str: raise RuntimeError(f'Sample data not supported for {typ} yet') +class TypeField(Field): + is_generated: None = None # type: ignore + is_updated_at: None = None # type: ignore + + class DefaultValue(BaseModel): args: Any = None name: str diff --git a/src/prisma/generator/templates/composite_types.py.jinja b/src/prisma/generator/templates/composite_types.py.jinja new file mode 100644 index 000000000..b84c4dd75 --- /dev/null +++ b/src/prisma/generator/templates/composite_types.py.jinja @@ -0,0 +1,39 @@ +{% include '_header.py.jinja' %} +{% from '_utils.py.jinja' import recursive_types with context %} + +from collections import OrderedDict +from typing_extensions import override + +from pydantic import BaseModel + +from . import types, enums, errors, fields, bases +from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface + +{% for composite_type in dmmf.datamodel.types %} +class {{ composite_type.name }}(BaseModel, MetaFieldsInterface): + {% for field in composite_type.fields %} + {{ field.name }}: {{ field.python_type_as_string }} + {% endfor %} + + @staticmethod + @override + def get_meta_fields() -> Dict['types.{{ composite_type.name }}Keys', PartialModelField]: + return _{{ composite_type.name }}_fields + +_{{ composite_type.name }}_fields: Dict['types.{{ composite_type.name }}Keys', PartialModelField] = OrderedDict( + [ + {% for field in composite_type.fields %} + ('{{ field.name }}', { + 'name': '{{ field.name }}', + 'is_list': {{ field.is_list }}, + 'optional': {{ field.is_optional }}, + 'type': {{ field.python_type_as_string }}, + 'is_relational': {{ field.relation_name is not none }}, + 'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %}, + 'composite_type': {% if field.is_composite_type %}{{ field.type }}{% else %}None{% endif %}, + }), + {% endfor %} + ], +) + +{% endfor %} diff --git a/src/prisma/generator/templates/models.py.jinja b/src/prisma/generator/templates/models.py.jinja index 17b5b55a8..4e4cdb50a 100644 --- a/src/prisma/generator/templates/models.py.jinja +++ b/src/prisma/generator/templates/models.py.jinja @@ -6,6 +6,7 @@ import logging import inspect import warnings from collections import OrderedDict +from typing_extensions import override from pydantic import BaseModel, Field @@ -13,14 +14,15 @@ from . import types, enums, errors, fields, bases from ._types import FuncType from ._compat import model_rebuild, field_validator from ._builder import serialize_base64 -from .generator import partial_models_ctx, PartialModelField +from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface +from .composite_types import * log: logging.Logger = logging.getLogger(__name__) _created_partial_types: Set[str] = set() {% for model in dmmf.datamodel.models %} -class {{ model.name }}(bases.Base{{ model.name }}): +class {{ model.name }}(bases.Base{{ model.name }}, MetaFieldsInterface): {% if model.documentation is none %} """Represents a {{ model.name }} record""" {% else %} @@ -185,6 +187,11 @@ class {{ model.name }}(bases.Base{{ model.name }}): ) _created_partial_types.add(name) + @staticmethod + @override + def get_meta_fields() -> Dict['types.{{ model.name }}Keys', PartialModelField]: + return _{{ model.name }}_fields + {% endfor %} @@ -208,6 +215,7 @@ _{{ model.name }}_fields: Dict['types.{{ model.name }}Keys', PartialModelField] 'type': {{ field.python_type_as_string }}, 'is_relational': {{ field.relation_name is not none }}, 'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %}, + 'composite_type': {% if field.is_composite_type %}{{ field.type }}{% else %}None{% endif %}, }), {% endfor %} ], diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index b3e7350b6..0fc0a4ae4 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -440,6 +440,29 @@ class _{{ type }}ListUpdatePush(TypedDict): {% endfor %} +{% for composite_type in dmmf.datamodel.types %} +# {{ composite_type.name }} types + +class {{ composite_type.name }}(TypedDict): + {% for field in composite_type.fields %} + {{ field.name }}: {{ field.maybe_optional(field.create_input_type) }} + {% endfor %} + + +class Partial{{ composite_type.name }}(TypedDict, total=False): + {% for field in composite_type.fields %} + {{ field.name }}: {{ field.maybe_optional(field.create_input_type, True) }} + {% endfor %} + + +{{ composite_type.name }}Keys = Literal[ + {% for field in composite_type.fields %} + '{{ field.name }}', + {% endfor %} +] +{% endfor %} + + {% for model in dmmf.datamodel.models %} {% set model_schema = type_schema.get_model(model.name) %} # {{ model.name }} types @@ -777,6 +800,49 @@ class {{ model.name }}NumberAggregateInput(TypedDict, total=False): {% endfor %} +{% for composite_type in dmmf.datamodel.types %} + +{{ composite_type.name }}Filter = Union[ + {{ composite_type.name }}, + '{{ composite_type.name }}FilterOperations', +] + +{{ composite_type.name }}FilterOperations = TypedDict( + '{{ composite_type.name }}FilterOperations', + { + 'equals': '{{ composite_type.name }}', + 'is': 'Partial{{ composite_type.name }}', + 'is_not': 'Partial{{ composite_type.name }}', + }, + total=False, +) + +class {{ composite_type.name }}OptionalFilterOperations({{ composite_type.name }}FilterOperations): + is_set: bool + + +class {{ composite_type.name }}FilterWithQuantifiers(Partial{{ composite_type.name }}, total=False): + AND: 'List[{{ composite_type.name }}FilterWithQuantifiers]' + OR: 'List[{{ composite_type.name }}FilterWithQuantifiers]' + NOT: 'List[{{ composite_type.name }}FilterWithQuantifiers]' + + +{{ composite_type.name }}ListFilter = Union[ + List[{{ composite_type.name }}], + '{{ composite_type.name }}ListFilterOperations', +] + + +class {{ composite_type.name }}ListFilterOperations(TypedDict, total=False): + equals: List['{{ composite_type.name }}'] + is_empty: bool + is_set: bool + every: '{{ composite_type.name }}FilterWithQuantifiers' + some: '{{ composite_type.name }}FilterWithQuantifiers' + none: '{{ composite_type.name }}FilterWithQuantifiers' + +{% endfor %} + # we have to import ourselves as types can be namespaced to types from . import types, enums, models, fields diff --git a/src/prisma/generator/types.py b/src/prisma/generator/types.py index 3cbc25f69..819e0125d 100644 --- a/src/prisma/generator/types.py +++ b/src/prisma/generator/types.py @@ -1,10 +1,12 @@ -from typing import Mapping, Optional +from abc import abstractmethod +from typing import Any, Dict, Mapping, Optional from .._types import TypedDict __all__ = ( 'PartialModel', 'PartialModelField', + 'MetaFieldsInterface', ) @@ -15,9 +17,16 @@ class PartialModelField(TypedDict): type: str documentation: Optional[str] is_relational: bool + composite_type: Optional[object] class PartialModel(TypedDict): name: str from_model: str fields: Mapping[str, PartialModelField] + + +class MetaFieldsInterface: + @staticmethod + @abstractmethod + def get_meta_fields() -> Dict[Any, PartialModelField]: ... diff --git a/tests/test_generation/test_composite_type.py b/tests/test_generation/test_composite_type.py new file mode 100644 index 000000000..4f3299f33 --- /dev/null +++ b/tests/test_generation/test_composite_type.py @@ -0,0 +1,43 @@ +from textwrap import dedent + +from ..utils import Testdir + + +def test_composite_type_not_supported(testdir: Testdir) -> None: + """Composite types are now supported""" + schema = ( + testdir.default_generator + + """ + datasource db {{ + provider = "mongodb" + url = env("foo") + }} + + model User {{ + id String @id @map("_id") + prouuuut String + settings UserSettings + }} + + type UserSettings {{ + theme String + }} + """ + ) + testdir.generate(schema=schema) + + client_types_path = testdir.path / 'prisma' / 'types.py' + client_types = client_types_path.read_text() + + assert ( + dedent(""" + class UserSettings(TypedDict, total=False): + theme: _str + """).strip() + in client_types + ) + + for line in client_types.splitlines(): + line = line.strip() + if line.startswith('settings:'): + assert line == "settings: 'UserSettings'" diff --git a/tests/test_generation/test_validation.py b/tests/test_generation/test_validation.py index b2f0762a3..55261c6d8 100644 --- a/tests/test_generation/test_validation.py +++ b/tests/test_generation/test_validation.py @@ -225,31 +225,3 @@ def test_decimal_type_experimental(testdir: Testdir) -> None: output = str(exc.value.output, 'utf-8') assert 'Support for the Decimal type is experimental' in output assert 'set the `enable_experimental_decimal` config flag to true' in output - - -def test_composite_type_not_supported(testdir: Testdir) -> None: - """Composite types are not supported yet""" - schema = ( - testdir.default_generator - + """ - datasource db {{ - provider = "mongodb" - url = env("foo") - }} - - model User {{ - id String @id @map("_id") - // settings UserSettings - }} - - type UserSettings {{ - points Decimal - }} - """ - ) - with pytest.raises(subprocess.CalledProcessError) as exc: - testdir.generate(schema=schema) - - output = str(exc.value.output, 'utf-8') - assert 'Composite types are not supported yet.' in output - assert 'https://github.com/RobertCraigie/prisma-client-py/issues/314' in output