@@ -3,7 +3,7 @@
from contextlib import contextmanager
from functools import singledispatch
from textwrap import dedent
- from typing import Union
+ from typing import TYPE_CHECKING, Callable, Optional, Union, cast

import numba
import numba.np.unsafe.ndarray as numba_ndarray
@@ -22,6 +22,7 @@
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply, NoParams
from pytensor.graph.fg import FunctionGraph
+ from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.utils import (
@@ -48,6 +49,10 @@
from pytensor.tensor.type_other import MakeSlice, NoneConst


+ if TYPE_CHECKING:
+     from pytensor.graph.op import StorageMapType
+
+
def numba_njit(*args, **kwargs):

    if len(args) > 0 and callable(args[0]):
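`numba_njit` follows the usual "decorator with optional arguments" convention: applied bare, the decorated function arrives as `args[0]`; applied with options, the helper must instead return a decorator. A minimal self-contained sketch of that convention (the `cache` default here is an illustrative stand-in, not PyTensor's actual defaults):

```python
import numba


def my_njit(*args, **kwargs):
    kwargs.setdefault("cache", False)  # illustrative default only
    if len(args) > 0 and callable(args[0]):
        # Bare use: `@my_njit` passes the function in directly.
        return numba.njit(*args[1:], **kwargs)(args[0])
    # Parameterized use: `@my_njit(...)` must return a decorator.
    return numba.njit(*args, **kwargs)


@my_njit
def add_one(x):
    return x + 1


@my_njit(fastmath=True)
def add_two(x):
    return x + 2


print(add_one(1.0), add_two(1.0))  # -> 2.0 3.0
```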
@@ -339,9 +344,42 @@ def numba_const_convert(data, dtype=None, **kwargs):
    return data


+ def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
+     """Convert `obj` to a Numba-JITable object."""
+     return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+
@singledispatch
- def numba_funcify(op, node=None, storage_map=None, **kwargs):
-     """Create a Numba compatible function from an PyTensor `Op`."""
+ def _numba_funcify(
+     obj,
+     node: Optional[Apply] = None,
+     storage_map: Optional["StorageMapType"] = None,
+     **kwargs,
+ ) -> Callable:
+     r"""Dispatch on PyTensor object types to perform Numba conversions.
+
+     Arguments
+     ---------
+     obj
+         The object used to determine the appropriate conversion function based
+         on its type.  This is generally an `Op` instance, but `FunctionGraph`\s
+         are also supported.
+     node
+         When `obj` is an `Op`, this value should be the corresponding `Apply` node.
+     storage_map
+         A storage map with, for example, the constant and `SharedVariable` values
+         of the graph being converted.
+
+     Returns
+     -------
+     A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
+
+     """
+
+
+ @_numba_funcify.register(Op)
+ def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
+     """Create a Numba-compatible function from a PyTensor `Op.perform`."""

    warnings.warn(
        f"Numba will use object mode to run {op}'s perform method",
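The split above keeps `numba_funcify` as a stable public entry point while all dispatch registrations move to the private `_numba_funcify`; the generic `Op` registration then becomes the object-mode fallback for any `Op` without a specific converter. A self-contained sketch of the wrapper-plus-private-dispatcher pattern, using generic names rather than PyTensor's:

```python
from functools import singledispatch
from typing import Callable


def funcify(obj, **kwargs) -> Callable:
    """Public entry point; delegates to the private dispatcher."""
    return _funcify(obj, **kwargs)


@singledispatch
def _funcify(obj, **kwargs) -> Callable:
    raise NotImplementedError(f"No conversion for {type(obj)}")


@_funcify.register(int)
def _funcify_int(obj, **kwargs) -> Callable:
    # Mirrors the `@_numba_funcify.register(...)` calls in this diff.
    return lambda: obj


print(funcify(3)())  # -> 3
```

Registering on the private name keeps the public signature and docstring in one place, so the wrapper stays authoritative while backends extend the dispatcher.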
@@ -392,10 +430,10 @@ def perform(*inputs):
            ret = py_perform_return(inputs)
            return ret

-     return perform
+     return cast(Callable, perform)


- @numba_funcify.register(OpFromGraph)
+ @_numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):

    _ = kwargs.pop("storage_map", None)
@@ -417,7 +455,7 @@ def opfromgraph(*inputs):
    return opfromgraph


- @numba_funcify.register(FunctionGraph)
+ @_numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
    fgraph,
    node=None,
@@ -525,9 +563,9 @@ def {fn_name}({", ".join(input_names)}):
    return subtensor_def_src


- @numba_funcify.register(Subtensor)
- @numba_funcify.register(AdvancedSubtensor)
- @numba_funcify.register(AdvancedSubtensor1)
+ @_numba_funcify.register(Subtensor)
+ @_numba_funcify.register(AdvancedSubtensor)
+ @_numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):

    subtensor_def_src = create_index_func(
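As the `def {fn_name}(...)` template in the hunk header suggests, the `Subtensor` family is converted by generating Python source specialized to the index pattern, compiling it, and jitting the result. An illustrative sketch of that generate-then-`exec` approach (`make_getter_src` and `take_row0` are hypothetical names, not PyTensor's `create_index_func`, which also routes the source through a `compile_function_src`-style helper before `numba_njit`):

```python
from textwrap import dedent

import numpy as np


def make_getter_src(fn_name, idx_expr):
    # Build source for an indexing function specialized to one index pattern.
    return dedent(
        f"""
        def {fn_name}(x):
            return x[{idx_expr}]
        """
    )


src = make_getter_src("take_row0", "0, :")
namespace = {}
exec(compile(src, "<generated>", "exec"), namespace)
take_row0 = namespace["take_row0"]

print(take_row0(np.arange(6).reshape(2, 3)))  # -> [0 1 2]
```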
@@ -543,8 +581,8 @@ def numba_funcify_Subtensor(op, node, **kwargs):
    return numba_njit(subtensor_fn)


- @numba_funcify.register(IncSubtensor)
- @numba_funcify.register(AdvancedIncSubtensor)
+ @_numba_funcify.register(IncSubtensor)
+ @_numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):

    incsubtensor_def_src = create_index_func(
@@ -560,7 +598,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
    return numba_njit(incsubtensor_fn)


- @numba_funcify.register(AdvancedIncSubtensor1)
+ @_numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
    inplace = op.inplace
    set_instead_of_inc = op.set_instead_of_inc
@@ -593,7 +631,7 @@ def advancedincsubtensor1(x, vals, idxs):
    return advancedincsubtensor1


- @numba_funcify.register(DeepCopyOp)
+ @_numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):

    # Scalars are apparently returned as actual Python scalar types and not
@@ -615,26 +653,26 @@ def deepcopyop(x):
    return deepcopyop


- @numba_funcify.register(MakeSlice)
- def numba_funcify_MakeSlice(op, **kwargs):
+ @_numba_funcify.register(MakeSlice)
+ def numba_funcify_MakeSlice(op, node, **kwargs):
    @numba_njit
    def makeslice(*x):
        return slice(*x)

    return makeslice


- @numba_funcify.register(Shape)
- def numba_funcify_Shape(op, **kwargs):
+ @_numba_funcify.register(Shape)
+ def numba_funcify_Shape(op, node, **kwargs):
    @numba_njit(inline="always")
    def shape(x):
        return np.asarray(np.shape(x))

    return shape


- @numba_funcify.register(Shape_i)
- def numba_funcify_Shape_i(op, **kwargs):
+ @_numba_funcify.register(Shape_i)
+ def numba_funcify_Shape_i(op, node, **kwargs):
    i = op.i

    @numba_njit(inline="always")
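Most of these converters share one shape: read the `Op`'s parameters once at conversion time (here `i = op.i`), then close over them in the jitted function so Numba treats them as compile-time constants. A runnable sketch of that factory pattern (a hypothetical helper, simplified from `numba_funcify_Shape_i`):

```python
import numba
import numpy as np


def make_shape_i(i):
    # `i` is a plain Python int captured by the closure, so Numba
    # specializes `shape_i` for this particular dimension.
    @numba.njit(inline="always")
    def shape_i(x):
        return x.shape[i]

    return shape_i


shape_0 = make_shape_i(0)
print(shape_0(np.zeros((3, 4))))  # -> 3
```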
@@ -664,8 +702,8 @@ def codegen(context, builder, signature, args):
    return sig, codegen


- @numba_funcify.register(Reshape)
- def numba_funcify_Reshape(op, **kwargs):
+ @_numba_funcify.register(Reshape)
+ def numba_funcify_Reshape(op, node, **kwargs):
    ndim = op.ndim

    if ndim == 0:
@@ -687,7 +725,7 @@ def reshape(x, shape):
    return reshape


- @numba_funcify.register(SpecifyShape)
+ @_numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, node, **kwargs):
    shape_inputs = node.inputs[1:]
    shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
@@ -734,7 +772,7 @@ def inputs_cast(x):
    return inputs_cast


- @numba_funcify.register(Dot)
+ @_numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
    # Numba's `np.dot` does not support integer dtypes, so we need to cast to
    # float.
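A minimal sketch of the workaround that comment describes: integer inputs are cast to a float dtype before `np.dot` so Numba can compile the call. (Casting to `float64` here is illustrative; the actual converter derives the cast dtype from the node's types.)

```python
import numba
import numpy as np


@numba.njit
def int_safe_dot(x, y):
    # Numba's np.dot rejects integer dtypes; cast first, then multiply.
    return np.dot(x.astype(np.float64), y.astype(np.float64))


print(int_safe_dot(np.arange(4).reshape(2, 2), np.eye(2, dtype=np.int64)))
```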
@@ -749,7 +787,7 @@ def dot(x, y):
    return dot


- @numba_funcify.register(Softplus)
+ @_numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):

    x_dtype = np.dtype(node.inputs[0].dtype)
@@ -768,7 +806,7 @@ def softplus(x):
    return softplus


- @numba_funcify.register(Cholesky)
+ @_numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
    lower = op.lower

@@ -804,7 +842,7 @@ def cholesky(a):
    return cholesky


- @numba_funcify.register(Solve)
+ @_numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):

    assume_a = op.assume_a
@@ -851,7 +889,7 @@ def solve(a, b):
    return solve


- @numba_funcify.register(BatchedDot)
+ @_numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
    dtype = node.outputs[0].type.numpy_dtype

@@ -872,7 +910,7 @@ def batched_dot(x, y):
    # optimizations are apparently already performed by Numba


- @numba_funcify.register(IfElse)
+ @_numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
    n_outs = op.n_outs