|
23 | 23 |
|
24 | 24 | import numpy as np
|
25 | 25 |
|
| 26 | +from pytensor import shared |
26 | 27 | from pytensor import tensor as pt
|
27 | 28 | from pytensor.compile.builders import OpFromGraph
|
28 | 29 | from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
|
@@ -683,33 +684,54 @@ def dist(
|
683 | 684 | logcdf = default_not_implemented(class_name, "logcdf")
|
684 | 685 |
|
685 | 686 | def dist_moment(rv, size, *dist_params):
|
686 |
| - size = normalize_size_param(size) |
687 |
| - dummy_size_param = size.type() |
688 |
| - dummy_dist_params = [dist_param.type() for dist_param in dist_params] |
689 |
| - dummy_rv = dist(*dummy_dist_params, dummy_size_param) |
690 |
| - # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,)) |
691 |
| - dummy_params = [dummy_size_param] + dummy_dist_params |
692 |
| - # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) |
693 |
| - rv_type = type( |
694 |
| - class_name, |
695 |
| - (CustomSymbolicDistRV,), |
696 |
| - # If logp is not provided, we try to infer it from the dist graph |
697 |
| - dict( |
698 |
| - inline_logprob=logp is None, |
699 |
| - ), |
700 |
| - ) |
701 |
| - rv_op = rv_type( |
702 |
| - inputs=dummy_params, |
703 |
| - # inputs=dummy_dist_params, |
704 |
| - outputs=[dummy_rv], |
705 |
| - ndim_supp=ndim_supp, |
706 |
| - ) |
707 |
| - fgraph = rv_op.fgraph.clone() |
| 687 | + # size = normalize_size_param(size) |
| 688 | + # dummy_size_param = size.type() |
| 689 | + # dummy_dist_params = [dist_param.type() for dist_param in dist_params] |
| 690 | + # dummy_rv = dist(*dummy_dist_params, dummy_size_param) |
| 691 | + # # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,)) |
| 692 | + # dummy_params = [dummy_size_param] + dummy_dist_params |
| 693 | + # # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) |
| 694 | + # rv_type = type( |
| 695 | + # class_name, |
| 696 | + # (CustomSymbolicDistRV,), |
| 697 | + # # If logp is not provided, we try to infer it from the dist graph |
| 698 | + # dict( |
| 699 | + # inline_logprob=logp is None, |
| 700 | + # ), |
| 701 | + # ) |
| 702 | + # rv_op = rv_type( |
| 703 | + # inputs=dummy_params, |
| 704 | + # # inputs=dummy_dist_params, |
| 705 | + # outputs=[dummy_rv], |
| 706 | + # ndim_supp=ndim_supp, |
| 707 | + # ) |
| 708 | + # fgraph = rv_op.fgraph.clone() |
| 709 | + fgraph = rv.owner.op.fgraph.clone() |
| 710 | + # inputs = filter_RNGs(rv.owner.op.fgraph.inputs) |
| 711 | + # scan_node = rv.owner.op.fgraph.toposort()[-1] |
| 712 | + # output = rv.owner.op.fgraph.outputs[-1] |
| 713 | + # fgraph = FunctionGraph(inputs=inputs, outputs=[output], clone=True) |
708 | 714 | replace_moments = MomentRewrite()
|
709 | 715 | replace_moments.rewrite(fgraph)
|
710 | 716 | for i, par in enumerate([size] + list(dist_params)):
|
711 | 717 | fgraph.replace(fgraph.inputs[i], par)
|
712 |
| - [moment] = fgraph.outputs |
| 718 | + # for i, inp in enumerate(replace_rng_nodes(rv.owner.inputs)): |
| 719 | + # if isinstance(inp.type, RandomGeneratorType): |
| 720 | + # fgraph.replace(fgraph.inputs[i], inp) |
| 721 | + # breakpoint() |
| 722 | + # outputs = replace_rng_nodes(fgraph.outputs) |
| 723 | + for node in fgraph.toposort(): |
| 724 | + if isinstance(node.op, Scan): |
| 725 | + # breakpoint() |
| 726 | + for inp in node.inputs: |
| 727 | + if isinstance(inp.type, RandomGeneratorType): |
| 728 | + fgraph.replace( |
| 729 | + inp, |
| 730 | + shared(np.random.Generator(np.random.PCG64())), |
| 731 | + import_missing=True, |
| 732 | + ) |
| 733 | + moment = fgraph.outputs[-1] |
| 734 | + # breakpoint() |
713 | 735 | return moment
|
714 | 736 |
|
715 | 737 | if moment is None:
|
|
0 commit comments