
Commit 2690eef

brandonwillard authored and twiecki committed
Separate interface and dispatch of numba_funcify
1 parent bdfba42 commit 2690eef

File tree

11 files changed (+168 / -120 lines)
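
At a high level, the change is an interface/dispatch split: the public `numba_funcify` becomes a plain wrapper, while the `functools.singledispatch` machinery moves to a private `_numba_funcify` that individual conversions register on. A minimal sketch of that pattern, with a toy fallback that is not the actual PyTensor behaviour (the real module registers an object-mode fallback for `Op`s, as shown in the basic.py diff below):

    from functools import singledispatch
    from typing import Callable


    def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
        # Public interface: a stable signature and docstring for callers.
        return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)


    @singledispatch
    def _numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
        # Private dispatcher: conversion implementations register on this.
        # Raising here only keeps the sketch self-contained; the real module
        # instead registers an object-mode fallback for all `Op`s.
        raise NotImplementedError(f"No Numba conversion for {type(obj)}")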

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 2 additions & 2 deletions
@@ -83,15 +83,15 @@ Here's an example for :class:`IfElse`:
     return res if n_outs > 1 else res[0]
 
 
-Step 3: Register the function with the `jax_funcify` dispatcher
+Step 3: Register the function with the `_jax_funcify` dispatcher
 ---------------------------------------------------------------
 
 With the PyTensor `Op` replicated in JAX, we’ll need to register the
 function with the PyTensor JAX `Linker`. This is done through the use of
 `singledispatch`. If you don't know how `singledispatch` works, see the
 `Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
 
-The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
+The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.basic._numba_funcify` and
 :func:`pytensor.link.jax.dispatch.jax_funcify`.
 
 Here’s an example for the `Eye`\ `Op`:
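
Following the doc change above, a new conversion would be registered on the private dispatcher rather than on the public function. A hypothetical sketch, where `MyEyeLike` and its conversion are invented for illustration and are not part of PyTensor:

    import numpy as np

    from pytensor.graph.op import Op
    from pytensor.link.numba.dispatch.basic import _numba_funcify, numba_njit


    class MyEyeLike(Op):
        """Placeholder Op used only to show the registration pattern."""

        def make_node(self, *inputs):
            raise NotImplementedError

        def perform(self, node, inputs, outputs):
            raise NotImplementedError


    @_numba_funcify.register(MyEyeLike)
    def numba_funcify_MyEyeLike(op, node=None, **kwargs):
        @numba_njit
        def eye_like(n):
            return np.eye(n)

        return eye_like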

pytensor/link/numba/dispatch/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
 # isort: off
-from pytensor.link.numba.dispatch.basic import numba_funcify, numba_const_convert
+from pytensor.link.numba.dispatch.basic import (
+    numba_funcify,
+    numba_const_convert,
+    numba_njit,
+)
 
 # Load dispatch specializations
 import pytensor.link.numba.dispatch.scalar
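
With `numba_njit` added to the package-level imports above, downstream code can presumably import it straight from the dispatch package; a small usage sketch (the decorated function is just an example):

    from pytensor.link.numba.dispatch import numba_njit


    @numba_njit
    def add_one(x):
        return x + 1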

pytensor/link/numba/dispatch/basic.py

Lines changed: 66 additions & 28 deletions
@@ -3,7 +3,7 @@
 from contextlib import contextmanager
 from functools import singledispatch
 from textwrap import dedent
-from typing import Union
+from typing import TYPE_CHECKING, Callable, Optional, Union, cast
 
 import numba
 import numba.np.unsafe.ndarray as numba_ndarray
@@ -22,6 +22,7 @@
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply, NoParams
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import (
@@ -48,6 +49,10 @@
 from pytensor.tensor.type_other import MakeSlice, NoneConst
 
 
+if TYPE_CHECKING:
+    from pytensor.graph.op import StorageMapType
+
+
 def numba_njit(*args, **kwargs):
 
     if len(args) > 0 and callable(args[0]):
@@ -339,9 +344,42 @@ def numba_const_convert(data, dtype=None, **kwargs):
     return data
 
 
+def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
+    """Convert `obj` to a Numba-JITable object."""
+    return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+
 @singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Create a Numba compatible function from an PyTensor `Op`."""
+def _numba_funcify(
+    obj,
+    node: Optional[Apply] = None,
+    storage_map: Optional["StorageMapType"] = None,
+    **kwargs,
+) -> Callable:
+    r"""Dispatch on PyTensor object types to perform Numba conversions.
+
+    Arguments
+    ---------
+    obj
+        The object used to determine the appropriate conversion function based
+        on its type. This is generally an `Op` instance, but `FunctionGraph`\s
+        are also supported.
+    node
+        When `obj` is an `Op`, this value should be the corresponding `Apply` node.
+    storage_map
+        A storage map with, for example, the constant and `SharedVariable` values
+        of the graph being converted.
+
+    Returns
+    -------
+    A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
+
+    """
+
+
+@_numba_funcify.register(Op)
+def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
+    """Create a Numba compatible function from an PyTensor `Op.perform`."""
 
     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method",
@@ -392,10 +430,10 @@ def perform(*inputs):
             ret = py_perform_return(inputs)
         return ret
 
-    return perform
+    return cast(Callable, perform)
 
 
-@numba_funcify.register(OpFromGraph)
+@_numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
 
     _ = kwargs.pop("storage_map", None)
@@ -417,7 +455,7 @@ def opfromgraph(*inputs):
     return opfromgraph
 
 
-@numba_funcify.register(FunctionGraph)
+@_numba_funcify.register(FunctionGraph)
 def numba_funcify_FunctionGraph(
     fgraph,
     node=None,
@@ -525,9 +563,9 @@ def {fn_name}({", ".join(input_names)}):
     return subtensor_def_src
 
 
-@numba_funcify.register(Subtensor)
-@numba_funcify.register(AdvancedSubtensor)
-@numba_funcify.register(AdvancedSubtensor1)
+@_numba_funcify.register(Subtensor)
+@_numba_funcify.register(AdvancedSubtensor)
+@_numba_funcify.register(AdvancedSubtensor1)
 def numba_funcify_Subtensor(op, node, **kwargs):
 
     subtensor_def_src = create_index_func(
@@ -543,8 +581,8 @@ def numba_funcify_Subtensor(op, node, **kwargs):
     return numba_njit(subtensor_fn)
 
 
-@numba_funcify.register(IncSubtensor)
-@numba_funcify.register(AdvancedIncSubtensor)
+@_numba_funcify.register(IncSubtensor)
+@_numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_IncSubtensor(op, node, **kwargs):
 
     incsubtensor_def_src = create_index_func(
@@ -560,7 +598,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
     return numba_njit(incsubtensor_fn)
 
 
-@numba_funcify.register(AdvancedIncSubtensor1)
+@_numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
@@ -593,7 +631,7 @@ def advancedincsubtensor1(x, vals, idxs):
     return advancedincsubtensor1
 
 
-@numba_funcify.register(DeepCopyOp)
+@_numba_funcify.register(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):
 
     # Scalars are apparently returned as actual Python scalar types and not
@@ -615,26 +653,26 @@ def deepcopyop(x):
     return deepcopyop
 
 
-@numba_funcify.register(MakeSlice)
-def numba_funcify_MakeSlice(op, **kwargs):
+@_numba_funcify.register(MakeSlice)
+def numba_funcify_MakeSlice(op, node, **kwargs):
     @numba_njit
     def makeslice(*x):
         return slice(*x)
 
     return makeslice
 
 
-@numba_funcify.register(Shape)
-def numba_funcify_Shape(op, **kwargs):
+@_numba_funcify.register(Shape)
+def numba_funcify_Shape(op, node, **kwargs):
     @numba_njit(inline="always")
     def shape(x):
         return np.asarray(np.shape(x))
 
     return shape
 
 
-@numba_funcify.register(Shape_i)
-def numba_funcify_Shape_i(op, **kwargs):
+@_numba_funcify.register(Shape_i)
+def numba_funcify_Shape_i(op, node, **kwargs):
     i = op.i
 
     @numba_njit(inline="always")
@@ -664,8 +702,8 @@ def codegen(context, builder, signature, args):
     return sig, codegen
 
 
-@numba_funcify.register(Reshape)
-def numba_funcify_Reshape(op, **kwargs):
+@_numba_funcify.register(Reshape)
+def numba_funcify_Reshape(op, node, **kwargs):
     ndim = op.ndim
 
     if ndim == 0:
@@ -687,7 +725,7 @@ def reshape(x, shape):
     return reshape
 
 
-@numba_funcify.register(SpecifyShape)
+@_numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
@@ -734,7 +772,7 @@ def inputs_cast(x):
     return inputs_cast
 
 
-@numba_funcify.register(Dot)
+@_numba_funcify.register(Dot)
 def numba_funcify_Dot(op, node, **kwargs):
     # Numba's `np.dot` does not support integer dtypes, so we need to cast to
     # float.
@@ -749,7 +787,7 @@ def dot(x, y):
     return dot
 
 
-@numba_funcify.register(Softplus)
+@_numba_funcify.register(Softplus)
 def numba_funcify_Softplus(op, node, **kwargs):
 
     x_dtype = np.dtype(node.inputs[0].dtype)
@@ -768,7 +806,7 @@ def softplus(x):
     return softplus
 
 
-@numba_funcify.register(Cholesky)
+@_numba_funcify.register(Cholesky)
 def numba_funcify_Cholesky(op, node, **kwargs):
     lower = op.lower
 
@@ -804,7 +842,7 @@ def cholesky(a):
     return cholesky
 
 
-@numba_funcify.register(Solve)
+@_numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):
 
     assume_a = op.assume_a
@@ -851,7 +889,7 @@ def solve(a, b):
     return solve
 
 
-@numba_funcify.register(BatchedDot)
+@_numba_funcify.register(BatchedDot)
 def numba_funcify_BatchedDot(op, node, **kwargs):
     dtype = node.outputs[0].type.numpy_dtype
 
@@ -872,7 +910,7 @@ def batched_dot(x, y):
 # optimizations are apparently already performed by Numba
 
 
-@numba_funcify.register(IfElse)
+@_numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs
 
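
One detail worth noting in the diff above: the object-mode fallback is now registered for the base `Op` class, and `singledispatch` resolves registrations along the argument's MRO, so a more specific registration (e.g. for `Dot`) still takes precedence. A self-contained toy demonstration of that resolution order, using stand-in classes rather than PyTensor's:

    from functools import singledispatch


    class Op:          # stand-in for pytensor.graph.op.Op
        pass


    class Dot(Op):     # stand-in for an Op with its own conversion
        pass


    @singledispatch
    def _numba_funcify(obj, **kwargs):
        raise NotImplementedError(type(obj))


    @_numba_funcify.register(Op)
    def _generic(op, **kwargs):
        return "object-mode perform fallback"


    @_numba_funcify.register(Dot)
    def _dot(op, **kwargs):
        return "specialized Dot conversion"


    print(_numba_funcify(Dot()))  # -> specialized Dot conversion
    print(_numba_funcify(Op()))   # -> object-mode perform fallback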

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 8 additions & 7 deletions
@@ -12,6 +12,7 @@
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
+    _numba_funcify,
     create_numba_signature,
     create_tuple_creator,
     numba_funcify,
@@ -422,7 +423,7 @@ def axis_apply_fn(x):
     return axis_apply_fn
 
 
-@numba_funcify.register(Elemwise)
+@_numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
 
     scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
@@ -474,7 +475,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
     return elemwise_fn
 
 
-@numba_funcify.register(CAReduce)
+@_numba_funcify.register(CAReduce)
 def numba_funcify_CAReduce(op, node, **kwargs):
     axes = op.axis
     if axes is None:
@@ -512,7 +513,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
     return careduce_fn
 
 
-@numba_funcify.register(DimShuffle)
+@_numba_funcify.register(DimShuffle)
 def numba_funcify_DimShuffle(op, **kwargs):
     shuffle = tuple(op.shuffle)
     transposition = tuple(op.transposition)
@@ -590,7 +591,7 @@ def dimshuffle(x):
     return dimshuffle
 
 
-@numba_funcify.register(Softmax)
+@_numba_funcify.register(Softmax)
 def numba_funcify_Softmax(op, node, **kwargs):
 
     x_at = node.inputs[0]
@@ -627,7 +628,7 @@ def softmax_py_fn(x):
     return softmax
 
 
-@numba_funcify.register(SoftmaxGrad)
+@_numba_funcify.register(SoftmaxGrad)
 def numba_funcify_SoftmaxGrad(op, node, **kwargs):
 
     sm_at = node.inputs[1]
@@ -658,7 +659,7 @@ def softmax_grad_py_fn(dy, sm):
     return softmax_grad
 
 
-@numba_funcify.register(LogSoftmax)
+@_numba_funcify.register(LogSoftmax)
 def numba_funcify_LogSoftmax(op, node, **kwargs):
 
     x_at = node.inputs[0]
@@ -692,7 +693,7 @@ def log_softmax_py_fn(x):
     return log_softmax
 
 
-@numba_funcify.register(MaxAndArgmax)
+@_numba_funcify.register(MaxAndArgmax)
 def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
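
The elemwise changes follow the same convention as the rest of the commit: registrations move to the private `_numba_funcify`, while recursive conversions (such as converting `op.scalar_op` inside `numba_funcify_Elemwise`) still go through the public `numba_funcify`. A toy sketch of that convention, with simplified classes rather than the real implementation:

    from functools import singledispatch


    class Elemwise:                       # toy stand-in holding a "scalar op"
        def __init__(self, scalar_op):
            self.scalar_op = scalar_op


    def numba_funcify(obj, **kwargs):
        # Public entry point; recursive conversions go through here.
        return _numba_funcify(obj, **kwargs)


    @singledispatch
    def _numba_funcify(obj, **kwargs):
        # Toy fallback: assume `obj` is already a plain callable.
        return obj


    @_numba_funcify.register(Elemwise)
    def _elemwise(op, **kwargs):
        scalar_fn = numba_funcify(op.scalar_op)   # recurse via the public API
        return lambda *args: scalar_fn(*args)


    add_elemwise = numba_funcify(Elemwise(lambda x, y: x + y))
    print(add_elemwise(2, 3))  # -> 5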
