diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 6bf39ed998..321d364957 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -25,14 +25,16 @@
 
 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import FunctionGraph, node_rewriter
-from pytensor.graph.basic import Node, Variable
-from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import in2out
+from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
+from pytensor.graph.basic import Node, Variable, io_toposort
+from pytensor.graph.features import ReplaceValidate
+from pytensor.graph.rewriting.basic import GraphRewriter, in2out
 from pytensor.graph.utils import MetaType
+from pytensor.scan.op import Scan
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
+from pytensor.tensor.random.type import RandomGeneratorType, RandomType
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.variable import TensorVariable
@@ -83,6 +85,65 @@
 PLATFORM = sys.platform
 
 
+class MomentRewrite(GraphRewriter):
+    """Graph rewriter that replaces random variables by their moments.
+
+    Top-level ``RandomVariable`` / ``SymbolicRandomVariable`` outputs are
+    replaced directly; for ``Scan`` nodes the same substitution is applied to
+    the inner graph and the node is rebuilt.
+    """
+
+    def rewrite_moment_scan_node(self, node):
+        """Return a rebuilt ``Scan`` node whose inner RVs are replaced by moments.
+
+        Returns ``None`` when ``node`` is not a ``Scan`` or when its inner graph
+        contains no random variables, i.e. when there is nothing to rewrite.
+        """
+        op = node.op
+        if not isinstance(op, Scan):
+            return None
+
+        inner_inputs, inner_outputs = op.inner_inputs, op.inner_outputs
+
+        # ``io_toposort`` yields every inner node exactly once, so no extra
+        # bookkeeping is needed to avoid duplicate replacements.
+        givens = {
+            inner_node.out: moment(inner_node.out)
+            for inner_node in io_toposort(inner_inputs, inner_outputs)
+            if isinstance(inner_node.op, (RandomVariable, SymbolicRandomVariable))
+        }
+        if not givens:
+            return None
+
+        new_inner_outputs = clone_replace(inner_outputs, replace=givens)
+
+        new_scan = Scan(
+            inner_inputs,
+            new_inner_outputs,
+            op.info,
+            mode=op.mode,
+            profile=op.profile,
+            truncate_gradient=op.truncate_gradient,
+            name=op.name,
+            allow_gc=op.allow_gc,
+        )
+        return new_scan(*node.inputs, return_list=True)[0].owner
+
+    def add_requirements(self, fgraph):
+        """Attach ``ReplaceValidate`` to ``fgraph``."""
+        fgraph.attach_feature(ReplaceValidate())
+
+    def apply(self, fgraph):
+        """Replace every RV in ``fgraph`` (and inside its Scans) by its moment."""
+        for node in fgraph.toposort():
+            if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
+                fgraph.replace(node.out, moment(node.out))
+            elif isinstance(node.op, Scan):
+                new_node = self.rewrite_moment_scan_node(node)
+                if new_node is not None:
+                    fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs)))
+
+
 class _Unpickling:
     pass
 
@@ -601,6 +662,21 @@ def update(self, node: Node):
         return updates
 
 
+@_moment.register(CustomSymbolicDistRV)
+def dist_moment(op, rv, *args):
+    """Derive the moment of ``rv`` by replacing the RVs in ``op``'s graph by moments."""
+    node = rv.owner
+    rv_out_idx = node.outputs.index(rv)
+
+    # Rewrite a clone rather than ``op.fgraph`` itself.
+    fgraph = op.fgraph.clone()
+    MomentRewrite().rewrite(fgraph)
+    # Replace dummy inner inputs by outer inputs
+    fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True)
+    # Returned directly: a local named ``moment`` would shadow the module-level callable.
+    return fgraph.outputs[rv_out_idx]
+
+
 class _CustomSymbolicDist(Distribution):
     rv_type = CustomSymbolicDistRV
 
