From bd5a68abb6af3dd2b4e52af42e9a38e6858c9a6b Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Thu, 24 Oct 2019 10:54:28 -0400 Subject: [PATCH 1/6] Fix N+1 problem for one-to-one and many-to-one relationships. --- graphene_sqlalchemy/resolver.py | 0 graphene_sqlalchemy/tests/conftest.py | 32 ++-- graphene_sqlalchemy/tests/test_batching.py | 203 +++++++++++++++++++++ graphene_sqlalchemy/tests/test_query.py | 11 +- graphene_sqlalchemy/tests/utils.py | 8 + graphene_sqlalchemy/types.py | 164 +++++++++++++++-- setup.py | 1 + 7 files changed, 375 insertions(+), 44 deletions(-) create mode 100644 graphene_sqlalchemy/resolver.py create mode 100644 graphene_sqlalchemy/tests/test_batching.py create mode 100644 graphene_sqlalchemy/tests/utils.py diff --git a/graphene_sqlalchemy/resolver.py b/graphene_sqlalchemy/resolver.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 9dc390eb..98515051 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,6 +1,6 @@ import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import sessionmaker import graphene @@ -23,19 +23,17 @@ def convert_composite_class(composite, registry): @pytest.yield_fixture(scope="function") -def session(): - db = create_engine(test_db_url) - connection = db.engine.connect() - transaction = connection.begin() - Base.metadata.create_all(connection) - - # options = dict(bind=connection, binds={}) - session_factory = sessionmaker(bind=connection) - session = scoped_session(session_factory) - - yield session - - # Finalize test here - transaction.rollback() - connection.close() - session.remove() +def session_factory(): + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + + yield sessionmaker(bind=engine) + + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() + + +@pytest.fixture(scope="function") +def session(session_factory): + return session_factory() diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py new file mode 100644 index 00000000..bd856e67 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -0,0 +1,203 @@ +import contextlib +import logging + +import graphene + +from ..types import SQLAlchemyObjectType +from .models import Article, Reporter +from .utils import to_std_dicts + + +class MockLoggingHandler(logging.Handler): + """Intercept and store log messages in a list.""" + def __init__(self, *args, **kwargs): + self.messages = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.messages.append(record.getMessage()) + + +@contextlib.contextmanager +def mock_sqlalchemy_logging_handler(): + logging.basicConfig() + sql_logger = logging.getLogger('sqlalchemy.engine') + previous_level = sql_logger.level + + sql_logger.setLevel(logging.INFO) + mock_logging_handler = MockLoggingHandler() + mock_logging_handler.setLevel(logging.INFO) + sql_logger.addHandler(mock_logging_handler) + + yield mock_logging_handler + + sql_logger.setLevel(previous_level) + + +def make_fixture(session): + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + +def get_schema(session): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, _info): + return session.query(Article).all() + + def resolve_reporters(self, _info): + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + +def test_many_to_one(session_factory): + session = session_factory() + make_fixture(session) + schema = get_schema(session) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + articles { + headline + reporter { + firstName + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date, ' + 'articles.reporter_id AS articles_reporter_id \n' + 'FROM articles', + '()', + + 'SELECT reporters.id AS reporters_id, ' + '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters \n' + 'WHERE reporters.id IN (?, ?)', + '(1, 2)', + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], + } + + +def test_one_to_one(session_factory): + session = session_factory() + make_fixture(session) + schema = get_schema(session) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT articles.reporter_id AS articles_reporter_id, ' + 'articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date \n' + 'FROM articles \n' + 'WHERE articles.reporter_id IN (?, ?) ' + 'ORDER BY articles.reporter_id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], + } diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 74a7249a..45272e0b 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -5,16 +5,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyObjectType from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter - - -def to_std_dicts(value): - """Convert nested ordered dicts to normal dicts for better comparison.""" - if isinstance(value, dict): - return {k: to_std_dicts(v) for k, v in value.items()} - elif isinstance(value, list): - return [to_std_dicts(v) for v in value] - else: - return value +from .utils import to_std_dicts def add_test_data(session): diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py new file mode 100644 index 00000000..b59ab0e8 --- /dev/null +++ b/graphene_sqlalchemy/tests/utils.py @@ -0,0 +1,8 @@ +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 2b3e5728..dfe048ca 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,11 +1,13 @@ from collections import OrderedDict import sqlalchemy +from promise import dataloader, promise from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect as sqlalchemyinspect -from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty) +from sqlalchemy.orm import (ColumnProperty, CompositeProperty, Load, + RelationshipProperty, Session) from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm.loading import PostLoad from graphene import Field from graphene.relay import Connection, Node @@ -152,22 +154,40 @@ def construct_fields( for orm_field_name, orm_field in orm_fields.items(): attr_name = orm_field.kwargs.pop('model_attr') attr = all_model_attrs[attr_name] - resolver = _get_field_resolver(obj_type, orm_field_name, attr_name) + custom_resolver = _get_custom_resolver(obj_type, orm_field_name) if isinstance(attr, ColumnProperty): - field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_column( + attr, + registry, + custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), + **orm_field.kwargs + ) elif isinstance(attr, RelationshipProperty): - field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, resolver, - **orm_field.kwargs) + field = convert_sqlalchemy_relationship( + attr, + registry, + connection_field_factory, + custom_resolver or _get_relationship_resolver(obj_type, attr, attr_name), + **orm_field.kwargs + ) elif isinstance(attr, CompositeProperty): if attr_name != orm_field_name or orm_field.kwargs: # TODO Add a way to override composite property fields raise ValueError( "ORMField kwargs for composite fields must be empty. " "Field: {}.{}".format(obj_type.__name__, orm_field_name)) - field = convert_sqlalchemy_composite(attr, registry, resolver) + field = convert_sqlalchemy_composite( + attr, + registry, + custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), + ) elif isinstance(attr, hybrid_property): - field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) + field = convert_sqlalchemy_hybrid_method( + attr, + custom_resolver or _get_attr_resolver(obj_type, orm_field_name, attr_name), + **orm_field.kwargs + ) else: raise Exception('Property type is not supported') # Should never happen @@ -177,22 +197,132 @@ def construct_fields( return fields -def _get_field_resolver(obj_type, orm_field_name, model_attr): +def _get_custom_resolver(obj_type, orm_field_name): + """ + Since `graphene` will call `resolve_` on a field only if it + does not have a `resolver`, we need to re-implement that logic here so + users are able to override the default resolvers that we provide. + """ + resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + if resolver: + return get_unbound_function(resolver) + + return None + + +def _get_relationship_resolver(obj_type, relationship_prop, model_attr): + """ + Batch SQL queries using Dataloader to avoid the N+1 problem. + + :param SQLAlchemyObjectType obj_type: + :param sqlalchemy.orm.properties.RelationshipProperty relationship_prop: + :param str model_attr: the name of the SQLAlchemy attribute + :rtype: Callable + """ + child_mapper = relationship_prop.mapper + parent_mapper = relationship_prop.parent + + if relationship_prop.uselist: + # TODO Batch many-to-many and one-to-many relationships + return _get_attr_resolver(obj_type, model_attr, model_attr) + + class NonListRelationshipLoader(dataloader.DataLoader): + cache = False + + def batch_load_fn(self, parents): # pylint: disable=method-hidden + """ + Batch loads the relationship of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + than re-implementing and maintainnig a big chunk of the `selectin` + loader logic ourselves. + + The approach is to here to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emits 2 `SELECT` statements + when callling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query and then + call then invoke the `selectin` post loader. + + For this reason, if you're trying to understand the steps below, + it's easier to start at the bottom (ie `post_load.invoke`) and + go backward. + """ + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + load = Load(parent_mapper.entity).selectinload(model_attr) + query = session.query(parent_mapper.entity).options(load) + + # Taken from orm.query.Query.__iter__ + # https://git.io/JeuBi + context = query._compile_context() + + # Taken from orm.loading.instances + # https://git.io/JeuBR + context.post_load_paths = {} + + # Taken from orm.strategies.SelectInLoader.__init__ + # https://git.io/JeuBd + selectin_strategy = getattr(parent_mapper.entity, model_attr).property._get_strategy(load.strategy) + + # Taken from orm.loading._instance_processor._instance + # https://git.io/JeuBq + post_load = PostLoad() + post_load.loaders[model_attr] = ( + model_attr, + parent_mapper, + selectin_strategy._load_for_path, + (child_mapper,), + {}, + ) + + # Taken from orm.loading._instance_processor._instance + # https://git.io/JeuBn + # https://git.io/Jeu4j + context.partials = {} + for parent in parents: + post_load.add_state(parent._sa_instance_state, True) + + # Taken from orm.strategies.SelectInLoader.create_row_processor + # https://git.io/Jeu4F + selectin_path = context.query._current_path + parent_mapper._path_registry + + # Taken from orm.loading.instances + # https://git.io/JeuBO + post_load.invoke(context, selectin_path.path) + + return promise.Promise.resolve([getattr(parent, model_attr) for parent in parents]) + + loader = NonListRelationshipLoader() + + def resolve(root, info): + return loader.load(root) + + return resolve + + +def _get_attr_resolver(obj_type, orm_field_name, model_attr): """ In order to support field renaming via `ORMField.model_attr`, we need to define resolver functions for each field. :param SQLAlchemyObjectType obj_type: - :param model: the SQLAlchemy model - :param str model_attr: the name of SQLAlchemy of the attribute used to resolve the field + :param str orm_field_name: + :param str model_attr: the name of the SQLAlchemy attribute :rtype: Callable """ - # Since `graphene` will call `resolve_` on a field only if it - # does not have a `resolver`, we need to re-implement that logic here. - resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) - if resolver: - return get_unbound_function(resolver) - return lambda root, _info: getattr(root, model_attr, None) diff --git a/setup.py b/setup.py index 66704b28..74333933 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene "graphene>=2.1.3,<3", + "promise>=2.1", # Tests fail with 1.0.19 "SQLAlchemy>=1.1,<2", "six>=1.10.0,<2", From cc0de2af66d91c0c0a4ee118144626ccd40d0418 Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Wed, 30 Oct 2019 10:50:40 -0400 Subject: [PATCH 2/6] address zzzeek comments --- graphene_sqlalchemy/types.py | 61 ++++++++++++------------------------ 1 file changed, 20 insertions(+), 41 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index dfe048ca..4a6f3dcf 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -3,11 +3,12 @@ import sqlalchemy from promise import dataloader, promise from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.inspection import inspect as sqlalchemyinspect -from sqlalchemy.orm import (ColumnProperty, CompositeProperty, Load, +from sqlalchemy.orm import (ColumnProperty, CompositeProperty, RelationshipProperty, Session) from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.orm.loading import PostLoad +from sqlalchemy.orm.query import QueryContext +from sqlalchemy.orm.strategies import SelectInLoader +from sqlalchemy.orm.util import PathRegistry from graphene import Field from graphene.relay import Connection, Node @@ -106,7 +107,7 @@ def construct_fields( :param function connection_field_factory: :rtype: OrderedDict[str, graphene.Field] """ - inspected_model = sqlalchemyinspect(model) + inspected_model = sqlalchemy.inspect(model) # Gather all the relevant attributes from the SQLAlchemy model in order all_model_attrs = OrderedDict( inspected_model.column_attrs.items() + @@ -262,46 +263,24 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden # The behavior of `selectin` is undefined if the parent is dirty assert parent not in session.dirty - load = Load(parent_mapper.entity).selectinload(model_attr) - query = session.query(parent_mapper.entity).options(load) - - # Taken from orm.query.Query.__iter__ - # https://git.io/JeuBi - context = query._compile_context() - - # Taken from orm.loading.instances - # https://git.io/JeuBR - context.post_load_paths = {} - - # Taken from orm.strategies.SelectInLoader.__init__ - # https://git.io/JeuBd - selectin_strategy = getattr(parent_mapper.entity, model_attr).property._get_strategy(load.strategy) - - # Taken from orm.loading._instance_processor._instance - # https://git.io/JeuBq - post_load = PostLoad() - post_load.loaders[model_attr] = ( - model_attr, - parent_mapper, - selectin_strategy._load_for_path, - (child_mapper,), - {}, - ) + loader = SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - # Taken from orm.loading._instance_processor._instance - # https://git.io/JeuBn - # https://git.io/Jeu4j - context.partials = {} - for parent in parents: - post_load.add_state(parent._sa_instance_state, True) + # The path is a fixed single token in this case + path = PathRegistry.root + parent_mapper._path_registry - # Taken from orm.strategies.SelectInLoader.create_row_processor - # https://git.io/Jeu4F - selectin_path = context.query._current_path + parent_mapper._path_registry + # Should the boolean be set to False? Does it matter for our purposes? + states = [(sqlalchemy.inspect(parent), True) for parent in parents] - # Taken from orm.loading.instances - # https://git.io/JeuBO - post_load.invoke(context, selectin_path.path) + # For our purposes, the query_context will only used to get the session + query_context = QueryContext(session.query(parent_mapper.entity)) + + loader._load_for_path( + query_context, + path, + states, + None, + child_mapper, + ) return promise.Promise.resolve([getattr(parent, model_attr) for parent in parents]) From 60b6df9a9c838d1ffa122e3826cd4de94580a5c1 Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Wed, 30 Oct 2019 10:54:34 -0400 Subject: [PATCH 3/6] simplify path --- graphene_sqlalchemy/types.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 4a6f3dcf..2f960ad2 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -8,7 +8,6 @@ from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.query import QueryContext from sqlalchemy.orm.strategies import SelectInLoader -from sqlalchemy.orm.util import PathRegistry from graphene import Field from graphene.relay import Connection, Node @@ -265,9 +264,6 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden loader = SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) - # The path is a fixed single token in this case - path = PathRegistry.root + parent_mapper._path_registry - # Should the boolean be set to False? Does it matter for our purposes? states = [(sqlalchemy.inspect(parent), True) for parent in parents] @@ -276,7 +272,7 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden loader._load_for_path( query_context, - path, + parent_mapper._path_registry, states, None, child_mapper, From c8f39bdda65b3f5af37799791e742658cf37e9b4 Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Wed, 30 Oct 2019 10:59:29 -0400 Subject: [PATCH 4/6] update comment --- graphene_sqlalchemy/types.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 2f960ad2..9d2eb7ac 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -245,12 +245,8 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden when callling `all()`, we skip the first `SELECT` statement and jump right before the `selectin` loader is called. To accomplish this, we have to construct objects that are - normally built in the first part of the query and then - call then invoke the `selectin` post loader. - - For this reason, if you're trying to understand the steps below, - it's easier to start at the bottom (ie `post_load.invoke`) and - go backward. + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. """ session = Session.object_session(parents[0]) From 0d0067a8961a5fe656b00b21b98fa5b1c50b78f0 Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Wed, 30 Oct 2019 17:40:48 -0400 Subject: [PATCH 5/6] bump sqlalchemy --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 74333933..4e7c4f9c 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ "graphene>=2.1.3,<3", "promise>=2.1", # Tests fail with 1.0.19 - "SQLAlchemy>=1.1,<2", + "SQLAlchemy>=1.2,<2", "six>=1.10.0,<2", "singledispatch>=3.4.0.3,<4", ] From d7d90f07b69603aad6ae98e38f0800c60faf4153 Mon Sep 17 00:00:00 2001 From: "@jnak" Date: Tue, 12 Nov 2019 13:34:03 -0500 Subject: [PATCH 6/6] disable batching for sqlalchemy < 1.2 --- graphene_sqlalchemy/tests/test_batching.py | 25 ++++++++++++++++++++++ graphene_sqlalchemy/types.py | 15 ++++++++----- setup.cfg | 2 +- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index bd856e67..0881f71e 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -1,6 +1,9 @@ import contextlib import logging +import pkg_resources +import pytest + import graphene from ..types import SQLAlchemyObjectType @@ -78,6 +81,14 @@ def resolve_reporters(self, _info): return graphene.Schema(query=Query) +def is_sqlalchemy_version_less_than(version_string): + return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) + + +if is_sqlalchemy_version_less_than('1.2'): + pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) + + def test_many_to_one(session_factory): session = session_factory() make_fixture(session) @@ -99,6 +110,13 @@ def test_many_to_one(session_factory): messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + return + assert messages == [ 'BEGIN (implicit)', @@ -161,6 +179,13 @@ def test_one_to_one(session_factory): messages = sqlalchemy_logging_handler.messages assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + return + assert messages == [ 'BEGIN (implicit)', diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 9d2eb7ac..23c8288e 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -4,10 +4,9 @@ from promise import dataloader, promise from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import (ColumnProperty, CompositeProperty, - RelationshipProperty, Session) + RelationshipProperty, Session, strategies) from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.query import QueryContext -from sqlalchemy.orm.strategies import SelectInLoader from graphene import Field from graphene.relay import Connection, Node @@ -213,6 +212,8 @@ def _get_custom_resolver(obj_type, orm_field_name): def _get_relationship_resolver(obj_type, relationship_prop, model_attr): """ Batch SQL queries using Dataloader to avoid the N+1 problem. + SQL batching only works for SQLAlchemy 1.2+ since it depends on + the `selectin` loader. :param SQLAlchemyObjectType obj_type: :param sqlalchemy.orm.properties.RelationshipProperty relationship_prop: @@ -222,7 +223,7 @@ def _get_relationship_resolver(obj_type, relationship_prop, model_attr): child_mapper = relationship_prop.mapper parent_mapper = relationship_prop.parent - if relationship_prop.uselist: + if not getattr(strategies, 'SelectInLoader', None) or relationship_prop.uselist: # TODO Batch many-to-many and one-to-many relationships return _get_attr_resolver(obj_type, model_attr, model_attr) @@ -239,7 +240,7 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden than re-implementing and maintainnig a big chunk of the `selectin` loader logic ourselves. - The approach is to here to build a regular query that + The approach here is to build a regular query that selects the parent and `selectin` load the relationship. But instead of having the query emits 2 `SELECT` statements when callling `all()`, we skip the first `SELECT` statement @@ -247,6 +248,10 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden To accomplish this, we have to construct objects that are normally built in the first part of the query in order to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maitainer suggestion. + See https://git.io/JewQ7 """ session = Session.object_session(parents[0]) @@ -258,7 +263,7 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden # The behavior of `selectin` is undefined if the parent is dirty assert parent not in session.dirty - loader = SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) + loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) # Should the boolean be set to False? Does it matter for our purposes? states = [(sqlalchemy.inspect(parent), True) for parent in parents] diff --git a/setup.cfg b/setup.cfg index 0aa80ba9..880c87d6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ max-line-length = 120 no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=app,database,flask,mock,models,nameko,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=app,database,flask,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER skip_glob=examples/nameko_sqlalchemy