Skip to content

Proof of concept: annotation-based argument extractor/replacer #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions uarray/tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import uarray
from uarray.typing import _generate_arg_extractor_replacer, DispatchableArg
from typing import Optional


def example_function(a: int, b: float, c: Optional[str] = None):
pass


EXAMPLE_ANNOTATIONS = (
DispatchableArg("a", dispatch_type="int", coercible=True),
DispatchableArg("b", dispatch_type="float", coercible=False),
)


def test_automatic_extractor():
extractor, _ = _generate_arg_extractor_replacer(
example_function, EXAMPLE_ANNOTATIONS
)

def validate_dispatchables(dispatchables, a, b):
assert isinstance(dispatchables, tuple)
assert len(dispatchables) == 2
assert dispatchables[0].value is a
assert dispatchables[0].type == "int"
assert dispatchables[0].coercible == True

assert dispatchables[1].value is b
assert dispatchables[1].type == "float"
assert dispatchables[1].coercible == False

a, b = 1, 2.0
validate_dispatchables(extractor(a, b), a, b)
validate_dispatchables(extractor(a, b=b), a, b)
validate_dispatchables(extractor(b=b, a=a), a, b)
validate_dispatchables(extractor(c="c", a=a, b=b), a, b)
validate_dispatchables(extractor(a, b, "c"), a, b)


def test_automatic_replacer():
_, replacer = _generate_arg_extractor_replacer(
example_function, EXAMPLE_ANNOTATIONS
)

a, b = 1, 2.0
d = (3, 4.0)

args, kwargs = replacer((a, b), {}, d)
assert args == (3, 4.0)
assert kwargs == {}

args, kwargs = replacer((a,), dict(b=b), d)
assert args == (3,)
assert kwargs == dict(b=4.0)

args, kwargs = replacer((), dict(a=a, b=b), d)
assert args == ()
assert kwargs == dict(a=3, b=4.0)

args, kwargs = replacer((a, b, "c"), dict(), d)
assert args == (3, 4.0, "c")
assert kwargs == dict()

args, kwargs = replacer((a, b), dict(c="c"), d)
assert args == (3, 4.0)
assert kwargs == dict(c="c")
72 changes: 72 additions & 0 deletions uarray/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any, Callable, Sequence
import inspect
from dataclasses import dataclass
import functools

import uarray


@dataclass(frozen=True)
class DispatchableArg:
name: str
dispatch_type: Any
coercible: bool = True


def _generate_arg_extractor_replacer(
func: Callable, dispatch_args: Sequence[DispatchableArg]
):
sig = inspect.signature(func)
dispatchable_args = []

annotations = {}
for d in dispatch_args:
if d.name in annotations:
raise ValueError(f"Duplicate DispatchableArg annotation for '{d.name}'")

annotations[d.name] = d

for i, p in enumerate(sig.parameters.values()):
ann = annotations.get(p.name, None)
if ann is None:
continue

dispatchable_args.append((i, ann))

@functools.wraps(func)
def arg_extractor(*args, **kwargs):
# Raise appropriate TypeError if the signature doesn't match
func(*args, **kwargs)

dispatchables = []
for i, ann in dispatchable_args:
if len(args) > i:
dispatchables.append(
uarray.Dispatchable(args[i], ann.dispatch_type, ann.coercible)
)
elif ann.name in kwargs:
dispatchables.append(
uarray.Dispatchable(
kwargs[ann.name], ann.dispatch_type, ann.coercible
)
)

return tuple(dispatchables)

def arg_replacer(args, kwargs, dispatchables):
new_args = list(args)
new_kwargs = kwargs.copy()
cur_idx = 0

for i, ann in dispatchable_args:
if len(args) > i:
new_args[i] = dispatchables[cur_idx]
cur_idx += 1
elif ann.name in kwargs:
new_kwargs[ann.name] = dispatchables[cur_idx]
cur_idx += 1

assert cur_idx == len(dispatchables)
return tuple(new_args), new_kwargs

return arg_extractor, arg_replacer