25
25
26
26
from pytensor import tensor as pt
27
27
from pytensor .compile .builders import OpFromGraph
28
- from pytensor .graph import FunctionGraph , node_rewriter
29
- from pytensor .graph .basic import Node , Variable
30
- from pytensor .graph .replace import clone_replace
31
- from pytensor .graph .rewriting .basic import in2out
28
+ from pytensor .graph import FunctionGraph , clone_replace , node_rewriter
29
+ from pytensor .graph .basic import Node , Variable , io_toposort
30
+ from pytensor .graph .features import ReplaceValidate
31
+ from pytensor .graph .rewriting .basic import GraphRewriter , in2out
32
32
from pytensor .graph .utils import MetaType
33
+ from pytensor .scan .op import Scan
33
34
from pytensor .tensor .basic import as_tensor_variable
34
35
from pytensor .tensor .random .op import RandomVariable
35
36
from pytensor .tensor .random .rewriting import local_subtensor_rv_lift
37
+ from pytensor .tensor .random .type import RandomGeneratorType , RandomType
36
38
from pytensor .tensor .random .utils import normalize_size_param
37
39
from pytensor .tensor .rewriting .shape import ShapeFeature
38
40
from pytensor .tensor .variable import TensorVariable
83
85
PLATFORM = sys .platform
84
86
85
87
88
+ class MomentRewrite (GraphRewriter ):
89
+ def rewrite_moment_scan_node (self , node ):
90
+ if not isinstance (node .op , Scan ):
91
+ return
92
+
93
+ node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
94
+ op = node .op
95
+
96
+ local_fgraph_topo = io_toposort (node_inputs , node_outputs )
97
+
98
+ replace_with_moment = []
99
+ to_replace_set = set ()
100
+
101
+ for nd in local_fgraph_topo :
102
+ if nd not in to_replace_set and isinstance (
103
+ nd .op , (RandomVariable , SymbolicRandomVariable )
104
+ ):
105
+ replace_with_moment .append (nd .out )
106
+ to_replace_set .add (nd )
107
+ givens = {}
108
+ if len (replace_with_moment ) > 0 :
109
+ for item in replace_with_moment :
110
+ givens [item ] = moment (item )
111
+ else :
112
+ return
113
+ op_outs = clone_replace (node_outputs , replace = givens )
114
+
115
+ nwScan = Scan (
116
+ node_inputs ,
117
+ op_outs ,
118
+ op .info ,
119
+ mode = op .mode ,
120
+ profile = op .profile ,
121
+ truncate_gradient = op .truncate_gradient ,
122
+ name = op .name ,
123
+ allow_gc = op .allow_gc ,
124
+ )
125
+ nw_node = nwScan (* (node .inputs ), return_list = True )[0 ].owner
126
+ return nw_node
127
+
128
+ def add_requirements (self , fgraph ):
129
+ fgraph .attach_feature (ReplaceValidate ())
130
+
131
+ def apply (self , fgraph ):
132
+ for node in fgraph .toposort ():
133
+ if isinstance (node .op , (RandomVariable , SymbolicRandomVariable )):
134
+ fgraph .replace (node .out , moment (node .out ))
135
+ elif isinstance (node .op , Scan ):
136
+ new_node = self .rewrite_moment_scan_node (node )
137
+ if new_node is not None :
138
+ fgraph .replace_all (tuple (zip (node .outputs , new_node .outputs )))
139
+
140
+
86
141
class _Unpickling :
87
142
pass
88
143
@@ -601,6 +656,20 @@ def update(self, node: Node):
601
656
return updates
602
657
603
658
659
+ @_moment .register (CustomSymbolicDistRV )
660
+ def dist_moment (op , rv , * args ):
661
+ node = rv .owner
662
+ rv_out_idx = node .outputs .index (rv )
663
+
664
+ fgraph = op .fgraph .clone ()
665
+ replace_moments = MomentRewrite ()
666
+ replace_moments .rewrite (fgraph )
667
+ # Replace dummy inner inputs by outer inputs
668
+ fgraph .replace_all (tuple (zip (op .inner_inputs , args )), import_missing = True )
669
+ moment = fgraph .outputs [rv_out_idx ]
670
+ return moment
671
+
672
+
604
673
class _CustomSymbolicDist (Distribution ):
605
674
rv_type = CustomSymbolicDistRV
606
675
@@ -622,14 +691,6 @@ def dist(
622
691
if logcdf is None :
623
692
logcdf = default_not_implemented (class_name , "logcdf" )
624
693
625
- if moment is None :
626
- moment = functools .partial (
627
- default_moment ,
628
- rv_name = class_name ,
629
- has_fallback = True ,
630
- ndim_supp = ndim_supp ,
631
- )
632
-
633
694
return super ().dist (
634
695
dist_params ,
635
696
class_name = class_name ,
@@ -685,9 +746,19 @@ def custom_dist_logp(op, values, size, *params, **kwargs):
685
746
def custom_dist_logcdf (op , value , size , * params , ** kwargs ):
686
747
return logcdf (value , * params [: len (dist_params )])
687
748
688
- @_moment .register (rv_type )
689
- def custom_dist_get_moment (op , rv , size , * params ):
690
- return moment (rv , size , * params [: len (params )])
749
+ if moment is not None :
750
+
751
+ @_moment .register (rv_type )
752
+ def custom_dist_get_moment (op , rv , size , * params ):
753
+ return moment (
754
+ rv ,
755
+ size ,
756
+ * [
757
+ p
758
+ for p in params
759
+ if not isinstance (p .type , (RandomType , RandomGeneratorType ))
760
+ ],
761
+ )
691
762
692
763
@_change_dist_size .register (rv_type )
693
764
def change_custom_symbolic_dist_size (op , rv , new_size , expand ):
0 commit comments