Skip to content

inject factories #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
introduction/fastapi
introduction/litestar
introduction/multiple-containers
introduction/inject-factories

.. toctree::
:maxdepth: 1
Expand Down
56 changes: 56 additions & 0 deletions docs/introduction/inject-factories.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Injecting factories

When you need to inject the factory itself, but not the result of its call, use:
1. `.provider` attribute for async resolver
2. `.sync_provider` attribute for sync resolver

Let's first define providers with container:
```python
import dataclasses
import datetime
import typing

from that_depends import BaseContainer, providers


async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]:
yield datetime.datetime.now(tz=datetime.timezone.utc)


@dataclasses.dataclass(kw_only=True, slots=True)
class SomeFactory:
start_at: datetime.datetime


@dataclasses.dataclass(kw_only=True, slots=True)
class FactoryWithFactories:
sync_factory: typing.Callable[..., SomeFactory]
async_factory: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, SomeFactory]]


class DIContainer(BaseContainer):
async_resource = providers.AsyncResource(create_async_resource)
dependent_factory = providers.Factory(SomeFactory, start_at=async_resource.cast)
factory_with_factories = providers.Factory(
FactoryWithFactories,
sync_factory=dependent_factory.sync_provider,
async_factory=dependent_factory.provider,
)
```

Async factory from `.provider` attribute can be used like this:
```python
factory_with_factories = await DIContainer.factory_with_factories()
instance1 = await factory_with_factories.async_factory()
instance2 = await factory_with_factories.async_factory()
assert instance1 is not instance2
```

Sync factory from `.sync_provider` attribute can be used like this:
```python
await DIContainer.init_async_resources()
factory_with_factories = await DIContainer.factory_with_factories()
instance1 = factory_with_factories.sync_factory()
instance2 = factory_with_factories.sync_factory()
assert instance1 is not instance2
```
150 changes: 75 additions & 75 deletions poetry.lock

Large diffs are not rendered by default.

Empty file added tests/integrations/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
59 changes: 59 additions & 0 deletions tests/test_inject_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import dataclasses
import typing

import pytest

from tests import container
from that_depends import BaseContainer, providers


@dataclasses.dataclass(kw_only=True, slots=True)
class InjectedFactories:
sync_factory: typing.Callable[..., container.DependentFactory]
async_factory: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, container.DependentFactory]]


class DIContainer(BaseContainer):
sync_resource = providers.Resource(container.create_sync_resource)
async_resource = providers.AsyncResource(container.create_async_resource)

simple_factory = providers.Factory(container.SimpleFactory, dep1="text", dep2=123)
dependent_factory = providers.Factory(
container.DependentFactory,
simple_factory=simple_factory.cast,
sync_resource=sync_resource.cast,
async_resource=async_resource.cast,
)
injected_factories = providers.Factory(
InjectedFactories,
sync_factory=dependent_factory.sync_provider,
async_factory=dependent_factory.provider,
)


async def test_async_provider() -> None:
injected_factories = await DIContainer.injected_factories()
instance1 = await injected_factories.async_factory()
instance2 = await injected_factories.async_factory()

assert isinstance(instance1, container.DependentFactory)
assert isinstance(instance2, container.DependentFactory)
assert instance1 is not instance2

await DIContainer.tear_down()


async def test_sync_provider() -> None:
injected_factories = await DIContainer.injected_factories()
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
injected_factories.sync_factory()

await DIContainer.init_async_resources()
instance1 = injected_factories.sync_factory()
instance2 = injected_factories.sync_factory()

assert isinstance(instance1, container.DependentFactory)
assert isinstance(instance2, container.DependentFactory)
assert instance1 is not instance2

await DIContainer.tear_down()
12 changes: 12 additions & 0 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,15 @@ class AbstractResource(AbstractProvider[T], abc.ABC):
@abc.abstractmethod
async def tear_down(self) -> None:
"""Tear down dependency."""


class AbstractFactory(AbstractProvider[T], abc.ABC):
"""Abstract Factory Class."""

@property
def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T]]:
return self.async_resolve

@property
def sync_provider(self) -> typing.Callable[[], T]:
return self.sync_resolve
8 changes: 4 additions & 4 deletions that_depends/providers/factories.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import typing

from that_depends.providers.base import AbstractProvider
from that_depends.providers.base import AbstractFactory, AbstractProvider


T = typing.TypeVar("T")
P = typing.ParamSpec("P")


class Factory(AbstractProvider[T]):
class Factory(AbstractFactory[T]):
def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
self._factory = factory
self._args = args
Expand All @@ -33,7 +33,7 @@ def sync_resolve(self) -> T:
)


class AsyncFactory(AbstractProvider[T]):
class AsyncFactory(AbstractFactory[T]):
def __init__(self, factory: typing.Callable[P, typing.Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> None:
self._factory = factory
self._args = args
Expand All @@ -49,6 +49,6 @@ async def async_resolve(self) -> T:
**{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)

def sync_resolve(self) -> T:
def sync_resolve(self) -> typing.NoReturn:
msg = "AsyncFactory cannot be resolved synchronously"
raise RuntimeError(msg)
Loading