From 32bf439d7ed2419ba4b0d34bd4995a03d9ade287 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Mon, 16 May 2022 15:05:30 +0200 Subject: [PATCH 01/19] feat(async): add support for async sessions --- graphene_sqlalchemy/fields.py | 42 +- graphene_sqlalchemy/tests/conftest.py | 38 +- graphene_sqlalchemy/tests/models.py | 34 +- graphene_sqlalchemy/tests/test_batching.py | 414 +++++++++++------- graphene_sqlalchemy/tests/test_benchmark.py | 112 +++-- graphene_sqlalchemy/tests/test_query.py | 106 +++-- graphene_sqlalchemy/tests/test_query_enums.py | 136 ++++-- graphene_sqlalchemy/tests/test_sort_enums.py | 22 +- graphene_sqlalchemy/tests/test_types.py | 299 +++++++------ graphene_sqlalchemy/tests/utils.py | 9 + graphene_sqlalchemy/types.py | 95 ++-- graphene_sqlalchemy/utils.py | 19 +- 12 files changed, 820 insertions(+), 506 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index d7a83392..ea421450 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -3,6 +3,8 @@ from functools import partial from promise import Promise, is_thenable +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.query import Query from graphene import NonNull @@ -11,7 +13,7 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query +from .utils import EnumValue, get_query, get_session class UnsortedSQLAlchemyConnectionField(ConnectionField): @@ -26,9 +28,7 @@ def type(self): assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" ).format(nullable_type.__name__) - assert ( - nullable_type.connection - ), "The type {} doesn't have a connection".format( + assert nullable_type.connection, "The type {} doesn't have a connection".format( nullable_type.__name__ ) assert type_ == nullable_type, ( @@ -46,9 +46,15 @@ def get_query(cls, model, info, **args): return get_query(model, info.context) @classmethod - def resolve_connection(cls, connection_type, model, info, args, resolved): + async def resolve_connection(cls, connection_type, model, info, args, resolved): + session = get_session(info.context) if resolved is None: - resolved = cls.get_query(model, info, **args) + if isinstance(session, AsyncSession): + resolved = ( + await session.scalars(cls.get_query(model, info, **args)) + ).all() + else: + resolved = cls.get_query(model, info, **args) if isinstance(resolved, Query): _len = resolved.count() else: @@ -111,7 +117,11 @@ def __init__(self, type_, *args, **kwargs): @classmethod def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) + session = get_session(info.context) + if isinstance(session, AsyncSession): + query = select(model) + else: + query = get_query(model, info.context) if sort is not None: if not isinstance(sort, list): sort = [sort] @@ -148,7 +158,11 @@ def wrap_resolve(self, parent_resolver): def from_relationship(cls, relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + return cls( + model_type.connection, + resolver=get_batch_resolver(relationship), + **field_kwargs, + ) def default_connection_field_factory(relationship, registry, **field_kwargs): @@ -163,8 +177,8 @@ def default_connection_field_factory(relationship, registry, **field_kwargs): def createConnectionField(type_, **field_kwargs): warnings.warn( - 'createConnectionField is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "createConnectionField is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) return __connectionFactory(type_, **field_kwargs) @@ -172,8 +186,8 @@ def createConnectionField(type_, **field_kwargs): def registerConnectionFieldFactory(factoryMethod): warnings.warn( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory @@ -182,8 +196,8 @@ def registerConnectionFieldFactory(factoryMethod): def unregisterConnectionFieldFactory(): warnings.warn( - 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + "registerConnectionFieldFactory is deprecated and will be removed in the next " + "major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.", DeprecationWarning, ) global __connectionFactory diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 34ba9d8a..1722224d 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,5 +1,6 @@ import pytest from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker import graphene @@ -8,8 +9,6 @@ from ..registry import reset_global_registry from .models import Base, CompositeFullName -test_db_url = 'sqlite://' # use in-memory database for tests - @pytest.fixture(autouse=True) def reset_registry(): @@ -22,16 +21,35 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(scope="function") -def session_factory(): - engine = create_engine(test_db_url) - Base.metadata.create_all(engine) +@pytest.fixture(params=[False, True]) +def async_session(request): + return request.param + - yield sessionmaker(bind=engine) +@pytest.fixture +def test_db_url(async_session: bool): + if async_session: + return "sqlite+aiosqlite://" + else: + return "sqlite://" - # SQLite in-memory db is deleted when its connection is closed. - # https://www.sqlite.org/inmemorydb.html - engine.dispose() + +@pytest.mark.asyncio +@pytest.fixture(scope="function") +async def session_factory(async_session: bool, test_db_url: str): + if async_session: + engine = create_async_engine(test_db_url) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) + await engine.dispose() + else: + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() @pytest.fixture(scope="function") diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index e41adb51..ff8123f5 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -15,8 +15,8 @@ class HairKind(enum.Enum): - LONG = 'long' - SHORT = 'short' + LONG = "long" + SHORT = "short" Base = declarative_base() @@ -64,9 +64,15 @@ class Reporter(Base): last_name = Column(String(30), doc="Last name") email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) - pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id") - articles = relationship("Article", backref="reporter") - favorite_article = relationship("Article", uselist=False) + pets = relationship( + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + lazy="joined", + ) + articles = relationship("Article", backref="reporter", lazy="joined") + favorite_article = relationship("Article", uselist=False, lazy="joined") @hybrid_property def hybrid_prop_with_doc(self): @@ -137,7 +143,7 @@ class ShoppingCartItem(Base): id = Column(Integer(), primary_key=True) @hybrid_property - def hybrid_prop_shopping_cart(self) -> List['ShoppingCart']: + def hybrid_prop_shopping_cart(self) -> List["ShoppingCart"]: return [ShoppingCart(id=1)] @@ -192,11 +198,17 @@ def hybrid_prop_list_date(self) -> List[datetime.date]: @hybrid_property def hybrid_prop_nested_list_int(self) -> List[List[int]]: - return [self.hybrid_prop_list_int, ] + return [ + self.hybrid_prop_list_int, + ] @hybrid_property def hybrid_prop_deeply_nested_list_int(self) -> List[List[List[int]]]: - return [[self.hybrid_prop_list_int, ], ] + return [ + [ + self.hybrid_prop_list_int, + ], + ] # Other SQLAlchemy Instances @hybrid_property @@ -216,15 +228,15 @@ def hybrid_prop_unsupported_type_tuple(self) -> Tuple[str, str]: # Self-references @hybrid_property - def hybrid_prop_self_referential(self) -> 'ShoppingCart': + def hybrid_prop_self_referential(self) -> "ShoppingCart": return ShoppingCart(id=1) @hybrid_property - def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: + def hybrid_prop_self_referential_list(self) -> List["ShoppingCart"]: return [ShoppingCart(id=1)] # Optional[T] @hybrid_property - def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: + def hybrid_prop_optional_self_referential(self) -> Optional["ShoppingCart"]: return None diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 1896900b..5f8c7695 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -3,6 +3,8 @@ import logging import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession import graphene from graphene import relay @@ -10,13 +12,14 @@ from ..fields import (BatchSQLAlchemyConnectionField, default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType -from ..utils import is_sqlalchemy_version_less_than +from ..utils import get_session, is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter from .utils import remove_cache_miss_stat, to_std_dicts class MockLoggingHandler(logging.Handler): """Intercept and store log messages in a list.""" + def __init__(self, *args, **kwargs): self.messages = [] logging.Handler.__init__(self, *args, **kwargs) @@ -28,7 +31,7 @@ def emit(self, record): @contextlib.contextmanager def mock_sqlalchemy_logging_handler(): logging.basicConfig() - sql_logger = logging.getLogger('sqlalchemy.engine') + sql_logger = logging.getLogger("sqlalchemy.engine") previous_level = sql_logger.level sql_logger.setLevel(logging.INFO) @@ -64,17 +67,30 @@ class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) reporters = graphene.Field(graphene.List(ReporterType)) - def resolve_articles(self, info): - return info.context.get('session').query(Article).all() + async def resolve_articles(self, info): + session = get_session(info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() - def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + async def resolve_reporters(self, info): + session = get_session(info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() return graphene.Schema(query=Query) -if is_sqlalchemy_version_less_than('1.2'): - pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) + + +async def eventually_await_session(session, func, *args): + if isinstance(session, AsyncSession): + await getattr(session, func)(*args) + else: + getattr(session, func)(*args) @pytest.mark.asyncio @@ -82,31 +98,32 @@ async def test_many_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { articles { headline @@ -115,20 +132,26 @@ async def test_many_to_one(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN reporters" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -138,20 +161,20 @@ async def test_many_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "articles": [ - { - "headline": "Article_1", - "reporter": { - "firstName": "Reporter_1", - }, - }, - { - "headline": "Article_2", - "reporter": { - "firstName": "Reporter_2", - }, - }, - ], + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], } @@ -160,19 +183,19 @@ async def test_one_to_one(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) @@ -184,7 +207,8 @@ async def test_one_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -193,20 +217,26 @@ async def test_one_to_one(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -216,20 +246,20 @@ async def test_one_to_one(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "favoriteArticle": { - "headline": "Article_1", - }, - }, - { - "firstName": "Reporter_2", - "favoriteArticle": { - "headline": "Article_2", - }, - }, - ], + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], } @@ -238,27 +268,27 @@ async def test_one_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_1 session.add(article_2) - article_3 = Article(headline='Article_3') + article_3 = Article(headline="Article_3") article_3.reporter = reporter_2 session.add(article_3) - article_4 = Article(headline='Article_4') + article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) @@ -270,7 +300,8 @@ async def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -283,20 +314,26 @@ async def test_one_to_many(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -306,42 +343,42 @@ async def test_one_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_1", - }, - }, - { - "node": { - "headline": "Article_2", + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "articles": { - "edges": [ - { - "node": { - "headline": "Article_3", + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], }, - }, - { - "node": { - "headline": "Article_4", - }, - }, - ], - }, - }, - ], + }, + ], } @@ -350,27 +387,27 @@ async def test_many_to_many(session_factory): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_1) - pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_2) reporter_1.pets.append(pet_1) reporter_1.pets.append(pet_2) - pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_3) - pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_4) reporter_2.pets.append(pet_3) @@ -384,7 +421,8 @@ async def test_many_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - result = await schema.execute_async(""" + result = await schema.execute_async( + """ query { reporters { firstName @@ -397,20 +435,26 @@ async def test_many_to_many(session_factory): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN pets" in message + ] assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than('1.4'): + if not is_sqlalchemy_version_less_than("1.4"): messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -420,50 +464,50 @@ async def test_many_to_many(session_factory): assert not result.errors result = to_std_dicts(result.data) assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_1", - }, - }, - { - "node": { - "name": "Pet_2", - }, - }, - ], - }, - }, - { - "firstName": "Reporter_2", - "pets": { - "edges": [ - { - "node": { - "name": "Pet_3", + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], }, - }, - { - "node": { - "name": "Pet_4", + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], }, - }, - ], - }, - }, - ], + }, + ], } def test_disable_batching_via_ormfield(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -486,7 +530,7 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) @@ -494,7 +538,8 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { favoriteArticle { @@ -502,17 +547,24 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 # Test one-to-many and many-to-many relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { articles { @@ -524,19 +576,25 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 @pytest.mark.asyncio async def test_connection_factory_field_overrides_batching_is_false(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -559,14 +617,15 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - await schema.execute_async(""" + await schema.execute_async( + """ query { reporters { articles { @@ -578,24 +637,34 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - if is_sqlalchemy_version_less_than('1.3'): + if is_sqlalchemy_version_less_than("1.3"): # The batched SQL statement generated is different in 1.2.x # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` # See https://git.io/JewQu - select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] else: - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 1 def test_connection_factory_field_overrides_batching_is_true(session_factory): session = session_factory() - reporter_1 = Reporter(first_name='Reporter_1') + reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) - reporter_2 = Reporter(first_name='Reporter_2') + reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) session.commit() session.close() @@ -618,14 +687,15 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + return info.context.get("session").query(Reporter).all() schema = graphene.Schema(query=Query) with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute(""" + schema.execute( + """ query { reporters { articles { @@ -637,8 +707,14 @@ def resolve_reporters(self, info): } } } - """, context_value={"session": session}) + """, + context_value={"session": session}, + ) messages = sqlalchemy_logging_handler.messages - select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + select_statements = [ + message + for message in messages + if "SELECT" in message and "FROM articles" in message + ] assert len(select_statements) == 2 diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 11e9d0e0..77e57db0 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,14 +1,17 @@ import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession import graphene from graphene import relay from ..types import SQLAlchemyObjectType -from ..utils import is_sqlalchemy_version_less_than +from ..utils import get_session, is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter +from .utils import eventually_await_session -if is_sqlalchemy_version_less_than('1.2'): - pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) def get_schema(): @@ -31,51 +34,61 @@ class Query(graphene.ObjectType): articles = graphene.Field(graphene.List(ArticleType)) reporters = graphene.Field(graphene.List(ReporterType)) - def resolve_articles(self, info): - return info.context.get('session').query(Article).all() + async def resolve_articles(self, info): + session = get_session(info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() - def resolve_reporters(self, info): - return info.context.get('session').query(Reporter).all() + async def resolve_reporters(self, info): + session = get_session(info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() return graphene.Schema(query=Query) -def benchmark_query(session_factory, benchmark, query): +async def benchmark_query(session_factory, benchmark, query): schema = get_schema() @benchmark - def execute_query(): - result = schema.execute( - query, - context_value={"session": session_factory()}, + async def execute_query(): + result = await schema.execute_async( + query, + context_value={"session": session_factory()}, ) assert not result.errors -def test_one_to_one(session_factory, benchmark): +@pytest.mark.asyncio +async def test_one_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query(session_factory, benchmark, """ + await benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -84,33 +97,37 @@ def test_one_to_one(session_factory, benchmark): } } } - """) + """, + ) def test_many_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { articles { headline @@ -119,41 +136,45 @@ def test_many_to_one(session_factory, benchmark): } } } - """) + """, + ) def test_one_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - article_1 = Article(headline='Article_1') + article_1 = Article(headline="Article_1") article_1.reporter = reporter_1 session.add(article_1) - article_2 = Article(headline='Article_2') + article_2 = Article(headline="Article_2") article_2.reporter = reporter_1 session.add(article_2) - article_3 = Article(headline='Article_3') + article_3 = Article(headline="Article_3") article_3.reporter = reporter_2 session.add(article_3) - article_4 = Article(headline='Article_4') + article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -166,34 +187,35 @@ def test_one_to_many(session_factory, benchmark): } } } - """) + """, + ) def test_many_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( - first_name='Reporter_1', + first_name="Reporter_1", ) session.add(reporter_1) reporter_2 = Reporter( - first_name='Reporter_2', + first_name="Reporter_2", ) session.add(reporter_2) - pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + pet_1 = Pet(name="Pet_1", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_1) - pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + pet_2 = Pet(name="Pet_2", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_2) reporter_1.pets.append(pet_1) reporter_1.pets.append(pet_2) - pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + pet_3 = Pet(name="Pet_3", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_3) - pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + pet_4 = Pet(name="Pet_4", pet_kind="cat", hair_kind=HairKind.LONG) session.add(pet_4) reporter_2.pets.append(pet_3) @@ -202,7 +224,10 @@ def test_many_to_many(session_factory, benchmark): session.commit() session.close() - benchmark_query(session_factory, benchmark, """ + benchmark_query( + session_factory, + benchmark, + """ query { reporters { firstName @@ -215,4 +240,5 @@ def test_many_to_many(session_factory, benchmark): } } } - """) + """, + ) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 39140814..ae1c1c78 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,36 +1,40 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + import graphene from graphene.relay import Node from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyObjectType +from ..utils import get_session from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter -from .utils import to_std_dicts +from .utils import eventually_await_session, to_std_dicts -def add_test_data(session): - reporter = Reporter( - first_name='John', last_name='Doe', favorite_pet_kind='cat') +async def add_test_data(session): + reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) - pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) session.add(pet) pet.reporters.append(reporter) - article = Article(headline='Hi!') + article = Article(headline="Hi!") article.reporter = reporter session.add(article) - reporter = Reporter( - first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + reporter = Reporter(first_name="Jane", last_name="Roe", favorite_pet_kind="dog") session.add(reporter) - pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG) pet.reporters.append(reporter) session.add(pet) editor = Editor(name="Jack") session.add(editor) - session.commit() + await eventually_await_session(session, "commit") -def test_query_fields(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_fields(session): + await add_test_data(session) @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): @@ -44,10 +48,16 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) query = """ @@ -73,14 +83,15 @@ def resolve_reporters(self, _info): "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_query_node(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_node(session): + await add_test_data(session) class ReporterNode(SQLAlchemyObjectType): class Meta: @@ -101,7 +112,10 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterNode) all_articles = SQLAlchemyConnectionField(ArticleNode.connection) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() query = """ @@ -145,14 +159,15 @@ def resolve_reporter(self, _info): "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_orm_field(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_orm_field(session): + await add_test_data(session) @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): @@ -163,12 +178,12 @@ class Meta: model = Reporter interfaces = (Node,) - first_name_v2 = ORMField(model_attr='first_name') - hybrid_prop_v2 = ORMField(model_attr='hybrid_prop') - column_prop_v2 = ORMField(model_attr='column_prop') + first_name_v2 = ORMField(model_attr="first_name") + hybrid_prop_v2 = ORMField(model_attr="hybrid_prop") + column_prop_v2 = ORMField(model_attr="column_prop") composite_prop = ORMField() - favorite_article_v2 = ORMField(model_attr='favorite_article') - articles_v2 = ORMField(model_attr='articles') + favorite_article_v2 = ORMField(model_attr="favorite_article") + articles_v2 = ORMField(model_attr="articles") class ArticleType(SQLAlchemyObjectType): class Meta: @@ -178,7 +193,10 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() query = """ @@ -212,14 +230,15 @@ def resolve_reporter(self, _info): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_custom_identifier(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_custom_identifier(session): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -253,14 +272,15 @@ class Query(graphene.ObjectType): } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_mutation(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_mutation(session, session_factory): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -273,8 +293,11 @@ class Meta: interfaces = (Node,) @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name="Cookie Monster") + async def get_node(cls, id, info): + session = get_session(info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() + return session.query(Reporter).first() class ArticleNode(SQLAlchemyObjectType): class Meta: @@ -289,11 +312,14 @@ class Arguments: ok = graphene.Boolean() article = graphene.Field(ArticleNode) - def mutate(self, info, headline, reporter_id): + async def mutate(self, info, headline, reporter_id): + reporter = await ReporterNode.get_node(reporter_id, info) new_article = Article(headline=headline, reporter_id=reporter_id) + reporter.articles = [*reporter.articles, new_article] + session = get_session(info.context) + session.add(reporter) - session.add(new_article) - session.commit() + await eventually_await_session(session, "commit") ok = True return CreateArticle(article=new_article, ok=ok) @@ -332,7 +358,9 @@ class Mutation(graphene.ObjectType): } schema = graphene.Schema(query=Query, mutation=Mutation) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors result = to_std_dicts(result.data) assert result == expected diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 5166c45f..0375da20 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,15 +1,22 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + import graphene +from graphene_sqlalchemy.tests.utils import eventually_await_session +from graphene_sqlalchemy.utils import get_session from ..types import SQLAlchemyObjectType from .models import HairKind, Pet, Reporter from .test_query import add_test_data, to_std_dicts -def test_query_pet_kinds(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_pet_kinds(session, session_factory): + await add_test_data(session) + await eventually_await_session(session, "close") class PetType(SQLAlchemyObjectType): - class Meta: model = Pet @@ -20,16 +27,29 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - pets = graphene.List(PetType, kind=graphene.Argument( - PetType.enum_for_field('pet_kind'))) + pets = graphene.List( + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) - def resolve_pets(self, _info, kind): + async def resolve_pets(self, _info, kind): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).unique().all() query = session.query(Pet) if kind: query = query.filter_by(pet_kind=kind.value) @@ -58,36 +78,36 @@ def resolve_pets(self, _info, kind): } """ expected = { - 'reporter': { - 'firstName': 'John', - 'lastName': 'Doe', - 'email': None, - 'favoritePetKind': 'CAT', - 'pets': [{ - 'name': 'Garfield', - 'petKind': 'CAT' - }] + "reporter": { + "firstName": "John", + "lastName": "Doe", + "email": None, + "favoritePetKind": "CAT", + "pets": [{"name": "Garfield", "petKind": "CAT"}], }, - 'reporters': [{ - 'firstName': 'John', - 'favoritePetKind': 'CAT', - }, { - 'firstName': 'Jane', - 'favoritePetKind': 'DOG', - }], - 'pets': [{ - 'name': 'Lassie', - 'petKind': 'DOG' - }] + "reporters": [ + { + "firstName": "John", + "favoritePetKind": "CAT", + }, + { + "firstName": "Jane", + "favoritePetKind": "DOG", + }, + ], + "pets": [{"name": "Lassie", "petKind": "DOG"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors assert result.data == expected -def test_query_more_enums(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_more_enums(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -96,7 +116,10 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field(PetType) - def resolve_pet(self, _info): + async def resolve_pet(self, _info): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Pet))).first() return session.query(Pet).first() query = """ @@ -110,14 +133,15 @@ def resolve_pet(self, _info): """ expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -125,10 +149,16 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field( - PetType, - kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) + ) - def resolve_pet(self, info, kind=None): + async def resolve_pet(self, info, kind=None): + session = get_session(info.context) + if isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).first() query = session.query(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -145,19 +175,24 @@ def resolve_pet(self, info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "CAT"}) + result = await schema.execute_async( + query, variables={"kind": "CAT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "DOG"}) + result = await schema.execute_async( + query, variables={"kind": "DOG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) assert result == expected -def test_py_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_py_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -169,7 +204,14 @@ class Query(graphene.ObjectType): kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), ) - def resolve_pet(self, _info, kind=None): + async def resolve_pet(self, _info, kind=None): + session = get_session(_info.context) + if isinstance(session, AsyncSession): + return ( + await session.scalars( + select(Pet).filter(Pet.hair_kind == HairKind(kind)) + ) + ).first() query = session.query(Pet) if kind: # enum arguments are expected to be strings, not PyEnums @@ -187,11 +229,15 @@ def resolve_pet(self, _info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "SHORT"}) + result = await schema.execute_async( + query, variables={"kind": "SHORT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "LONG"}) + result = await schema.execute_async( + query, variables={"kind": "LONG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 6291d4f8..0fd769b6 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -9,16 +9,17 @@ from ..utils import to_type_name from .models import Base, HairKind, Pet from .test_query import to_std_dicts +from .utils import eventually_await_session -def add_pets(session): +async def add_pets(session): pets = [ Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), ] session.add_all(pets) - session.commit() + await eventually_await_session(session, "commit") def test_sort_enum(): @@ -241,8 +242,9 @@ def get_symbol_name(column_name, sort_asc=True): assert sort_arg.default_value == ["IdUp"] -def test_sort_query(session): - add_pets(session) +@pytest.mark.asyncio +async def test_sort_query(session): + await add_pets(session) class PetNode(SQLAlchemyObjectType): class Meta: @@ -315,9 +317,7 @@ def makeNodes(nodeList): return {"edges": nodes} expected = { - "defaultSort": makeNodes( - [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] - ), + "defaultSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), "noDefaultSort": makeNodes( [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] @@ -336,7 +336,7 @@ def makeNodes(nodeList): } # yapf: disable schema = Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -352,9 +352,9 @@ def makeNodes(nodeList): } } """ - result = schema.execute(queryError, context_value={"session": session}) + result = await schema.execute_async(queryError, context_value={"session": session}) assert result.errors is not None - assert 'cannot represent non-enum value' in result.errors[0].message + assert "cannot represent non-enum value" in result.errors[0].message queryNoSort = """ query sortTest { @@ -375,7 +375,7 @@ def makeNodes(nodeList): } """ - result = schema.execute(queryNoSort, context_value={"session": session}) + result = await schema.execute_async(queryNoSort, context_value={"session": session}) assert not result.errors # TODO: SQLite usually returns the results ordered by primary key, # so we cannot test this way whether sorting actually happens or not. diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 9a2e992d..a8736806 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -3,6 +3,8 @@ import pytest import sqlalchemy.exc import sqlalchemy.orm.exc +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) @@ -16,11 +18,13 @@ unregisterConnectionFieldFactory) from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, CompositeFullName, Pet, Reporter +from .utils import eventually_await_session def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): pass @@ -28,12 +32,14 @@ class Character1(SQLAlchemyObjectType): def test_should_raise_if_model_is_invalid(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): class Meta: model = 1 -def test_sqlalchemy_node(session): +@pytest.mark.asyncioas +async def test_sqlalchemy_node(session): class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -44,8 +50,8 @@ class Meta: reporter = Reporter() session.add(reporter) - session.commit() - info = mock.Mock(context={'session': session}) + await eventually_await_session(session, "commit") + info = mock.Mock(context={"session": session}) reporter_node = ReporterType.get_node(info, reporter.id) assert reporter == reporter_node @@ -74,91 +80,93 @@ class Meta: model = Article interfaces = (Node,) - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - # Columns - "column_prop", # SQLAlchemy retuns column properties first - "id", - "first_name", - "last_name", - "email", - "favorite_pet_kind", - # Composite - "composite_prop", - # Hybrid - "hybrid_prop_with_doc", - "hybrid_prop", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - # Relationship - "pets", - "articles", - "favorite_article", - ]) + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Columns + "column_prop", # SQLAlchemy retuns column properties first + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + # Composite + "composite_prop", + # Hybrid + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + # Relationship + "pets", + "articles", + "favorite_article", + ] + ) # column - first_name_field = ReporterType._meta.fields['first_name'] + first_name_field = ReporterType._meta.fields["first_name"] assert first_name_field.type == String assert first_name_field.description == "First name" # column_property - column_prop_field = ReporterType._meta.fields['column_prop'] + column_prop_field = ReporterType._meta.fields["column_prop"] assert column_prop_field.type == Int # "doc" is ignored by column_property assert column_prop_field.description is None # composite - full_name_field = ReporterType._meta.fields['composite_prop'] + full_name_field = ReporterType._meta.fields["composite_prop"] assert full_name_field.type == String # "doc" is ignored by composite assert full_name_field.description is None # hybrid_property - hybrid_prop = ReporterType._meta.fields['hybrid_prop'] + hybrid_prop = ReporterType._meta.fields["hybrid_prop"] assert hybrid_prop.type == String # "doc" is ignored by hybrid_property assert hybrid_prop.description is None # hybrid_property_str - hybrid_prop_str = ReporterType._meta.fields['hybrid_prop_str'] + hybrid_prop_str = ReporterType._meta.fields["hybrid_prop_str"] assert hybrid_prop_str.type == String # "doc" is ignored by hybrid_property assert hybrid_prop_str.description is None # hybrid_property_int - hybrid_prop_int = ReporterType._meta.fields['hybrid_prop_int'] + hybrid_prop_int = ReporterType._meta.fields["hybrid_prop_int"] assert hybrid_prop_int.type == Int # "doc" is ignored by hybrid_property assert hybrid_prop_int.description is None # hybrid_property_float - hybrid_prop_float = ReporterType._meta.fields['hybrid_prop_float'] + hybrid_prop_float = ReporterType._meta.fields["hybrid_prop_float"] assert hybrid_prop_float.type == Float # "doc" is ignored by hybrid_property assert hybrid_prop_float.description is None # hybrid_property_bool - hybrid_prop_bool = ReporterType._meta.fields['hybrid_prop_bool'] + hybrid_prop_bool = ReporterType._meta.fields["hybrid_prop_bool"] assert hybrid_prop_bool.type == Boolean # "doc" is ignored by hybrid_property assert hybrid_prop_bool.description is None # hybrid_property_list - hybrid_prop_list = ReporterType._meta.fields['hybrid_prop_list'] + hybrid_prop_list = ReporterType._meta.fields["hybrid_prop_list"] assert hybrid_prop_list.type == List(Int) # "doc" is ignored by hybrid_property assert hybrid_prop_list.description is None # hybrid_prop_with_doc - hybrid_prop_with_doc = ReporterType._meta.fields['hybrid_prop_with_doc'] + hybrid_prop_with_doc = ReporterType._meta.fields["hybrid_prop_with_doc"] assert hybrid_prop_with_doc.type == String # docstring is picked up from hybrid_prop_with_doc assert hybrid_prop_with_doc.description == "Docstring test" # relationship - favorite_article_field = ReporterType._meta.fields['favorite_article'] + favorite_article_field = ReporterType._meta.fields["favorite_article"] assert isinstance(favorite_article_field, Dynamic) assert favorite_article_field.type().type == ArticleType assert favorite_article_field.type().description is None @@ -172,7 +180,7 @@ def convert_composite_class(composite, registry): class ReporterMixin(object): # columns first_name = ORMField(required=True) - last_name = ORMField(description='Overridden') + last_name = ORMField(description="Overridden") class ReporterType(SQLAlchemyObjectType, ReporterMixin): class Meta: @@ -180,8 +188,8 @@ class Meta: interfaces = (Node,) # columns - email = ORMField(deprecation_reason='Overridden') - email_v2 = ORMField(model_attr='email', type_=Int) + email = ORMField(deprecation_reason="Overridden") + email_v2 = ORMField(model_attr="email", type_=Int) # column_property column_prop = ORMField(type_=String) @@ -190,13 +198,13 @@ class Meta: composite_prop = ORMField() # hybrid_property - hybrid_prop_with_doc = ORMField(description='Overridden') - hybrid_prop = ORMField(description='Overridden') + hybrid_prop_with_doc = ORMField(description="Overridden") + hybrid_prop = ORMField(description="Overridden") # relationships - favorite_article = ORMField(description='Overridden') - articles = ORMField(deprecation_reason='Overridden') - pets = ORMField(description='Overridden') + favorite_article = ORMField(description="Overridden") + articles = ORMField(deprecation_reason="Overridden") + pets = ORMField(description="Overridden") class ArticleType(SQLAlchemyObjectType): class Meta: @@ -209,99 +217,101 @@ class Meta: interfaces = (Node,) use_connection = False - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - # Fields from ReporterMixin - "first_name", - "last_name", - # Fields from ReporterType - "email", - "email_v2", - "column_prop", - "composite_prop", - "hybrid_prop_with_doc", - "hybrid_prop", - "favorite_article", - "articles", - "pets", - # Then the automatic SQLAlchemy fields - "id", - "favorite_pet_kind", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - ]) - - first_name_field = ReporterType._meta.fields['first_name'] + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + # Fields from ReporterMixin + "first_name", + "last_name", + # Fields from ReporterType + "email", + "email_v2", + "column_prop", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "favorite_article", + "articles", + "pets", + # Then the automatic SQLAlchemy fields + "id", + "favorite_pet_kind", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + ] + ) + + first_name_field = ReporterType._meta.fields["first_name"] assert isinstance(first_name_field.type, NonNull) assert first_name_field.type.of_type == String assert first_name_field.description == "First name" assert first_name_field.deprecation_reason is None - last_name_field = ReporterType._meta.fields['last_name'] + last_name_field = ReporterType._meta.fields["last_name"] assert last_name_field.type == String assert last_name_field.description == "Overridden" assert last_name_field.deprecation_reason is None - email_field = ReporterType._meta.fields['email'] + email_field = ReporterType._meta.fields["email"] assert email_field.type == String assert email_field.description == "Email" assert email_field.deprecation_reason == "Overridden" - email_field_v2 = ReporterType._meta.fields['email_v2'] + email_field_v2 = ReporterType._meta.fields["email_v2"] assert email_field_v2.type == Int assert email_field_v2.description == "Email" assert email_field_v2.deprecation_reason is None - hybrid_prop_field = ReporterType._meta.fields['hybrid_prop'] + hybrid_prop_field = ReporterType._meta.fields["hybrid_prop"] assert hybrid_prop_field.type == String assert hybrid_prop_field.description == "Overridden" assert hybrid_prop_field.deprecation_reason is None - hybrid_prop_with_doc_field = ReporterType._meta.fields['hybrid_prop_with_doc'] + hybrid_prop_with_doc_field = ReporterType._meta.fields["hybrid_prop_with_doc"] assert hybrid_prop_with_doc_field.type == String assert hybrid_prop_with_doc_field.description == "Overridden" assert hybrid_prop_with_doc_field.deprecation_reason is None - column_prop_field_v2 = ReporterType._meta.fields['column_prop'] + column_prop_field_v2 = ReporterType._meta.fields["column_prop"] assert column_prop_field_v2.type == String assert column_prop_field_v2.description is None assert column_prop_field_v2.deprecation_reason is None - composite_prop_field = ReporterType._meta.fields['composite_prop'] + composite_prop_field = ReporterType._meta.fields["composite_prop"] assert composite_prop_field.type == String assert composite_prop_field.description is None assert composite_prop_field.deprecation_reason is None - favorite_article_field = ReporterType._meta.fields['favorite_article'] + favorite_article_field = ReporterType._meta.fields["favorite_article"] assert isinstance(favorite_article_field, Dynamic) assert favorite_article_field.type().type == ArticleType - assert favorite_article_field.type().description == 'Overridden' + assert favorite_article_field.type().description == "Overridden" - articles_field = ReporterType._meta.fields['articles'] + articles_field = ReporterType._meta.fields["articles"] assert isinstance(articles_field, Dynamic) assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) assert articles_field.type().deprecation_reason == "Overridden" - pets_field = ReporterType._meta.fields['pets'] + pets_field = ReporterType._meta.fields["pets"] assert isinstance(pets_field, Dynamic) assert isinstance(pets_field.type().type, List) assert pets_field.type().type.of_type == PetType - assert pets_field.type().description == 'Overridden' + assert pets_field.type().description == "Overridden" def test_invalid_model_attr(): err_msg = ( - "Cannot map ORMField to a model attribute.\n" - "Field: 'ReporterType.first_name'" + "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" ) with pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter - first_name = ORMField(model_attr='does_not_exist') + first_name = ORMField(model_attr="does_not_exist") def test_only_fields(): @@ -325,29 +335,32 @@ class Meta: first_name = ORMField() # Takes precedence last_name = ORMField() # Noop - assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ - "first_name", - "last_name", - "column_prop", - "email", - "favorite_pet_kind", - "composite_prop", - "hybrid_prop_with_doc", - "hybrid_prop", - "hybrid_prop_str", - "hybrid_prop_int", - "hybrid_prop_float", - "hybrid_prop_bool", - "hybrid_prop_list", - "pets", - "articles", - "favorite_article", - ]) + assert sorted(list(ReporterType._meta.fields.keys())) == sorted( + [ + "first_name", + "last_name", + "column_prop", + "email", + "favorite_pet_kind", + "composite_prop", + "hybrid_prop_with_doc", + "hybrid_prop", + "hybrid_prop_str", + "hybrid_prop_int", + "hybrid_prop_float", + "hybrid_prop_bool", + "hybrid_prop_list", + "pets", + "articles", + "favorite_article", + ] + ) def test_only_and_exclude_fields(): re_err = r"'only_fields' and 'exclude_fields' cannot be both set" with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -367,19 +380,29 @@ class Meta: assert first_name_field.type == Int -def test_resolvers(session): +@pytest.mark.asyncio +async def test_resolvers(session): """Test that the correct resolver functions are called""" + reporter = Reporter( + first_name="first_name", + last_name="last_name", + email="email", + favorite_pet_kind="cat", + ) + session.add(reporter) + await eventually_await_session(session, "commit") + class ReporterMixin(object): def resolve_id(root, _info): - return 'ID' + return "ID" class ReporterType(ReporterMixin, SQLAlchemyObjectType): class Meta: model = Reporter email = ORMField() - email_v2 = ORMField(model_attr='email') + email_v2 = ORMField(model_attr="email") favorite_pet_kind = Field(String) favorite_pet_kind_v2 = Field(String) @@ -387,23 +410,23 @@ def resolve_last_name(root, _info): return root.last_name.upper() def resolve_email_v2(root, _info): - return root.email + '_V2' + return root.email + "_V2" def resolve_favorite_pet_kind_v2(root, _info): - return str(root.favorite_pet_kind) + '_V2' + return str(root.favorite_pet_kind) + "_V2" class Query(ObjectType): reporter = Field(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = utils.get_session(_info.context) + if isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat') - session.add(reporter) - session.commit() - schema = Schema(query=Query) - result = schema.execute(""" + result = await schema.execute_async( + """ query { reporter { id @@ -415,27 +438,30 @@ def resolve_reporter(self, _info): favoritePetKindV2 } } - """) + """, + context_value={"session": session}, + ) assert not result.errors # Custom resolver on a base class - assert result.data['reporter']['id'] == 'ID' + assert result.data["reporter"]["id"] == "ID" # Default field + default resolver - assert result.data['reporter']['firstName'] == 'first_name' + assert result.data["reporter"]["firstName"] == "first_name" # Default field + custom resolver - assert result.data['reporter']['lastName'] == 'LAST_NAME' + assert result.data["reporter"]["lastName"] == "LAST_NAME" # ORMField + default resolver - assert result.data['reporter']['email'] == 'email' + assert result.data["reporter"]["email"] == "email" # ORMField + custom resolver - assert result.data['reporter']['emailV2'] == 'email_V2' + assert result.data["reporter"]["emailV2"] == "email_V2" # Field + default resolver - assert result.data['reporter']['favoritePetKind'] == 'cat' + assert result.data["reporter"]["favoritePetKind"] == "cat" # Field + custom resolver - assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2' + assert result.data["reporter"]["favoritePetKindV2"] == "cat_V2" # Test Custom SQLAlchemyObjectType Implementation + def test_custom_objecttype_registered(): class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): class Meta: @@ -479,6 +505,7 @@ class Meta: # Tests for connection_field_factory + class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): pass @@ -494,7 +521,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), UnsortedSQLAlchemyConnectionField + ) def test_custom_connection_field_factory(): @@ -514,7 +543,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_registerConnectionFieldFactory(): @@ -531,7 +562,9 @@ class Meta: model = Article interfaces = (Node,) - assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_unregisterConnectionFieldFactory(): @@ -549,7 +582,9 @@ class Meta: model = Article interfaces = (Node,) - assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + assert not isinstance( + ReporterType._meta.fields["articles"].type(), _TestSQLAlchemyConnectionField + ) def test_deprecated_createConnectionField(): @@ -557,7 +592,7 @@ def test_deprecated_createConnectionField(): createConnectionField(None) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_unique_errors_propagate(class_mapper_mock): # Define unique error to detect class UniqueError(Exception): @@ -569,9 +604,11 @@ class UniqueError(Exception): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleOne(SQLAlchemyObjectType): class Meta(object): model = Article + except UniqueError as e: error = e @@ -580,7 +617,7 @@ class Meta(object): assert isinstance(error, UniqueError) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_argument_errors_propagate(class_mapper_mock): # Mock class_mapper effect class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError @@ -588,9 +625,11 @@ def test_argument_errors_propagate(class_mapper_mock): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleTwo(SQLAlchemyObjectType): class Meta(object): model = Article + except sqlalchemy.exc.ArgumentError as e: error = e @@ -599,7 +638,7 @@ class Meta(object): assert isinstance(error, sqlalchemy.exc.ArgumentError) -@mock.patch(utils.__name__ + '.class_mapper') +@mock.patch(utils.__name__ + ".class_mapper") def test_unmapped_errors_reformat(class_mapper_mock): # Mock class_mapper effect class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) @@ -607,9 +646,11 @@ def test_unmapped_errors_reformat(class_mapper_mock): # Make sure that errors are propagated from class_mapper when instantiating new classes error = None try: + class ArticleThree(SQLAlchemyObjectType): class Meta(object): model = Article + except ValueError as e: error = e diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index c90ee476..42960a32 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,5 +1,7 @@ import re +from sqlalchemy.ext.asyncio import AsyncSession + def to_std_dicts(value): """Convert nested ordered dicts to normal dicts for better comparison.""" @@ -15,3 +17,10 @@ def remove_cache_miss_stat(message): """Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4""" # https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177 return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) + + +async def eventually_await_session(session, func, *args): + if isinstance(session, AsyncSession): + await getattr(session, func)(*args) + else: + getattr(session, func)(*args) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index ac69b697..d48bfe10 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,6 +1,7 @@ from collections import OrderedDict import sqlalchemy +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, RelationshipProperty) @@ -20,7 +21,7 @@ sort_enum_for_object_type) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import get_query, get_session, is_mapped_class, is_mapped_instance class ORMField(OrderedType): @@ -76,20 +77,28 @@ class Meta: super(ORMField, self).__init__(_creation_counter=_creation_counter) # The is only useful for documentation and auto-completion common_kwargs = { - 'model_attr': model_attr, - 'type_': type_, - 'required': required, - 'description': description, - 'deprecation_reason': deprecation_reason, - 'batching': batching, + "model_attr": model_attr, + "type_": type_, + "required": required, + "description": description, + "deprecation_reason": deprecation_reason, + "batching": batching, + } + common_kwargs = { + kwarg: value for kwarg, value in common_kwargs.items() if value is not None } - common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} self.kwargs = field_kwargs self.kwargs.update(common_kwargs) def construct_fields( - obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + connection_field_factory, ): """ Construct all the fields for a SQLAlchemyObjectType. @@ -110,17 +119,22 @@ def construct_fields( inspected_model = sqlalchemy.inspect(model) # Gather all the relevant attributes from the SQLAlchemy model in order all_model_attrs = OrderedDict( - inspected_model.column_attrs.items() + - inspected_model.composites.items() + - [(name, item) for name, item in inspected_model.all_orm_descriptors.items() - if isinstance(item, hybrid_property)] + - inspected_model.relationships.items() + inspected_model.column_attrs.items() + + inspected_model.composites.items() + + [ + (name, item) + for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property) + ] + + inspected_model.relationships.items() ) # Filter out excluded fields auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): - if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields): + if (only_fields and attr_name not in only_fields) or ( + attr_name in exclude_fields + ): continue auto_orm_field_names.append(attr_name) @@ -135,13 +149,15 @@ def construct_fields( # Set the model_attr if not set for orm_field_name, orm_field in custom_orm_fields_items: - attr_name = orm_field.kwargs.get('model_attr', orm_field_name) + attr_name = orm_field.kwargs.get("model_attr", orm_field_name) if attr_name not in all_model_attrs: - raise ValueError(( - "Cannot map ORMField to a model attribute.\n" - "Field: '{}.{}'" - ).format(obj_type.__name__, orm_field_name,)) - orm_field.kwargs['model_attr'] = attr_name + raise ValueError( + ("Cannot map ORMField to a model attribute.\n" "Field: '{}.{}'").format( + obj_type.__name__, + orm_field_name, + ) + ) + orm_field.kwargs["model_attr"] = attr_name # Merge automatic fields with custom ORM fields orm_fields = OrderedDict(custom_orm_fields_items) @@ -153,27 +169,38 @@ def construct_fields( # Build all the field dictionary fields = OrderedDict() for orm_field_name, orm_field in orm_fields.items(): - attr_name = orm_field.kwargs.pop('model_attr') + attr_name = orm_field.kwargs.pop("model_attr") attr = all_model_attrs[attr_name] - resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) + resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver( + obj_type, attr_name + ) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_column( + attr, registry, resolver, **orm_field.kwargs + ) elif isinstance(attr, RelationshipProperty): - batching_ = orm_field.kwargs.pop('batching', batching) + batching_ = orm_field.kwargs.pop("batching", batching) field = convert_sqlalchemy_relationship( - attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) + attr, + obj_type, + connection_field_factory, + batching_, + orm_field_name, + **orm_field.kwargs + ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. " - "Field: {}.{}".format(obj_type.__name__, orm_field_name)) + "Field: {}.{}".format(obj_type.__name__, orm_field_name) + ) field = convert_sqlalchemy_composite(attr, registry, resolver) elif isinstance(attr, hybrid_property): field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) else: - raise Exception('Property type is not supported') # Should never happen + raise Exception("Property type is not supported") # Should never happen registry.register_orm_field(obj_type, orm_field_name, attr) fields[orm_field_name] = field @@ -210,7 +237,8 @@ def __init_subclass_with_meta__( # Make sure model is a valid SQLAlchemy model if not is_mapped_class(model): raise ValueError( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'.format(cls.__name__, model) + "You need to pass a valid SQLAlchemy Model in " + '{}.Meta, received "{}".'.format(cls.__name__, model) ) if not registry: @@ -222,7 +250,9 @@ def __init_subclass_with_meta__( ).format(cls.__name__, registry) if only_fields and exclude_fields: - raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.") + raise ValueError( + "The options 'only_fields' and 'exclude_fields' cannot be both set on the same type." + ) sqla_fields = yank_fields_from_attrs( construct_fields( @@ -294,7 +324,10 @@ def get_query(cls, info): return get_query(model, info.context) @classmethod - def get_node(cls, info, id): + async def get_node(cls, info, id): + session = get_session(info.context) + if isinstance(session, AsyncSession): + return await session.get(cls._meta.model, id) try: return cls.get_query(info).get(id) except NoResultFound: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 301e782c..d7b8b92f 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -4,7 +4,9 @@ from typing import Any, Callable, Dict, Optional import pkg_resources +from sqlalchemy import select from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError @@ -24,6 +26,8 @@ def get_query(model, context): "A query in the model Base or a session in the schema is required for querying.\n" "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) + if isinstance(session, AsyncSession): + return select(model) query = session.query(model) return query @@ -154,7 +158,9 @@ def sort_argument_for_model(cls, has_default=True): def is_sqlalchemy_version_less_than(version_string): """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + return pkg_resources.get_distribution( + "SQLAlchemy" + ).parsed_version < pkg_resources.parse_version(version_string) class singledispatchbymatchfunction: @@ -178,7 +184,6 @@ def __call__(self, *args, **kwargs): return self.default(*args, **kwargs) def register(self, matcher_function: Callable[[Any], bool]): - def grab_function_from_outside(f): self.registry[matcher_function] = f return self @@ -188,7 +193,7 @@ def grab_function_from_outside(f): def value_equals(value): """A simple function that makes the equality based matcher functions for - SingleDispatchByMatchFunction prettier""" + SingleDispatchByMatchFunction prettier""" return lambda x: x == value @@ -198,11 +203,17 @@ def safe_isinstance_checker(arg): return isinstance(arg, cls) except TypeError: pass + return safe_isinstance_checker def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: try: - return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) + return next( + filter( + lambda x: x.__name__ == model_name, + list(get_global_registry()._registry.keys()), + ) + ) except StopIteration: pass From 41c88f9890b20b4d1e4925ec9f377b08deeeb6a3 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Mon, 16 May 2022 15:25:17 +0200 Subject: [PATCH 02/19] fix(test batching): ensure that objects are added to database in async mode --- graphene_sqlalchemy/tests/test_batching.py | 52 +++++++++++----------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 5f8c7695..fee9550d 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -124,15 +124,15 @@ async def test_many_to_one(session_factory): session = session_factory() result = await schema.execute_async( """ - query { - articles { - headline - reporter { - firstName + query { + articles { + headline + reporter { + firstName + } + } } - } - } - """, + """, context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages @@ -199,8 +199,8 @@ async def test_one_to_one(session_factory): article_2.reporter = reporter_2 session.add(article_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_schema() @@ -291,9 +291,8 @@ async def test_one_to_many(session_factory): article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) - - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_schema() @@ -412,9 +411,8 @@ async def test_many_to_many(session_factory): reporter_2.pets.append(pet_3) reporter_2.pets.append(pet_4) - - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_schema() @@ -503,14 +501,15 @@ async def test_many_to_many(session_factory): } -def test_disable_batching_via_ormfield(session_factory): +@pytest.mark.asyncio +async def test_disable_batching_via_ormfield(session_factory): session = session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") class ReporterType(SQLAlchemyObjectType): class Meta: @@ -563,7 +562,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute( + await schema.execute_async( """ query { reporters { @@ -596,8 +595,8 @@ async def test_connection_factory_field_overrides_batching_is_false(session_fact session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") class ReporterType(SQLAlchemyObjectType): class Meta: @@ -660,14 +659,15 @@ def resolve_reporters(self, info): assert len(select_statements) == 1 -def test_connection_factory_field_overrides_batching_is_true(session_factory): +@pytest.mark.asyncio +async def test_connection_factory_field_overrides_batching_is_true(session_factory): session = session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") class ReporterType(SQLAlchemyObjectType): class Meta: @@ -694,7 +694,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = session_factory() - schema.execute( + await schema.execute_async( """ query { reporters { From 47d224e0e7a5bc4b1954475641a05d8829df09d7 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Tue, 17 May 2022 10:09:52 +0200 Subject: [PATCH 03/19] test: only run batching tests with sync session --- graphene_sqlalchemy/tests/conftest.py | 10 +++++ graphene_sqlalchemy/tests/test_batching.py | 44 +++++++++++----------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 1722224d..7ed3e54e 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -52,6 +52,16 @@ async def session_factory(async_session: bool, test_db_url: str): engine.dispose() +@pytest.fixture(scope="function") +async def sync_session_factory(): + engine = create_engine("sqlite://") + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() + + @pytest.fixture(scope="function") def session(session_factory): return session_factory() diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index fee9550d..b82f83e9 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -94,8 +94,8 @@ async def eventually_await_session(session, func, *args): @pytest.mark.asyncio -async def test_many_to_one(session_factory): - session = session_factory() +async def test_many_to_one(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -121,7 +121,7 @@ async def test_many_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -179,8 +179,8 @@ async def test_many_to_one(session_factory): @pytest.mark.asyncio -async def test_one_to_one(session_factory): - session = session_factory() +async def test_one_to_one(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -206,7 +206,7 @@ async def test_one_to_one(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -264,8 +264,8 @@ async def test_one_to_one(session_factory): @pytest.mark.asyncio -async def test_one_to_many(session_factory): - session = session_factory() +async def test_one_to_many(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -298,7 +298,7 @@ async def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -382,8 +382,8 @@ async def test_one_to_many(session_factory): @pytest.mark.asyncio -async def test_many_to_many(session_factory): - session = session_factory() +async def test_many_to_many(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -418,7 +418,7 @@ async def test_many_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -502,8 +502,8 @@ async def test_many_to_many(session_factory): @pytest.mark.asyncio -async def test_disable_batching_via_ormfield(session_factory): - session = session_factory() +async def test_disable_batching_via_ormfield(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -536,7 +536,7 @@ def resolve_reporters(self, info): # Test one-to-one and many-to-one relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -561,7 +561,7 @@ def resolve_reporters(self, info): # Test one-to-many and many-to-many relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() await schema.execute_async( """ query { @@ -589,8 +589,8 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_connection_factory_field_overrides_batching_is_false(session_factory): - session = session_factory() +async def test_connection_factory_field_overrides_batching_is_false(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -622,7 +622,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() await schema.execute_async( """ query { @@ -660,8 +660,8 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_connection_factory_field_overrides_batching_is_true(session_factory): - session = session_factory() +async def test_connection_factory_field_overrides_batching_is_true(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -693,7 +693,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() await schema.execute_async( """ query { From 811cdf289c443165ee26a4f18bc4faab1a8efa80 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Tue, 31 May 2022 15:27:47 +0200 Subject: [PATCH 04/19] chore(fields): use get_query instead of manually crafting the query --- graphene_sqlalchemy/fields.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index ea421450..ca784745 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -3,7 +3,6 @@ from functools import partial from promise import Promise, is_thenable -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.query import Query @@ -117,11 +116,7 @@ def __init__(self, type_, *args, **kwargs): @classmethod def get_query(cls, model, info, sort=None, **args): - session = get_session(info.context) - if isinstance(session, AsyncSession): - query = select(model) - else: - query = get_query(model, info.context) + query = get_query(model, info.context) if sort is not None: if not isinstance(sort, list): sort = [sort] From 3149830842e708583529ff5f3e2dadc1ac25dead Mon Sep 17 00:00:00 2001 From: Jendrik Date: Tue, 31 May 2022 20:19:28 +0200 Subject: [PATCH 05/19] fix: throw exceptions if Async Session is used with old sql alchemy --- graphene_sqlalchemy/__init__.py | 2 +- graphene_sqlalchemy/fields.py | 8 ++++- graphene_sqlalchemy/tests/conftest.py | 10 +++++- graphene_sqlalchemy/tests/test_batching.py | 34 ++++++++++----------- graphene_sqlalchemy/tests/test_benchmark.py | 28 +++++++++-------- graphene_sqlalchemy/tests/test_types.py | 4 +-- graphene_sqlalchemy/types.py | 18 ++++++----- graphene_sqlalchemy/utils.py | 6 ++++ setup.py | 5 ++- 9 files changed, 70 insertions(+), 45 deletions(-) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 060bd13b..18d34f1d 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ -from .types import SQLAlchemyObjectType from .fields import SQLAlchemyConnectionField +from .types import SQLAlchemyObjectType from .utils import get_query, get_session __version__ = "3.0.0b1" diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index ca784745..2c5cb95c 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -12,7 +12,8 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query, get_session +from .utils import (EnumValue, get_query, get_session, + is_sqlalchemy_version_less_than) class UnsortedSQLAlchemyConnectionField(ConnectionField): @@ -49,6 +50,11 @@ async def resolve_connection(cls, connection_type, model, info, args, resolved): session = get_session(info.context) if resolved is None: if isinstance(session, AsyncSession): + if is_sqlalchemy_version_less_than("1.4"): + raise Exception( + "You are using an async session with SQLAlchemy < 1.4.\n" + "Please upgrade SQLAlchemy to 1.4.0 or higher." + ) resolved = ( await session.scalars(cls.get_query(model, info, **args)) ).all() diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 7ed3e54e..6c47e92b 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import sessionmaker import graphene +from graphene_sqlalchemy.utils import is_sqlalchemy_version_less_than from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry @@ -21,7 +22,14 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(params=[False, True]) +@pytest.fixture( + params=[ + False, + pytest.mark.xfail(True, strict=True) + if is_sqlalchemy_version_less_than("1.4") + else True, + ] +) def async_session(request): return request.param diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index b82f83e9..a1d5528a 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -114,8 +114,8 @@ async def test_many_to_one(sync_session_factory): article_2.reporter = reporter_2 session.add(article_2) - await eventually_await_session(session, "commit") - await eventually_await_session(session, "close") + session.commit() + session.close() schema = get_schema() @@ -199,8 +199,8 @@ async def test_one_to_one(sync_session_factory): article_2.reporter = reporter_2 session.add(article_2) - await eventually_await_session(session, "commit") - await eventually_await_session(session, "close") + session.commit() + session.close() schema = get_schema() @@ -291,8 +291,8 @@ async def test_one_to_many(sync_session_factory): article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) - await eventually_await_session(session, "commit") - await eventually_await_session(session, "close") + session.commit() + session.close() schema = get_schema() @@ -501,15 +501,14 @@ async def test_many_to_many(sync_session_factory): } -@pytest.mark.asyncio -async def test_disable_batching_via_ormfield(sync_session_factory): +def test_disable_batching_via_ormfield(sync_session_factory): session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) - await eventually_await_session(session, "commit") - await eventually_await_session(session, "close") + session.commit() + session.close() class ReporterType(SQLAlchemyObjectType): class Meta: @@ -562,7 +561,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = sync_session_factory() - await schema.execute_async( + schema.execute( """ query { reporters { @@ -595,8 +594,8 @@ async def test_connection_factory_field_overrides_batching_is_false(sync_session session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) - await eventually_await_session(session, "commit") - await eventually_await_session(session, "close") + session.commit() + session.close() class ReporterType(SQLAlchemyObjectType): class Meta: @@ -659,15 +658,14 @@ def resolve_reporters(self, info): assert len(select_statements) == 1 -@pytest.mark.asyncio -async def test_connection_factory_field_overrides_batching_is_true(sync_session_factory): +def test_connection_factory_field_overrides_batching_is_true(sync_session_factory): session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") session.add(reporter_2) - await eventually_await_session(session, "commit") - await eventually_await_session(session, "close") + session.commit() + session.close() class ReporterType(SQLAlchemyObjectType): class Meta: @@ -694,7 +692,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level session = sync_session_factory() - await schema.execute_async( + schema.execute( """ query { reporters { diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 77e57db0..20d36d4d 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -63,6 +63,7 @@ async def execute_query(): @pytest.mark.asyncio async def test_one_to_one(session_factory, benchmark): + print(is_sqlalchemy_version_less_than("1.4")) session = session_factory() reporter_1 = Reporter( @@ -101,7 +102,8 @@ async def test_one_to_one(session_factory, benchmark): ) -def test_many_to_one(session_factory, benchmark): +@pytest.mark.asyncio +async def test_many_to_one(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( @@ -121,10 +123,10 @@ def test_many_to_one(session_factory, benchmark): article_2.reporter = reporter_2 session.add(article_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( + await benchmark_query( session_factory, benchmark, """ @@ -140,7 +142,8 @@ def test_many_to_one(session_factory, benchmark): ) -def test_one_to_many(session_factory, benchmark): +@pytest.mark.asyncio +async def test_one_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( @@ -168,10 +171,10 @@ def test_one_to_many(session_factory, benchmark): article_4.reporter = reporter_2 session.add(article_4) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( + await benchmark_query( session_factory, benchmark, """ @@ -191,7 +194,8 @@ def test_one_to_many(session_factory, benchmark): ) -def test_many_to_many(session_factory, benchmark): +@pytest.mark.asyncio +async def test_many_to_many(session_factory, benchmark): session = session_factory() reporter_1 = Reporter( @@ -221,10 +225,10 @@ def test_many_to_many(session_factory, benchmark): reporter_2.pets.append(pet_3) reporter_2.pets.append(pet_4) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( + await benchmark_query( session_factory, benchmark, """ diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index a8736806..30ae0d64 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -38,7 +38,7 @@ class Meta: model = 1 -@pytest.mark.asyncioas +@pytest.mark.asyncio async def test_sqlalchemy_node(session): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -52,7 +52,7 @@ class Meta: session.add(reporter) await eventually_await_session(session, "commit") info = mock.Mock(context={"session": session}) - reporter_node = ReporterType.get_node(info, reporter.id) + reporter_node = await ReporterType.get_node(info, reporter.id) assert reporter == reporter_node diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index d48bfe10..8acecf00 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -21,7 +21,8 @@ sort_enum_for_object_type) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import get_query, get_session, is_mapped_class, is_mapped_instance +from .utils import (get_query, get_session, is_mapped_class, + is_mapped_instance, is_sqlalchemy_version_less_than) class ORMField(OrderedType): @@ -325,13 +326,16 @@ def get_query(cls, info): @classmethod async def get_node(cls, info, id): + session = get_session(info.context) - if isinstance(session, AsyncSession): - return await session.get(cls._meta.model, id) - try: - return cls.get_query(info).get(id) - except NoResultFound: - return None + if is_sqlalchemy_version_less_than("1.4") or not isinstance( + session, AsyncSession + ): + try: + return cls.get_query(info).get(id) + except NoResultFound: + return None + return await session.get(cls._meta.model, id) def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index d7b8b92f..686a86b8 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -27,6 +27,11 @@ def get_query(model, context): "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) if isinstance(session, AsyncSession): + if is_sqlalchemy_version_less_than("1.4"): + raise Exception( + "You are using an async session with SQLAlchemy < 1.4.\n" + "Please upgrade SQLAlchemy to 1.4.0 or higher." + ) return select(model) query = session.query(model) return query @@ -144,6 +149,7 @@ def sort_argument_for_model(cls, has_default=True): ) from graphene import Argument, List + from .enums import sort_enum_for_object_type enum = sort_enum_for_object_type( diff --git a/setup.py b/setup.py index da49f1d4..52b1e84c 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,7 @@ _version_re = re.compile(r"__version__\s+=\s+(.*)") with open("graphene_sqlalchemy/__init__.py", "rb") as f: - version = str( - ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) - ) + version = str(ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))) requirements = [ # To keep things simple, we only support newer versions of Graphene @@ -25,6 +23,7 @@ "pytest-cov>=2.11.0,<3.0", "sqlalchemy_utils>=0.37.0,<1.0", "pytest-benchmark>=3.4.0,<4.0", + "aiosqlite>=0.17.0", ] setup( From 0e60c03b48a1726c00f229248bc254079b52b071 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Tue, 31 May 2022 21:07:19 +0200 Subject: [PATCH 06/19] test: fix sqlalchemy 1.2 and 1.3 tests, fix batching tests by separating models --- graphene_sqlalchemy/fields.py | 13 ++-- graphene_sqlalchemy/tests/conftest.py | 15 ++--- graphene_sqlalchemy/tests/models.py | 10 +-- graphene_sqlalchemy/tests/models_batching.py | 64 +++++++++++++++++++ graphene_sqlalchemy/tests/test_batching.py | 20 ++++-- graphene_sqlalchemy/tests/test_benchmark.py | 12 ++-- graphene_sqlalchemy/tests/test_query.py | 26 ++++++-- graphene_sqlalchemy/tests/test_query_enums.py | 19 +++--- graphene_sqlalchemy/tests/test_types.py | 15 +++-- graphene_sqlalchemy/tests/utils.py | 6 +- graphene_sqlalchemy/types.py | 17 +++-- graphene_sqlalchemy/utils.py | 28 ++++---- 12 files changed, 172 insertions(+), 73 deletions(-) create mode 100644 graphene_sqlalchemy/tests/models_batching.py diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 2c5cb95c..c9b1af08 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -3,7 +3,6 @@ from functools import partial from promise import Promise, is_thenable -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.query import Query from graphene import NonNull @@ -15,6 +14,9 @@ from .utils import (EnumValue, get_query, get_session, is_sqlalchemy_version_less_than) +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + class UnsortedSQLAlchemyConnectionField(ConnectionField): @property @@ -49,12 +51,9 @@ def get_query(cls, model, info, **args): async def resolve_connection(cls, connection_type, model, info, args, resolved): session = get_session(info.context) if resolved is None: - if isinstance(session, AsyncSession): - if is_sqlalchemy_version_less_than("1.4"): - raise Exception( - "You are using an async session with SQLAlchemy < 1.4.\n" - "Please upgrade SQLAlchemy to 1.4.0 or higher." - ) + if is_sqlalchemy_version_less_than("1.4"): + resolved = cls.get_query(model, info, **args) + elif isinstance(session, AsyncSession): resolved = ( await session.scalars(cls.get_query(model, info, **args)) ).all() diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 6c47e92b..a5cf559a 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,6 +1,5 @@ import pytest from sqlalchemy import create_engine -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker import graphene @@ -10,6 +9,9 @@ from ..registry import reset_global_registry from .models import Base, CompositeFullName +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + @pytest.fixture(autouse=True) def reset_registry(): @@ -22,14 +24,7 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture( - params=[ - False, - pytest.mark.xfail(True, strict=True) - if is_sqlalchemy_version_less_than("1.4") - else True, - ] -) +@pytest.fixture(params=[False, True]) def async_session(request): return request.param @@ -46,6 +41,8 @@ def test_db_url(async_session: bool): @pytest.fixture(scope="function") async def session_factory(async_session: bool, test_db_url: str): if async_session: + if is_sqlalchemy_version_less_than("1.4"): + pytest.skip(f"Async Sessions only work in sql alchemy 1.4 and above") engine = create_async_engine(test_db_url) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index ff8123f5..2a68a752 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -69,10 +69,10 @@ class Reporter(Base): secondary=association_table, backref="reporters", order_by="Pet.id", - lazy="joined", + lazy="selectin", ) - articles = relationship("Article", backref="reporter", lazy="joined") - favorite_article = relationship("Article", uselist=False, lazy="joined") + articles = relationship("Article", backref="reporter", lazy="selectin") + favorite_article = relationship("Article", uselist=False, lazy="selectin") @hybrid_property def hybrid_prop_with_doc(self): @@ -107,7 +107,9 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") + composite_prop = composite( + CompositeFullName, first_name, last_name, doc="Composite" + ) class Article(Base): diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py new file mode 100644 index 00000000..11216ea5 --- /dev/null +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import + +import enum + +from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, + func, select) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import column_property, relationship + +PetKind = Enum("cat", "dog", name="pet_kind") + + +class HairKind(enum.Enum): + LONG = "long" + SHORT = "short" + + +Base = declarative_base() + +association_table = Table( + "association", + Base.metadata, + Column("pet_id", Integer, ForeignKey("pets.id")), + Column("reporter_id", Integer, ForeignKey("reporters.id")), +) + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class Reporter(Base): + __tablename__ = "reporters" + + id = Column(Integer(), primary_key=True) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") + favorite_pet_kind = Column(PetKind) + pets = relationship( + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + ) + articles = relationship("Article", backref="reporter") + favorite_article = relationship("Article", uselist=False) + + column_prop = column_property( + select([func.cast(func.count(id), Integer)]), doc="Column property" + ) + + +class Article(Base): + __tablename__ = "articles" + id = Column(Integer(), primary_key=True) + headline = Column(String(100)) + pub_date = Column(Date()) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index a1d5528a..21fc2604 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -4,7 +4,6 @@ import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession import graphene from graphene import relay @@ -13,9 +12,12 @@ default_connection_field_factory) from ..types import ORMField, SQLAlchemyObjectType from ..utils import get_session, is_sqlalchemy_version_less_than -from .models import Article, HairKind, Pet, Reporter +from .models_batching import Article, HairKind, Pet, Reporter from .utils import remove_cache_miss_stat, to_std_dicts +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + class MockLoggingHandler(logging.Handler): """Intercept and store log messages in a list.""" @@ -69,13 +71,17 @@ class Query(graphene.ObjectType): async def resolve_articles(self, info): session = get_session(info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Article))).all() return session.query(Article).all() async def resolve_reporters(self, info): session = get_session(info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).all() return session.query(Reporter).all() @@ -87,7 +93,7 @@ async def resolve_reporters(self, info): async def eventually_await_session(session, func, *args): - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): await getattr(session, func)(*args) else: getattr(session, func)(*args) @@ -588,7 +594,9 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_connection_factory_field_overrides_batching_is_false(sync_session_factory): +async def test_connection_factory_field_overrides_batching_is_false( + sync_session_factory, +): session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 20d36d4d..3e2a1a97 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,6 +1,5 @@ import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession import graphene from graphene import relay @@ -10,6 +9,8 @@ from .models import Article, HairKind, Pet, Reporter from .utils import eventually_await_session +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession if is_sqlalchemy_version_less_than("1.2"): pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) @@ -36,13 +37,17 @@ class Query(graphene.ObjectType): async def resolve_articles(self, info): session = get_session(info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Article))).all() return session.query(Article).all() async def resolve_reporters(self, info): session = get_session(info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).all() return session.query(Reporter).all() @@ -63,7 +68,6 @@ async def execute_query(): @pytest.mark.asyncio async def test_one_to_one(session_factory, benchmark): - print(is_sqlalchemy_version_less_than("1.4")) session = session_factory() reporter_1 = Reporter( diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index ae1c1c78..d9b24898 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,6 +1,5 @@ import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession import graphene from graphene.relay import Node @@ -8,10 +7,13 @@ from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyObjectType -from ..utils import get_session +from ..utils import get_session, is_sqlalchemy_version_less_than from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter from .utils import eventually_await_session, to_std_dicts +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + async def add_test_data(session): reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") @@ -50,13 +52,17 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() async def resolve_reporters(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) @@ -114,7 +120,9 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() @@ -195,7 +203,9 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() @@ -295,7 +305,9 @@ class Meta: @classmethod async def get_node(cls, id, info): session = get_session(info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 0375da20..d9119dc3 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,15 +1,18 @@ import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession import graphene from graphene_sqlalchemy.tests.utils import eventually_await_session -from graphene_sqlalchemy.utils import get_session +from graphene_sqlalchemy.utils import (get_session, + is_sqlalchemy_version_less_than) from ..types import SQLAlchemyObjectType from .models import HairKind, Pet, Reporter from .test_query import add_test_data, to_std_dicts +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + @pytest.mark.asyncio async def test_query_pet_kinds(session, session_factory): @@ -33,19 +36,19 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() async def resolve_reporters(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) async def resolve_pets(self, _info, kind): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): query = select(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -118,7 +121,7 @@ class Query(graphene.ObjectType): async def resolve_pet(self, _info): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): return (await session.scalars(select(Pet))).first() return session.query(Pet).first() @@ -154,7 +157,7 @@ class Query(graphene.ObjectType): async def resolve_pet(self, info, kind=None): session = get_session(info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): query = select(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -206,7 +209,7 @@ class Query(graphene.ObjectType): async def resolve_pet(self, _info, kind=None): session = get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): return ( await session.scalars( select(Pet).filter(Pet.hair_kind == HairKind(kind)) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 30ae0d64..c440be1a 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,7 +4,6 @@ import sqlalchemy.exc import sqlalchemy.orm.exc from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) @@ -17,9 +16,13 @@ registerConnectionFieldFactory, unregisterConnectionFieldFactory) from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions +from ..utils import is_sqlalchemy_version_less_than from .models import Article, CompositeFullName, Pet, Reporter from .utils import eventually_await_session +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" @@ -420,7 +423,9 @@ class Query(ObjectType): async def resolve_reporter(self, _info): session = utils.get_session(_info.context) - if isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() @@ -489,9 +494,9 @@ class Meta: def __init_subclass_with_meta__(cls, custom_option=None, **options): _meta = CustomOptions(cls) _meta.custom_option = custom_option - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + super( + SQLAlchemyObjectTypeWithCustomOptions, cls + ).__init_subclass_with_meta__(_meta=_meta, **options) class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): class Meta: diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index 42960a32..4a118243 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,7 +1,6 @@ +import inspect import re -from sqlalchemy.ext.asyncio import AsyncSession - def to_std_dicts(value): """Convert nested ordered dicts to normal dicts for better comparison.""" @@ -20,7 +19,8 @@ def remove_cache_miss_stat(message): async def eventually_await_session(session, func, *args): - if isinstance(session, AsyncSession): + + if inspect.iscoroutinefunction(getattr(session, func)): await getattr(session, func)(*args) else: getattr(session, func)(*args) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 8acecf00..9f52417b 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,7 +1,6 @@ from collections import OrderedDict import sqlalchemy -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, RelationshipProperty) @@ -24,6 +23,9 @@ from .utils import (get_query, get_session, is_mapped_class, is_mapped_instance, is_sqlalchemy_version_less_than) +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + class ORMField(OrderedType): def __init__( @@ -327,15 +329,18 @@ def get_query(cls, info): @classmethod async def get_node(cls, info, id): - session = get_session(info.context) - if is_sqlalchemy_version_less_than("1.4") or not isinstance( - session, AsyncSession - ): + if is_sqlalchemy_version_less_than("1.4"): try: return cls.get_query(info).get(id) except NoResultFound: return None - return await session.get(cls._meta.model, id) + session = get_session(info.context) + if isinstance(session, AsyncSession): + return await session.get(cls._meta.model, id) + try: + return cls.get_query(info).get(id) + except NoResultFound: + return None def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 686a86b8..896b8d5d 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -6,13 +6,23 @@ import pkg_resources from sqlalchemy import select from sqlalchemy.exc import ArgumentError -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError from graphene_sqlalchemy.registry import get_global_registry +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution( + "SQLAlchemy" + ).parsed_version < pkg_resources.parse_version(version_string) + + +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + + def get_session(context): return context.get("session") @@ -26,12 +36,9 @@ def get_query(model, context): "A query in the model Base or a session in the schema is required for querying.\n" "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) - if isinstance(session, AsyncSession): - if is_sqlalchemy_version_less_than("1.4"): - raise Exception( - "You are using an async session with SQLAlchemy < 1.4.\n" - "Please upgrade SQLAlchemy to 1.4.0 or higher." - ) + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): return select(model) query = session.query(model) return query @@ -162,13 +169,6 @@ def sort_argument_for_model(cls, has_default=True): return Argument(List(enum), default_value=enum.default) -def is_sqlalchemy_version_less_than(version_string): - """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution( - "SQLAlchemy" - ).parsed_version < pkg_resources.parse_version(version_string) - - class singledispatchbymatchfunction: """ Inspired by @singledispatch, this is a variant that works using a matcher function From 0180f696280699f056e3d3f8221ca0d87c460aba Mon Sep 17 00:00:00 2001 From: Jendrik Date: Thu, 2 Jun 2022 11:20:59 +0200 Subject: [PATCH 07/19] fix: ensure that synchronous execute calls are still feasible --- graphene_sqlalchemy/fields.py | 41 ++++++++++- graphene_sqlalchemy/tests/test_query.py | 98 ++++++++++++++++++++++++- graphene_sqlalchemy/tests/test_types.py | 5 +- graphene_sqlalchemy/types.py | 11 ++- 4 files changed, 144 insertions(+), 11 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index c9b1af08..bba19c6e 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -48,15 +48,20 @@ def get_query(cls, model, info, **args): return get_query(model, info.context) @classmethod - async def resolve_connection(cls, connection_type, model, info, args, resolved): + def resolve_connection(cls, connection_type, model, info, args, resolved): session = get_session(info.context) if resolved is None: if is_sqlalchemy_version_less_than("1.4"): resolved = cls.get_query(model, info, **args) elif isinstance(session, AsyncSession): - resolved = ( - await session.scalars(cls.get_query(model, info, **args)) - ).all() + + async def get_result(): + return await cls.resolve_connection_async( + connection_type, model, info, args, resolved + ) + + return get_result() + else: resolved = cls.get_query(model, info, **args) if isinstance(resolved, Query): @@ -81,6 +86,34 @@ def adjusted_connection_adapter(edges, pageInfo): connection.length = _len return connection + @classmethod + async def resolve_connection_async(cls, connection_type, model, info, args, resolved): + session = get_session(info.context) + if resolved is None: + query = cls.get_query(model, info, **args) + resolved = (await session.scalars(query)).all() + if isinstance(resolved, Query): + _len = resolved.count() + else: + _len = len(resolved) + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, + slice_start=0, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, + edge_type=connection_type.Edge, + page_info_type=page_info_adapter, + ) + connection.iterable = resolved + connection.length = _len + return connection + @classmethod def connection_resolver(cls, resolver, connection_type, model, root, info, **args): resolved = resolver(root, info, **args) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index d9b24898..3c4e47be 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -96,7 +96,7 @@ async def resolve_reporters(self, _info): @pytest.mark.asyncio -async def test_query_node(session): +async def test_query_node_sync(session): await add_test_data(session) class ReporterNode(SQLAlchemyObjectType): @@ -118,12 +118,104 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterNode) all_articles = SQLAlchemyConnectionField(ArticleNode.connection) - async def resolve_reporter(self, _info): + def resolve_reporter(self, _info): session = get_session(_info.context) if not is_sqlalchemy_version_less_than("1.4") and isinstance( session, AsyncSession ): - return (await session.scalars(select(Reporter))).first() + + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + + return session.query(Reporter).first() + + query = """ + query { + reporter { + id + firstName + articles { + edges { + node { + headline + } + } + } + } + allArticles { + edges { + node { + headline + } + } + } + myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { + id + ... on ReporterNode { + firstName + } + ... on ArticleNode { + headline + } + } + } + """ + expected = { + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "John", + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, + "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, + } + schema = graphene.Schema(query=Query) + if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): + result = schema.execute(query, context_value={"session": session}) + assert result.errors + else: + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +@pytest.mark.asyncio +async def test_query_node_async(session): + await add_test_data(session) + + class ReporterNode(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_node(cls, info, id): + return Reporter(id=2, first_name="Cookie Monster") + + class ArticleNode(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class Query(graphene.ObjectType): + node = Node.Field() + reporter = graphene.Field(ReporterNode) + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) + + def resolve_reporter(self, _info): + session = get_session(_info.context) + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): + + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + return session.query(Reporter).first() query = """ diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index c440be1a..b3a4fcf8 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -3,6 +3,7 @@ import pytest import sqlalchemy.exc import sqlalchemy.orm.exc +from graphql.pyutils import is_awaitable from sqlalchemy import select from graphene import (Boolean, Dynamic, Field, Float, GlobalID, Int, List, @@ -55,7 +56,9 @@ class Meta: session.add(reporter) await eventually_await_session(session, "commit") info = mock.Mock(context={"session": session}) - reporter_node = await ReporterType.get_node(info, reporter.id) + reporter_node = ReporterType.get_node(info, reporter.id) + if is_awaitable(reporter_node): + reporter_node = await reporter_node assert reporter == reporter_node diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 9f52417b..02709361 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from typing import Any import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property @@ -327,16 +328,20 @@ def get_query(cls, info): return get_query(model, info.context) @classmethod - async def get_node(cls, info, id): - + def get_node(cls, info, id): if is_sqlalchemy_version_less_than("1.4"): try: return cls.get_query(info).get(id) except NoResultFound: return None + session = get_session(info.context) if isinstance(session, AsyncSession): - return await session.get(cls._meta.model, id) + + async def get_result() -> Any: + return await session.get(cls._meta.model, id) + + return get_result() try: return cls.get_query(info).get(id) except NoResultFound: From ec766975e39d41ea076cf1c712c0782db3d921b0 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Tue, 7 Jun 2022 16:00:44 +0200 Subject: [PATCH 08/19] refactor: remove duplicate code by fixing if condition --- graphene_sqlalchemy/fields.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index bba19c6e..09f48672 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,8 +11,7 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import (EnumValue, get_query, get_session, - is_sqlalchemy_version_less_than) +from .utils import EnumValue, get_query, get_session, is_sqlalchemy_version_less_than if not is_sqlalchemy_version_less_than("1.4"): from sqlalchemy.ext.asyncio import AsyncSession @@ -51,9 +50,9 @@ def get_query(cls, model, info, **args): def resolve_connection(cls, connection_type, model, info, args, resolved): session = get_session(info.context) if resolved is None: - if is_sqlalchemy_version_less_than("1.4"): - resolved = cls.get_query(model, info, **args) - elif isinstance(session, AsyncSession): + if not is_sqlalchemy_version_less_than("1.4") and isinstance( + session, AsyncSession + ): async def get_result(): return await cls.resolve_connection_async( @@ -87,7 +86,9 @@ def adjusted_connection_adapter(edges, pageInfo): return connection @classmethod - async def resolve_connection_async(cls, connection_type, model, info, args, resolved): + async def resolve_connection_async( + cls, connection_type, model, info, args, resolved + ): session = get_session(info.context) if resolved is None: query = cls.get_query(model, info, **args) From 6a008462dab69a2f0986c126fcf8f59a1593a5d2 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Tue, 7 Jun 2022 16:08:19 +0200 Subject: [PATCH 09/19] chore: add specific error if awaitable is returned in synchronous execution context --- graphene_sqlalchemy/fields.py | 3 ++- graphene_sqlalchemy/types.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 09f48672..16ec6fc5 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,7 +11,8 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query, get_session, is_sqlalchemy_version_less_than +from .utils import (EnumValue, get_query, get_session, + is_sqlalchemy_version_less_than) if not is_sqlalchemy_version_less_than("1.4"): from sqlalchemy.ext.asyncio import AsyncSession diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 02709361..d1a853fb 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from inspect import isawaitable from typing import Any import sqlalchemy @@ -318,6 +319,11 @@ def __init_subclass_with_meta__( def is_type_of(cls, root, info): if isinstance(root, cls): return True + if isawaitable(root): + raise Exception( + "Received coroutine instead of sql alchemy model. " + "You seem to use an async engine with synchronous schema execution" + ) if not is_mapped_instance(root): raise Exception(('Received incompatible instance "{}".').format(root)) return isinstance(root, cls._meta.model) From fff782f03b7b3c9d6e718db42eea90b2651d24d7 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Thu, 15 Sep 2022 18:15:02 +0200 Subject: [PATCH 10/19] test: use pytest_asyncio.fixture instead normal fixture, fix issues in batching test --- graphene_sqlalchemy/tests/conftest.py | 7 +++-- graphene_sqlalchemy/tests/test_batching.py | 33 +++++++++++----------- setup.py | 2 +- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 1a975e4b..2d8293b2 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,4 +1,5 @@ import pytest +import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker @@ -38,7 +39,7 @@ def test_db_url(async_session: bool): @pytest.mark.asyncio -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def session_factory(async_session: bool, test_db_url: str): if async_session: if is_sqlalchemy_version_less_than("1.4"): @@ -57,7 +58,7 @@ async def session_factory(async_session: bool, test_db_url: str): engine.dispose() -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") async def sync_session_factory(): engine = create_engine("sqlite://") Base.metadata.create_all(engine) @@ -67,6 +68,6 @@ async def sync_session_factory(): engine.dispose() -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") def session(session_factory): return session_factory() diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index edadde24..99d0a78b 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -513,8 +513,6 @@ async def test_many_to_many(sync_session_factory): ) messages = sqlalchemy_logging_handler.messages - print(messages) - print(result) assert not result.errors result = to_std_dicts(result.data) assert result == { @@ -875,7 +873,9 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_batching_across_nested_relay_schema(session_factory): +async def test_batching_across_nested_relay_schema( + session_factory, async_session: bool +): session = session_factory() for first_name in "fgerbhjikzutzxsdfdqqa": @@ -890,8 +890,8 @@ async def test_batching_across_nested_relay_schema(session_factory): reader.articles = [article] session.add(reader) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_full_relay_schema() @@ -929,17 +929,18 @@ async def test_batching_across_nested_relay_schema(session_factory): messages = sqlalchemy_logging_handler.messages result = to_std_dicts(result.data) - print(result) select_statements = [message for message in messages if "SELECT" in message] - print(select_statements) - assert len(select_statements) == 4 - assert select_statements[-1].startswith("SELECT articles_1.id") - if is_sqlalchemy_version_less_than("1.3"): - assert select_statements[-2].startswith("SELECT reporters_1.id") - assert "WHERE reporters_1.id IN" in select_statements[-2] + if async_session: + assert len(select_statements) == 2 # TODO: Figure out why async has less calls else: - assert select_statements[-2].startswith("SELECT articles.reporter_id") - assert "WHERE articles.reporter_id IN" in select_statements[-2] + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than("1.3"): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] @pytest.mark.asyncio @@ -953,8 +954,8 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f article_1.reporter = reporter_1 session.add(article_1) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_full_relay_schema() diff --git a/setup.py b/setup.py index c5e56df0..e5b0f58e 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ tests_require = [ "pytest>=6.2.0,<7.0", - "pytest-asyncio>=0.15.1", + "pytest-asyncio>=0.18.3", "pytest-cov>=2.11.0,<3.0", "sqlalchemy_utils>=0.37.0,<1.0", "pytest-benchmark>=3.4.0,<4.0", From 125023190cda78208a2c451271820de780580779 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 10:27:41 +0200 Subject: [PATCH 11/19] chore: remove duplicate eventually_await_session --- graphene_sqlalchemy/tests/test_batching.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 99d0a78b..1c99c0e0 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -12,7 +12,7 @@ from ..types import ORMField, SQLAlchemyObjectType from ..utils import get_session, is_sqlalchemy_version_less_than from .models_batching import Article, HairKind, Pet, Reader, Reporter -from .utils import remove_cache_miss_stat, to_std_dicts +from .utils import eventually_await_session, remove_cache_miss_stat, to_std_dicts if not is_sqlalchemy_version_less_than("1.4"): from sqlalchemy.ext.asyncio import AsyncSession @@ -125,13 +125,6 @@ def resolve_reporters(self, info): pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) -async def eventually_await_session(session, func, *args): - if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): - await getattr(session, func)(*args) - else: - getattr(session, func)(*args) - - def get_full_relay_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: From eee2314de6e02c64e1092fdf1631b37ea734cc34 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 10:29:23 +0200 Subject: [PATCH 12/19] chore: remove duplicate skip statement --- graphene_sqlalchemy/tests/test_batching.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 1c99c0e0..73989d66 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -159,10 +159,6 @@ class Query(graphene.ObjectType): return graphene.Schema(query=Query) -if is_sqlalchemy_version_less_than("1.2"): - pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) - - @pytest.mark.asyncio @pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) async def test_many_to_one(sync_session_factory, schema_provider): From bacf15d55b8d620aeed2d545c67fe5be5c9d5639 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 14:42:08 +0200 Subject: [PATCH 13/19] fix: fix benchmark tests not being executed properly --- graphene_sqlalchemy/tests/models.py | 10 +++--- graphene_sqlalchemy/tests/test_benchmark.py | 40 +++++++++++---------- setup.py | 5 ++- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 2ef649c8..1a70c1eb 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -19,7 +19,7 @@ ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import column_property, composite, mapper, relationship +from sqlalchemy.orm import column_property, composite, mapper, relationship, backref PetKind = Enum("cat", "dog", name="pet_kind") @@ -81,7 +81,9 @@ class Reporter(Base): order_by="Pet.id", lazy="selectin", ) - articles = relationship("Article", backref="reporter", lazy="selectin") + articles = relationship( + "Article", backref=backref("reporter", lazy="selectin"), lazy="selectin" + ) favorite_article = relationship("Article", uselist=False, lazy="selectin") @hybrid_property @@ -117,9 +119,7 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite( - CompositeFullName, first_name, last_name, doc="Composite" - ) + composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 5fa0acfd..906e3a26 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,3 +1,4 @@ +import asyncio import pytest from sqlalchemy import select @@ -83,19 +84,26 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) -async def benchmark_query(session_factory, benchmark, query, schema): - schema = get_schema() +async def benchmark_query(session, benchmark, schema, query): + import nest_asyncio - @benchmark - async def execute_query(): - result = await schema.execute_async( - query, - context_value={"session": session_factory()}, + nest_asyncio.apply() + loop = asyncio.get_event_loop() + result = benchmark( + lambda: loop.run_until_complete( + schema.execute_async(query, context_value={"session": session}) ) - assert not result.errors + ) + assert not result.errors + + +@pytest.fixture(params=[get_schema, get_async_schema]) +def schema_provider(request, async_session): + if async_session and request.param == get_schema: + pytest.skip("Cannot test sync schema with async sessions") + return request.param -@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) @pytest.mark.asyncio async def test_one_to_one(session_factory, benchmark, schema_provider): session = session_factory() @@ -122,7 +130,7 @@ async def test_one_to_one(session_factory, benchmark, schema_provider): await eventually_await_session(session, "close") await benchmark_query( - session_factory, + session, benchmark, schema, """ @@ -138,12 +146,10 @@ async def test_one_to_one(session_factory, benchmark, schema_provider): ) -@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) @pytest.mark.asyncio async def test_many_to_one(session_factory, benchmark, schema_provider): session = session_factory() schema = schema_provider() - reporter_1 = Reporter( first_name="Reporter_1", ) @@ -160,12 +166,12 @@ async def test_many_to_one(session_factory, benchmark, schema_provider): article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) - + await eventually_await_session(session, "flush") await eventually_await_session(session, "commit") await eventually_await_session(session, "close") await benchmark_query( - session_factory, + session, benchmark, schema, """ @@ -182,7 +188,6 @@ async def test_many_to_one(session_factory, benchmark, schema_provider): @pytest.mark.asyncio -@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) async def test_one_to_many(session_factory, benchmark, schema_provider): session = session_factory() schema = schema_provider() @@ -216,7 +221,7 @@ async def test_one_to_many(session_factory, benchmark, schema_provider): await eventually_await_session(session, "close") await benchmark_query( - session_factory, + session, benchmark, schema, """ @@ -236,7 +241,6 @@ async def test_one_to_many(session_factory, benchmark, schema_provider): ) -@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) @pytest.mark.asyncio async def test_many_to_many(session_factory, benchmark, schema_provider): session = session_factory() @@ -272,7 +276,7 @@ async def test_many_to_many(session_factory, benchmark, schema_provider): await eventually_await_session(session, "close") await benchmark_query( - session_factory, + session, benchmark, schema, """ diff --git a/setup.py b/setup.py index e5b0f58e..c18a9515 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,7 @@ _version_re = re.compile(r"__version__\s+=\s+(.*)") with open("graphene_sqlalchemy/__init__.py", "rb") as f: - version = str( - ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) - ) + version = str(ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))) requirements = [ # To keep things simple, we only support newer versions of Graphene @@ -26,6 +24,7 @@ "sqlalchemy_utils>=0.37.0,<1.0", "pytest-benchmark>=3.4.0,<4.0", "aiosqlite>=0.17.0", + "nest-asyncio", ] setup( From 2bc6f8493a8940a7d3acf915aa1bfa39e8019557 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 14:50:53 +0200 Subject: [PATCH 14/19] chore: format files --- graphene_sqlalchemy/tests/models.py | 6 ++++-- graphene_sqlalchemy/tests/test_benchmark.py | 1 + setup.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 1a70c1eb..0e7e1870 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -19,7 +19,7 @@ ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import column_property, composite, mapper, relationship, backref +from sqlalchemy.orm import backref, column_property, composite, mapper, relationship PetKind = Enum("cat", "dog", name="pet_kind") @@ -119,7 +119,9 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") + composite_prop = composite( + CompositeFullName, first_name, last_name, doc="Composite" + ) class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 906e3a26..86dcf0d4 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,4 +1,5 @@ import asyncio + import pytest from sqlalchemy import select diff --git a/setup.py b/setup.py index c18a9515..b6a2c95c 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,9 @@ _version_re = re.compile(r"__version__\s+=\s+(.*)") with open("graphene_sqlalchemy/__init__.py", "rb") as f: - version = str(ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))) + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) requirements = [ # To keep things simple, we only support newer versions of Graphene From a968ff85319354af31a0cba51c723ef38114ce6f Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 14:59:27 +0200 Subject: [PATCH 15/19] chore: move is_graphene_version_less_than to top of file --- graphene_sqlalchemy/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index bdf865a3..d9245cc9 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -17,6 +17,13 @@ def is_sqlalchemy_version_less_than(version_string): ).parsed_version < pkg_resources.parse_version(version_string) +def is_graphene_version_less_than(version_string): # pragma: no cover + """Check the installed graphene version""" + return pkg_resources.get_distribution( + "graphene" + ).parsed_version < pkg_resources.parse_version(version_string) + + if not is_sqlalchemy_version_less_than("1.4"): from sqlalchemy.ext.asyncio import AsyncSession @@ -167,13 +174,6 @@ def sort_argument_for_model(cls, has_default=True): return Argument(List(enum), default_value=enum.default) -def is_graphene_version_less_than(version_string): # pragma: no cover - """Check the installed graphene version""" - return pkg_resources.get_distribution( - "graphene" - ).parsed_version < pkg_resources.parse_version(version_string) - - class singledispatchbymatchfunction: """ Inspired by @singledispatch, this is a variant that works using a matcher function From 1039f03f9c49c36a922017bf5fff21bb2e9a88f2 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 15:07:10 +0200 Subject: [PATCH 16/19] test: remove unnecessary pytest.mark.asyncio, auto-reformatting --- examples/flask_sqlalchemy/app.py | 3 +-- examples/flask_sqlalchemy/database.py | 4 +--- examples/flask_sqlalchemy/schema.py | 4 ++-- examples/nameko_sqlalchemy/app.py | 3 +-- examples/nameko_sqlalchemy/database.py | 4 +--- examples/nameko_sqlalchemy/schema.py | 4 ++-- graphene_sqlalchemy/converter.py | 5 ++--- graphene_sqlalchemy/enums.py | 3 +-- graphene_sqlalchemy/fields.py | 9 +++------ graphene_sqlalchemy/registry.py | 11 +++-------- graphene_sqlalchemy/tests/conftest.py | 2 +- graphene_sqlalchemy/tests/models.py | 4 +--- graphene_sqlalchemy/tests/test_batching.py | 13 ++++--------- graphene_sqlalchemy/tests/test_benchmark.py | 5 ++--- graphene_sqlalchemy/tests/test_converter.py | 7 +++---- graphene_sqlalchemy/tests/test_enums.py | 16 +++++----------- graphene_sqlalchemy/tests/test_fields.py | 3 +-- graphene_sqlalchemy/tests/test_query.py | 5 ++--- graphene_sqlalchemy/tests/test_query_enums.py | 2 +- graphene_sqlalchemy/tests/test_registry.py | 5 ++--- graphene_sqlalchemy/tests/test_sort_enums.py | 5 +---- graphene_sqlalchemy/tests/test_types.py | 11 +++++------ graphene_sqlalchemy/tests/test_utils.py | 1 - graphene_sqlalchemy/types.py | 11 ++++------- setup.py | 4 +--- 25 files changed, 50 insertions(+), 94 deletions(-) diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index 1066020c..ab13857e 100755 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -2,9 +2,8 @@ from database import db_session, init_db from flask import Flask -from schema import schema - from flask_graphql import GraphQLView +from schema import schema app = Flask(__name__) app.debug = True diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index 74ec7ca9..0fd39399 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -3,9 +3,7 @@ from sqlalchemy.orm import scoped_session, sessionmaker engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) -db_session = scoped_session( - sessionmaker(autocommit=False, autoflush=False, bind=engine) -) +db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) Base = declarative_base() Base.query = db_session.query_property() diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index c4a91e63..51e2bdc4 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,9 +1,9 @@ +import graphene +from graphene import relay from models import Department as DepartmentModel from models import Employee as EmployeeModel from models import Role as RoleModel -import graphene -from graphene import relay from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py index 64d305ea..dfc35189 100755 --- a/examples/nameko_sqlalchemy/app.py +++ b/examples/nameko_sqlalchemy/app.py @@ -1,6 +1,4 @@ from database import db_session, init_db -from schema import schema - from graphql_server import ( HttpQueryError, default_format_error, @@ -9,6 +7,7 @@ load_json_body, run_http_query, ) +from schema import schema class App: diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py index 74ec7ca9..0fd39399 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -3,9 +3,7 @@ from sqlalchemy.orm import scoped_session, sessionmaker engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) -db_session = scoped_session( - sessionmaker(autocommit=False, autoflush=False, bind=engine) -) +db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) Base = declarative_base() Base.query = db_session.query_property() diff --git a/examples/nameko_sqlalchemy/schema.py b/examples/nameko_sqlalchemy/schema.py index ced300b3..77ecf164 100644 --- a/examples/nameko_sqlalchemy/schema.py +++ b/examples/nameko_sqlalchemy/schema.py @@ -1,9 +1,9 @@ +import graphene +from graphene import relay from models import Department as DepartmentModel from models import Employee as EmployeeModel from models import Role as RoleModel -import graphene -from graphene import relay from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index d1873c2b..8db7e62d 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -6,13 +6,12 @@ from functools import singledispatch from typing import Any, cast +import graphene +from graphene.types.json import JSONString from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies -import graphene -from graphene.types.json import JSONString - from .batching import get_batch_resolver from .enums import enum_for_sa_enum from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index 97f8997c..d35a47df 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,8 +1,7 @@ +from graphene import Argument, Enum, List from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType -from graphene import Argument, Enum, List - from .utils import EnumValue, to_enum_value_name, to_type_name diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 9f69b53f..8377d2ef 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -2,13 +2,12 @@ import warnings from functools import partial -from promise import Promise, is_thenable -from sqlalchemy.orm.query import Query - from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import connection_adapter, page_info_adapter from graphql_relay import connection_from_array_slice +from promise import Promise, is_thenable +from sqlalchemy.orm.query import Query from .batching import get_batch_resolver from .utils import EnumValue, get_query, get_session, is_sqlalchemy_version_less_than @@ -122,9 +121,7 @@ def adjusted_connection_adapter(edges, pageInfo): return connection @classmethod - async def resolve_connection_async( - cls, connection_type, model, info, args, resolved - ): + async def resolve_connection_async(cls, connection_type, model, info, args, resolved): session = get_session(info.context) if resolved is None: query = cls.get_query(model, info, **args) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 8f2bc9e7..a53c0974 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,10 +1,9 @@ from collections import defaultdict from typing import List, Type -from sqlalchemy.types import Enum as SQLAlchemyEnumType - import graphene from graphene import Enum +from sqlalchemy.types import Enum as SQLAlchemyEnumType class Registry(object): @@ -61,13 +60,9 @@ def get_converter_for_composite(self, composite): def register_enum(self, sa_enum: SQLAlchemyEnumType, graphene_enum: Enum): if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError( - "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) - ) + raise TypeError("Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum)) if not isinstance(graphene_enum, type(Enum)): - raise TypeError( - "Expected Graphene Enum, but got: {!r}".format(graphene_enum) - ) + raise TypeError("Expected Graphene Enum, but got: {!r}".format(graphene_enum)) self._registry_enums[sa_enum] = graphene_enum diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 2d8293b2..2343c6b3 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,9 +1,9 @@ +import graphene import pytest import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -import graphene from graphene_sqlalchemy.utils import is_sqlalchemy_version_less_than from ..converter import convert_sqlalchemy_composite diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 0e7e1870..2b967b2d 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -119,9 +119,7 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite( - CompositeFullName, first_name, last_name, doc="Composite" - ) + composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 73989d66..ff2fbed2 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -2,11 +2,10 @@ import contextlib import logging -import pytest -from sqlalchemy import select - import graphene +import pytest from graphene import Connection, relay +from sqlalchemy import select from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType @@ -651,7 +650,6 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 -@pytest.mark.asyncio def test_batch_sorting_with_custom_ormfield(sync_session_factory): session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") @@ -862,9 +860,7 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_batching_across_nested_relay_schema( - session_factory, async_session: bool -): +async def test_batching_across_nested_relay_schema(session_factory, async_session: bool): session = session_factory() for first_name in "fgerbhjikzutzxsdfdqqa": @@ -967,6 +963,5 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f result = to_std_dicts(result.data) assert [ - r["node"]["firstName"] + r["node"]["email"] - for r in result["reporters"]["edges"] + r["node"]["firstName"] + r["node"]["email"] for r in result["reporters"]["edges"] ] == ["aa", "ba", "bb", "bc", "ca", "da"] diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 86dcf0d4..5a6610cd 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,10 +1,9 @@ import asyncio -import pytest -from sqlalchemy import select - import graphene +import pytest from graphene import relay +from sqlalchemy import select from ..types import SQLAlchemyObjectType from ..utils import get_session, is_sqlalchemy_version_less_than diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 812b4cea..f1503559 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -2,8 +2,11 @@ import sys from typing import Dict, Union +import graphene import pytest import sqlalchemy_utils as sqa_utils +from graphene.relay import Node +from graphene.types.structures import Structure from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -11,10 +14,6 @@ from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -import graphene -from graphene.relay import Node -from graphene.types.structures import Structure - from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index 3de6904b..ec4546b4 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -1,9 +1,8 @@ from enum import Enum as PyEnum import pytest -from sqlalchemy.types import Enum as SQLAlchemyEnumType - from graphene import Enum +from sqlalchemy.types import Enum as SQLAlchemyEnumType from ..enums import _convert_sa_to_graphene_enum, enum_for_field from ..types import SQLAlchemyObjectType @@ -41,8 +40,7 @@ class Color(PyEnum): assert graphene_enum._meta.name == "Color" assert graphene_enum._meta.enum is not Color assert [ - (key, value.value) - for key, value in graphene_enum._meta.enum.__members__.items() + (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)] @@ -52,8 +50,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): assert isinstance(graphene_enum, type(Enum)) assert graphene_enum._meta.name == "ColorValues" assert [ - (key, value.value) - for key, value in graphene_enum._meta.enum.__members__.items() + (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] @@ -63,8 +60,7 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): assert isinstance(graphene_enum, type(Enum)) assert graphene_enum._meta.name == "FallbackName" assert [ - (key, value.value) - for key, value in graphene_enum._meta.enum.__members__.items() + (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] @@ -83,9 +79,7 @@ class Meta: enum = enum_for_field(PetType, "pet_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "PetKind" - assert [ - (key, value.value) for key, value in enum._meta.enum.__members__.items() - ] == [ + assert [(key, value.value) for key, value in enum._meta.enum.__members__.items()] == [ ("CAT", "cat"), ("DOG", "dog"), ] diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index 9fed146d..c5a1040b 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,8 +1,7 @@ import pytest -from promise import Promise - from graphene import NonNull, ObjectType from graphene.relay import Connection, Node +from promise import Promise from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 3c4e47be..3e2ba7ee 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,8 +1,7 @@ -import pytest -from sqlalchemy import select - import graphene +import pytest from graphene.relay import Node +from sqlalchemy import select from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index a329b88a..6a8428c7 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,7 +1,7 @@ +import graphene import pytest from sqlalchemy import select -import graphene from graphene_sqlalchemy.tests.utils import eventually_await_session from graphene_sqlalchemy.utils import get_session, is_sqlalchemy_version_less_than diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index cb7e9034..5aea8d76 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,8 +1,7 @@ -import pytest -from sqlalchemy.types import Enum as SQLAlchemyEnum - import graphene +import pytest from graphene import Enum as GrapheneEnum +from sqlalchemy.types import Enum as SQLAlchemyEnum from ..registry import Registry from ..types import SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index f8f1ff8c..5461885d 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -1,6 +1,5 @@ import pytest import sqlalchemy as sa - from graphene import Argument, Enum, List, ObjectType, Schema from graphene.relay import Node @@ -317,9 +316,7 @@ def makeNodes(nodeList): return {"edges": nodes} expected = { - "defaultSort": makeNodes( - [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] - ), + "defaultSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), "noDefaultSort": makeNodes( [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 4637a115..6e61fb63 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -3,9 +3,6 @@ import pytest import sqlalchemy.exc import sqlalchemy.orm.exc -from graphql.pyutils import is_awaitable -from sqlalchemy import select - from graphene import ( Boolean, Dynamic, @@ -21,6 +18,8 @@ String, ) from graphene.relay import Connection +from graphql.pyutils import is_awaitable +from sqlalchemy import select from .. import utils from ..converter import convert_sqlalchemy_composite @@ -512,9 +511,9 @@ class Meta: def __init_subclass_with_meta__(cls, custom_option=None, **options): _meta = CustomOptions(cls) _meta.custom_option = custom_option - super( - SQLAlchemyObjectTypeWithCustomOptions, cls - ).__init_subclass_with_meta__(_meta=_meta, **options) + super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): class Meta: diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index 75328280..b88600f0 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,6 +1,5 @@ import pytest import sqlalchemy as sa - from graphene import Enum, List, ObjectType, Schema, String from ..utils import ( diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 69305fe3..8240ccda 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,15 +3,14 @@ from typing import Any import sqlalchemy -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty -from sqlalchemy.orm.exc import NoResultFound - from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty +from sqlalchemy.orm.exc import NoResultFound from .converter import ( convert_sqlalchemy_column, @@ -283,9 +282,7 @@ def __init_subclass_with_meta__( ) if use_connection is None and interfaces: - use_connection = any( - issubclass(interface, Node) for interface in interfaces - ) + use_connection = any(issubclass(interface, Node) for interface in interfaces) if use_connection and not connection: # We create the connection automatically diff --git a/setup.py b/setup.py index b6a2c95c..c18a9515 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,7 @@ _version_re = re.compile(r"__version__\s+=\s+(.*)") with open("graphene_sqlalchemy/__init__.py", "rb") as f: - version = str( - ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) - ) + version = str(ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))) requirements = [ # To keep things simple, we only support newer versions of Graphene From e61df343fb18f4c2e14c2d32ad87034ef4392ee5 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Fri, 7 Oct 2022 15:10:31 +0200 Subject: [PATCH 17/19] chore: revert faulty formatting --- examples/flask_sqlalchemy/app.py | 3 ++- examples/flask_sqlalchemy/database.py | 4 +++- examples/flask_sqlalchemy/schema.py | 4 ++-- examples/nameko_sqlalchemy/app.py | 3 ++- examples/nameko_sqlalchemy/database.py | 4 +++- examples/nameko_sqlalchemy/schema.py | 4 ++-- graphene_sqlalchemy/converter.py | 5 +++-- graphene_sqlalchemy/enums.py | 3 ++- graphene_sqlalchemy/fields.py | 9 ++++++--- graphene_sqlalchemy/registry.py | 11 ++++++++--- graphene_sqlalchemy/tests/conftest.py | 2 +- graphene_sqlalchemy/tests/models.py | 4 +++- graphene_sqlalchemy/tests/test_batching.py | 12 ++++++++---- graphene_sqlalchemy/tests/test_benchmark.py | 5 +++-- graphene_sqlalchemy/tests/test_converter.py | 7 ++++--- graphene_sqlalchemy/tests/test_enums.py | 16 +++++++++++----- graphene_sqlalchemy/tests/test_fields.py | 3 ++- graphene_sqlalchemy/tests/test_query.py | 5 +++-- graphene_sqlalchemy/tests/test_query_enums.py | 2 +- graphene_sqlalchemy/tests/test_registry.py | 5 +++-- graphene_sqlalchemy/tests/test_sort_enums.py | 5 ++++- graphene_sqlalchemy/tests/test_types.py | 11 ++++++----- graphene_sqlalchemy/tests/test_utils.py | 1 + graphene_sqlalchemy/types.py | 11 +++++++---- setup.py | 4 +++- 25 files changed, 93 insertions(+), 50 deletions(-) diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index ab13857e..1066020c 100755 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -2,9 +2,10 @@ from database import db_session, init_db from flask import Flask -from flask_graphql import GraphQLView from schema import schema +from flask_graphql import GraphQLView + app = Flask(__name__) app.debug = True diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index 0fd39399..74ec7ca9 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -3,7 +3,9 @@ from sqlalchemy.orm import scoped_session, sessionmaker engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index 51e2bdc4..c4a91e63 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,9 +1,9 @@ -import graphene -from graphene import relay from models import Department as DepartmentModel from models import Employee as EmployeeModel from models import Role as RoleModel +import graphene +from graphene import relay from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py index dfc35189..64d305ea 100755 --- a/examples/nameko_sqlalchemy/app.py +++ b/examples/nameko_sqlalchemy/app.py @@ -1,4 +1,6 @@ from database import db_session, init_db +from schema import schema + from graphql_server import ( HttpQueryError, default_format_error, @@ -7,7 +9,6 @@ load_json_body, run_http_query, ) -from schema import schema class App: diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py index 0fd39399..74ec7ca9 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -3,7 +3,9 @@ from sqlalchemy.orm import scoped_session, sessionmaker engine = create_engine("sqlite:///database.sqlite3", convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) +db_session = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) Base = declarative_base() Base.query = db_session.query_property() diff --git a/examples/nameko_sqlalchemy/schema.py b/examples/nameko_sqlalchemy/schema.py index 77ecf164..ced300b3 100644 --- a/examples/nameko_sqlalchemy/schema.py +++ b/examples/nameko_sqlalchemy/schema.py @@ -1,9 +1,9 @@ -import graphene -from graphene import relay from models import Department as DepartmentModel from models import Employee as EmployeeModel from models import Role as RoleModel +import graphene +from graphene import relay from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 8db7e62d..d1873c2b 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -6,12 +6,13 @@ from functools import singledispatch from typing import Any, cast -import graphene -from graphene.types.json import JSONString from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies +import graphene +from graphene.types.json import JSONString + from .batching import get_batch_resolver from .enums import enum_for_sa_enum from .fields import BatchSQLAlchemyConnectionField, default_connection_field_factory diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index d35a47df..97f8997c 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,7 +1,8 @@ -from graphene import Argument, Enum, List from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType +from graphene import Argument, Enum, List + from .utils import EnumValue, to_enum_value_name, to_type_name diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 8377d2ef..9f69b53f 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -2,12 +2,13 @@ import warnings from functools import partial +from promise import Promise, is_thenable +from sqlalchemy.orm.query import Query + from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import connection_adapter, page_info_adapter from graphql_relay import connection_from_array_slice -from promise import Promise, is_thenable -from sqlalchemy.orm.query import Query from .batching import get_batch_resolver from .utils import EnumValue, get_query, get_session, is_sqlalchemy_version_less_than @@ -121,7 +122,9 @@ def adjusted_connection_adapter(edges, pageInfo): return connection @classmethod - async def resolve_connection_async(cls, connection_type, model, info, args, resolved): + async def resolve_connection_async( + cls, connection_type, model, info, args, resolved + ): session = get_session(info.context) if resolved is None: query = cls.get_query(model, info, **args) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index a53c0974..8f2bc9e7 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,9 +1,10 @@ from collections import defaultdict from typing import List, Type +from sqlalchemy.types import Enum as SQLAlchemyEnumType + import graphene from graphene import Enum -from sqlalchemy.types import Enum as SQLAlchemyEnumType class Registry(object): @@ -60,9 +61,13 @@ def get_converter_for_composite(self, composite): def register_enum(self, sa_enum: SQLAlchemyEnumType, graphene_enum: Enum): if not isinstance(sa_enum, SQLAlchemyEnumType): - raise TypeError("Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum)) + raise TypeError( + "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) + ) if not isinstance(graphene_enum, type(Enum)): - raise TypeError("Expected Graphene Enum, but got: {!r}".format(graphene_enum)) + raise TypeError( + "Expected Graphene Enum, but got: {!r}".format(graphene_enum) + ) self._registry_enums[sa_enum] = graphene_enum diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 2343c6b3..2d8293b2 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,9 +1,9 @@ -import graphene import pytest import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +import graphene from graphene_sqlalchemy.utils import is_sqlalchemy_version_less_than from ..converter import convert_sqlalchemy_composite diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 2b967b2d..0e7e1870 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -119,7 +119,9 @@ def hybrid_prop_list(self) -> List[int]: select([func.cast(func.count(id), Integer)]), doc="Column property" ) - composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") + composite_prop = composite( + CompositeFullName, first_name, last_name, doc="Composite" + ) class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index ff2fbed2..2659136d 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -2,11 +2,12 @@ import contextlib import logging -import graphene import pytest -from graphene import Connection, relay from sqlalchemy import select +import graphene +from graphene import Connection, relay + from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType from ..utils import get_session, is_sqlalchemy_version_less_than @@ -860,7 +861,9 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_batching_across_nested_relay_schema(session_factory, async_session: bool): +async def test_batching_across_nested_relay_schema( + session_factory, async_session: bool +): session = session_factory() for first_name in "fgerbhjikzutzxsdfdqqa": @@ -963,5 +966,6 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f result = to_std_dicts(result.data) assert [ - r["node"]["firstName"] + r["node"]["email"] for r in result["reporters"]["edges"] + r["node"]["firstName"] + r["node"]["email"] + for r in result["reporters"]["edges"] ] == ["aa", "ba", "bb", "bc", "ca", "da"] diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 5a6610cd..86dcf0d4 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,10 +1,11 @@ import asyncio -import graphene import pytest -from graphene import relay from sqlalchemy import select +import graphene +from graphene import relay + from ..types import SQLAlchemyObjectType from ..utils import get_session, is_sqlalchemy_version_less_than from .models import Article, HairKind, Pet, Reporter diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f1503559..812b4cea 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -2,11 +2,8 @@ import sys from typing import Dict, Union -import graphene import pytest import sqlalchemy_utils as sqa_utils -from graphene.relay import Node -from graphene.types.structures import Structure from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -14,6 +11,10 @@ from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite +import graphene +from graphene.relay import Node +from graphene.types.structures import Structure + from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index ec4546b4..3de6904b 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -1,9 +1,10 @@ from enum import Enum as PyEnum import pytest -from graphene import Enum from sqlalchemy.types import Enum as SQLAlchemyEnumType +from graphene import Enum + from ..enums import _convert_sa_to_graphene_enum, enum_for_field from ..types import SQLAlchemyObjectType from .models import HairKind, Pet @@ -40,7 +41,8 @@ class Color(PyEnum): assert graphene_enum._meta.name == "Color" assert graphene_enum._meta.enum is not Color assert [ - (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)] @@ -50,7 +52,8 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): assert isinstance(graphene_enum, type(Enum)) assert graphene_enum._meta.name == "ColorValues" assert [ - (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] @@ -60,7 +63,8 @@ def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): assert isinstance(graphene_enum, type(Enum)) assert graphene_enum._meta.name == "FallbackName" assert [ - (key, value.value) for key, value in graphene_enum._meta.enum.__members__.items() + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() ] == [("RED", "red"), ("GREEN", "green"), ("BLUE", "blue")] @@ -79,7 +83,9 @@ class Meta: enum = enum_for_field(PetType, "pet_kind") assert isinstance(enum, type(Enum)) assert enum._meta.name == "PetKind" - assert [(key, value.value) for key, value in enum._meta.enum.__members__.items()] == [ + assert [ + (key, value.value) for key, value in enum._meta.enum.__members__.items() + ] == [ ("CAT", "cat"), ("DOG", "dog"), ] diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index c5a1040b..9fed146d 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,7 +1,8 @@ import pytest +from promise import Promise + from graphene import NonNull, ObjectType from graphene.relay import Connection, Node -from promise import Promise from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField from ..types import SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 3e2ba7ee..3c4e47be 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,8 +1,9 @@ -import graphene import pytest -from graphene.relay import Node from sqlalchemy import select +import graphene +from graphene.relay import Node + from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyObjectType diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 6a8428c7..a329b88a 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,7 +1,7 @@ -import graphene import pytest from sqlalchemy import select +import graphene from graphene_sqlalchemy.tests.utils import eventually_await_session from graphene_sqlalchemy.utils import get_session, is_sqlalchemy_version_less_than diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 5aea8d76..cb7e9034 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,8 +1,9 @@ -import graphene import pytest -from graphene import Enum as GrapheneEnum from sqlalchemy.types import Enum as SQLAlchemyEnum +import graphene +from graphene import Enum as GrapheneEnum + from ..registry import Registry from ..types import SQLAlchemyObjectType from ..utils import EnumValue diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 5461885d..f8f1ff8c 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -1,5 +1,6 @@ import pytest import sqlalchemy as sa + from graphene import Argument, Enum, List, ObjectType, Schema from graphene.relay import Node @@ -316,7 +317,9 @@ def makeNodes(nodeList): return {"edges": nodes} expected = { - "defaultSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), + "defaultSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), "noDefaultSort": makeNodes( [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 6e61fb63..4637a115 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -3,6 +3,9 @@ import pytest import sqlalchemy.exc import sqlalchemy.orm.exc +from graphql.pyutils import is_awaitable +from sqlalchemy import select + from graphene import ( Boolean, Dynamic, @@ -18,8 +21,6 @@ String, ) from graphene.relay import Connection -from graphql.pyutils import is_awaitable -from sqlalchemy import select from .. import utils from ..converter import convert_sqlalchemy_composite @@ -511,9 +512,9 @@ class Meta: def __init_subclass_with_meta__(cls, custom_option=None, **options): _meta = CustomOptions(cls) _meta.custom_option = custom_option - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) + super( + SQLAlchemyObjectTypeWithCustomOptions, cls + ).__init_subclass_with_meta__(_meta=_meta, **options) class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): class Meta: diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index b88600f0..75328280 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,5 +1,6 @@ import pytest import sqlalchemy as sa + from graphene import Enum, List, ObjectType, Schema, String from ..utils import ( diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 8240ccda..69305fe3 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,14 +3,15 @@ from typing import Any import sqlalchemy +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty +from sqlalchemy.orm.exc import NoResultFound + from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty -from sqlalchemy.orm.exc import NoResultFound from .converter import ( convert_sqlalchemy_column, @@ -282,7 +283,9 @@ def __init_subclass_with_meta__( ) if use_connection is None and interfaces: - use_connection = any(issubclass(interface, Node) for interface in interfaces) + use_connection = any( + issubclass(interface, Node) for interface in interfaces + ) if use_connection and not connection: # We create the connection automatically diff --git a/setup.py b/setup.py index c18a9515..b6a2c95c 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,9 @@ _version_re = re.compile(r"__version__\s+=\s+(.*)") with open("graphene_sqlalchemy/__init__.py", "rb") as f: - version = str(ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))) + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) requirements = [ # To keep things simple, we only support newer versions of Graphene From 2ff54dccafaf4c4eeec769ffce7d9f07db68a1f2 Mon Sep 17 00:00:00 2001 From: Jendrik Date: Mon, 31 Oct 2022 12:27:26 +0100 Subject: [PATCH 18/19] fix: run startup checkt for sqlalchemy version --- graphene_sqlalchemy/batching.py | 13 ++++---- graphene_sqlalchemy/fields.py | 8 ++--- graphene_sqlalchemy/tests/conftest.py | 6 ++-- graphene_sqlalchemy/tests/test_batching.py | 26 ++++++++-------- graphene_sqlalchemy/tests/test_benchmark.py | 16 +++++----- graphene_sqlalchemy/tests/test_query.py | 30 ++++++------------- graphene_sqlalchemy/tests/test_query_enums.py | 28 +++++------------ graphene_sqlalchemy/tests/test_types.py | 8 ++--- graphene_sqlalchemy/types.py | 6 ++-- graphene_sqlalchemy/utils.py | 8 +++-- setup.py | 1 + 11 files changed, 62 insertions(+), 88 deletions(-) diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 275d5904..29e13fef 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than +from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than class RelationshipLoader(aiodataloader.DataLoader): @@ -58,19 +58,19 @@ async def batch_load_fn(self, parents): # For our purposes, the query_context will only used to get the session query_context = None - if is_sqlalchemy_version_less_than("1.4"): - query_context = QueryContext(session.query(parent_mapper.entity)) - else: + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: parent_mapper_query = session.query(parent_mapper.entity) query_context = parent_mapper_query._compile_context() - - if is_sqlalchemy_version_less_than("1.4"): + else: + query_context = QueryContext(session.query(parent_mapper.entity)) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, states, None, child_mapper, + None, ) else: self.selectin_loader._load_for_path( @@ -79,7 +79,6 @@ async def batch_load_fn(self, parents): states, None, child_mapper, - None, ) return [getattr(parent, self.relationship_prop.key) for parent in parents] diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 9f69b53f..6dbc134f 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,9 +11,9 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query, get_session, is_sqlalchemy_version_less_than +from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -86,9 +86,7 @@ def get_query(cls, model, info, sort=None, **args): def resolve_connection(cls, connection_type, model, info, args, resolved): session = get_session(info.context) if resolved is None: - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): async def get_result(): return await cls.resolve_connection_async( diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 2d8293b2..89b357a4 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -4,13 +4,13 @@ from sqlalchemy.orm import sessionmaker import graphene -from graphene_sqlalchemy.utils import is_sqlalchemy_version_less_than +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry from .models import Base, CompositeFullName -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -42,7 +42,7 @@ def test_db_url(async_session: bool): @pytest_asyncio.fixture(scope="function") async def session_factory(async_session: bool, test_db_url: str): if async_session: - if is_sqlalchemy_version_less_than("1.4"): + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") engine = create_async_engine(test_db_url) async with engine.begin() as conn: diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 2659136d..5eccd5fc 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -10,11 +10,15 @@ from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType -from ..utils import get_session, is_sqlalchemy_version_less_than +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) from .models_batching import Article, HairKind, Pet, Reader, Reporter from .utils import eventually_await_session, remove_cache_miss_stat, to_std_dicts -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -70,17 +74,13 @@ class Query(graphene.ObjectType): async def resolve_articles(self, info): session = get_session(info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Article))).all() return session.query(Article).all() async def resolve_reporters(self, info): session = get_session(info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).all() return session.query(Reporter).all() @@ -235,7 +235,7 @@ async def test_many_to_one(sync_session_factory, schema_provider): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -319,7 +319,7 @@ async def test_one_to_one(sync_session_factory, schema_provider): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -437,7 +437,7 @@ async def test_one_to_many(sync_session_factory): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -557,7 +557,7 @@ async def test_many_to_many(sync_session_factory): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) @@ -701,7 +701,7 @@ class Meta: context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages - + assert not result.errors result = to_std_dicts(result.data) assert result == { "reporters": { diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index 86dcf0d4..dc656f41 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -7,11 +7,15 @@ from graphene import relay from ..types import SQLAlchemyObjectType -from ..utils import get_session, is_sqlalchemy_version_less_than +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) from .models import Article, HairKind, Pet, Reporter from .utils import eventually_await_session -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession if is_sqlalchemy_version_less_than("1.2"): pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) @@ -39,17 +43,13 @@ class Query(graphene.ObjectType): async def resolve_articles(self, info): session = get_session(info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Article))).all() return session.query(Article).all() async def resolve_reporters(self, info): session = get_session(info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).all() return session.query(Reporter).all() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 3c4e47be..bd61b0b9 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -7,11 +7,11 @@ from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyObjectType -from ..utils import get_session, is_sqlalchemy_version_less_than +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter from .utils import eventually_await_session, to_std_dicts -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -52,17 +52,13 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() async def resolve_reporters(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) @@ -120,9 +116,7 @@ class Query(graphene.ObjectType): def resolve_reporter(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): async def get_result(): return (await session.scalars(select(Reporter))).first() @@ -172,7 +166,7 @@ async def get_result(): "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, } schema = graphene.Schema(query=Query) - if not is_sqlalchemy_version_less_than("1.4") and isinstance(session, AsyncSession): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): result = schema.execute(query, context_value={"session": session}) assert result.errors else: @@ -207,9 +201,7 @@ class Query(graphene.ObjectType): def resolve_reporter(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): async def get_result(): return (await session.scalars(select(Reporter))).first() @@ -295,9 +287,7 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() @@ -397,9 +387,7 @@ class Meta: @classmethod async def get_node(cls, id, info): session = get_session(info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index a329b88a..14c87f74 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -3,13 +3,13 @@ import graphene from graphene_sqlalchemy.tests.utils import eventually_await_session -from graphene_sqlalchemy.utils import get_session, is_sqlalchemy_version_less_than +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from ..types import SQLAlchemyObjectType from .models import HairKind, Pet, Reporter from .test_query import add_test_data, to_std_dicts -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -35,25 +35,19 @@ class Query(graphene.ObjectType): async def resolve_reporter(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() async def resolve_reporters(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) async def resolve_pets(self, _info, kind): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): query = select(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -126,9 +120,7 @@ class Query(graphene.ObjectType): async def resolve_pet(self, _info): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Pet))).first() return session.query(Pet).first() @@ -164,9 +156,7 @@ class Query(graphene.ObjectType): async def resolve_pet(self, info, kind=None): session = get_session(info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): query = select(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -218,9 +208,7 @@ class Query(graphene.ObjectType): async def resolve_pet(self, _info, kind=None): session = get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return ( await session.scalars( select(Pet).filter(Pet.hair_kind == HairKind(kind)) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 4637a115..4ba5e53f 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -32,11 +32,11 @@ unregisterConnectionFieldFactory, ) from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions -from ..utils import is_sqlalchemy_version_less_than +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 from .models import Article, CompositeFullName, Pet, Reporter from .utils import eventually_await_session -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -441,9 +441,7 @@ class Query(ObjectType): async def resolve_reporter(self, _info): session = utils.get_session(_info.context) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 69305fe3..b3490e3c 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -27,14 +27,14 @@ from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_query, get_session, is_mapped_class, is_mapped_instance, - is_sqlalchemy_version_less_than, ) -if not is_sqlalchemy_version_less_than("1.4"): +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession @@ -344,7 +344,7 @@ def get_query(cls, info): @classmethod def get_node(cls, info, id): - if is_sqlalchemy_version_less_than("1.4"): + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: try: return cls.get_query(info).get(id) except NoResultFound: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index d9245cc9..62c71d8d 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -24,9 +24,13 @@ def is_graphene_version_less_than(version_string): # pragma: no cover ).parsed_version < pkg_resources.parse_version(version_string) +SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False + if not is_sqlalchemy_version_less_than("1.4"): from sqlalchemy.ext.asyncio import AsyncSession + SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True + def get_session(context): return context.get("session") @@ -41,9 +45,7 @@ def get_query(model, context): "A query in the model Base or a session in the schema is required for querying.\n" "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) - if not is_sqlalchemy_version_less_than("1.4") and isinstance( - session, AsyncSession - ): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): return select(model) query = session.query(model) return query diff --git a/setup.py b/setup.py index b6a2c95c..9122baf2 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "pytest-benchmark>=3.4.0,<4.0", "aiosqlite>=0.17.0", "nest-asyncio", + "greenlet", ] setup( From 1e857e077fab0fe56b1fef4ade945d31c9c95710 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 9 Dec 2022 17:00:17 +0100 Subject: [PATCH 19/19] fix: allow polymorphism with async session Signed-off-by: Erik Wrede --- docs/inheritance.rst | 42 +++++++++++++++++++++++-- graphene_sqlalchemy/tests/models.py | 1 + graphene_sqlalchemy/tests/test_query.py | 17 +++++----- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/docs/inheritance.rst b/docs/inheritance.rst index 13645462..74732162 100644 --- a/docs/inheritance.rst +++ b/docs/inheritance.rst @@ -3,7 +3,7 @@ Inheritance Examples Create interfaces from inheritance relationships ------------------------------------------------ - +.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_. SQLAlchemy has excellent support for class inheritance hierarchies. These hierarchies can be represented in your GraphQL schema by means of interfaces_. Much like ObjectTypes, Interfaces in @@ -87,9 +87,13 @@ and fields on concrete implementations using the `... on` syntax: } +.. danger:: + When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications. + See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`. + Please note that by default, the "polymorphic_on" column is *not* generated as a field on types that use polymorphic inheritance, as -this is considered an implentation detail. The idiomatic way to +this is considered an implementation detail. The idiomatic way to retrieve the concrete GraphQL type of an object is to query for the `__typename` field. To override this behavior, an `ORMField` needs to be created @@ -104,4 +108,38 @@ class to the Schema constructor via the `types=` argument: schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType]) + See also: `Graphene Interfaces `_ + +Eager Loading & Using with AsyncSession +-------------------- +When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly. +This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables. +To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model: + +.. code:: python + class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session + } + +Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers: + +.. code:: python + + class Query(graphene.ObjectType): + people = graphene.Field(graphene.List(PersonType)) + + async def resolve_people(self, _info): + return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all() + +Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR. + +For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs `_. diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index d39f5699..ee286585 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -310,6 +310,7 @@ class Person(Base): __tablename__ = "person" __mapper_args__ = { "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session } diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 557815a3..055a87f8 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -469,19 +469,20 @@ class Mutation(graphene.ObjectType): assert result == expected -def add_person_data(session): +async def add_person_data(session): bob = Employee(name="Bob", birth_date=date(1990, 1, 1), hire_date=date(2015, 1, 1)) session.add(bob) joe = Employee(name="Joe", birth_date=date(1980, 1, 1), hire_date=date(2010, 1, 1)) session.add(joe) jen = Employee(name="Jen", birth_date=date(1995, 1, 1), hire_date=date(2020, 1, 1)) session.add(jen) - session.commit() + await eventually_await_session(session, "commit") -def test_interface_query_on_base_type(sync_session_factory): - session = sync_session_factory() - add_person_data(session) +@pytest.mark.asyncio +async def test_interface_query_on_base_type(session_factory): + session = session_factory() + await add_person_data(session) class PersonType(SQLAlchemyInterface): class Meta: @@ -495,11 +496,13 @@ class Meta: class Query(graphene.ObjectType): people = graphene.Field(graphene.List(PersonType)) - def resolve_people(self, _info): + async def resolve_people(self, _info): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Person))).all() return session.query(Person).all() schema = graphene.Schema(query=Query, types=[PersonType, EmployeeType]) - result = schema.execute( + result = await schema.execute_async( """ query { people {