Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
a7f6263
feat(generator): initial support for typing of Composite Types
arthur-fontaine Nov 17, 2024
91d7db3
chore(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
4d6e3b1
test(mongodb): init test environment for mongodb
arthur-fontaine Nov 18, 2024
f413334
chore(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2024
f08bc8b
test(mongodb): create test for mongodb composite types
arthur-fontaine Nov 18, 2024
e1a8a1f
chore(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2024
8ad4651
feat(queryBuilder): include composite_types in the query sent to prisma
arthur-fontaine Nov 19, 2024
4024541
fix(compositeTypes): add BaseModel to Composite Type classes to allow…
arthur-fontaine Nov 19, 2024
0c6f530
fix(generator): fix type errors that failed tests
arthur-fontaine Nov 19, 2024
83ebdd1
chore(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
64c6c03
test(compositeTypes): add some tests for filtering with composite types
arthur-fontaine Nov 20, 2024
4ee38b7
feat(compositeTypes): add types for filters of composite types
arthur-fontaine Nov 20, 2024
ad50676
chore(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
f77f9e7
Revert "test(mongodb): create test for mongodb composite types"
arthur-fontaine Nov 22, 2024
d377fe7
Revert "test(mongodb): init test environment for mongodb"
arthur-fontaine Nov 22, 2024
590bd89
fix(queryBuilder): fix regressions caused by using get_meta_fields in…
arthur-fontaine Nov 27, 2024
e9e5bfb
chore(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ src/prisma/bases.py
src/prisma/types.py
src/prisma/enums.py
src/prisma/client.py
src/prisma/composite_types.py
src/prisma/models.py
src/prisma/actions.py
src/prisma/metadata.py
Expand Down
40 changes: 34 additions & 6 deletions src/prisma/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .errors import InvalidModelError, UnknownModelError, UnknownRelationalFieldError
from ._compat import get_args, is_union, get_origin, model_fields, model_field_type
from ._typing import is_list_type
from .generator import PartialModelField, MetaFieldsInterface
from ._constants import QUERY_BUILDER_ALIASES

if TYPE_CHECKING:
Expand Down Expand Up @@ -204,7 +205,6 @@ def _create_root_node(self) -> 'RootNode':

def get_default_fields(self, model: type[PrismaModel]) -> list[str]:
"""Returns a list of all the scalar fields of a model

Raises UnknownModelError if the current model cannot be found.
"""
name = getattr(model, '__prisma_model__', MISSING)
Expand All @@ -223,7 +223,7 @@ def get_default_fields(self, model: type[PrismaModel]) -> list[str]:
if not _field_is_prisma_model(info, name=field, parent=model)
]

def get_relational_model(self, current_model: type[PrismaModel], field: str) -> type[PrismaModel]:
def get_relational_model(self, current_model: type[BaseModel], field: str) -> type[PrismaModel]:
"""Returns the model that the field is related to.

Raises UnknownModelError if the current model is invalid.
Expand Down Expand Up @@ -313,7 +313,6 @@ def _prisma_model_for_field(

def _field_is_prisma_model(field: FieldInfo, *, name: str, parent: type[BaseModel]) -> bool:
"""Whether or not the given field info represents a model at the database level.

This will return `True` for cases where the field represents a list of models or a single model.
"""
return _prisma_model_for_field(field, name=name, parent=parent) is not None
Expand Down Expand Up @@ -694,7 +693,7 @@ class Selection(Node):
}
"""

model: type[PrismaModel] | None
model: type[MetaFieldsInterface] | None
include: dict[str, Any] | None
root_selection: list[str] | None

Expand All @@ -706,7 +705,7 @@ class Selection(Node):

def __init__(
self,
model: type[PrismaModel] | None = None,
model: type[MetaFieldsInterface] | None = None,
include: dict[str, Any] | None = None,
root_selection: list[str] | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -742,12 +741,20 @@ def create_children(self) -> list[ChildType]:
if root_selection is not None:
children.extend(root_selection)
elif model is not None:
children.extend(builder.get_default_fields(model))
if hasattr(model, 'get_meta_fields'):
for field, info in model.get_meta_fields().items():
children.append(self._get_child_from_model(field, info))
else:
children.extend(builder.get_default_fields(model)) # type: ignore

if include is not None:
if model is None:
raise ValueError('Cannot include fields when model is None.')

if not isinstance(model, type(BaseModel)):
raise ValueError(f'Expected model to be a Pydantic model but got {type(model)} instead.')
model = cast(type[BaseModel], model)

for key, value in include.items():
if value is True:
# e.g. posts { post_fields }
Expand Down Expand Up @@ -788,6 +795,27 @@ def create_children(self) -> list[ChildType]:

return children

def _get_child_from_model(self, field: str, info: PartialModelField) -> ChildType:
builder = self.builder

composite_type = info.get('composite_type')

if composite_type is not None:
return Key(
field,
sep=' ',
node=Selection.create(
builder,
include=None,
model=composite_type,
),
)

if info.get('is_relational'):
return ''

return field


class Key(AbstractNode):
"""Node for rendering a child node with a prefixed key"""
Expand Down
62 changes: 48 additions & 14 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import enum
import textwrap
import warnings
import importlib
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -652,19 +653,13 @@ class DMMF(BaseModel):
class Datamodel(BaseModel):
enums: List['Enum']
models: List['Model']
types: List['Types']

# not implemented yet
types: List[object]

@field_validator('types')
@classmethod
def no_composite_types_validator(cls, types: List[object]) -> object:
if types:
raise ValueError(
'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314'
)

return types
class Types(BaseModel):
name: str
db_name: Optional[str] = FieldInfo(alias='dbName')
fields: List['TypeField']


class Enum(BaseModel):
Expand Down Expand Up @@ -761,7 +756,7 @@ def relational_fields(self) -> Iterator['Field']:
@property
def scalar_fields(self) -> Iterator['Field']:
for field in self.all_fields:
if not field.is_relational:
if not field.is_relational and not field.is_composite_type:
yield field

@property
Expand Down Expand Up @@ -960,6 +955,9 @@ def _actual_python_type(self) -> str:
if self.kind == 'enum':
return f"'enums.{self.type}'"

if self.is_composite_type:
return f"'{self.type}'"

if self.kind == 'object':
return f"'models.{self.type}'"

Expand All @@ -973,6 +971,9 @@ def _actual_python_type(self) -> str:

@property
def create_input_type(self) -> str:
if self.is_composite_type:
return self.python_type

if self.kind != 'object':
return self.python_type

Expand All @@ -984,6 +985,14 @@ def create_input_type(self) -> str:
@property
def where_input_type(self) -> str:
typ = self.type

if self.is_composite_type:
if self.is_list:
return f"'types.{typ}ListFilter'"
if self.is_optional:
return f"'types.{typ}OptionalFilter'"
return f"'types.{typ}Filter'"

if self.is_relational:
if self.is_list:
return f"'{typ}ListRelationFilter'"
Expand Down Expand Up @@ -1026,6 +1035,18 @@ def required_on_create(self) -> bool:
and not self.is_list
)

@property
def is_composite_type(self) -> bool:
if self.kind != 'object':
return False

types = get_datamodel().types
for typ in types:
if typ.name == self.type:
return True

return False

@property
def is_optional(self) -> bool:
return not (self.is_required and not self.relation_name)
Expand All @@ -1042,13 +1063,16 @@ def is_atomic(self) -> bool:
def is_number(self) -> bool:
return self.type in {'Int', 'BigInt', 'Float'}

def maybe_optional(self, typ: str) -> str:
def maybe_optional(self, typ: str, force: bool = False) -> str:
"""Wrap the given type string within `Optional` if applicable"""
if self.is_required or self.is_relational:
if (not force) and (self.is_required or self.is_relational or self.is_composite_type):
return typ
return f'Optional[{typ}]'

def get_update_input_type(self) -> str:
if self.is_composite_type:
return self.python_type

if self.kind == 'object':
if self.is_list:
return f"'{self.type}UpdateManyWithoutRelationsInput'"
Expand Down Expand Up @@ -1106,6 +1130,11 @@ def _get_sample_data(self) -> str:
assert enum is not None, self.type
return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}'

if self.kind == 'object':
# TODO
warnings.warn('Data sampling for object fields not supported yet', stacklevel=2)
return f'{{}}'

typ = self.type
if typ == 'Boolean':
return str(FAKER.boolean())
Expand All @@ -1130,6 +1159,11 @@ def _get_sample_data(self) -> str:
raise RuntimeError(f'Sample data not supported for {typ} yet')


class TypeField(Field):
is_generated: None = None # type: ignore
is_updated_at: None = None # type: ignore


class DefaultValue(BaseModel):
args: Any = None
name: str
Expand Down
39 changes: 39 additions & 0 deletions src/prisma/generator/templates/composite_types.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{% include '_header.py.jinja' %}
{% from '_utils.py.jinja' import recursive_types with context %}

from collections import OrderedDict
from typing_extensions import override

from pydantic import BaseModel

from . import types, enums, errors, fields, bases
from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface

{% for composite_type in dmmf.datamodel.types %}
class {{ composite_type.name }}(BaseModel, MetaFieldsInterface):
{% for field in composite_type.fields %}
{{ field.name }}: {{ field.python_type_as_string }}
{% endfor %}

@staticmethod
@override
def get_meta_fields() -> Dict['types.{{ composite_type.name }}Keys', PartialModelField]:
return _{{ composite_type.name }}_fields

_{{ composite_type.name }}_fields: Dict['types.{{ composite_type.name }}Keys', PartialModelField] = OrderedDict(
[
{% for field in composite_type.fields %}
('{{ field.name }}', {
'name': '{{ field.name }}',
'is_list': {{ field.is_list }},
'optional': {{ field.is_optional }},
'type': {{ field.python_type_as_string }},
'is_relational': {{ field.relation_name is not none }},
'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %},
'composite_type': {% if field.is_composite_type %}{{ field.type }}{% else %}None{% endif %},
}),
{% endfor %}
],
)

{% endfor %}
12 changes: 10 additions & 2 deletions src/prisma/generator/templates/models.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,23 @@ import logging
import inspect
import warnings
from collections import OrderedDict
from typing_extensions import override

from pydantic import BaseModel, Field

from . import types, enums, errors, fields, bases
from ._types import FuncType
from ._compat import model_rebuild, field_validator
from ._builder import serialize_base64
from .generator import partial_models_ctx, PartialModelField
from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface
from .composite_types import *


log: logging.Logger = logging.getLogger(__name__)
_created_partial_types: Set[str] = set()

{% for model in dmmf.datamodel.models %}
class {{ model.name }}(bases.Base{{ model.name }}):
class {{ model.name }}(bases.Base{{ model.name }}, MetaFieldsInterface):
{% if model.documentation is none %}
"""Represents a {{ model.name }} record"""
{% else %}
Expand Down Expand Up @@ -185,6 +187,11 @@ class {{ model.name }}(bases.Base{{ model.name }}):
)
_created_partial_types.add(name)

@staticmethod
@override
def get_meta_fields() -> Dict['types.{{ model.name }}Keys', PartialModelField]:
return _{{ model.name }}_fields


{% endfor %}

Expand All @@ -208,6 +215,7 @@ _{{ model.name }}_fields: Dict['types.{{ model.name }}Keys', PartialModelField]
'type': {{ field.python_type_as_string }},
'is_relational': {{ field.relation_name is not none }},
'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %},
'composite_type': {% if field.is_composite_type %}{{ field.type }}{% else %}None{% endif %},
}),
{% endfor %}
],
Expand Down
Loading