diff --git a/docs/execution/dataloader.rst b/docs/execution/dataloader.rst index 8a8e2ae3e..618909512 100644 --- a/docs/execution/dataloader.rst +++ b/docs/execution/dataloader.rst @@ -4,7 +4,7 @@ Dataloader DataLoader is a generic utility to be used as part of your application's data fetching layer to provide a simplified and consistent API over various remote data sources such as databases or web services via batching -and caching. +and caching. It is provided by a separate package `aiodataloader `. Batching @@ -15,21 +15,19 @@ Create loaders by providing a batch loading function. .. code:: python - from promise import Promise - from promise.dataloader import DataLoader + from aiodataloader import DataLoader class UserLoader(DataLoader): - def batch_load_fn(self, keys): - # Here we return a promise that will result on the - # corresponding user for each key in keys - return Promise.resolve([get_user(id=key) for key in keys]) + async def batch_load_fn(self, keys): + # Here we call a function to return a user for each key in keys + return [get_user(id=key) for key in keys] -A batch loading function accepts a list of keys, and returns a ``Promise`` -which resolves to a list of ``values``. +A batch loading async function accepts a list of keys, and returns a list of ``values``. + ``DataLoader`` will coalesce all individual loads which occur within a -single frame of execution (executed once the wrapping promise is resolved) +single frame of execution (executed once the wrapping event loop is resolved) and then call your batch function with all requested keys. @@ -37,9 +35,11 @@ and then call your batch function with all requested keys. user_loader = UserLoader() - user_loader.load(1).then(lambda user: user_loader.load(user.best_friend_id)) + user1 = await user_loader.load(1) + user1_best_friend = await user_loader.load(user1.best_friend_id)) - user_loader.load(2).then(lambda user: user_loader.load(user.best_friend_id)) + user2 = await user_loader.load(2) + user2_best_friend = await user_loader.load(user2.best_friend_id)) A naive application may have issued *four* round-trips to a backend for the @@ -53,9 +53,9 @@ make sure that you then order the query result for the results to match the keys .. code:: python class UserLoader(DataLoader): - def batch_load_fn(self, keys): + async def batch_load_fn(self, keys): users = {user.id: user for user in User.objects.filter(id__in=keys)} - return Promise.resolve([users.get(user_id) for user_id in keys]) + return [users.get(user_id) for user_id in keys] ``DataLoader`` allows you to decouple unrelated parts of your application without @@ -110,8 +110,8 @@ leaner code and at most 4 database requests, and possibly fewer if there are cac best_friend = graphene.Field(lambda: User) friends = graphene.List(lambda: User) - def resolve_best_friend(root, info): - return user_loader.load(root.best_friend_id) + async def resolve_best_friend(root, info): + return await user_loader.load(root.best_friend_id) - def resolve_friends(root, info): - return user_loader.load_many(root.friend_ids) + async def resolve_friends(root, info): + return await user_loader.load_many(root.friend_ids) diff --git a/setup.py b/setup.py index dce6aa6c0..b87f56ccd 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ def run_tests(self): "snapshottest>=0.6,<1", "coveralls>=3.3,<4", "promise>=2.3,<3", + "aiodataloader<1", "mock>=4,<5", "pytz==2022.1", "iso8601>=1,<2", diff --git a/tests_asyncio/test_dataloader.py b/tests_asyncio/test_dataloader.py new file mode 100644 index 000000000..fb8d1630e --- /dev/null +++ b/tests_asyncio/test_dataloader.py @@ -0,0 +1,79 @@ +from collections import namedtuple +from unittest.mock import Mock +from pytest import mark +from aiodataloader import DataLoader + +from graphene import ObjectType, String, Schema, Field, List + + +CHARACTERS = { + "1": {"name": "Luke Skywalker", "sibling": "3"}, + "2": {"name": "Darth Vader", "sibling": None}, + "3": {"name": "Leia Organa", "sibling": "1"}, +} + + +get_character = Mock(side_effect=lambda character_id: CHARACTERS[character_id]) + + +class CharacterType(ObjectType): + name = String() + sibling = Field(lambda: CharacterType) + + async def resolve_sibling(character, info): + if character["sibling"]: + return await info.context.character_loader.load(character["sibling"]) + return None + + +class Query(ObjectType): + skywalker_family = List(CharacterType) + + async def resolve_skywalker_family(_, info): + return await info.context.character_loader.load_many(["1", "2", "3"]) + + +mock_batch_load_fn = Mock( + side_effect=lambda character_ids: [get_character(id) for id in character_ids] +) + + +class CharacterLoader(DataLoader): + async def batch_load_fn(self, character_ids): + return mock_batch_load_fn(character_ids) + + +Context = namedtuple("Context", "character_loader") + + +@mark.asyncio +async def test_basic_dataloader(): + schema = Schema(query=Query) + + character_loader = CharacterLoader() + context = Context(character_loader=character_loader) + + query = """ + { + skywalkerFamily { + name + sibling { + name + } + } + } + """ + + result = await schema.execute_async(query, context=context) + + assert not result.errors + assert result.data == { + "skywalkerFamily": [ + {"name": "Luke Skywalker", "sibling": {"name": "Leia Organa"}}, + {"name": "Darth Vader", "sibling": None}, + {"name": "Leia Organa", "sibling": {"name": "Luke Skywalker"}}, + ] + } + + assert mock_batch_load_fn.call_count == 1 + assert get_character.call_count == 3