From a7f62630a362d1e67f9d8736b86d7215dea0da78 Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Sun, 17 Nov 2024 22:32:43 +0100 Subject: [PATCH 01/17] feat(generator): initial support for typing of Composite Types --- src/prisma/generator/models.py | 53 ++++++++++++++----- src/prisma/generator/templates/types.py.jinja | 10 ++++ tests/test_generation/test_composite_type.py | 40 ++++++++++++++ tests/test_generation/test_validation.py | 28 ---------- 4 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 tests/test_generation/test_composite_type.py diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index 40e48eeb8..04cbdd9b8 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -2,6 +2,7 @@ import sys import enum import textwrap +import warnings import importlib from typing import ( TYPE_CHECKING, @@ -652,19 +653,13 @@ class DMMF(BaseModel): class Datamodel(BaseModel): enums: List['Enum'] models: List['Model'] + types: List['Types'] - # not implemented yet - types: List[object] - @field_validator('types') - @classmethod - def no_composite_types_validator(cls, types: List[object]) -> object: - if types: - raise ValueError( - 'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314' - ) - - return types +class Types(BaseModel): + name: str + db_name: Optional[str] = FieldInfo(alias='dbName') + fields: List['TypeField'] class Enum(BaseModel): @@ -761,7 +756,7 @@ def relational_fields(self) -> Iterator['Field']: @property def scalar_fields(self) -> Iterator['Field']: for field in self.all_fields: - if not field.is_relational: + if not field.is_relational and not field.is_composite_type: yield field @property @@ -960,6 +955,9 @@ def _actual_python_type(self) -> str: if self.kind == 'enum': return f"'enums.{self.type}'" + if self.is_composite_type: + return f"'{self.type}'" + if self.kind == 'object': return f"'models.{self.type}'" @@ -973,6 +971,9 @@ def _actual_python_type(self) -> str: @property def create_input_type(self) -> str: + if self.is_composite_type: + return self.python_type + if self.kind != 'object': return self.python_type @@ -983,6 +984,9 @@ def create_input_type(self) -> str: @property def where_input_type(self) -> str: + if self.is_composite_type: + return self.python_type + typ = self.type if self.is_relational: if self.is_list: @@ -1025,6 +1029,18 @@ def required_on_create(self) -> bool: and not self.relation_name and not self.is_list ) + + @property + def is_composite_type(self) -> bool: + if self.kind != 'object': + return False + + types = get_datamodel().types + for typ in types: + if typ.name == self.type: + return True + + return False @property def is_optional(self) -> bool: @@ -1049,6 +1065,9 @@ def maybe_optional(self, typ: str) -> str: return f'Optional[{typ}]' def get_update_input_type(self) -> str: + if self.is_composite_type: + return self.python_type + if self.kind == 'object': if self.is_list: return f"'{self.type}UpdateManyWithoutRelationsInput'" @@ -1106,6 +1125,11 @@ def _get_sample_data(self) -> str: assert enum is not None, self.type return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}' + if self.kind == 'object': + # TODO + warnings.warn('Data sampling for object fields not supported yet', stacklevel=2) + return f'{{}}' + typ = self.type if typ == 'Boolean': return str(FAKER.boolean()) @@ -1130,6 +1154,11 @@ def _get_sample_data(self) -> str: raise RuntimeError(f'Sample data not supported for {typ} yet') +class TypeField(Field): + is_generated: None = None # type: ignore + is_updated_at: None = None # type: ignore + + class DefaultValue(BaseModel): args: Any = None name: str diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index b3e7350b6..2f49bc115 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -440,6 +440,16 @@ class _{{ type }}ListUpdatePush(TypedDict): {% endfor %} +{% for composite_type in dmmf.datamodel.types %} +# {{ composite_type.name }} types + +class {{ composite_type.name }}(TypedDict, total=False): + {% for field in composite_type.fields %} + {{ field.name }}: {{ field.maybe_optional(field.create_input_type) }} + {% endfor %} +{% endfor %} + + {% for model in dmmf.datamodel.models %} {% set model_schema = type_schema.get_model(model.name) %} # {{ model.name }} types diff --git a/tests/test_generation/test_composite_type.py b/tests/test_generation/test_composite_type.py new file mode 100644 index 000000000..fca1f66e8 --- /dev/null +++ b/tests/test_generation/test_composite_type.py @@ -0,0 +1,40 @@ +from textwrap import dedent + +from ..utils import Testdir + + +def test_composite_type_not_supported(testdir: Testdir) -> None: + """Composite types are now supported""" + schema = ( + testdir.default_generator + + """ + datasource db {{ + provider = "mongodb" + url = env("foo") + }} + + model User {{ + id String @id @map("_id") + prouuuut String + settings UserSettings + }} + + type UserSettings {{ + theme String + }} + """ + ) + testdir.generate(schema=schema) + + client_types_path = testdir.path / "prisma" / "types.py" + client_types = client_types_path.read_text() + + assert dedent(""" + class UserSettings(TypedDict, total=False): + theme: _str + """).strip() in client_types + + for line in client_types.splitlines(): + line = line.strip() + if line.startswith("settings:"): + assert line == "settings: 'UserSettings'" diff --git a/tests/test_generation/test_validation.py b/tests/test_generation/test_validation.py index b2f0762a3..55261c6d8 100644 --- a/tests/test_generation/test_validation.py +++ b/tests/test_generation/test_validation.py @@ -225,31 +225,3 @@ def test_decimal_type_experimental(testdir: Testdir) -> None: output = str(exc.value.output, 'utf-8') assert 'Support for the Decimal type is experimental' in output assert 'set the `enable_experimental_decimal` config flag to true' in output - - -def test_composite_type_not_supported(testdir: Testdir) -> None: - """Composite types are not supported yet""" - schema = ( - testdir.default_generator - + """ - datasource db {{ - provider = "mongodb" - url = env("foo") - }} - - model User {{ - id String @id @map("_id") - // settings UserSettings - }} - - type UserSettings {{ - points Decimal - }} - """ - ) - with pytest.raises(subprocess.CalledProcessError) as exc: - testdir.generate(schema=schema) - - output = str(exc.value.output, 'utf-8') - assert 'Composite types are not supported yet.' in output - assert 'https://github.com/RobertCraigie/prisma-client-py/issues/314' in output From 91d7db320b48332bae76897adc3f5897ac1d8e46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Nov 2024 21:38:34 +0000 Subject: [PATCH 02/17] chore(pre-commit.ci): auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/prisma/generator/models.py | 2 +- tests/test_generation/test_composite_type.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index 04cbdd9b8..d1a511eb3 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -1029,7 +1029,7 @@ def required_on_create(self) -> bool: and not self.relation_name and not self.is_list ) - + @property def is_composite_type(self) -> bool: if self.kind != 'object': diff --git a/tests/test_generation/test_composite_type.py b/tests/test_generation/test_composite_type.py index fca1f66e8..4f3299f33 100644 --- a/tests/test_generation/test_composite_type.py +++ b/tests/test_generation/test_composite_type.py @@ -26,15 +26,18 @@ def test_composite_type_not_supported(testdir: Testdir) -> None: ) testdir.generate(schema=schema) - client_types_path = testdir.path / "prisma" / "types.py" + client_types_path = testdir.path / 'prisma' / 'types.py' client_types = client_types_path.read_text() - assert dedent(""" + assert ( + dedent(""" class UserSettings(TypedDict, total=False): theme: _str - """).strip() in client_types + """).strip() + in client_types + ) for line in client_types.splitlines(): line = line.strip() - if line.startswith("settings:"): + if line.startswith('settings:'): assert line == "settings: 'UserSettings'" From 4d6e3b1cb96095557c01784e0ff798d07634b20d Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Mon, 18 Nov 2024 01:09:08 +0100 Subject: [PATCH 03/17] test(mongodb): init test environment for mongodb --- databases/_types.py | 2 ++ databases/constants.py | 9 ++++++ databases/main.py | 8 ++++- databases/templates/schema.prisma.jinja2 | 17 ++++++++++ databases/tests_mongodb/__init__.py | 0 databases/tests_mongodb/conftest.py | 40 ++++++++++++++++++++++++ databases/tests_mongodb/test_hello.py | 8 +++++ 7 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 databases/tests_mongodb/__init__.py create mode 100644 databases/tests_mongodb/conftest.py create mode 100644 databases/tests_mongodb/test_hello.py diff --git a/databases/_types.py b/databases/_types.py index 9e86a281c..9f462edf7 100644 --- a/databases/_types.py +++ b/databases/_types.py @@ -14,6 +14,7 @@ 'mariadb', 'postgresql', 'cockroachdb', + 'mongodb', ] @@ -28,3 +29,4 @@ class DatabaseMapping(TypedDict, Generic[_T]): mariadb: _T postgresql: _T cockroachdb: _T + mongodb: _T diff --git a/databases/constants.py b/databases/constants.py index 486b743de..41c825516 100644 --- a/databases/constants.py +++ b/databases/constants.py @@ -83,6 +83,15 @@ def _fromdir(path: str) -> list[str]: 'full_text_search', }, ), + 'mongodb': DatabaseConfig( + id='mongodb', + name='MongoDB', + env_var='MONGODB_URL', + bools_are_ints=False, + default_date_func='', + autoincrement_id='', + unsupported_features=set(), + ), } SUPPORTED_DATABASES = cast(List[SupportedDatabase], list(get_args(SupportedDatabase))) diff --git a/databases/main.py b/databases/main.py index d9de7b6a8..230734af5 100644 --- a/databases/main.py +++ b/databases/main.py @@ -299,7 +299,7 @@ def setup(self) -> None: # template variables config=self.config, for_async=self.for_async, - partial_generator=escape_path(DATABASES_DIR / 'partials.py'), + partial_generator=(escape_path(DATABASES_DIR / 'partials.py') if self.config.id != 'mongodb' else None), ) ) @@ -326,6 +326,8 @@ def test(self, *, pytest_args: str | None) -> None: ) args = [] + if self.config.id == 'mongodb': + args.append('tests_mongodb') if pytest_args is not None: # pragma: no cover args = shlex.split(pytest_args) @@ -395,6 +397,10 @@ def exclude_files(self) -> set[str]: # ensure the tests for the sync client are not ran during the async tests anc vice versa files.append(tests_reldir(for_async=not self.for_async)) + if self.config.id == 'mongodb': + files.append(tests_relpath('.', for_async=True)) + files.append(tests_relpath('.', for_async=False)) + return set(files) diff --git a/databases/templates/schema.prisma.jinja2 b/databases/templates/schema.prisma.jinja2 index 74d7241a2..adbcebd10 100644 --- a/databases/templates/schema.prisma.jinja2 +++ b/databases/templates/schema.prisma.jinja2 @@ -13,7 +13,9 @@ generator client { recursive_type_depth = -1 engineType = "binary" enable_experimental_decimal = true + {% if partial_generator %} partial_type_generator = "{{ partial_generator }}" + {% endif %} {% if config.id == "postgresql" %} previewFeatures = ["fullTextSearch"] {% elif config.id == "mysql" %} @@ -21,6 +23,7 @@ generator client { {% endif %} } +{% if config.id != "mongodb" %} model User { id String @id @default(cuid()) created_at DateTime @default(now()) @@ -233,3 +236,17 @@ enum Role { } {% endif %} +{% endif %} + +{% if config.id == "mongodb" %} +model User { + id String @id @default(auto()) @map("_id") @db.ObjectId + name String + {# contact Contact #} +} + +{# type Contact { + email String + phone String +} #} +{% endif %} diff --git a/databases/tests_mongodb/__init__.py b/databases/tests_mongodb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/databases/tests_mongodb/conftest.py b/databases/tests_mongodb/conftest.py new file mode 100644 index 000000000..ae506fb35 --- /dev/null +++ b/databases/tests_mongodb/conftest.py @@ -0,0 +1,40 @@ +import os + +import pytest +from syrupy.assertion import SnapshotAssertion + +import prisma +from prisma import Prisma +from prisma._compat import model_parse_json +from lib.testing.shared_conftest import * +from lib.testing.shared_conftest.async_client import * + +from ..utils import ( + RAW_QUERIES_MAPPING, + RawQueries, + DatabaseConfig, + AmberSharedExtension, +) + +prisma.register(Prisma()) + + +# TODO: better error messages for invalid state +@pytest.fixture(name='database') +def database_fixture() -> str: + return os.environ['PRISMA_DATABASE'] + + +@pytest.fixture(name='raw_queries') +def raw_queries_fixture(database: str) -> RawQueries: + return RAW_QUERIES_MAPPING[database] + + +@pytest.fixture(name='config', scope='session') +def config_fixture() -> DatabaseConfig: + return model_parse_json(DatabaseConfig, os.environ['DATABASE_CONFIG']) + + +@pytest.fixture() +def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: + return snapshot.use_extension(AmberSharedExtension) diff --git a/databases/tests_mongodb/test_hello.py b/databases/tests_mongodb/test_hello.py new file mode 100644 index 000000000..54919cf5c --- /dev/null +++ b/databases/tests_mongodb/test_hello.py @@ -0,0 +1,8 @@ +import pytest + +from prisma import Prisma + + +@pytest.mark.asyncio +async def test_base_usage(client: Prisma) -> None: + assert "Hello, World!" == "Hello, World!" From f413334ea12af7bc036e15295e4c59bb578db8b4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 00:09:17 +0000 Subject: [PATCH 04/17] chore(pre-commit.ci): auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- databases/tests_mongodb/test_hello.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databases/tests_mongodb/test_hello.py b/databases/tests_mongodb/test_hello.py index 54919cf5c..6d6cdd20f 100644 --- a/databases/tests_mongodb/test_hello.py +++ b/databases/tests_mongodb/test_hello.py @@ -5,4 +5,4 @@ @pytest.mark.asyncio async def test_base_usage(client: Prisma) -> None: - assert "Hello, World!" == "Hello, World!" + assert 'Hello, World!' == 'Hello, World!' From f08bc8bcea70d31a50ba43d857e70a1360238a12 Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Mon, 18 Nov 2024 19:24:02 +0100 Subject: [PATCH 05/17] test(mongodb): create test for mongodb composite types --- databases/templates/schema.prisma.jinja2 | 6 +++--- .../tests_mongodb/test_composite_types.py | 18 ++++++++++++++++++ databases/tests_mongodb/test_hello.py | 8 -------- 3 files changed, 21 insertions(+), 11 deletions(-) create mode 100644 databases/tests_mongodb/test_composite_types.py delete mode 100644 databases/tests_mongodb/test_hello.py diff --git a/databases/templates/schema.prisma.jinja2 b/databases/templates/schema.prisma.jinja2 index adbcebd10..fe667c38b 100644 --- a/databases/templates/schema.prisma.jinja2 +++ b/databases/templates/schema.prisma.jinja2 @@ -242,11 +242,11 @@ enum Role { model User { id String @id @default(auto()) @map("_id") @db.ObjectId name String - {# contact Contact #} + contact Contact } -{# type Contact { +type Contact { email String phone String -} #} +} {% endif %} diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py new file mode 100644 index 000000000..1edbef0a4 --- /dev/null +++ b/databases/tests_mongodb/test_composite_types.py @@ -0,0 +1,18 @@ +import pytest + +from prisma import Prisma + + +@pytest.mark.asyncio +async def test_composite_types(client: Prisma) -> None: + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'test@test.com', + 'phone': '123-456-7890' + } + }) + user = await client.user.find_first() + + assert user is not None + assert user.name == 'Alice' diff --git a/databases/tests_mongodb/test_hello.py b/databases/tests_mongodb/test_hello.py deleted file mode 100644 index 6d6cdd20f..000000000 --- a/databases/tests_mongodb/test_hello.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - -from prisma import Prisma - - -@pytest.mark.asyncio -async def test_base_usage(client: Prisma) -> None: - assert 'Hello, World!' == 'Hello, World!' From e1a8a1f1e69c267325e555e3f1a78acc36df4aea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 18:24:25 +0000 Subject: [PATCH 06/17] chore(pre-commit.ci): auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- databases/tests_mongodb/test_composite_types.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py index 1edbef0a4..a9fa4b0fe 100644 --- a/databases/tests_mongodb/test_composite_types.py +++ b/databases/tests_mongodb/test_composite_types.py @@ -5,13 +5,7 @@ @pytest.mark.asyncio async def test_composite_types(client: Prisma) -> None: - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'test@test.com', - 'phone': '123-456-7890' - } - }) + await client.user.create({'name': 'Alice', 'contact': {'email': 'test@test.com', 'phone': '123-456-7890'}}) user = await client.user.find_first() assert user is not None From 8ad465124c04a2646cda154c70743e380ab8523f Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Tue, 19 Nov 2024 18:42:10 +0100 Subject: [PATCH 07/17] feat(queryBuilder): include composite_types in the query sent to prisma --- .gitignore | 1 + src/prisma/_builder.py | 60 +++++++++---------- .../templates/composite_types.py.jinja | 35 +++++++++++ .../generator/templates/models.py.jinja | 10 +++- src/prisma/generator/templates/types.py.jinja | 6 ++ src/prisma/generator/types.py | 12 +++- 6 files changed, 88 insertions(+), 36 deletions(-) create mode 100644 src/prisma/generator/templates/composite_types.py.jinja diff --git a/.gitignore b/.gitignore index 85fdfa60b..ac083802b 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ src/prisma/bases.py src/prisma/types.py src/prisma/enums.py src/prisma/client.py +src/prisma/composite_types.py src/prisma/models.py src/prisma/actions.py src/prisma/metadata.py diff --git a/src/prisma/_builder.py b/src/prisma/_builder.py index bb5a39234..1d797a7f1 100644 --- a/src/prisma/_builder.py +++ b/src/prisma/_builder.py @@ -20,6 +20,7 @@ from .errors import InvalidModelError, UnknownModelError, UnknownRelationalFieldError from ._compat import get_args, is_union, get_origin, model_fields, model_field_type from ._typing import is_list_type +from .generator import PartialModelField, MetaFieldsInterface from ._constants import QUERY_BUILDER_ALIASES if TYPE_CHECKING: @@ -202,28 +203,7 @@ def _create_root_node(self) -> 'RootNode': ) return root - def get_default_fields(self, model: type[PrismaModel]) -> list[str]: - """Returns a list of all the scalar fields of a model - - Raises UnknownModelError if the current model cannot be found. - """ - name = getattr(model, '__prisma_model__', MISSING) - if name is MISSING: - raise InvalidModelError(model) - - name = model.__prisma_model__ - if name not in self.prisma_models: - raise UnknownModelError(name) - - # by default we exclude every field that points to a PrismaModel as that indicates that it is a relational field - # we explicitly keep fields that point to anything else, even other pydantic.BaseModel types, as they can be used to deserialize JSON - return [ - field - for field, info in model_fields(model).items() - if not _field_is_prisma_model(info, name=field, parent=model) - ] - - def get_relational_model(self, current_model: type[PrismaModel], field: str) -> type[PrismaModel]: + def get_relational_model(self, current_model: type[BaseModel], field: str) -> type[PrismaModel]: """Returns the model that the field is related to. Raises UnknownModelError if the current model is invalid. @@ -311,14 +291,6 @@ def _prisma_model_for_field( return None -def _field_is_prisma_model(field: FieldInfo, *, name: str, parent: type[BaseModel]) -> bool: - """Whether or not the given field info represents a model at the database level. - - This will return `True` for cases where the field represents a list of models or a single model. - """ - return _prisma_model_for_field(field, name=name, parent=parent) is not None - - def _is_prisma_model_type(type_: type[BaseModel]) -> TypeGuard[type[PrismaModel]]: from .bases import _PrismaModel # noqa: TID251 @@ -694,7 +666,7 @@ class Selection(Node): } """ - model: type[PrismaModel] | None + model: type[MetaFieldsInterface] | None include: dict[str, Any] | None root_selection: list[str] | None @@ -706,7 +678,7 @@ class Selection(Node): def __init__( self, - model: type[PrismaModel] | None = None, + model: type[MetaFieldsInterface] | None = None, include: dict[str, Any] | None = None, root_selection: list[str] | None = None, **kwargs: Any, @@ -742,12 +714,16 @@ def create_children(self) -> list[ChildType]: if root_selection is not None: children.extend(root_selection) elif model is not None: - children.extend(builder.get_default_fields(model)) + for field, info in model.get_meta_fields().items(): + children.append(self._get_child_from_model(field, info)) if include is not None: if model is None: raise ValueError('Cannot include fields when model is None.') + if not isinstance(model, type(BaseModel)): + raise ValueError(f'Expected model to be a Pydantic model but got {type(model)} instead.') + for key, value in include.items(): if value is True: # e.g. posts { post_fields } @@ -788,6 +764,24 @@ def create_children(self) -> list[ChildType]: return children + def _get_child_from_model(self, field: str, info: PartialModelField) -> ChildType: + builder = self.builder + + composite_type = info.get('composite_type') + + if composite_type is not None: + return Key( + field, + sep=' ', + node=Selection.create( + builder, + include=None, + model=composite_type, + ), + ) + + return field + class Key(AbstractNode): """Node for rendering a child node with a prefixed key""" diff --git a/src/prisma/generator/templates/composite_types.py.jinja b/src/prisma/generator/templates/composite_types.py.jinja new file mode 100644 index 000000000..8a1d54412 --- /dev/null +++ b/src/prisma/generator/templates/composite_types.py.jinja @@ -0,0 +1,35 @@ +{% include '_header.py.jinja' %} +{% from '_utils.py.jinja' import recursive_types with context %} + +from collections import OrderedDict + +from . import types, enums, errors, fields, bases +from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface + +{% for composite_type in dmmf.datamodel.types %} +class {{ composite_type.name }}(MetaFieldsInterface): + {% for field in composite_type.fields %} + {{ field.name }}: {{ field.python_type_as_string }} + {% endfor %} + + @staticmethod + def get_meta_fields() -> Dict['types.{{ composite_type.name }}Keys', PartialModelField]: + return _{{ composite_type.name }}_fields + +_{{ composite_type.name }}_fields: Dict['types.{{ composite_type.name }}Keys', PartialModelField] = OrderedDict( + [ + {% for field in composite_type.fields %} + ('{{ field.name }}', { + 'name': '{{ field.name }}', + 'is_list': {{ field.is_list }}, + 'optional': {{ field.is_optional }}, + 'type': {{ field.python_type_as_string }}, + 'is_relational': {{ field.relation_name is not none }}, + 'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %}, + 'composite_type': {% if field.is_composite_type %}{{ field.type }}{% else %}None{% endif %}, + }), + {% endfor %} + ], +) + +{% endfor %} diff --git a/src/prisma/generator/templates/models.py.jinja b/src/prisma/generator/templates/models.py.jinja index 17b5b55a8..e9ad48335 100644 --- a/src/prisma/generator/templates/models.py.jinja +++ b/src/prisma/generator/templates/models.py.jinja @@ -13,14 +13,15 @@ from . import types, enums, errors, fields, bases from ._types import FuncType from ._compat import model_rebuild, field_validator from ._builder import serialize_base64 -from .generator import partial_models_ctx, PartialModelField +from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface +from .composite_types import * log: logging.Logger = logging.getLogger(__name__) _created_partial_types: Set[str] = set() {% for model in dmmf.datamodel.models %} -class {{ model.name }}(bases.Base{{ model.name }}): +class {{ model.name }}(bases.Base{{ model.name }}, MetaFieldsInterface): {% if model.documentation is none %} """Represents a {{ model.name }} record""" {% else %} @@ -184,6 +185,10 @@ class {{ model.name }}(bases.Base{{ model.name }}): } ) _created_partial_types.add(name) + + @staticmethod + def get_meta_fields() -> Dict['types.{{ model.name }}Keys', PartialModelField]: + return _{{ model.name }}_fields {% endfor %} @@ -208,6 +213,7 @@ _{{ model.name }}_fields: Dict['types.{{ model.name }}Keys', PartialModelField] 'type': {{ field.python_type_as_string }}, 'is_relational': {{ field.relation_name is not none }}, 'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %}, + 'composite_type': {% if field.is_composite_type %}{{ field.type }}{% else %}None{% endif %}, }), {% endfor %} ], diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index 2f49bc115..68a4a8b90 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -447,6 +447,12 @@ class {{ composite_type.name }}(TypedDict, total=False): {% for field in composite_type.fields %} {{ field.name }}: {{ field.maybe_optional(field.create_input_type) }} {% endfor %} + +{{ composite_type.name }}Keys = Literal[ + {% for field in composite_type.fields %} + '{{ field.name }}', + {% endfor %} +] {% endfor %} diff --git a/src/prisma/generator/types.py b/src/prisma/generator/types.py index 3cbc25f69..a5c506893 100644 --- a/src/prisma/generator/types.py +++ b/src/prisma/generator/types.py @@ -1,10 +1,12 @@ -from typing import Mapping, Optional +from abc import abstractmethod +from typing import Dict, Mapping, Optional from .._types import TypedDict __all__ = ( 'PartialModel', 'PartialModelField', + 'MetaFieldsInterface', ) @@ -15,9 +17,17 @@ class PartialModelField(TypedDict): type: str documentation: Optional[str] is_relational: bool + composite_type: Optional[object] class PartialModel(TypedDict): name: str from_model: str fields: Mapping[str, PartialModelField] + + +class MetaFieldsInterface: + @staticmethod + @abstractmethod + def get_meta_fields() -> Dict[str, PartialModelField]: + ... From 402454156dcec011f507ef69a651c793c133900a Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Tue, 19 Nov 2024 19:04:37 +0100 Subject: [PATCH 08/17] fix(compositeTypes): add BaseModel to Composite Type classes to allow Pydantic to parse --- databases/tests_mongodb/test_composite_types.py | 1 + src/prisma/generator/templates/composite_types.py.jinja | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py index a9fa4b0fe..679a9d2ea 100644 --- a/databases/tests_mongodb/test_composite_types.py +++ b/databases/tests_mongodb/test_composite_types.py @@ -10,3 +10,4 @@ async def test_composite_types(client: Prisma) -> None: assert user is not None assert user.name == 'Alice' + assert user.contact.email == 'test@test.com' diff --git a/src/prisma/generator/templates/composite_types.py.jinja b/src/prisma/generator/templates/composite_types.py.jinja index 8a1d54412..96b0624ec 100644 --- a/src/prisma/generator/templates/composite_types.py.jinja +++ b/src/prisma/generator/templates/composite_types.py.jinja @@ -3,11 +3,13 @@ from collections import OrderedDict +from pydantic import BaseModel + from . import types, enums, errors, fields, bases from .generator import partial_models_ctx, PartialModelField, MetaFieldsInterface {% for composite_type in dmmf.datamodel.types %} -class {{ composite_type.name }}(MetaFieldsInterface): +class {{ composite_type.name }}(BaseModel, MetaFieldsInterface): {% for field in composite_type.fields %} {{ field.name }}: {{ field.python_type_as_string }} {% endfor %} From 0c6f530732b4783856931d405ff3c27ec4932c4c Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Tue, 19 Nov 2024 20:05:04 +0100 Subject: [PATCH 09/17] fix(generator): fix type errors that failed tests --- src/prisma/_builder.py | 1 + src/prisma/generator/templates/composite_types.py.jinja | 2 ++ src/prisma/generator/templates/models.py.jinja | 2 ++ src/prisma/generator/types.py | 4 ++-- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/prisma/_builder.py b/src/prisma/_builder.py index 1d797a7f1..6c67e1239 100644 --- a/src/prisma/_builder.py +++ b/src/prisma/_builder.py @@ -723,6 +723,7 @@ def create_children(self) -> list[ChildType]: if not isinstance(model, type(BaseModel)): raise ValueError(f'Expected model to be a Pydantic model but got {type(model)} instead.') + model = cast(type[BaseModel], model) for key, value in include.items(): if value is True: diff --git a/src/prisma/generator/templates/composite_types.py.jinja b/src/prisma/generator/templates/composite_types.py.jinja index 96b0624ec..b84c4dd75 100644 --- a/src/prisma/generator/templates/composite_types.py.jinja +++ b/src/prisma/generator/templates/composite_types.py.jinja @@ -2,6 +2,7 @@ {% from '_utils.py.jinja' import recursive_types with context %} from collections import OrderedDict +from typing_extensions import override from pydantic import BaseModel @@ -15,6 +16,7 @@ class {{ composite_type.name }}(BaseModel, MetaFieldsInterface): {% endfor %} @staticmethod + @override def get_meta_fields() -> Dict['types.{{ composite_type.name }}Keys', PartialModelField]: return _{{ composite_type.name }}_fields diff --git a/src/prisma/generator/templates/models.py.jinja b/src/prisma/generator/templates/models.py.jinja index e9ad48335..7d25df759 100644 --- a/src/prisma/generator/templates/models.py.jinja +++ b/src/prisma/generator/templates/models.py.jinja @@ -6,6 +6,7 @@ import logging import inspect import warnings from collections import OrderedDict +from typing_extensions import override from pydantic import BaseModel, Field @@ -187,6 +188,7 @@ class {{ model.name }}(bases.Base{{ model.name }}, MetaFieldsInterface): _created_partial_types.add(name) @staticmethod + @override def get_meta_fields() -> Dict['types.{{ model.name }}Keys', PartialModelField]: return _{{ model.name }}_fields diff --git a/src/prisma/generator/types.py b/src/prisma/generator/types.py index a5c506893..fb1750a42 100644 --- a/src/prisma/generator/types.py +++ b/src/prisma/generator/types.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Dict, Mapping, Optional +from typing import Any, Dict, Mapping, Optional from .._types import TypedDict @@ -29,5 +29,5 @@ class PartialModel(TypedDict): class MetaFieldsInterface: @staticmethod @abstractmethod - def get_meta_fields() -> Dict[str, PartialModelField]: + def get_meta_fields() -> Dict[Any, PartialModelField]: ... From 83ebdd1aa083269d1a590e3298665f33d3029e9f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 19:07:07 +0000 Subject: [PATCH 10/17] chore(pre-commit.ci): auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/prisma/generator/templates/models.py.jinja | 2 +- src/prisma/generator/types.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/prisma/generator/templates/models.py.jinja b/src/prisma/generator/templates/models.py.jinja index 7d25df759..4e4cdb50a 100644 --- a/src/prisma/generator/templates/models.py.jinja +++ b/src/prisma/generator/templates/models.py.jinja @@ -186,7 +186,7 @@ class {{ model.name }}(bases.Base{{ model.name }}, MetaFieldsInterface): } ) _created_partial_types.add(name) - + @staticmethod @override def get_meta_fields() -> Dict['types.{{ model.name }}Keys', PartialModelField]: diff --git a/src/prisma/generator/types.py b/src/prisma/generator/types.py index fb1750a42..819e0125d 100644 --- a/src/prisma/generator/types.py +++ b/src/prisma/generator/types.py @@ -29,5 +29,4 @@ class PartialModel(TypedDict): class MetaFieldsInterface: @staticmethod @abstractmethod - def get_meta_fields() -> Dict[Any, PartialModelField]: - ... + def get_meta_fields() -> Dict[Any, PartialModelField]: ... From 64c6c03277793c3fa655339b3b962d1969876914 Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Wed, 20 Nov 2024 14:50:36 +0100 Subject: [PATCH 11/17] test(compositeTypes): add some tests for filtering with composite types --- databases/templates/schema.prisma.jinja2 | 12 ++ .../tests_mongodb/test_composite_types.py | 168 +++++++++++++++++- 2 files changed, 177 insertions(+), 3 deletions(-) diff --git a/databases/templates/schema.prisma.jinja2 b/databases/templates/schema.prisma.jinja2 index fe667c38b..950cc53f7 100644 --- a/databases/templates/schema.prisma.jinja2 +++ b/databases/templates/schema.prisma.jinja2 @@ -243,10 +243,22 @@ model User { id String @id @default(auto()) @map("_id") @db.ObjectId name String contact Contact + pets Pet[] } type Contact { email String phone String } + +type Pet { + name String + type PetType +} + +enum PetType { + CAT + DOG + FISH +} {% endif %} diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py index 679a9d2ea..e0bc92419 100644 --- a/databases/tests_mongodb/test_composite_types.py +++ b/databases/tests_mongodb/test_composite_types.py @@ -1,13 +1,175 @@ import pytest +import prisma from prisma import Prisma @pytest.mark.asyncio -async def test_composite_types(client: Prisma) -> None: - await client.user.create({'name': 'Alice', 'contact': {'email': 'test@test.com', 'phone': '123-456-7890'}}) +async def test_composite_type_create(client: Prisma) -> None: + """ + Test creating a user with a composite type + """ + user = await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + + assert user is not None + assert user.id is not None + assert user.name == 'Alice' + + +@pytest.mark.asyncio +async def test_composite_type_find_first_without_filters(client: Prisma) -> None: + """ + Test finding a user with a composite type without any filters + """ + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + user = await client.user.find_first() assert user is not None + assert user.id is not None + assert user.name == 'Alice' + + +@pytest.mark.asyncio +async def test_composite_type_without_complete_where(client: Prisma) -> None: + """ + Test finding a user with a composite type without providing all the fields raises an error + """ + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + + with pytest.raises(prisma.errors.MissingRequiredValueError) as exc: + await client.user.find_first(where={ + 'contact': { + 'email': 'alice@example.com', + }, + }) + + assert '`where.contact.phone`: A value is required but not set' in str(exc.value) + + +@pytest.mark.asyncio +async def test_composite_type_with_complete_where(client: Prisma) -> None: + """ + Test finding a user with a composite type with all the fields + """ + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + + user = await client.user.find_first(where={ + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + }) + + assert user is not None + assert user.id is not None + assert user.name == 'Alice' + + +@pytest.mark.asyncio +async def test_composite_type_with_where_is(client: Prisma) -> None: + """ + Test finding a user with a composite type with the `is` operator + """ + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + + user = await client.user.find_first(where={ + 'contact': { + 'is': { + 'email': 'alice@example.com', + } + }, + }) + + assert user is not None + assert user.id is not None + assert user.name == 'Alice' + + +@pytest.mark.asyncio +async def test_composite_type_with_where_equals_without_all_fields(client: Prisma) -> None: + """ + Test finding a user with a composite type with the `equals` operator without all the fields + """ + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + + with pytest.raises(prisma.errors.MissingRequiredValueError) as exc: + await client.user.find_first(where={ + 'contact': { + 'equals': { + 'email': 'alice@example.com', + } + }, + }) + + assert '`where.contact.equals.phone`: A value is required but not set' in str(exc.value) + + +@pytest.mark.asyncio +async def test_composite_type_with_where_equals_with_all_fields(client: Prisma) -> None: + """ + Test finding a user with a composite type with the `equals` operator with all the fields + """ + await client.user.create({ + 'name': 'Alice', + 'contact': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + }, + 'pets': [] + }) + + user = await client.user.find_first(where={ + 'contact': { + 'equals': { + 'email': 'alice@example.com', + 'phone': '123-456-7890' + } + }, + }) + + assert user is not None + assert user.id is not None assert user.name == 'Alice' - assert user.contact.email == 'test@test.com' From 4ee38b7b369336675f1a43e9e7f66cb5ed73212e Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Wed, 20 Nov 2024 20:49:09 +0100 Subject: [PATCH 12/17] feat(compositeTypes): add types for filters of composite types --- .../tests_mongodb/test_composite_types.py | 4 +- src/prisma/generator/models.py | 13 +++-- src/prisma/generator/templates/types.py.jinja | 52 ++++++++++++++++++- 3 files changed, 62 insertions(+), 7 deletions(-) diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py index e0bc92419..e0f5a2d55 100644 --- a/databases/tests_mongodb/test_composite_types.py +++ b/databases/tests_mongodb/test_composite_types.py @@ -63,7 +63,7 @@ async def test_composite_type_without_complete_where(client: Prisma) -> None: 'contact': { 'email': 'alice@example.com', }, - }) + }) # type: ignore assert '`where.contact.phone`: A value is required but not set' in str(exc.value) @@ -142,7 +142,7 @@ async def test_composite_type_with_where_equals_without_all_fields(client: Prism 'email': 'alice@example.com', } }, - }) + }) # type: ignore assert '`where.contact.equals.phone`: A value is required but not set' in str(exc.value) diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index d1a511eb3..d2bf7f8d5 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -984,10 +984,15 @@ def create_input_type(self) -> str: @property def where_input_type(self) -> str: + typ = self.type + if self.is_composite_type: - return self.python_type + if self.is_list: + return f"'types.{typ}ListFilter'" + if self.is_optional: + return f"'types.{typ}OptionalFilter'" + return f"'types.{typ}Filter'" - typ = self.type if self.is_relational: if self.is_list: return f"'{typ}ListRelationFilter'" @@ -1058,9 +1063,9 @@ def is_atomic(self) -> bool: def is_number(self) -> bool: return self.type in {'Int', 'BigInt', 'Float'} - def maybe_optional(self, typ: str) -> str: + def maybe_optional(self, typ: str, force: bool = False) -> str: """Wrap the given type string within `Optional` if applicable""" - if self.is_required or self.is_relational: + if (not force) and (self.is_required or self.is_relational or self.is_composite_type): return typ return f'Optional[{typ}]' diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index 68a4a8b90..0fc0a4ae4 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -443,11 +443,18 @@ class _{{ type }}ListUpdatePush(TypedDict): {% for composite_type in dmmf.datamodel.types %} # {{ composite_type.name }} types -class {{ composite_type.name }}(TypedDict, total=False): +class {{ composite_type.name }}(TypedDict): {% for field in composite_type.fields %} {{ field.name }}: {{ field.maybe_optional(field.create_input_type) }} {% endfor %} + +class Partial{{ composite_type.name }}(TypedDict, total=False): + {% for field in composite_type.fields %} + {{ field.name }}: {{ field.maybe_optional(field.create_input_type, True) }} + {% endfor %} + + {{ composite_type.name }}Keys = Literal[ {% for field in composite_type.fields %} '{{ field.name }}', @@ -793,6 +800,49 @@ class {{ model.name }}NumberAggregateInput(TypedDict, total=False): {% endfor %} +{% for composite_type in dmmf.datamodel.types %} + +{{ composite_type.name }}Filter = Union[ + {{ composite_type.name }}, + '{{ composite_type.name }}FilterOperations', +] + +{{ composite_type.name }}FilterOperations = TypedDict( + '{{ composite_type.name }}FilterOperations', + { + 'equals': '{{ composite_type.name }}', + 'is': 'Partial{{ composite_type.name }}', + 'is_not': 'Partial{{ composite_type.name }}', + }, + total=False, +) + +class {{ composite_type.name }}OptionalFilterOperations({{ composite_type.name }}FilterOperations): + is_set: bool + + +class {{ composite_type.name }}FilterWithQuantifiers(Partial{{ composite_type.name }}, total=False): + AND: 'List[{{ composite_type.name }}FilterWithQuantifiers]' + OR: 'List[{{ composite_type.name }}FilterWithQuantifiers]' + NOT: 'List[{{ composite_type.name }}FilterWithQuantifiers]' + + +{{ composite_type.name }}ListFilter = Union[ + List[{{ composite_type.name }}], + '{{ composite_type.name }}ListFilterOperations', +] + + +class {{ composite_type.name }}ListFilterOperations(TypedDict, total=False): + equals: List['{{ composite_type.name }}'] + is_empty: bool + is_set: bool + every: '{{ composite_type.name }}FilterWithQuantifiers' + some: '{{ composite_type.name }}FilterWithQuantifiers' + none: '{{ composite_type.name }}FilterWithQuantifiers' + +{% endfor %} + # we have to import ourselves as types can be namespaced to types from . import types, enums, models, fields From ad50676c78ea90049bd881cd64cbd5b1160a6593 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 19:50:38 +0000 Subject: [PATCH 13/17] chore(pre-commit.ci): auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests_mongodb/test_composite_types.py | 155 +++++++----------- 1 file changed, 61 insertions(+), 94 deletions(-) diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py index e0f5a2d55..e497bb863 100644 --- a/databases/tests_mongodb/test_composite_types.py +++ b/databases/tests_mongodb/test_composite_types.py @@ -9,14 +9,9 @@ async def test_composite_type_create(client: Prisma) -> None: """ Test creating a user with a composite type """ - user = await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) + user = await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) assert user is not None assert user.id is not None @@ -28,14 +23,9 @@ async def test_composite_type_find_first_without_filters(client: Prisma) -> None """ Test finding a user with a composite type without any filters """ - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) + await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) user = await client.user.find_first() @@ -49,45 +39,36 @@ async def test_composite_type_without_complete_where(client: Prisma) -> None: """ Test finding a user with a composite type without providing all the fields raises an error """ - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) + await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) with pytest.raises(prisma.errors.MissingRequiredValueError) as exc: - await client.user.find_first(where={ - 'contact': { - 'email': 'alice@example.com', - }, - }) # type: ignore - + await client.user.find_first( + where={ + 'contact': { + 'email': 'alice@example.com', + }, + } + ) # type: ignore + assert '`where.contact.phone`: A value is required but not set' in str(exc.value) - + @pytest.mark.asyncio async def test_composite_type_with_complete_where(client: Prisma) -> None: """ Test finding a user with a composite type with all the fields """ - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) - - user = await client.user.find_first(where={ - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - }) + await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) + + user = await client.user.find_first( + where={ + 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, + } + ) assert user is not None assert user.id is not None @@ -99,22 +80,19 @@ async def test_composite_type_with_where_is(client: Prisma) -> None: """ Test finding a user with a composite type with the `is` operator """ - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) - - user = await client.user.find_first(where={ - 'contact': { - 'is': { - 'email': 'alice@example.com', - } - }, - }) + await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) + + user = await client.user.find_first( + where={ + 'contact': { + 'is': { + 'email': 'alice@example.com', + } + }, + } + ) assert user is not None assert user.id is not None @@ -126,23 +104,20 @@ async def test_composite_type_with_where_equals_without_all_fields(client: Prism """ Test finding a user with a composite type with the `equals` operator without all the fields """ - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) + await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) with pytest.raises(prisma.errors.MissingRequiredValueError) as exc: - await client.user.find_first(where={ - 'contact': { - 'equals': { - 'email': 'alice@example.com', - } - }, - }) # type: ignore + await client.user.find_first( + where={ + 'contact': { + 'equals': { + 'email': 'alice@example.com', + } + }, + } + ) # type: ignore assert '`where.contact.equals.phone`: A value is required but not set' in str(exc.value) @@ -152,23 +127,15 @@ async def test_composite_type_with_where_equals_with_all_fields(client: Prisma) """ Test finding a user with a composite type with the `equals` operator with all the fields """ - await client.user.create({ - 'name': 'Alice', - 'contact': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - }, - 'pets': [] - }) - - user = await client.user.find_first(where={ - 'contact': { - 'equals': { - 'email': 'alice@example.com', - 'phone': '123-456-7890' - } - }, - }) + await client.user.create( + {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} + ) + + user = await client.user.find_first( + where={ + 'contact': {'equals': {'email': 'alice@example.com', 'phone': '123-456-7890'}}, + } + ) assert user is not None assert user.id is not None From f77f9e7e2f7da04c891169c90efa8ec418b4b0e6 Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Fri, 22 Nov 2024 16:06:14 +0100 Subject: [PATCH 14/17] Revert "test(mongodb): create test for mongodb composite types" This reverts commit f08bc8bcea70d31a50ba43d857e70a1360238a12. --- databases/templates/schema.prisma.jinja2 | 18 +-- .../tests_mongodb/test_composite_types.py | 142 ------------------ databases/tests_mongodb/test_hello.py | 8 + 3 files changed, 11 insertions(+), 157 deletions(-) delete mode 100644 databases/tests_mongodb/test_composite_types.py create mode 100644 databases/tests_mongodb/test_hello.py diff --git a/databases/templates/schema.prisma.jinja2 b/databases/templates/schema.prisma.jinja2 index 950cc53f7..adbcebd10 100644 --- a/databases/templates/schema.prisma.jinja2 +++ b/databases/templates/schema.prisma.jinja2 @@ -242,23 +242,11 @@ enum Role { model User { id String @id @default(auto()) @map("_id") @db.ObjectId name String - contact Contact - pets Pet[] + {# contact Contact #} } -type Contact { +{# type Contact { email String phone String -} - -type Pet { - name String - type PetType -} - -enum PetType { - CAT - DOG - FISH -} +} #} {% endif %} diff --git a/databases/tests_mongodb/test_composite_types.py b/databases/tests_mongodb/test_composite_types.py deleted file mode 100644 index e497bb863..000000000 --- a/databases/tests_mongodb/test_composite_types.py +++ /dev/null @@ -1,142 +0,0 @@ -import pytest - -import prisma -from prisma import Prisma - - -@pytest.mark.asyncio -async def test_composite_type_create(client: Prisma) -> None: - """ - Test creating a user with a composite type - """ - user = await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - assert user is not None - assert user.id is not None - assert user.name == 'Alice' - - -@pytest.mark.asyncio -async def test_composite_type_find_first_without_filters(client: Prisma) -> None: - """ - Test finding a user with a composite type without any filters - """ - await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - user = await client.user.find_first() - - assert user is not None - assert user.id is not None - assert user.name == 'Alice' - - -@pytest.mark.asyncio -async def test_composite_type_without_complete_where(client: Prisma) -> None: - """ - Test finding a user with a composite type without providing all the fields raises an error - """ - await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - with pytest.raises(prisma.errors.MissingRequiredValueError) as exc: - await client.user.find_first( - where={ - 'contact': { - 'email': 'alice@example.com', - }, - } - ) # type: ignore - - assert '`where.contact.phone`: A value is required but not set' in str(exc.value) - - -@pytest.mark.asyncio -async def test_composite_type_with_complete_where(client: Prisma) -> None: - """ - Test finding a user with a composite type with all the fields - """ - await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - user = await client.user.find_first( - where={ - 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, - } - ) - - assert user is not None - assert user.id is not None - assert user.name == 'Alice' - - -@pytest.mark.asyncio -async def test_composite_type_with_where_is(client: Prisma) -> None: - """ - Test finding a user with a composite type with the `is` operator - """ - await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - user = await client.user.find_first( - where={ - 'contact': { - 'is': { - 'email': 'alice@example.com', - } - }, - } - ) - - assert user is not None - assert user.id is not None - assert user.name == 'Alice' - - -@pytest.mark.asyncio -async def test_composite_type_with_where_equals_without_all_fields(client: Prisma) -> None: - """ - Test finding a user with a composite type with the `equals` operator without all the fields - """ - await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - with pytest.raises(prisma.errors.MissingRequiredValueError) as exc: - await client.user.find_first( - where={ - 'contact': { - 'equals': { - 'email': 'alice@example.com', - } - }, - } - ) # type: ignore - - assert '`where.contact.equals.phone`: A value is required but not set' in str(exc.value) - - -@pytest.mark.asyncio -async def test_composite_type_with_where_equals_with_all_fields(client: Prisma) -> None: - """ - Test finding a user with a composite type with the `equals` operator with all the fields - """ - await client.user.create( - {'name': 'Alice', 'contact': {'email': 'alice@example.com', 'phone': '123-456-7890'}, 'pets': []} - ) - - user = await client.user.find_first( - where={ - 'contact': {'equals': {'email': 'alice@example.com', 'phone': '123-456-7890'}}, - } - ) - - assert user is not None - assert user.id is not None - assert user.name == 'Alice' diff --git a/databases/tests_mongodb/test_hello.py b/databases/tests_mongodb/test_hello.py new file mode 100644 index 000000000..6d6cdd20f --- /dev/null +++ b/databases/tests_mongodb/test_hello.py @@ -0,0 +1,8 @@ +import pytest + +from prisma import Prisma + + +@pytest.mark.asyncio +async def test_base_usage(client: Prisma) -> None: + assert 'Hello, World!' == 'Hello, World!' From d377fe7129f83140f944e090d671f2ddb1392a0c Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Fri, 22 Nov 2024 16:06:52 +0100 Subject: [PATCH 15/17] Revert "test(mongodb): init test environment for mongodb" This reverts commit 4d6e3b1cb96095557c01784e0ff798d07634b20d. --- databases/_types.py | 2 -- databases/constants.py | 9 ------ databases/main.py | 8 +---- databases/templates/schema.prisma.jinja2 | 17 ---------- databases/tests_mongodb/__init__.py | 0 databases/tests_mongodb/conftest.py | 40 ------------------------ databases/tests_mongodb/test_hello.py | 8 ----- 7 files changed, 1 insertion(+), 83 deletions(-) delete mode 100644 databases/tests_mongodb/__init__.py delete mode 100644 databases/tests_mongodb/conftest.py delete mode 100644 databases/tests_mongodb/test_hello.py diff --git a/databases/_types.py b/databases/_types.py index 9f462edf7..9e86a281c 100644 --- a/databases/_types.py +++ b/databases/_types.py @@ -14,7 +14,6 @@ 'mariadb', 'postgresql', 'cockroachdb', - 'mongodb', ] @@ -29,4 +28,3 @@ class DatabaseMapping(TypedDict, Generic[_T]): mariadb: _T postgresql: _T cockroachdb: _T - mongodb: _T diff --git a/databases/constants.py b/databases/constants.py index 41c825516..486b743de 100644 --- a/databases/constants.py +++ b/databases/constants.py @@ -83,15 +83,6 @@ def _fromdir(path: str) -> list[str]: 'full_text_search', }, ), - 'mongodb': DatabaseConfig( - id='mongodb', - name='MongoDB', - env_var='MONGODB_URL', - bools_are_ints=False, - default_date_func='', - autoincrement_id='', - unsupported_features=set(), - ), } SUPPORTED_DATABASES = cast(List[SupportedDatabase], list(get_args(SupportedDatabase))) diff --git a/databases/main.py b/databases/main.py index 230734af5..d9de7b6a8 100644 --- a/databases/main.py +++ b/databases/main.py @@ -299,7 +299,7 @@ def setup(self) -> None: # template variables config=self.config, for_async=self.for_async, - partial_generator=(escape_path(DATABASES_DIR / 'partials.py') if self.config.id != 'mongodb' else None), + partial_generator=escape_path(DATABASES_DIR / 'partials.py'), ) ) @@ -326,8 +326,6 @@ def test(self, *, pytest_args: str | None) -> None: ) args = [] - if self.config.id == 'mongodb': - args.append('tests_mongodb') if pytest_args is not None: # pragma: no cover args = shlex.split(pytest_args) @@ -397,10 +395,6 @@ def exclude_files(self) -> set[str]: # ensure the tests for the sync client are not ran during the async tests anc vice versa files.append(tests_reldir(for_async=not self.for_async)) - if self.config.id == 'mongodb': - files.append(tests_relpath('.', for_async=True)) - files.append(tests_relpath('.', for_async=False)) - return set(files) diff --git a/databases/templates/schema.prisma.jinja2 b/databases/templates/schema.prisma.jinja2 index adbcebd10..74d7241a2 100644 --- a/databases/templates/schema.prisma.jinja2 +++ b/databases/templates/schema.prisma.jinja2 @@ -13,9 +13,7 @@ generator client { recursive_type_depth = -1 engineType = "binary" enable_experimental_decimal = true - {% if partial_generator %} partial_type_generator = "{{ partial_generator }}" - {% endif %} {% if config.id == "postgresql" %} previewFeatures = ["fullTextSearch"] {% elif config.id == "mysql" %} @@ -23,7 +21,6 @@ generator client { {% endif %} } -{% if config.id != "mongodb" %} model User { id String @id @default(cuid()) created_at DateTime @default(now()) @@ -236,17 +233,3 @@ enum Role { } {% endif %} -{% endif %} - -{% if config.id == "mongodb" %} -model User { - id String @id @default(auto()) @map("_id") @db.ObjectId - name String - {# contact Contact #} -} - -{# type Contact { - email String - phone String -} #} -{% endif %} diff --git a/databases/tests_mongodb/__init__.py b/databases/tests_mongodb/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/databases/tests_mongodb/conftest.py b/databases/tests_mongodb/conftest.py deleted file mode 100644 index ae506fb35..000000000 --- a/databases/tests_mongodb/conftest.py +++ /dev/null @@ -1,40 +0,0 @@ -import os - -import pytest -from syrupy.assertion import SnapshotAssertion - -import prisma -from prisma import Prisma -from prisma._compat import model_parse_json -from lib.testing.shared_conftest import * -from lib.testing.shared_conftest.async_client import * - -from ..utils import ( - RAW_QUERIES_MAPPING, - RawQueries, - DatabaseConfig, - AmberSharedExtension, -) - -prisma.register(Prisma()) - - -# TODO: better error messages for invalid state -@pytest.fixture(name='database') -def database_fixture() -> str: - return os.environ['PRISMA_DATABASE'] - - -@pytest.fixture(name='raw_queries') -def raw_queries_fixture(database: str) -> RawQueries: - return RAW_QUERIES_MAPPING[database] - - -@pytest.fixture(name='config', scope='session') -def config_fixture() -> DatabaseConfig: - return model_parse_json(DatabaseConfig, os.environ['DATABASE_CONFIG']) - - -@pytest.fixture() -def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: - return snapshot.use_extension(AmberSharedExtension) diff --git a/databases/tests_mongodb/test_hello.py b/databases/tests_mongodb/test_hello.py deleted file mode 100644 index 6d6cdd20f..000000000 --- a/databases/tests_mongodb/test_hello.py +++ /dev/null @@ -1,8 +0,0 @@ -import pytest - -from prisma import Prisma - - -@pytest.mark.asyncio -async def test_base_usage(client: Prisma) -> None: - assert 'Hello, World!' == 'Hello, World!' From 590bd894f2ac33ee9136d55eaa965c242e5c36ac Mon Sep 17 00:00:00 2001 From: Arthur Fontaine <0arthur.fontaine@gmail.com> Date: Wed, 27 Nov 2024 22:33:02 +0100 Subject: [PATCH 16/17] fix(queryBuilder): fix regressions caused by using get_meta_fields instead of get_default_fields --- src/prisma/_builder.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/src/prisma/_builder.py b/src/prisma/_builder.py index 6c67e1239..886b29cd0 100644 --- a/src/prisma/_builder.py +++ b/src/prisma/_builder.py @@ -203,6 +203,26 @@ def _create_root_node(self) -> 'RootNode': ) return root + def get_default_fields(self, model: type[PrismaModel]) -> list[str]: + """Returns a list of all the scalar fields of a model + Raises UnknownModelError if the current model cannot be found. + """ + name = getattr(model, '__prisma_model__', MISSING) + if name is MISSING: + raise InvalidModelError(model) + + name = model.__prisma_model__ + if name not in self.prisma_models: + raise UnknownModelError(name) + + # by default we exclude every field that points to a PrismaModel as that indicates that it is a relational field + # we explicitly keep fields that point to anything else, even other pydantic.BaseModel types, as they can be used to deserialize JSON + return [ + field + for field, info in model_fields(model).items() + if not _field_is_prisma_model(info, name=field, parent=model) + ] + def get_relational_model(self, current_model: type[BaseModel], field: str) -> type[PrismaModel]: """Returns the model that the field is related to. @@ -291,6 +311,13 @@ def _prisma_model_for_field( return None +def _field_is_prisma_model(field: FieldInfo, *, name: str, parent: type[BaseModel]) -> bool: + """Whether or not the given field info represents a model at the database level. + This will return `True` for cases where the field represents a list of models or a single model. + """ + return _prisma_model_for_field(field, name=name, parent=parent) is not None + + def _is_prisma_model_type(type_: type[BaseModel]) -> TypeGuard[type[PrismaModel]]: from .bases import _PrismaModel # noqa: TID251 @@ -714,8 +741,11 @@ def create_children(self) -> list[ChildType]: if root_selection is not None: children.extend(root_selection) elif model is not None: - for field, info in model.get_meta_fields().items(): - children.append(self._get_child_from_model(field, info)) + if hasattr(model, 'get_meta_fields'): + for field, info in model.get_meta_fields().items(): + children.append(self._get_child_from_model(field, info)) + else: + children.extend(builder.get_default_fields(model)) # type: ignore if include is not None: if model is None: @@ -780,6 +810,9 @@ def _get_child_from_model(self, field: str, info: PartialModelField) -> ChildTyp model=composite_type, ), ) + + if info.get("is_relational"): + return "" return field From e9e5bfb5ffd99f8dfcd6a5501dc7a0c650998c30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 21:33:22 +0000 Subject: [PATCH 17/17] chore(pre-commit.ci): auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/prisma/_builder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/prisma/_builder.py b/src/prisma/_builder.py index 886b29cd0..c1a1f9979 100644 --- a/src/prisma/_builder.py +++ b/src/prisma/_builder.py @@ -745,7 +745,7 @@ def create_children(self) -> list[ChildType]: for field, info in model.get_meta_fields().items(): children.append(self._get_child_from_model(field, info)) else: - children.extend(builder.get_default_fields(model)) # type: ignore + children.extend(builder.get_default_fields(model)) # type: ignore if include is not None: if model is None: @@ -810,9 +810,9 @@ def _get_child_from_model(self, field: str, info: PartialModelField) -> ChildTyp model=composite_type, ), ) - - if info.get("is_relational"): - return "" + + if info.get('is_relational'): + return '' return field