diff --git a/uarray/tests/test_typing.py b/uarray/tests/test_typing.py new file mode 100644 index 0000000..cd63fc7 --- /dev/null +++ b/uarray/tests/test_typing.py @@ -0,0 +1,66 @@ +import uarray +from uarray.typing import _generate_arg_extractor_replacer, DispatchableArg +from typing import Optional + + +def example_function(a: int, b: float, c: Optional[str] = None): + pass + + +EXAMPLE_ANNOTATIONS = ( + DispatchableArg("a", dispatch_type="int", coercible=True), + DispatchableArg("b", dispatch_type="float", coercible=False), +) + + +def test_automatic_extractor(): + extractor, _ = _generate_arg_extractor_replacer( + example_function, EXAMPLE_ANNOTATIONS + ) + + def validate_dispatchables(dispatchables, a, b): + assert isinstance(dispatchables, tuple) + assert len(dispatchables) == 2 + assert dispatchables[0].value is a + assert dispatchables[0].type == "int" + assert dispatchables[0].coercible == True + + assert dispatchables[1].value is b + assert dispatchables[1].type == "float" + assert dispatchables[1].coercible == False + + a, b = 1, 2.0 + validate_dispatchables(extractor(a, b), a, b) + validate_dispatchables(extractor(a, b=b), a, b) + validate_dispatchables(extractor(b=b, a=a), a, b) + validate_dispatchables(extractor(c="c", a=a, b=b), a, b) + validate_dispatchables(extractor(a, b, "c"), a, b) + + +def test_automatic_replacer(): + _, replacer = _generate_arg_extractor_replacer( + example_function, EXAMPLE_ANNOTATIONS + ) + + a, b = 1, 2.0 + d = (3, 4.0) + + args, kwargs = replacer((a, b), {}, d) + assert args == (3, 4.0) + assert kwargs == {} + + args, kwargs = replacer((a,), dict(b=b), d) + assert args == (3,) + assert kwargs == dict(b=4.0) + + args, kwargs = replacer((), dict(a=a, b=b), d) + assert args == () + assert kwargs == dict(a=3, b=4.0) + + args, kwargs = replacer((a, b, "c"), dict(), d) + assert args == (3, 4.0, "c") + assert kwargs == dict() + + args, kwargs = replacer((a, b), dict(c="c"), d) + assert args == (3, 4.0) + assert kwargs == dict(c="c") diff --git a/uarray/typing.py b/uarray/typing.py new file mode 100644 index 0000000..9c32d79 --- /dev/null +++ b/uarray/typing.py @@ -0,0 +1,72 @@ +from typing import Any, Callable, Sequence +import inspect +from dataclasses import dataclass +import functools + +import uarray + + +@dataclass(frozen=True) +class DispatchableArg: + name: str + dispatch_type: Any + coercible: bool = True + + +def _generate_arg_extractor_replacer( + func: Callable, dispatch_args: Sequence[DispatchableArg] +): + sig = inspect.signature(func) + dispatchable_args = [] + + annotations = {} + for d in dispatch_args: + if d.name in annotations: + raise ValueError(f"Duplicate DispatchableArg annotation for '{d.name}'") + + annotations[d.name] = d + + for i, p in enumerate(sig.parameters.values()): + ann = annotations.get(p.name, None) + if ann is None: + continue + + dispatchable_args.append((i, ann)) + + @functools.wraps(func) + def arg_extractor(*args, **kwargs): + # Raise appropriate TypeError if the signature doesn't match + func(*args, **kwargs) + + dispatchables = [] + for i, ann in dispatchable_args: + if len(args) > i: + dispatchables.append( + uarray.Dispatchable(args[i], ann.dispatch_type, ann.coercible) + ) + elif ann.name in kwargs: + dispatchables.append( + uarray.Dispatchable( + kwargs[ann.name], ann.dispatch_type, ann.coercible + ) + ) + + return tuple(dispatchables) + + def arg_replacer(args, kwargs, dispatchables): + new_args = list(args) + new_kwargs = kwargs.copy() + cur_idx = 0 + + for i, ann in dispatchable_args: + if len(args) > i: + new_args[i] = dispatchables[cur_idx] + cur_idx += 1 + elif ann.name in kwargs: + new_kwargs[ann.name] = dispatchables[cur_idx] + cur_idx += 1 + + assert cur_idx == len(dispatchables) + return tuple(new_args), new_kwargs + + return arg_extractor, arg_replacer