Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions databases/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

import prisma
# import prisma
from prisma import Prisma
from prisma.errors import FieldNotFoundError, ForeignKeyViolationError

Expand Down Expand Up @@ -39,19 +39,19 @@ async def test_field_not_found_error(client: Prisma) -> None:
)


@pytest.mark.asyncio
@pytest.mark.prisma
async def test_field_not_found_error_selection() -> None:
"""The FieldNotFoundError is raised when an unknown field is passed to selections."""
# @pytest.mark.asyncio
# @pytest.mark.prisma
# async def test_field_not_found_error_selection() -> None:
# """The FieldNotFoundError is raised when an unknown field is passed to selections."""

class CustomPost(prisma.bases.BasePost):
foo_field: str
# class CustomPost(prisma.bases.BasePost):
# foo_field: str

with pytest.raises(
FieldNotFoundError,
match=r'Field \'foo_field\' not found in enclosing type \'Post\'',
):
await CustomPost.prisma().find_first()
# with pytest.raises(
# FieldNotFoundError,
# match=r'Field \'foo_field\' not found in enclosing type \'Post\'',
# ):
# await CustomPost.prisma().find_first()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

test_field_not_found_error_selection skipped

JSON protocol does not explicitly select all fields, so it won't raise a FieldNotFoundError here.

However we could:

  • explicitly select all fields (cons: this could be hard to impl and bring overhead to serializer, prisma-engine and maybe even db); or...
  • leave it to pydantic as it will raise validation errors if data doesn't match the model.



@pytest.mark.asyncio
Expand Down
53 changes: 39 additions & 14 deletions src/prisma/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@

from pydantic import BaseModel

from ._types import Datasource, HttpConfig, PrismaMethod, MetricsFormat, TransactionId, DatasourceOverride
from ._types import (
Datasource,
HttpConfig,
PrismaMethod,
MetricsFormat,
TransactionId,
DatasourceOverride,
)
from .engine import (
SyncQueryEngine,
AsyncQueryEngine,
BaseAbstractEngine,
SyncAbstractEngine,
AsyncAbstractEngine,
json as json_proto,
)
from .errors import ClientNotConnectedError, ClientNotRegisteredError
from ._compat import model_parse, removeprefix
from ._builder import QueryBuilder
from ._metrics import Metrics
from ._registry import get_client
from .generator.models import EngineType
Expand Down Expand Up @@ -286,15 +293,15 @@ def _prepare_connect_args(
log.debug('datasources: %s', datasources)
return timeout, datasources

def _make_query_builder(
def _serialize(
self,
*,
method: PrismaMethod,
arguments: dict[str, Any],
model: type[BaseModel] | None,
root_selection: list[str] | None,
) -> QueryBuilder:
return QueryBuilder(
root_selection: json_proto.JsonSelectionSet | None = None,
) -> json_proto.JsonQuery:
return json_proto.serialize(
method=method,
model=model,
arguments=arguments,
Expand Down Expand Up @@ -415,12 +422,21 @@ def _execute(
method: PrismaMethod,
arguments: dict[str, Any],
model: type[BaseModel] | None = None,
root_selection: list[str] | None = None,
root_selection: json_proto.JsonSelectionSet | None = None,
) -> Any:
builder = self._make_query_builder(
method=method, model=model, arguments=arguments, root_selection=root_selection
return json_proto.deserialize(
self._engine.query(
json_proto.dumps(
self._serialize(
method=method,
arguments=arguments,
model=model,
root_selection=root_selection,
)
),
tx_id=self._tx_id,
)
)
return self._engine.query(builder.build(), tx_id=self._tx_id)


class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
Expand Down Expand Up @@ -535,9 +551,18 @@ async def _execute(
method: PrismaMethod,
arguments: dict[str, Any],
model: type[BaseModel] | None = None,
root_selection: list[str] | None = None,
root_selection: json_proto.JsonSelectionSet | None = None,
) -> Any:
builder = self._make_query_builder(
method=method, model=model, arguments=arguments, root_selection=root_selection
return json_proto.deserialize(
await self._engine.query(
json_proto.dumps(
self._serialize(
method=method,
arguments=arguments,
model=model,
root_selection=root_selection,
)
),
tx_id=self._tx_id,
)
)
return await self._engine.query(builder.build(), tx_id=self._tx_id)
2 changes: 1 addition & 1 deletion src/prisma/engine/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _spawn_process(
RUST_LOG='error',
RUST_LOG_FORMAT='json',
PRISMA_CLIENT_ENGINE_TYPE='binary',
PRISMA_ENGINE_PROTOCOL='graphql',
PRISMA_ENGINE_PROTOCOL='json',
)

if DEBUG:
Expand Down
3 changes: 3 additions & 0 deletions src/prisma/engine/json/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .types import *
from .serializer import dumps as dumps, serialize as serialize
from .deserializer import deserialize as deserialize
43 changes: 43 additions & 0 deletions src/prisma/engine/json/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import json
from typing import Any
from decimal import Decimal
from datetime import datetime
from typing_extensions import TypeGuard

from .types import JsonOutputTaggedValue
from ...fields import Base64


def deserialize(result: Any) -> Any:
if not result:
return result

if isinstance(result, list):
return list(map(deserialize, result))

if isinstance(result, dict):
if is_tagged_value(result):
return result['value'] # XXX: will pydantic cast this?

return {k: deserialize(v) for k, v in result.items()}

return result


def is_tagged_value(value: dict[Any, Any]) -> TypeGuard[JsonOutputTaggedValue]:
return isinstance(value.get('$type'), str)


def deserialize_tagged_value(tagged: JsonOutputTaggedValue) -> Any:
if tagged['$type'] == 'BigInt':
return int(tagged['value'])
elif tagged['$type'] == 'Bytes':
return Base64.fromb64(tagged['value'])
elif tagged['$type'] == 'DateTime':
return datetime.fromisoformat(tagged['value'])
elif tagged['$type'] == 'Decimal':
return Decimal(tagged['value'])
elif tagged['$type'] == 'Json':
return json.loads(tagged['value'])
Loading