diff --git a/.gitignore b/.gitignore index e4070f31..d4f71e35 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ __pycache__/ # Distribution / packaging .Python env/ +.venv/ build/ develop-eggs/ dist/ diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index a4d3f29e..1066020c 100755 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -1,43 +1,46 @@ #!/usr/bin/env python +from database import db_session, init_db from flask import Flask +from schema import schema from flask_graphql import GraphQLView -from .database import db_session, init_db -from .schema import schema - app = Flask(__name__) app.debug = True -default_query = ''' +example_query = """ { - allEmployees { + allEmployees(sort: [NAME_ASC, ID_ASC]) { edges { node { - id, - name, + id + name department { - id, + id name - }, + } role { - id, + id name } } } } -}'''.strip() +} +""" -app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True)) +app.add_url_rule( + "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True) +) @app.teardown_appcontext def shutdown_session(exception=None): db_session.remove() -if __name__ == '__main__': + +if __name__ == "__main__": init_db() app.run() diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index 01e76ca6..ca4d4122 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -14,7 +14,7 @@ def init_db(): # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - from .models import Department, Employee, Role + from models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index e164c015..efbbe690 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -1,8 +1,7 @@ +from database import Base from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import backref, relationship -from .database import Base - class Department(Base): __tablename__ = 'department' diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index cbee081c..9ed09464 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,11 +1,10 @@ +from models import Department as DepartmentModel +from models import Employee as EmployeeModel +from models import Role as RoleModel + import graphene from graphene import relay -from graphene_sqlalchemy import (SQLAlchemyConnectionField, - SQLAlchemyObjectType, utils) - -from .models import Department as DepartmentModel -from .models import Employee as EmployeeModel -from .models import Role as RoleModel +from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType class Department(SQLAlchemyObjectType): @@ -26,18 +25,11 @@ class Meta: interfaces = (relay.Node, ) -SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee', - lambda c, d: c.upper() + ('_ASC' if d else '_DESC')) - - class Query(graphene.ObjectType): node = relay.Node.Field() # Allow only single column sorting all_employees = SQLAlchemyConnectionField( - Employee, - sort=graphene.Argument( - SortEnumEmployee, - default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc()))) + Employee, sort=Employee.sort_argument()) # Allows sorting over multiple columns, by default over the primary key all_roles = SQLAlchemyConnectionField(Role) # Disable sorting over this field diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 7cc259e0..9466cbaf 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -7,6 +7,9 @@ String) from graphene.types.json import JSONString +from .enums import enum_for_sa_enum +from .registry import get_global_registry + try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType except ImportError: @@ -145,21 +148,15 @@ def convert_column_to_float(type, column, registry=None): @convert_sqlalchemy_type.register(types.Enum) def convert_enum_to_enum(type, column, registry=None): - enum_class = getattr(type, 'enum_class', None) - if enum_class: # Check if an enum.Enum type is used - graphene_type = Enum.from_enum(enum_class) - else: # Nope, just a list of string options - items = zip(type.enums, type.enums) - graphene_type = Enum(type.name, items) return Field( - graphene_type, + lambda: enum_for_sa_enum(type, registry or get_global_registry()), description=get_column_doc(column), required=not (is_column_nullable(column)), ) @convert_sqlalchemy_type.register(ChoiceType) -def convert_column_to_enum(type, column, registry=None): +def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() return Enum(name, type.choices, description=get_column_doc(column)) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py new file mode 100644 index 00000000..6b84bf52 --- /dev/null +++ b/graphene_sqlalchemy/enums.py @@ -0,0 +1,203 @@ +from sqlalchemy import Column +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Argument, Enum, List + +from .utils import EnumValue, to_enum_value_name, to_type_name + + +def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): + """Convert the given SQLAlchemy Enum type to a Graphene Enum type. + + The name of the Graphene Enum will be determined as follows: + If the SQLAlchemy Enum is based on a Python Enum, use the name + of the Python Enum. Otherwise, if the SQLAlchemy Enum is named, + use the SQL name after conversion to a type name. Otherwise, use + the given fallback_name or raise an error if it is empty. + + The Enum value names are converted to upper case if necessary. + """ + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) + ) + enum_class = sa_enum.enum_class + if enum_class: + if all(to_enum_value_name(key) == key for key in enum_class.__members__): + return Enum.from_enum(enum_class) + name = enum_class.__name__ + members = [ + (to_enum_value_name(key), value.value) + for key, value in enum_class.__members__.items() + ] + else: + sql_enum_name = sa_enum.name + if sql_enum_name: + name = to_type_name(sql_enum_name) + elif fallback_name: + name = fallback_name + else: + raise TypeError("No type name specified for {!r}".format(sa_enum)) + members = [(to_enum_value_name(key), key) for key in sa_enum.enums] + return Enum(name, members) + + +def enum_for_sa_enum(sa_enum, registry): + """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) + ) + enum = registry.get_graphene_enum_for_sa_enum(sa_enum) + if not enum: + enum = _convert_sa_to_graphene_enum(sa_enum) + registry.register_enum(sa_enum, enum) + return enum + + +def enum_for_field(obj_type, field_name): + """Return the Graphene Enum type for the specified Graphene field.""" + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + if not field_name or not isinstance(field_name, str): + raise TypeError( + "Expected a field name, but got: {!r}".format(field_name)) + registry = obj_type._meta.registry + orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) + if orm_field is None: + raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name)) + if not isinstance(orm_field, Column): + raise TypeError( + "{}.{} does not map to model column".format(obj_type._meta.name, field_name) + ) + sa_enum = orm_field.type + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "{}.{} does not map to enum column".format(obj_type._meta.name, field_name) + ) + enum = registry.get_graphene_enum_for_sa_enum(sa_enum) + if not enum: + fallback_name = obj_type._meta.name + to_type_name(field_name) + enum = _convert_sa_to_graphene_enum(sa_enum, fallback_name) + registry.register_enum(sa_enum, enum) + return enum + + +def _default_sort_enum_symbol_name(column_name, sort_asc=True): + return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC") + + +def sort_enum_for_object_type( + obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None +): + """Return Graphene Enum for sorting the given SQLAlchemyObjectType. + + Parameters + - obj_type : SQLAlchemyObjectType + The object type for which the sort Enum shall be generated. + - name : str, optional, default None + Name to use for the sort Enum. + If not provided, it will be set to the object type name + 'SortEnum' + - only_fields : sequence, optional, default None + If this is set, only fields from this sequence will be considered. + - only_indexed : bool, optional, default False + If this is set, only indexed columns will be considered. + - get_symbol_name : function, optional, default None + Function which takes the column name and a boolean indicating + if the sort direction is ascending, and returns the symbol name + for the current column and sort direction. If no such function + is passed, a default function will be used that creates the symbols + 'foo_asc' and 'foo_desc' for a column with the name 'foo'. + + Returns + - Enum + The Graphene Enum type + """ + name = name or obj_type._meta.name + "SortEnum" + registry = obj_type._meta.registry + enum = registry.get_sort_enum_for_object_type(obj_type) + custom_options = dict( + only_fields=only_fields, + only_indexed=only_indexed, + get_symbol_name=get_symbol_name, + ) + if enum: + if name != enum.__name__ or custom_options != enum.custom_options: + raise ValueError( + "Sort enum for {} has already been customized".format(obj_type) + ) + else: + members = [] + default = [] + fields = obj_type._meta.fields + get_name = get_symbol_name or _default_sort_enum_symbol_name + for field_name in fields: + if only_fields and field_name not in only_fields: + continue + orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) + if not isinstance(orm_field, Column): + continue + if only_indexed and not (orm_field.primary_key or orm_field.index): + continue + asc_name = get_name(orm_field.name, True) + asc_value = EnumValue(asc_name, orm_field.asc()) + desc_name = get_name(orm_field.name, False) + desc_value = EnumValue(desc_name, orm_field.desc()) + if orm_field.primary_key: + default.append(asc_value) + members.extend(((asc_name, asc_value), (desc_name, desc_value))) + enum = Enum(name, members) + enum.default = default # store default as attribute + enum.custom_options = custom_options + registry.register_sort_enum(obj_type, enum) + return enum + + +def sort_argument_for_object_type( + obj_type, + enum_name=None, + only_fields=None, + only_indexed=None, + get_symbol_name=None, + has_default=True, +): + """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + + Parameters + - obj_type : SQLAlchemyObjectType + The object type for which the sort Argument shall be generated. + - enum_name : str, optional, default None + Name to use for the sort Enum. + If not provided, it will be set to the object type name + 'SortEnum' + - only_fields : sequence, optional, default None + If this is set, only fields from this sequence will be considered. + - only_indexed : bool, optional, default False + If this is set, only indexed columns will be considered. + - get_symbol_name : function, optional, default None + Function which takes the column name and a boolean indicating + if the sort direction is ascending, and returns the symbol name + for the current column and sort direction. If no such function + is passed, a default function will be used that creates the symbols + 'foo_asc' and 'foo_desc' for a column with the name 'foo'. + - has_default : bool, optional, default True + If this is set to False, no sorting will happen when this argument is not + passed. Otherwise results will be sortied by the primary key(s) of the model. + + Returns + - Enum + A Graphene Argument that accepts a list of sorting directions for the model. + """ + enum = sort_enum_for_object_type( + obj_type, + enum_name, + only_fields=only_fields, + only_indexed=only_indexed, + get_symbol_name=get_symbol_name, + ) + if not has_default: + enum.default = None + + return Argument(List(enum), default_value=enum.default) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 4a46b749..3ad15a92 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -8,7 +8,7 @@ from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice -from .utils import get_query, sort_argument_for_model +from .utils import get_query log = logging.getLogger() @@ -84,10 +84,9 @@ def __init__(self, type, *args, **kwargs): if "sort" not in kwargs and issubclass(type, Connection): # Let super class raise if type is not a Connection try: - model = type.Edge.node._type._meta.model - kwargs.setdefault("sort", sort_argument_for_model(model)) - except Exception: - raise Exception( + kwargs.setdefault("sort", type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' " to None to disabling the creation of the sort query argument".format( type.__name__ @@ -109,7 +108,7 @@ def default_connection_field_factory(relationship, registry): def createConnectionField(_type): - log.warn( + log.warning( 'createConnectionField is deprecated and will be removed in the next ' 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' ) @@ -117,7 +116,7 @@ def createConnectionField(_type): def registerConnectionFieldFactory(factoryMethod): - log.warn( + log.warning( 'registerConnectionFieldFactory is deprecated and will be removed in the next ' 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' ) @@ -126,7 +125,7 @@ def registerConnectionFieldFactory(factoryMethod): def unregisterConnectionFieldFactory(): - log.warn( + log.warning( 'registerConnectionFieldFactory is deprecated and will be removed in the next ' 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' ) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 460053f2..acfa744b 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,32 +1,91 @@ +from collections import defaultdict + +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + + class Registry(object): def __init__(self): self._registry = {} self._registry_models = {} + self._registry_orm_fields = defaultdict(dict) self._registry_composites = {} + self._registry_enums = {} + self._registry_sort_enums = {} - def register(self, cls): + def register(self, obj_type): from .types import SQLAlchemyObjectType - assert issubclass(cls, SQLAlchemyObjectType), ( - "Only classes of type SQLAlchemyObjectType can be registered, " - 'received "{}"' - ).format(cls.__name__) - assert cls._meta.registry == self, "Registry for a Model have to match." + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + assert obj_type._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( # 'SQLAlchemy model "{}" already associated with ' # 'another type "{}".' # ).format(cls._meta.model, self._registry[cls._meta.model]) - self._registry[cls._meta.model] = cls + self._registry[obj_type._meta.model] = obj_type def get_type_for_model(self, model): return self._registry.get(model) + def register_orm_field(self, obj_type, field_name, orm_field): + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + if not field_name or not isinstance(field_name, str): + raise TypeError("Expected a field name, but got: {!r}".format(field_name)) + self._registry_orm_fields[obj_type][field_name] = orm_field + + def get_orm_field_for_graphene_field(self, obj_type, field_name): + return self._registry_orm_fields.get(obj_type, {}).get(field_name) + def register_composite_converter(self, composite, converter): self._registry_composites[composite] = converter def get_converter_for_composite(self, composite): return self._registry_composites.get(composite) + def register_enum(self, sa_enum, graphene_enum): + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) + ) + if not isinstance(graphene_enum, type(Enum)): + raise TypeError( + "Expected Graphene Enum, but got: {!r}".format(graphene_enum) + ) + + self._registry_enums[sa_enum] = graphene_enum + + def get_graphene_enum_for_sa_enum(self, sa_enum): + return self._registry_enums.get(sa_enum) + + def register_sort_enum(self, obj_type, sort_enum): + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + if not isinstance(sort_enum, type(Enum)): + raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) + self._registry_sort_enums[obj_type] = sort_enum + + def get_sort_enum_for_object_type(self, obj_type): + return self._registry_sort_enums.get(obj_type) + registry = None diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py new file mode 100644 index 00000000..2825eb3c --- /dev/null +++ b/graphene_sqlalchemy/tests/conftest.py @@ -0,0 +1,32 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +from ..registry import reset_global_registry +from .models import Base + +test_db_url = 'sqlite://' # use in-memory database for tests + + +@pytest.fixture(autouse=True) +def reset_registry(): + reset_global_registry() + + +@pytest.yield_fixture(scope="function") +def session(): + db = create_engine(test_db_url) + connection = db.engine.connect() + transaction = connection.begin() + Base.metadata.create_all(connection) + + # options = dict(bind=connection, binds={}) + session_factory = sessionmaker(bind=connection) + session = scoped_session(session_factory) + + yield session + + # Finalize test here + transaction.rollback() + connection.close() + session.remove() diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3ba23a8a..12781cc5 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,8 +6,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import mapper, relationship +PetKind = Enum("cat", "dog", name="pet_kind") -class Hairkind(enum.Enum): + +class HairKind(enum.Enum): LONG = 'long' SHORT = 'short' @@ -32,8 +34,8 @@ class Pet(Base): __tablename__ = "pets" id = Column(Integer(), primary_key=True) name = Column(String(30)) - pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) - hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) @@ -43,6 +45,7 @@ class Reporter(Base): first_name = Column(String(30)) last_name = Column(String(30)) email = Column(String()) + favorite_pet_kind = Column(PetKind) pets = relationship("Pet", secondary=association_table, backref="reporters") articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 5cc16e79..f38999d2 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,6 +1,6 @@ import enum -from py.test import raises +import pytest from sqlalchemy import Column, Table, case, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -52,9 +52,9 @@ def assert_composite_conversion( def test_should_unknown_sqlalchemy_field_raise_exception(): - with raises(Exception) as excinfo: + re_err = "Don't know how to convert the SQLAlchemy field" + with pytest.raises(Exception, match=re_err): convert_sqlalchemy_column(None) - assert "Don't know how to convert the SQLAlchemy field" in str(excinfo.value) def test_should_date_convert_string(): @@ -87,18 +87,34 @@ def test_should_unicodetext_convert_string(): def test_should_enum_convert_enum(): field = assert_column_conversion( - types.Enum(enum.Enum("one", "two")), graphene.Field + types.Enum(enum.Enum("TwoNumbers", ("one", "two"))), graphene.Field ) field_type = field.type() assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + field = assert_column_conversion( types.Enum("one", "two", name="two_numbers"), graphene.Field ) field_type = field.type() - assert field_type.__class__.__name__ == "two_numbers" + assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + +def test_should_not_enum_convert_enum_without_name(): + field = assert_column_conversion( + types.Enum("one", "two"), graphene.Field + ) + re_err = r"No type name specified for Enum\('one', 'two'\)" + with pytest.raises(TypeError, match=re_err): + field.type() def test_should_small_integer_convert_int(): @@ -260,7 +276,9 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory + Reporter.favorite_article.property, + A._meta.registry, + default_connection_field_factory, ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -277,19 +295,26 @@ def test_should_postgresql_enum_convert(): postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field ) field_type = field.type() - assert field_type.__class__.__name__ == "two_numbers" + assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") def test_should_postgresql_py_enum_convert(): field = assert_column_conversion( - postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field + postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), + graphene.Field, ) field_type = field.type() - assert field_type.__class__.__name__ == "TwoNumbers" + assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") def test_should_postgresql_array_convert(): @@ -309,7 +334,7 @@ def test_should_postgresql_hstore_convert(): def test_should_composite_convert(): - class CompositeClass(object): + class CompositeClass: def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 @@ -331,7 +356,8 @@ def convert_composite_class(composite, registry): def test_should_unknown_sqlalchemy_composite_raise_exception(): registry = Registry() - with raises(Exception) as excinfo: + re_err = "Don't know how to convert the composite field" + with pytest.raises(Exception, match=re_err): class CompositeClass(object): def __init__(self, col1, col2): @@ -344,5 +370,3 @@ def __init__(self, col1, col2): graphene.String, registry, ) - - assert "Don't know how to convert the composite field" in str(excinfo.value) diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py new file mode 100644 index 00000000..ca376964 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -0,0 +1,122 @@ +from enum import Enum as PyEnum + +import pytest +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + +from ..enums import _convert_sa_to_graphene_enum, enum_for_field +from ..types import SQLAlchemyObjectType +from .models import HairKind, Pet + + +def test_convert_sa_to_graphene_enum_bad_type(): + re_err = "Expected sqlalchemy.types.Enum, but got: 'foo'" + with pytest.raises(TypeError, match=re_err): + _convert_sa_to_graphene_enum("foo") + + +def test_convert_sa_to_graphene_enum_based_on_py_enum(): + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + sa_enum = SQLAlchemyEnumType(Color) + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "Color" + assert graphene_enum._meta.enum is Color + + +def test_convert_sa_to_graphene_enum_based_on_py_enum_with_bad_names(): + class Color(PyEnum): + red = 1 + green = 2 + blue = 3 + + sa_enum = SQLAlchemyEnumType(Color) + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "Color" + assert graphene_enum._meta.enum is not Color + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue", name="color_values") + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "ColorValues" + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue") + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "FallbackName" + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue") + re_err = r"No type name specified for Enum\('red', 'green', 'blue'\)" + with pytest.raises(TypeError, match=re_err): + _convert_sa_to_graphene_enum(sa_enum) + + +def test_enum_for_field(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + enum = enum_for_field(PetType, 'pet_kind') + assert isinstance(enum, type(Enum)) + assert enum._meta.name == "PetKind" + assert [ + (key, value.value) + for key, value in enum._meta.enum.__members__.items() + ] == [("CAT", 'cat'), ("DOG", 'dog')] + enum2 = enum_for_field(PetType, 'pet_kind') + assert enum2 is enum + enum2 = PetType.enum_for_field('pet_kind') + assert enum2 is enum + + enum = enum_for_field(PetType, 'hair_kind') + assert isinstance(enum, type(Enum)) + assert enum._meta.name == "HairKind" + assert enum._meta.enum is HairKind + enum2 = PetType.enum_for_field('hair_kind') + assert enum2 is enum + + re_err = r"Cannot get PetType\.other_kind" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, 'other_kind') + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field('other_kind') + + re_err = r"PetType\.name does not map to enum column" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, 'name') + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field('name') + + re_err = r"Expected a field name, but got: None" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, None) + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field(None) + + re_err = "Expected SQLAlchemyObjectType, but got: None" + with pytest.raises(TypeError, match=re_err): + enum_for_field(None, 'other_kind') diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index ff616b30..0f8738f0 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -4,8 +4,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from ..utils import sort_argument_for_model -from .models import Editor +from .models import Editor as EditorModel from .models import Pet as PetModel @@ -14,27 +13,32 @@ class Meta: model = PetModel +class Editor(SQLAlchemyObjectType): + class Meta: + model = EditorModel + + class PetConn(Connection): class Meta: node = Pet def test_sort_added_by_default(): - arg = SQLAlchemyConnectionField(PetConn) - assert "sort" in arg.args - assert arg.args["sort"] == sort_argument_for_model(PetModel) + field = SQLAlchemyConnectionField(PetConn) + assert "sort" in field.args + assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - arg = SQLAlchemyConnectionField(PetConn, sort=None) - assert "sort" not in arg.args + field = SQLAlchemyConnectionField(PetConn, sort=None) + assert "sort" not in field.args def test_custom_sort(): - arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor)) - assert arg.args["sort"] == sort_argument_for_model(Editor) + field = SQLAlchemyConnectionField(PetConn, sort=Editor.sort_argument()) + assert field.args["sort"] == Editor.sort_argument() def test_init_raises(): - with pytest.raises(Exception, match="Cannot create sort"): + with pytest.raises(TypeError, match="Cannot create sort"): SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 146c54e6..5279bd87 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,55 +1,44 @@ -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker - import graphene from graphene.relay import Connection, Node from ..fields import SQLAlchemyConnectionField -from ..registry import reset_global_registry from ..types import SQLAlchemyObjectType -from ..utils import sort_argument_for_model, sort_enum_for_model -from .models import Article, Base, Editor, Hairkind, Pet, Reporter - -db = create_engine("sqlite:///test_sqlalchemy.sqlite3") - - -@pytest.yield_fixture(scope="function") -def session(): - reset_global_registry() - connection = db.engine.connect() - transaction = connection.begin() - Base.metadata.create_all(connection) +from .models import Article, Editor, HairKind, Pet, Reporter - # options = dict(bind=connection, binds={}) - session_factory = sessionmaker(bind=connection) - session = scoped_session(session_factory) - yield session +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value - # Finalize test here - transaction.rollback() - connection.close() - session.remove() - -def setup_fixtures(session): - pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG) - session.add(pet) - reporter = Reporter(first_name="ABA", last_name="X") +def add_test_data(session): + reporter = Reporter( + first_name='John', last_name='Doe', favorite_pet_kind='cat') session.add(reporter) - reporter2 = Reporter(first_name="ABO", last_name="Y") - session.add(reporter2) - article = Article(headline="Hi!") + pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + session.add(pet) + pet.reporters.append(reporter) + article = Article(headline='Hi!') article.reporter = reporter session.add(article) - editor = Editor(name="John") + reporter = Reporter( + first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + session.add(reporter) + pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet.reporters.append(reporter) + session.add(pet) + editor = Editor(name="Jack") session.add(editor) session.commit() def test_should_query_well(session): - setup_fixtures(session) + add_test_data(session) class ReporterType(SQLAlchemyObjectType): class Meta: @@ -59,17 +48,17 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_reporters(self, *args, **kwargs): + def resolve_reporters(self, _info): return session.query(Reporter) query = """ query ReporterQuery { reporter { - firstName, - lastName, + firstName + lastName email } reporters { @@ -78,117 +67,18 @@ def resolve_reporters(self, *args, **kwargs): } """ expected = { - "reporter": {"firstName": "ABA", "lastName": "X", "email": None}, - "reporters": [{"firstName": "ABA"}, {"firstName": "ABO"}], + "reporter": {"firstName": "John", "lastName": "Doe", "email": None}, + "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data == expected - - -def test_should_query_enums(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType) - - def resolve_pet(self, *args, **kwargs): - return session.query(Pet).first() - - query = """ - query PetQuery { - pet { - name, - petKind - hairKind - } - } - """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} - schema = graphene.Schema(query=Query) - result = schema.execute(query) - assert not result.errors - assert result.data == expected, result.data - - -def test_enum_parameter(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) - - def resolve_pet(self, info, kind=None, *args, **kwargs): - query = session.query(Pet) - if kind: - query = query.filter(Pet.pet_kind == kind) - return query.first() - - query = """ - query PetQuery($kind: pet_kind) { - pet(kind: $kind) { - name, - petKind - hairKind - } - } - """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} - schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "cat"}) - assert not result.errors - assert result.data == {"pet": None} - result = schema.execute(query, variables={"kind": "dog"}) - assert not result.errors - assert result.data == expected, result.data - - -def test_py_enum_parameter(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type)) - - def resolve_pet(self, info, kind=None, *args, **kwargs): - query = session.query(Pet) - if kind: - # XXX Why kind passed in as a str instead of a Hairkind instance? - query = query.filter(Pet.hair_kind == Hairkind(kind)) - return query.first() - - query = """ - query PetQuery($kind: Hairkind) { - pet(kind: $kind) { - name, - petKind - hairKind - } - } - """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} - schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "SHORT"}) - assert not result.errors - assert result.data == {"pet": None} - result = schema.execute(query, variables={"kind": "LONG"}) - assert not result.errors - assert result.data == expected, result.data + result = to_std_dicts(result.data) + assert result == expected -def test_should_node(session): - setup_fixtures(session) +def test_should_query_node(session): + add_test_data(session) class ReporterNode(SQLAlchemyObjectType): class Meta: @@ -204,10 +94,6 @@ class Meta: model = Article interfaces = (Node,) - # @classmethod - # def get_node(cls, id, info): - # return Article(id=1, headline='Article node') - class ArticleConnection(Connection): class Meta: node = ArticleNode @@ -218,16 +104,16 @@ class Query(graphene.ObjectType): article = graphene.Field(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleConnection) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_article(self, *args, **kwargs): + def resolve_article(self, _info): return session.query(Article).first() query = """ query ReporterQuery { reporter { - id, + id firstName, articles { edges { @@ -260,8 +146,8 @@ def resolve_article(self, *args, **kwargs): expected = { "reporter": { "id": "UmVwb3J0ZXJOb2RlOjE=", - "firstName": "ABA", - "lastName": "X", + "firstName": "John", + "lastName": "Doe", "email": None, "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, }, @@ -271,11 +157,12 @@ def resolve_article(self, *args, **kwargs): schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected def test_should_custom_identifier(session): - setup_fixtures(session) + add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -295,7 +182,7 @@ class Query(graphene.ObjectType): allEditors { edges { node { - id, + id name } } @@ -308,18 +195,19 @@ class Query(graphene.ObjectType): } """ expected = { - "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "John"}}]}, - "node": {"name": "John"}, + "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "Jack"}}]}, + "node": {"name": "Jack"}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected def test_should_mutate_well(session): - setup_fixtures(session) + add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -385,7 +273,7 @@ class Mutation(graphene.ObjectType): "ok": True, "article": { "headline": "My Article", - "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "ABA"}, + "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John"}, }, } } @@ -393,165 +281,5 @@ class Mutation(graphene.ObjectType): schema = graphene.Schema(query=Query, mutation=Mutation) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected - - -def sort_setup(session): - pets = [ - Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG), - Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG), - Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG), - ] - session.add_all(pets) - session.commit() - - -def test_sort(session): - sort_setup(session) - - class PetNode(SQLAlchemyObjectType): - class Meta: - model = Pet - interfaces = (Node,) - - class PetConnection(Connection): - class Meta: - node = PetNode - - class Query(graphene.ObjectType): - defaultSort = SQLAlchemyConnectionField(PetConnection) - nameSort = SQLAlchemyConnectionField(PetConnection) - multipleSort = SQLAlchemyConnectionField(PetConnection) - descSort = SQLAlchemyConnectionField(PetConnection) - singleColumnSort = SQLAlchemyConnectionField( - PetConnection, sort=graphene.Argument(sort_enum_for_model(Pet)) - ) - noDefaultSort = SQLAlchemyConnectionField( - PetConnection, sort=sort_argument_for_model(Pet, False) - ) - noSort = SQLAlchemyConnectionField(PetConnection, sort=None) - - query = """ - query sortTest { - defaultSort{ - edges{ - node{ - id - } - } - } - nameSort(sort: name_asc){ - edges{ - node{ - name - } - } - } - multipleSort(sort: [pet_kind_asc, name_desc]){ - edges{ - node{ - name - petKind - } - } - } - descSort(sort: [name_desc]){ - edges{ - node{ - name - } - } - } - singleColumnSort(sort: name_desc){ - edges{ - node{ - name - } - } - } - noDefaultSort(sort: name_asc){ - edges{ - node{ - name - } - } - } - } - """ - - def makeNodes(nodeList): - nodes = [{"node": item} for item in nodeList] - return {"edges": nodes} - - expected = { - "defaultSort": makeNodes( - [{"id": "UGV0Tm9kZToy"}, {"id": "UGV0Tm9kZToz"}, {"id": "UGV0Tm9kZToyMg=="}] - ), - "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), - "noDefaultSort": makeNodes( - [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] - ), - "multipleSort": makeNodes( - [ - {"name": "Alf", "petKind": "cat"}, - {"name": "Lassie", "petKind": "dog"}, - {"name": "Barf", "petKind": "dog"}, - ] - ), - "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), - "singleColumnSort": makeNodes( - [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] - ), - } # yapf: disable - - schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) - assert not result.errors - assert result.data == expected - - queryError = """ - query sortTest { - singleColumnSort(sort: [pet_kind_asc, name_desc]){ - edges{ - node{ - name - } - } - } - } - """ - result = schema.execute(queryError, context_value={"session": session}) - assert result.errors is not None - - queryNoSort = """ - query sortTest { - noDefaultSort{ - edges{ - node{ - name - } - } - } - noSort{ - edges{ - node{ - name - } - } - } - } - """ - - expectedNoSort = { - "noDefaultSort": makeNodes( - [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] - ), - "noSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), - } # yapf: disable - - result = schema.execute(queryNoSort, context_value={"session": session}) - assert not result.errors - for key, value in result.data.items(): - assert set(node["node"]["name"] for node in value["edges"]) == set( - node["node"]["name"] for node in expectedNoSort[key]["edges"] - ) + result = to_std_dicts(result.data) + assert result == expected diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py new file mode 100644 index 00000000..ec585d57 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -0,0 +1,198 @@ +import graphene + +from ..types import SQLAlchemyObjectType +from .models import HairKind, Pet, Reporter +from .test_query import add_test_data, to_std_dicts + + +def test_query_pet_kinds(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + + class Meta: + model = Pet + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + reporters = graphene.List(ReporterType) + pets = graphene.List(PetType, kind=graphene.Argument( + PetType.enum_for_field('pet_kind'))) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + def resolve_reporters(self, _info): + return session.query(Reporter) + + def resolve_pets(self, _info, kind): + query = session.query(Pet) + if kind: + query = query.filter_by(pet_kind=kind) + return query + + query = """ + query ReporterQuery { + reporter { + firstName + lastName + email + favoritePetKind + pets { + name + petKind + } + } + reporters { + firstName + favoritePetKind + } + pets(kind: DOG) { + name + petKind + } + } + """ + expected = { + 'reporter': { + 'firstName': 'John', + 'lastName': 'Doe', + 'email': None, + 'favoritePetKind': 'CAT', + 'pets': [{ + 'name': 'Garfield', + 'petKind': 'CAT' + }] + }, + 'reporters': [{ + 'firstName': 'John', + 'favoritePetKind': 'CAT', + }, { + 'firstName': 'Jane', + 'favoritePetKind': 'DOG', + }], + 'pets': [{ + 'name': 'Lassie', + 'petKind': 'DOG' + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_query_more_enums(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType) + + def resolve_pet(self, _info): + return session.query(Pet).first() + + query = """ + query PetQuery { + pet { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_enum_as_argument(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + + def resolve_pet(self, info, kind=None): + query = session.query(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind) + return query.first() + + query = """ + query PetQuery($kind: PetKind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "CAT"}) + assert not result.errors + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "DOG"}) + assert not result.errors + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + result = to_std_dicts(result.data) + assert result == expected + + +def test_py_enum_as_argument(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), + ) + + def resolve_pet(self, _info, kind=None): + query = session.query(Pet) + if kind: + # enum arguments are expected to be strings, not PyEnums + query = query.filter(Pet.hair_kind == HairKind(kind)) + return query.first() + + query = """ + query PetQuery($kind: HairKind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "SHORT"}) + assert not result.errors + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "LONG"}) + assert not result.errors + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + result = to_std_dicts(result.data) + assert result == expected diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 1945af6d..0403c4f0 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,25 +1,112 @@ import pytest +from sqlalchemy.types import Enum as SQLAlchemyEnum + +from graphene import Enum as GrapheneEnum from ..registry import Registry from ..types import SQLAlchemyObjectType +from ..utils import EnumValue from .models import Pet -def test_register_incorrect_objecttype(): +def test_register_object_type(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + reg.register(PetType) + assert reg.get_type_for_model(Pet) is PetType + + +def test_register_incorrect_object_type(): reg = Registry() class Spam: pass - with pytest.raises(AssertionError) as excinfo: + re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): reg.register(Spam) - assert "Only classes of type SQLAlchemyObjectType can be registered" in str( - excinfo.value + +def test_register_orm_field(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + reg.register_orm_field(PetType, "name", Pet.name) + assert reg.get_orm_field_for_graphene_field(PetType, "name") is Pet.name + + +def test_register_orm_field_incorrect_types(): + reg = Registry() + + class Spam: + pass + + re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register_orm_field(Spam, "name", Pet.name) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + re_err = "Expected a field name, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register_orm_field(PetType, Spam, Pet.name) + + +def test_register_enum(): + reg = Registry() + + sa_enum = SQLAlchemyEnum("cat", "dog") + graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) + + reg.register_enum(sa_enum, graphene_enum) + assert reg.get_graphene_enum_for_sa_enum(sa_enum) is graphene_enum + + +def test_register_enum_incorrect_types(): + reg = Registry() + + sa_enum = SQLAlchemyEnum("cat", "dog") + graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) + + re_err = r"Expected Graphene Enum, but got: Enum\('cat', 'dog'\)" + with pytest.raises(TypeError, match=re_err): + reg.register_enum(sa_enum, sa_enum) + + re_err = r"Expected SQLAlchemyEnumType, but got: .*PetKind.*" + with pytest.raises(TypeError, match=re_err): + reg.register_enum(graphene_enum, graphene_enum) + + +def test_register_sort_enum(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + sort_enum = GrapheneEnum( + "PetSort", + [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], ) + reg.register_sort_enum(PetType, sort_enum) + assert reg.get_sort_enum_for_object_type(PetType) is sort_enum + -def test_register_objecttype(): +def test_register_sort_enum_incorrect_types(): reg = Registry() class PetType(SQLAlchemyObjectType): @@ -27,7 +114,15 @@ class Meta: model = Pet registry = reg - try: - reg.register(PetType) - except AssertionError: - pytest.fail("expected no AssertionError") + sort_enum = GrapheneEnum( + "PetSort", + [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], + ) + + re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*" + with pytest.raises(TypeError, match=re_err): + reg.register_sort_enum(sort_enum, sort_enum) + + re_err = r"Expected Graphene Enum, but got: .*PetType.*" + with pytest.raises(TypeError, match=re_err): + reg.register_sort_enum(PetType, PetType) diff --git a/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/tests/test_schema.py index 628da185..87739bdb 100644 --- a/graphene_sqlalchemy/tests/test_schema.py +++ b/graphene_sqlalchemy/tests/test_schema.py @@ -35,6 +35,7 @@ class Meta: "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py new file mode 100644 index 00000000..1eb106da --- /dev/null +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -0,0 +1,389 @@ +import pytest +import sqlalchemy as sa + +from graphene import Argument, Enum, List, ObjectType, Schema +from graphene.relay import Connection, Node + +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from ..utils import to_type_name +from .models import Base, HairKind, Pet +from .test_query import to_std_dicts + + +def add_pets(session): + pets = [ + Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), + ] + session.add_all(pets) + session.commit() + + +def test_sort_enum(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum() + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "PetTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + "REPORTER_ID_ASC", + "REPORTER_ID_DESC", + ] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" + assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" + assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" + + +def test_sort_enum_with_custom_name(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum(name="CustomSortName") + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "CustomSortName" + + +def test_sort_enum_cache(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum() + sort_enum_2 = PetType.sort_enum() + assert sort_enum_2 is sort_enum + sort_enum_2 = PetType.sort_enum(name="PetTypeSortEnum") + assert sort_enum_2 is sort_enum + err_msg = "Sort enum for PetType has already been customized" + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(name="CustomSortName") + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(only_fields=["id"]) + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(only_indexed=True) + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(get_symbol_name=lambda: "foo") + + +def test_sort_enum_with_excluded_field_in_object_type(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + exclude_fields = ["reporter_id"] + + sort_enum = PetType.sort_enum() + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + ] + + +def test_sort_enum_only_fields(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum(only_fields=["id", "name"]) + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + ] + + +def test_sort_argument(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_arg = PetType.sort_argument() + assert isinstance(sort_arg, Argument) + + assert isinstance(sort_arg.type, List) + sort_enum = sort_arg.type._of_type + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "PetTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + "REPORTER_ID_ASC", + "REPORTER_ID_DESC", + ] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" + assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" + assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" + + assert sort_arg.default_value == ["ID_ASC"] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + + +def test_sort_argument_with_excluded_fields_in_object_type(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + exclude_fields = ["hair_kind", "reporter_id"] + + sort_arg = PetType.sort_argument() + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_only_fields(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + only_fields = ["id", "pet_kind"] + + sort_arg = PetType.sort_argument() + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_for_multi_column_pk(): + class MultiPkTestModel(Base): + __tablename__ = "multi_pk_test_table" + foo = sa.Column(sa.Integer, primary_key=True) + bar = sa.Column(sa.Integer, primary_key=True) + + class MultiPkTestType(SQLAlchemyObjectType): + class Meta: + model = MultiPkTestModel + + sort_arg = MultiPkTestType.sort_argument() + assert sort_arg.default_value == ["FOO_ASC", "BAR_ASC"] + + +def test_sort_argument_only_indexed(): + class IndexedTestModel(Base): + __tablename__ = "indexed_test_table" + id = sa.Column(sa.Integer, primary_key=True) + foo = sa.Column(sa.Integer, index=False) + bar = sa.Column(sa.Integer, index=True) + + class IndexedTestType(SQLAlchemyObjectType): + class Meta: + model = IndexedTestModel + + sort_arg = IndexedTestType.sort_argument(only_indexed=True) + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "BAR_ASC", + "BAR_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_with_custom_symbol_names(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + def get_symbol_name(column_name, sort_asc=True): + return to_type_name(column_name) + ("Up" if sort_asc else "Down") + + sort_arg = PetType.sort_argument(get_symbol_name=get_symbol_name) + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "IdUp", + "IdDown", + "NameUp", + "NameDown", + "PetKindUp", + "PetKindDown", + "HairKindUp", + "HairKindDown", + "ReporterIdUp", + "ReporterIdDown", + ] + assert sort_arg.default_value == ["IdUp"] + + +def test_sort_query(session): + add_pets(session) + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + + class PetConnection(Connection): + class Meta: + node = PetNode + + class Query(ObjectType): + defaultSort = SQLAlchemyConnectionField(PetConnection) + nameSort = SQLAlchemyConnectionField(PetConnection) + multipleSort = SQLAlchemyConnectionField(PetConnection) + descSort = SQLAlchemyConnectionField(PetConnection) + singleColumnSort = SQLAlchemyConnectionField( + PetConnection, sort=Argument(PetNode.sort_enum()) + ) + noDefaultSort = SQLAlchemyConnectionField( + PetConnection, sort=PetNode.sort_argument(has_default=False) + ) + noSort = SQLAlchemyConnectionField(PetConnection, sort=None) + + query = """ + query sortTest { + defaultSort { + edges { + node { + name + } + } + } + nameSort(sort: NAME_ASC) { + edges { + node { + name + } + } + } + multipleSort(sort: [PET_KIND_ASC, NAME_DESC]) { + edges { + node { + name + petKind + } + } + } + descSort(sort: [NAME_DESC]) { + edges { + node { + name + } + } + } + singleColumnSort(sort: NAME_DESC) { + edges { + node { + name + } + } + } + noDefaultSort(sort: NAME_ASC) { + edges { + node { + name + } + } + } + } + """ + + def makeNodes(nodeList): + nodes = [{"node": item} for item in nodeList] + return {"edges": nodes} + + expected = { + "defaultSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), + "noDefaultSort": makeNodes( + [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] + ), + "multipleSort": makeNodes( + [ + {"name": "Alf", "petKind": "CAT"}, + {"name": "Lassie", "petKind": "DOG"}, + {"name": "Barf", "petKind": "DOG"}, + ] + ), + "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), + "singleColumnSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + } # yapf: disable + + schema = Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + queryError = """ + query sortTest { + singleColumnSort(sort: [PET_KIND_ASC, NAME_DESC]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(queryError, context_value={"session": session}) + assert result.errors is not None + assert '"sort" has invalid value' in result.errors[0].message + + queryNoSort = """ + query sortTest { + noDefaultSort { + edges { + node { + name + } + } + } + noSort { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute(queryNoSort, context_value={"session": session}) + assert not result.errors + # TODO: SQLite usually returns the results ordered by primary key, + # so we cannot test this way whether sorting actually happens or not. + # Also, no sort order is guaranteed by SQLite if "no order" by is used. + assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ + node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] + ] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 0360a644..b76136fb 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -57,6 +57,7 @@ def test_objecttype_registered(): "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", @@ -124,6 +125,7 @@ def test_custom_objecttype_registered(): "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", @@ -168,6 +170,7 @@ def test_objecttype_with_custom_options(): "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", @@ -181,7 +184,7 @@ class TestConnection(Connection): class Meta: node = ReporterWithCustomOptions - def resolver(*args, **kwargs): + def resolver(_obj, _info): return Promise.resolve([]) result = SQLAlchemyConnectionField.connection_resolver( diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index a7b902fe..e13d919c 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,9 +1,11 @@ +import pytest import sqlalchemy as sa from graphene import Enum, List, ObjectType, Schema, String -from ..utils import get_session, sort_argument_for_model, sort_enum_for_model -from .models import Editor, Pet +from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, + to_enum_value_name, to_type_name) +from .models import Base, Editor, Pet def test_get_session(): @@ -27,8 +29,25 @@ def resolve_x(self, info): assert result.data["x"] == session +def test_to_type_name(): + assert to_type_name("make_camel_case") == "MakeCamelCase" + assert to_type_name("AlreadyCamelCase") == "AlreadyCamelCase" + assert to_type_name("A_Snake_and_a_Camel") == "ASnakeAndACamel" + + +def test_to_enum_value_name(): + assert to_enum_value_name("make_enum_value_name") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("makeEnumValueName") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("HTTPStatus400Message") == "HTTP_STATUS400_MESSAGE" + assert to_enum_value_name("ALREADY_ENUM_VALUE_NAME") == "ALREADY_ENUM_VALUE_NAME" + + +# test deprecated sort enum utility functions + + def test_sort_enum_for_model(): - enum = sort_enum_for_model(Pet) + with pytest.warns(DeprecationWarning): + enum = sort_enum_for_model(Pet) assert isinstance(enum, type(Enum)) assert str(enum) == "PetSortEnum" for col in sa.inspect(Pet).columns: @@ -37,7 +56,10 @@ def test_sort_enum_for_model(): def test_sort_enum_for_model_custom_naming(): - enum = sort_enum_for_model(Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D")) + with pytest.warns(DeprecationWarning): + enum = sort_enum_for_model( + Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D") + ) assert str(enum) == "Foo" for col in sa.inspect(Pet).columns: assert hasattr(enum, col.name.upper() + "A") @@ -45,32 +67,35 @@ def test_sort_enum_for_model_custom_naming(): def test_enum_cache(): - assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) + with pytest.warns(DeprecationWarning): + assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) def test_sort_argument_for_model(): - arg = sort_argument_for_model(Pet) + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(Pet) assert isinstance(arg.type, List) assert arg.default_value == [Pet.id.name + "_asc"] - assert arg.type.of_type == sort_enum_for_model(Pet) + with pytest.warns(DeprecationWarning): + assert arg.type.of_type is sort_enum_for_model(Pet) def test_sort_argument_for_model_no_default(): - arg = sort_argument_for_model(Pet, False) + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(Pet, False) assert arg.default_value is None def test_sort_argument_for_model_multiple_pk(): - Base = sa.ext.declarative.declarative_base() - class MultiplePK(Base): foo = sa.Column(sa.Integer, primary_key=True) bar = sa.Column(sa.Integer, primary_key=True) __tablename__ = "MultiplePK" - arg = sort_argument_for_model(MultiplePK) + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(MultiplePK) assert set(arg.default_value) == set( (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") ) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 394d5062..c20e8cfc 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -5,7 +5,7 @@ from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.orm.exc import NoResultFound -from graphene import Field # , annotate, ResolveInfo +from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs @@ -14,12 +14,16 @@ convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) +from .enums import (enum_for_field, sort_argument_for_object_type, + sort_enum_for_object_type) from .fields import default_connection_field_factory from .registry import Registry, get_global_registry from .utils import get_query, is_mapped_class, is_mapped_instance -def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory): +def construct_fields( + obj_type, model, registry, only_fields, exclude_fields, connection_field_factory +): inspected_model = sqlalchemyinspect(model) fields = OrderedDict() @@ -33,6 +37,7 @@ def construct_fields(model, registry, only_fields, exclude_fields, connection_fi # in there. Or when we exclude this field in exclude_fields continue converted_column = convert_sqlalchemy_column(column, registry) + registry.register_orm_field(obj_type, name, column) fields[name] = converted_column for name, composite in inspected_model.composites.items(): @@ -44,6 +49,7 @@ def construct_fields(model, registry, only_fields, exclude_fields, connection_fi # in there. Or when we exclude this field in exclude_fields continue converted_composite = convert_sqlalchemy_composite(composite, registry) + registry.register_orm_field(obj_type, name, composite) fields[name] = converted_composite for hybrid_item in inspected_model.all_orm_descriptors: @@ -61,6 +67,7 @@ def construct_fields(model, registry, only_fields, exclude_fields, connection_fi continue converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) + registry.register_orm_field(obj_type, name, hybrid_item) fields[name] = converted_hybrid_property # Get all the columns for the relationships on the model @@ -72,8 +79,11 @@ def construct_fields(model, registry, only_fields, exclude_fields, connection_fi # We skip this field if we specify only_fields and is not # in there. Or when we exclude this field in exclude_fields continue - converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory) + converted_relationship = convert_sqlalchemy_relationship( + relationship, registry, connection_field_factory + ) name = relationship.key + registry.register_orm_field(obj_type, name, relationship) fields[name] = converted_relationship return fields @@ -118,13 +128,14 @@ def __init_subclass_with_meta__( sqla_fields = yank_fields_from_attrs( construct_fields( + obj_type=cls, model=model, registry=registry, only_fields=only_fields, exclude_fields=exclude_fields, - connection_field_factory=connection_field_factory + connection_field_factory=connection_field_factory, ), - _as=Field + _as=Field, ) if use_connection is None and interfaces: @@ -191,3 +202,11 @@ def resolve_id(self, info): # graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) return tuple(keys) if len(keys) > 1 else keys[0] + + @classmethod + def enum_for_field(cls, field_name): + return enum_for_field(cls, field_name) + + sort_enum = classmethod(sort_enum_for_object_type) + + sort_argument = classmethod(sort_argument_for_object_type) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 276a8075..7139eefc 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,10 +1,10 @@ +import re +import warnings + from sqlalchemy.exc import ArgumentError -from sqlalchemy.inspection import inspect from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError -from graphene import Argument, Enum, List - def get_session(context): return context.get("session") @@ -41,70 +41,102 @@ def is_mapped_instance(cls): return True -def _symbol_name(column_name, is_asc): - return column_name + ("_asc" if is_asc else "_desc") +def to_type_name(name): + """Convert the given name to a GraphQL type name.""" + return "".join(part[:1].upper() + part[1:] for part in name.split("_")) + + +_re_enum_value_name_1 = re.compile("(.)([A-Z][a-z]+)") +_re_enum_value_name_2 = re.compile("([a-z0-9])([A-Z])") + + +def to_enum_value_name(name): + """Convert the given name to a GraphQL enum value name.""" + return _re_enum_value_name_2.sub( + r"\1_\2", _re_enum_value_name_1.sub(r"\1_\2", name) + ).upper() class EnumValue(str): - """Subclass of str that stores a string and an arbitrary value in the "value" property""" + """String that has an additional value attached. - def __new__(cls, str_value, value): - return super(EnumValue, cls).__new__(cls, str_value) + This is used to attach SQLAlchemy model columns to Enum symbols. + """ + + def __new__(cls, s, value): + return super(EnumValue, cls).__new__(cls, s) - def __init__(self, str_value, value): + def __init__(self, _s, value): super(EnumValue, self).__init__() self.value = value -# Cache for the generated enums, to avoid name clash -_ENUM_CACHE = {} - - -def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): - name = name or cls.__name__ + "SortEnum" - if name in _ENUM_CACHE: - return _ENUM_CACHE[name] - items = [] - default = [] - for column in inspect(cls).columns.values(): - asc_name = symbol_name(column.name, True) - asc_value = EnumValue(asc_name, column.asc()) - desc_name = symbol_name(column.name, False) - desc_value = EnumValue(desc_name, column.desc()) - if column.primary_key: - default.append(asc_value) - items.extend(((asc_name, asc_value), (desc_name, desc_value))) - enum = Enum(name, items) - _ENUM_CACHE[name] = (enum, default) - return enum, default - - -def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): - """Create Graphene Enum for sorting a SQLAlchemy class query - - Parameters - - cls : Sqlalchemy model class - Model used to create the sort enumerator - - name : str, optional, default None - Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'` - - symbol_name : function, optional, default `_symbol_name` - Function which takes the column name and a boolean indicating if the sort direction is ascending, - and returns the symbol name for the current column and sort direction. - The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc' - - Returns - - Enum - The Graphene enumerator +def _deprecated_default_symbol_name(column_name, sort_asc): + return column_name + ("_asc" if sort_asc else "_desc") + + +# unfortunately, we cannot use lru_cache because we still support Python 2 +_deprecated_object_type_cache = {} + + +def _deprecated_object_type_for_model(cls, name): + + try: + return _deprecated_object_type_cache[cls, name] + except KeyError: + from .types import SQLAlchemyObjectType + + obj_type_name = name or cls.__name__ + + class ObjType(SQLAlchemyObjectType): + class Meta: + name = obj_type_name + model = cls + + _deprecated_object_type_cache[cls, name] = ObjType + return ObjType + + +def sort_enum_for_model(cls, name=None, symbol_name=None): + """Get a Graphene Enum for sorting the given model class. + + This is deprecated, please use object_type.sort_enum() instead. """ - enum, _ = _sort_enum_for_model(cls, name, symbol_name) - return enum + warnings.warn( + "sort_enum_for_model() is deprecated; use object_type.sort_enum() instead.", + DeprecationWarning, + stacklevel=2, + ) + + from .enums import sort_enum_for_object_type + + return sort_enum_for_object_type( + _deprecated_object_type_for_model(cls, name), + name, + get_symbol_name=symbol_name or _deprecated_default_symbol_name, + ) def sort_argument_for_model(cls, has_default=True): - """Returns a Graphene argument for the sort field that accepts a list of sorting directions for a model. - If `has_default` is True (the default) it will sort the result by the primary key(s) + """Get a Graphene Argument for sorting the given model class. + + This is deprecated, please use object_type.sort_argument() instead. """ - enum, default = _sort_enum_for_model(cls) + warnings.warn( + "sort_argument_for_model() is deprecated;" + " use object_type.sort_argument() instead.", + DeprecationWarning, + stacklevel=2, + ) + + from graphene import Argument, List + from .enums import sort_enum_for_object_type + + enum = sort_enum_for_object_type( + _deprecated_object_type_for_model(cls, None), + get_symbol_name=_deprecated_default_symbol_name, + ) if not has_default: - default = None - return Argument(List(enum), default_value=default) + enum.default = None + + return Argument(List(enum), default_value=enum.default) diff --git a/setup.cfg b/setup.cfg index 7fd23df6..39a48fd2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ max-line-length = 120 [isort] known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=flask,nameko,promise,py,pytest,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=database,flask,models,nameko,promise,py,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER no_lines_before=FIRSTPARTY