Skip to content

Commit 3004799

Browse files
committed
remove rv creation
1 parent ef7a7a0 commit 3004799

File tree

1 file changed

+45
-23
lines changed

1 file changed

+45
-23
lines changed

pymc/distributions/distribution.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import numpy as np
2525

26+
from pytensor import shared
2627
from pytensor import tensor as pt
2728
from pytensor.compile.builders import OpFromGraph
2829
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
@@ -683,33 +684,54 @@ def dist(
683684
logcdf = default_not_implemented(class_name, "logcdf")
684685

685686
def dist_moment(rv, size, *dist_params):
686-
size = normalize_size_param(size)
687-
dummy_size_param = size.type()
688-
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
689-
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
690-
# dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,))
691-
dummy_params = [dummy_size_param] + dummy_dist_params
692-
# dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
693-
rv_type = type(
694-
class_name,
695-
(CustomSymbolicDistRV,),
696-
# If logp is not provided, we try to infer it from the dist graph
697-
dict(
698-
inline_logprob=logp is None,
699-
),
700-
)
701-
rv_op = rv_type(
702-
inputs=dummy_params,
703-
# inputs=dummy_dist_params,
704-
outputs=[dummy_rv],
705-
ndim_supp=ndim_supp,
706-
)
707-
fgraph = rv_op.fgraph.clone()
687+
# size = normalize_size_param(size)
688+
# dummy_size_param = size.type()
689+
# dummy_dist_params = [dist_param.type() for dist_param in dist_params]
690+
# dummy_rv = dist(*dummy_dist_params, dummy_size_param)
691+
# # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,))
692+
# dummy_params = [dummy_size_param] + dummy_dist_params
693+
# # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
694+
# rv_type = type(
695+
# class_name,
696+
# (CustomSymbolicDistRV,),
697+
# # If logp is not provided, we try to infer it from the dist graph
698+
# dict(
699+
# inline_logprob=logp is None,
700+
# ),
701+
# )
702+
# rv_op = rv_type(
703+
# inputs=dummy_params,
704+
# # inputs=dummy_dist_params,
705+
# outputs=[dummy_rv],
706+
# ndim_supp=ndim_supp,
707+
# )
708+
# fgraph = rv_op.fgraph.clone()
709+
fgraph = rv.owner.op.fgraph.clone()
710+
# inputs = filter_RNGs(rv.owner.op.fgraph.inputs)
711+
# scan_node = rv.owner.op.fgraph.toposort()[-1]
712+
# output = rv.owner.op.fgraph.outputs[-1]
713+
# fgraph = FunctionGraph(inputs=inputs, outputs=[output], clone=True)
708714
replace_moments = MomentRewrite()
709715
replace_moments.rewrite(fgraph)
710716
for i, par in enumerate([size] + list(dist_params)):
711717
fgraph.replace(fgraph.inputs[i], par)
712-
[moment] = fgraph.outputs
718+
# for i, inp in enumerate(replace_rng_nodes(rv.owner.inputs)):
719+
# if isinstance(inp.type, RandomGeneratorType):
720+
# fgraph.replace(fgraph.inputs[i], inp)
721+
# breakpoint()
722+
# outputs = replace_rng_nodes(fgraph.outputs)
723+
for node in fgraph.toposort():
724+
if isinstance(node.op, Scan):
725+
# breakpoint()
726+
for inp in node.inputs:
727+
if isinstance(inp.type, RandomGeneratorType):
728+
fgraph.replace(
729+
inp,
730+
shared(np.random.Generator(np.random.PCG64())),
731+
import_missing=True,
732+
)
733+
moment = fgraph.outputs[-1]
734+
# breakpoint()
713735
return moment
714736

715737
if moment is None:

0 commit comments

Comments
 (0)