@@ -684,45 +684,13 @@ def dist(
684
684
logcdf = default_not_implemented (class_name , "logcdf" )
685
685
686
686
def dist_moment (rv , size , * dist_params ):
687
- # size = normalize_size_param(size)
688
- # dummy_size_param = size.type()
689
- # dummy_dist_params = [dist_param.type() for dist_param in dist_params]
690
- # dummy_rv = dist(*dummy_dist_params, dummy_size_param)
691
- # # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,))
692
- # dummy_params = [dummy_size_param] + dummy_dist_params
693
- # # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
694
- # rv_type = type(
695
- # class_name,
696
- # (CustomSymbolicDistRV,),
697
- # # If logp is not provided, we try to infer it from the dist graph
698
- # dict(
699
- # inline_logprob=logp is None,
700
- # ),
701
- # )
702
- # rv_op = rv_type(
703
- # inputs=dummy_params,
704
- # # inputs=dummy_dist_params,
705
- # outputs=[dummy_rv],
706
- # ndim_supp=ndim_supp,
707
- # )
708
- # fgraph = rv_op.fgraph.clone()
709
687
fgraph = rv .owner .op .fgraph .clone ()
710
- # inputs = filter_RNGs(rv.owner.op.fgraph.inputs)
711
- # scan_node = rv.owner.op.fgraph.toposort()[-1]
712
- # output = rv.owner.op.fgraph.outputs[-1]
713
- # fgraph = FunctionGraph(inputs=inputs, outputs=[output], clone=True)
714
688
replace_moments = MomentRewrite ()
715
689
replace_moments .rewrite (fgraph )
716
690
for i , par in enumerate ([size ] + list (dist_params )):
717
691
fgraph .replace (fgraph .inputs [i ], par )
718
- # for i, inp in enumerate(replace_rng_nodes(rv.owner.inputs)):
719
- # if isinstance(inp.type, RandomGeneratorType):
720
- # fgraph.replace(fgraph.inputs[i], inp)
721
- # breakpoint()
722
- # outputs = replace_rng_nodes(fgraph.outputs)
723
692
for node in fgraph .toposort ():
724
693
if isinstance (node .op , Scan ):
725
- # breakpoint()
726
694
for inp in node .inputs :
727
695
if isinstance (inp .type , RandomGeneratorType ):
728
696
fgraph .replace (
@@ -731,7 +699,6 @@ def dist_moment(rv, size, *dist_params):
731
699
import_missing = True ,
732
700
)
733
701
moment = fgraph .outputs [- 1 ]
734
- # breakpoint()
735
702
return moment
736
703
737
704
if moment is None :
0 commit comments