@@ -622,14 +698,6 @@ def dist(
         if logcdf is None:
             logcdf = default_not_implemented(class_name, "logcdf")
 
-        if moment is None:
-            moment = functools.partial(
-                default_moment,
-                rv_name=class_name,
-                has_fallback=True,
-                ndim_supp=ndim_supp,
-            )
-
         return super().dist(
             dist_params,
             class_name=class_name,
@@ -685,9 +753,20 @@ def custom_dist_logp(op, values, size, *params, **kwargs):
         def custom_dist_logcdf(op, value, size, *params, **kwargs):
             return logcdf(value, *params[: len(dist_params)])
 
-        @_moment.register(rv_type)
-        def custom_dist_get_moment(op, rv, size, *params):
-            return moment(rv, size, *params[: len(params)])
+        if moment is not None:
+
+            @_moment.register(rv_type)
+            def custom_dist_get_moment(op, rv, size, *params):
+                # Drop inputs of random (RNG) type before calling the user-supplied moment
+                return moment(
+                    rv,
+                    size,
+                    *[
+                        p
+                        for p in params
+                        if not isinstance(p.type, (RandomType, RandomGeneratorType))
+                    ],
+                )
 
         @_change_dist_size.register(rv_type)
         def change_custom_symbolic_dist_size(op, rv, new_size, expand):
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index 9bdfdb2053..b324460bc2 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -430,6 +430,104 @@ def custom_dist(mu, sigma, size):
         ip = m.initial_point()
         np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))
 
+    @pytest.mark.parametrize(
+        "dist_params, size, expected, dist_fn",
+        [
+            (
+                (5, 1),
+                None,
+                np.exp(5),
+                lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
+            ),
+            (
+                (2, np.ones(5)),
+                None,
+                np.exp([2, 2, 2, 2, 2] + np.ones(5)),
+                lambda mu, sigma, size: pt.exp(
+                    pm.Normal.dist(mu, sigma, size=size) + pt.ones(size)
+                ),
+            ),
+            (
+                (1, 2),
+                None,
+                np.sqrt(np.exp(1 + 0.5 * 2**2)),
+                lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
+            ),
+            (
+                (4,),
+                (3,),
+                np.log([4, 4, 4]),
+                lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
+            ),
+            (
+                (12, 1),
+                None,
+                12,
+                lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
+            ),
+        ],
+    )
+    def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn):
+        with Model() as model:
+            CustomDist("x", *dist_params, dist=dist_fn, size=size)
+        assert_moment_is_expected(model, expected)
+
+    def test_custom_dist_default_moment_scan(self):
+        def scan_step(left, right):
+            x = pm.Uniform.dist(left, right)
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        def dist(size):
+            xs, _ = scan(
+                fn=scan_step,
+                sequences=[
+                    pt.as_tensor_variable(np.array([-4, -3])),
+                    pt.as_tensor_variable(np.array([-2, -1])),
+                ],
+                name="xs",
+            )
+            return xs
+
+        with Model() as model:
+            CustomDist("x", dist=dist)
+        assert_moment_is_expected(model, np.array([-3, -2]))
+
+    def test_custom_dist_default_moment_scan_recurring(self):
+        def scan_step(xtm1):
+            x = pm.Normal.dist(xtm1 + 1)
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        def dist(size):
+            xs, _ = scan(
+                fn=scan_step,
+                outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
+                n_steps=3,
+                name="xs",
+            )
+            return xs
+
+        with Model() as model:
+            CustomDist("x", dist=dist)
+        assert_moment_is_expected(model, np.array([[1], [2], [3]]))
+
+    @pytest.mark.parametrize(
+        "left, right, size, expected",
+        [
+            (-1, 1, None, 0 + 5),
+            (-3, -1, None, -2 + 5),
+            (-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])),
+        ],
+    )
+    def test_custom_dist_default_moment_nested(self, left, right, size, expected):
+        def dist_fn(left, right, size):
+            return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5
+
+        with Model() as model:
+            CustomDist("x", left, right, size=size, dist=dist_fn)
+        assert_moment_is_expected(model, expected)
+
     def test_logcdf_inference(self):
         def custom_dist(mu, sigma, size):
             return pt.exp(pm.Normal.dist(mu, sigma, size=size))