From 030824ec552f5f1ede7b92a6c1678186af0ca24c Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 22 Aug 2023 18:09:01 +0300 Subject: [PATCH 01/41] add test for custom dist default moment --- tests/distributions/test_distribution.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 9bdfdb2053..2d58600a4b 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -430,6 +430,32 @@ def custom_dist(mu, sigma, size): ip = m.initial_point() np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip)) + @pytest.mark.parametrize( + "mu, sigma, size, expected", + [ + (0, 1, None, np.exp(0 + 0.5 * 1**2)), + (0, np.ones(5), None, np.exp(0 + 0.5 * np.ones(5) ** 2)), + (np.arange(5), np.ones(5), None, np.exp(np.arange(5) + 0.5 * 1**2)), + ], + ) + def test_custom_dist_default_moment(self, mu, sigma, size, expected): + def custom_dist(mu, sigma, size): + return pm.math.exp(pm.Normal.dist(mu, sigma, size=size)) + + with Model() as model: + mu = Normal("mu") + sigma = HalfNormal("sigma") + CustomDist( + "x", + mu, + sigma, + dist=custom_dist, + size=(10,), + transform=log, + initval=np.ones(10), + ) + assert_moment_is_expected(model, expected) + def test_logcdf_inference(self): def custom_dist(mu, sigma, size): return pt.exp(pm.Normal.dist(mu, sigma, size=size)) From cdbe6f866f361a346e793975247b5206ceeb2240 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Fri, 1 Sep 2023 19:01:02 +0300 Subject: [PATCH 02/41] add graph rewriting --- pymc/distributions/distribution.py | 32 ++++++++++++++++++------ tests/distributions/test_distribution.py | 14 +++-------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 6bf39ed998..e18d42e677 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -28,7 +28,7 @@ from pytensor.graph import FunctionGraph, node_rewriter from pytensor.graph.basic import Node, Variable from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import in2out +from pytensor.graph.rewriting.basic import NodeRewriter, WalkingGraphRewriter, in2out from pytensor.graph.utils import MetaType from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.random.op import RandomVariable @@ -83,6 +83,14 @@ PLATFORM = sys.platform +class MomentRewrite(NodeRewriter): + def transform(self, fgraph, node): + if isinstance(node.op, Distribution) and hasattr(node, "owner"): + node = moment(node) + return node + return False + + class _Unpickling: pass @@ -622,13 +630,21 @@ def dist( if logcdf is None: logcdf = default_not_implemented(class_name, "logcdf") + def dist_moment(rv, size, *dist_params): + fgraph = FunctionGraph(outputs=[dist(*dist_params, size=size)], clone=True) + replace_moments = WalkingGraphRewriter(MomentRewrite()) + replace_moments.rewrite(fgraph) + [moment] = fgraph.outputs + return moment + if moment is None: - moment = functools.partial( - default_moment, - rv_name=class_name, - has_fallback=True, - ndim_supp=ndim_supp, - ) + # moment = functools.partial( + # default_moment, + # rv_name=class_name, + # has_fallback=True, + # ndim_supp=ndim_supp, + # ) + moment = dist_moment return super().dist( dist_params, @@ -687,7 +703,7 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs): @_moment.register(rv_type) def custom_dist_get_moment(op, rv, size, *params): - return moment(rv, size, *params[: len(params)]) + return moment(rv, size, *params[: len(params) - 1]) @_change_dist_size.register(rv_type) def change_custom_symbolic_dist_size(op, rv, new_size, expand): diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 2d58600a4b..2ef8e2e65e 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -434,8 +434,8 @@ def custom_dist(mu, sigma, size): "mu, sigma, size, expected", [ (0, 1, None, np.exp(0 + 0.5 * 1**2)), - (0, np.ones(5), None, np.exp(0 + 0.5 * np.ones(5) ** 2)), - (np.arange(5), np.ones(5), None, np.exp(np.arange(5) + 0.5 * 1**2)), + # (0, np.ones(5), None, np.exp(0 + 0.5 * np.ones(5) ** 2)), + # (np.arange(5), np.ones(5), None, np.exp(np.arange(5) + 0.5 * 1**2)), ], ) def test_custom_dist_default_moment(self, mu, sigma, size, expected): @@ -445,15 +445,7 @@ def custom_dist(mu, sigma, size): with Model() as model: mu = Normal("mu") sigma = HalfNormal("sigma") - CustomDist( - "x", - mu, - sigma, - dist=custom_dist, - size=(10,), - transform=log, - initval=np.ones(10), - ) + CustomDist("x", mu, sigma, dist=custom_dist, size=size) assert_moment_is_expected(model, expected) def test_logcdf_inference(self): From af82efc1a5ef0ce7231604ebde74bc2605002856 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 4 Sep 2023 18:18:42 +0300 Subject: [PATCH 03/41] add graph rewrite --- pymc/distributions/distribution.py | 35 ++++++++++++++++++------ tests/distributions/test_distribution.py | 6 ++-- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index e18d42e677..e467f8464c 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -27,8 +27,9 @@ from pytensor.compile.builders import OpFromGraph from pytensor.graph import FunctionGraph, node_rewriter from pytensor.graph.basic import Node, Variable +from pytensor.graph.features import ReplaceValidate from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import NodeRewriter, WalkingGraphRewriter, in2out +from pytensor.graph.rewriting.basic import GraphRewriter, in2out from pytensor.graph.utils import MetaType from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.random.op import RandomVariable @@ -83,12 +84,30 @@ PLATFORM = sys.platform -class MomentRewrite(NodeRewriter): - def transform(self, fgraph, node): - if isinstance(node.op, Distribution) and hasattr(node, "owner"): - node = moment(node) - return node - return False +class MomentRewrite(GraphRewriter): + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + + def apply(self, fgraph): + for node in fgraph.toposort(): + for i, inp in enumerate(node.inputs): + if ( + hasattr(inp, "owner") + and hasattr(inp.owner, "op") + and isinstance(inp.owner.op, Distribution) + ): + fgraph.replace(node.inputs[i], moment(node.inputs[i])) + + # def transform(self, fgraph, node): + # flag = False + # for i, inp in enumerate(node.inputs): + # if hasattr(inp, "owner") and hasattr(inp.owner, "op") and isinstance(inp.owner.op, Distribution): + # #node.inputs[i] = moment(inp) + # node.inputs[i] = inp + # flag = True + # if flag: + # return [node] + # return False class _Unpickling: @@ -632,7 +651,7 @@ def dist( def dist_moment(rv, size, *dist_params): fgraph = FunctionGraph(outputs=[dist(*dist_params, size=size)], clone=True) - replace_moments = WalkingGraphRewriter(MomentRewrite()) + replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) [moment] = fgraph.outputs return moment diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 2ef8e2e65e..570184d867 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -433,7 +433,7 @@ def custom_dist(mu, sigma, size): @pytest.mark.parametrize( "mu, sigma, size, expected", [ - (0, 1, None, np.exp(0 + 0.5 * 1**2)), + (5, 1, None, np.exp(5)), # (0, np.ones(5), None, np.exp(0 + 0.5 * np.ones(5) ** 2)), # (np.arange(5), np.ones(5), None, np.exp(np.arange(5) + 0.5 * 1**2)), ], @@ -443,8 +443,8 @@ def custom_dist(mu, sigma, size): return pm.math.exp(pm.Normal.dist(mu, sigma, size=size)) with Model() as model: - mu = Normal("mu") - sigma = HalfNormal("sigma") + # mu = Normal("mu") + # sigma = HalfNormal("sigma") CustomDist("x", mu, sigma, dist=custom_dist, size=size) assert_moment_is_expected(model, expected) From c12da7ea4d1b2cce38f418d861cc4995a455c647 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 4 Sep 2023 18:24:29 +0300 Subject: [PATCH 04/41] change test case --- tests/distributions/test_distribution.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 570184d867..db51ad402d 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -434,8 +434,8 @@ def custom_dist(mu, sigma, size): "mu, sigma, size, expected", [ (5, 1, None, np.exp(5)), - # (0, np.ones(5), None, np.exp(0 + 0.5 * np.ones(5) ** 2)), - # (np.arange(5), np.ones(5), None, np.exp(np.arange(5) + 0.5 * 1**2)), + (2, np.ones(5), None, np.exp([2, 2, 2, 2, 2])), + (np.arange(5), np.ones(5), None, np.exp(np.arange(5))), ], ) def test_custom_dist_default_moment(self, mu, sigma, size, expected): @@ -443,8 +443,6 @@ def custom_dist(mu, sigma, size): return pm.math.exp(pm.Normal.dist(mu, sigma, size=size)) with Model() as model: - # mu = Normal("mu") - # sigma = HalfNormal("sigma") CustomDist("x", mu, sigma, dist=custom_dist, size=size) assert_moment_is_expected(model, expected) From b570fb62b2d5b5a82d3856e61fd9e393349e52ac Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 4 Sep 2023 18:34:33 +0300 Subject: [PATCH 05/41] remove commented code --- pymc/distributions/distribution.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index e467f8464c..d21fab53dc 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -98,17 +98,6 @@ def apply(self, fgraph): ): fgraph.replace(node.inputs[i], moment(node.inputs[i])) - # def transform(self, fgraph, node): - # flag = False - # for i, inp in enumerate(node.inputs): - # if hasattr(inp, "owner") and hasattr(inp.owner, "op") and isinstance(inp.owner.op, Distribution): - # #node.inputs[i] = moment(inp) - # node.inputs[i] = inp - # flag = True - # if flag: - # return [node] - # return False - class _Unpickling: pass @@ -657,12 +646,6 @@ def dist_moment(rv, size, *dist_params): return moment if moment is None: - # moment = functools.partial( - # default_moment, - # rv_name=class_name, - # has_fallback=True, - # ndim_supp=ndim_supp, - # ) moment = dist_moment return super().dist( From ea228489974e19853b5f96c2c32e09b80f6a39ee Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 4 Sep 2023 20:15:46 +0300 Subject: [PATCH 06/41] add more test cases --- pymc/distributions/distribution.py | 4 ++- tests/distributions/test_distribution.py | 38 ++++++++++++++++++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index d21fab53dc..eabbd4a270 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -35,6 +35,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.rewriting import local_subtensor_rv_lift from pytensor.tensor.random.utils import normalize_size_param +from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.variable import TensorVariable from typing_extensions import TypeAlias @@ -705,7 +706,8 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs): @_moment.register(rv_type) def custom_dist_get_moment(op, rv, size, *params): - return moment(rv, size, *params[: len(params) - 1]) + params = [i for i in params if not isinstance(i, RandomGeneratorSharedVariable)] + return moment(rv, size, *params[: len(params)]) @_change_dist_size.register(rv_type) def change_custom_symbolic_dist_size(op, rv, new_size, expand): diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index db51ad402d..c8f141a26b 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -431,19 +431,39 @@ def custom_dist(mu, sigma, size): np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip)) @pytest.mark.parametrize( - "mu, sigma, size, expected", + "dist_params, size, expected, dist_fn", [ - (5, 1, None, np.exp(5)), - (2, np.ones(5), None, np.exp([2, 2, 2, 2, 2])), - (np.arange(5), np.ones(5), None, np.exp(np.arange(5))), + ( + (5, 1), + None, + np.exp(5), + lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)), + ), + ( + (2, np.ones(5)), + None, + np.exp([2, 2, 2, 2, 2] + np.ones(5)), + lambda mu, sigma, size: pt.exp( + pm.Normal.dist(mu, sigma, size=size) + pt.ones(size) + ), + ), + ( + (1, 2), + None, + np.sqrt(np.exp(1 + 0.5 * 2**2)), + lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)), + ), + ( + (4,), + (3,), + np.log([4, 4, 4]), + lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)), + ), ], ) - def test_custom_dist_default_moment(self, mu, sigma, size, expected): - def custom_dist(mu, sigma, size): - return pm.math.exp(pm.Normal.dist(mu, sigma, size=size)) - + def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): with Model() as model: - CustomDist("x", mu, sigma, dist=custom_dist, size=size) + CustomDist("x", *dist_params, dist=dist_fn, size=size) assert_moment_is_expected(model, expected) def test_logcdf_inference(self): From 06c264602898e954d94b1427af163c72c12fe145 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Thu, 7 Sep 2023 17:56:10 +0300 Subject: [PATCH 07/41] replace Distribution by RandomVariable in node input check --- pymc/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index eabbd4a270..01368e111f 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -95,7 +95,7 @@ def apply(self, fgraph): if ( hasattr(inp, "owner") and hasattr(inp.owner, "op") - and isinstance(inp.owner.op, Distribution) + and isinstance(inp.owner.op, RandomVariable) ): fgraph.replace(node.inputs[i], moment(node.inputs[i])) From 511f0f618611794a998ef791f1a8b75ccea4e6cb Mon Sep 17 00:00:00 2001 From: Anatoly <44327258+aerubanov@users.noreply.github.com> Date: Sat, 9 Sep 2023 13:53:29 +0300 Subject: [PATCH 08/41] Update pymc/distributions/distribution.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 01368e111f..2d58c0cbb8 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -95,7 +95,7 @@ def apply(self, fgraph): if ( hasattr(inp, "owner") and hasattr(inp.owner, "op") - and isinstance(inp.owner.op, RandomVariable) + and isinstance(inp.owner.op, (RandomVariable, SymbolicRandomVariable)) ): fgraph.replace(node.inputs[i], moment(node.inputs[i])) From 9a4a801890ec637ff295d1d18930e67762e6203b Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 11 Sep 2023 17:38:06 +0300 Subject: [PATCH 09/41] add test case for dist with inner graph --- tests/distributions/test_distribution.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index c8f141a26b..876d621753 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -466,6 +466,26 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): CustomDist("x", *dist_params, dist=dist_fn, size=size) assert_moment_is_expected(model, expected) + def test_custom_dist_custom_moment_inner_graph(self): + def scan_step(mu): + x = pm.Normal.dist(mu, 1) + x_update = collect_default_updates([x]) + return x, x_update + + def dist(mu, size): + # size = size.reshape(mu.shape) + ys, _ = pytensor.scan( + fn=scan_step, + sequences=[mu], + outputs_info=[None], + name="ys", + ) + return pt.sum(ys) + + with Model() as model: + CustomDist("x", pt.ones(2), dist=dist) + assert_moment_is_expected(model, 2) + def test_logcdf_inference(self): def custom_dist(mu, sigma, size): return pt.exp(pm.Normal.dist(mu, sigma, size=size)) From 10ba63d006e5d38d31319367fdb0277e50701745 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 12 Sep 2023 13:56:53 +0300 Subject: [PATCH 10/41] add extra test case and change apply method --- pymc/distributions/distribution.py | 16 +++++++++++++++- tests/distributions/test_distribution.py | 13 ++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 2d58c0cbb8..aac07152b7 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -85,11 +85,12 @@ PLATFORM = sys.platform -class MomentRewrite(GraphRewriter): +class InnerMomentRewrite(GraphRewriter): def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) def apply(self, fgraph): + # breakpoint() for node in fgraph.toposort(): for i, inp in enumerate(node.inputs): if ( @@ -100,6 +101,19 @@ def apply(self, fgraph): fgraph.replace(node.inputs[i], moment(node.inputs[i])) +class MomentRewrite(GraphRewriter): + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + + def apply(self, fgraph): + for node in fgraph.toposort(): + if hasattr(node.op, "fgraph"): + moment_replace = InnerMomentRewrite() + moment_replace.rewrite(node.op.fgraph) + elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): + fgraph.replace(node.out, moment(node.out)) + + class _Unpickling: pass diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 876d621753..9aaad72303 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -459,6 +459,12 @@ def custom_dist(mu, sigma, size): np.log([4, 4, 4]), lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)), ), + ( + (12, 1), + None, + 12, + lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size), + ), ], ) def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): @@ -466,9 +472,9 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): CustomDist("x", *dist_params, dist=dist_fn, size=size) assert_moment_is_expected(model, expected) - def test_custom_dist_custom_moment_inner_graph(self): + def test_custom_dist_default_moment_inner_graph(self): def scan_step(mu): - x = pm.Normal.dist(mu, 1) + x = pt.exp(pm.Normal.dist(mu, 1)) x_update = collect_default_updates([x]) return x, x_update @@ -484,7 +490,8 @@ def dist(mu, size): with Model() as model: CustomDist("x", pt.ones(2), dist=dist) - assert_moment_is_expected(model, 2) + # assert_moment_is_expected(model, 5.43656365691809) + assert_moment_is_expected(model, np.sum(np.exp(np.ones(2)))) def test_logcdf_inference(self): def custom_dist(mu, sigma, size): From f078665735ca84a6bb9bb0c073c67be4a8ed5b1d Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 12 Sep 2023 16:09:58 +0300 Subject: [PATCH 11/41] change inner graph moment replacement --- pymc/distributions/distribution.py | 19 +----------- tests/distributions/test_distribution.py | 39 ++++++++++++++++-------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index aac07152b7..ffcbe3e746 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -85,22 +85,6 @@ PLATFORM = sys.platform -class InnerMomentRewrite(GraphRewriter): - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - - def apply(self, fgraph): - # breakpoint() - for node in fgraph.toposort(): - for i, inp in enumerate(node.inputs): - if ( - hasattr(inp, "owner") - and hasattr(inp.owner, "op") - and isinstance(inp.owner.op, (RandomVariable, SymbolicRandomVariable)) - ): - fgraph.replace(node.inputs[i], moment(node.inputs[i])) - - class MomentRewrite(GraphRewriter): def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) @@ -108,8 +92,7 @@ def add_requirements(self, fgraph): def apply(self, fgraph): for node in fgraph.toposort(): if hasattr(node.op, "fgraph"): - moment_replace = InnerMomentRewrite() - moment_replace.rewrite(node.op.fgraph) + self.rewrite(node.op.fgraph) elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): fgraph.replace(node.out, moment(node.out)) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 9aaad72303..ae3f591b5c 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -473,25 +473,40 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): assert_moment_is_expected(model, expected) def test_custom_dist_default_moment_inner_graph(self): - def scan_step(mu): - x = pt.exp(pm.Normal.dist(mu, 1)) + # def scan_step(mu): + # x = pt.exp(pm.Normal.dist(mu, 1)) + # x_update = collect_default_updates([x]) + # return x, x_update + + # def dist(mu, size): + # # size = size.reshape(mu.shape) + # ys, _ = pytensor.scan( + # fn=scan_step, + # sequences=[mu], + # outputs_info=[None], + # name="ys", + # ) + # return pt.sum(ys) + + def scan_step(xtm1): + x = xtm1 * 2 x_update = collect_default_updates([x]) return x, x_update - def dist(mu, size): - # size = size.reshape(mu.shape) - ys, _ = pytensor.scan( + def dist(size): + x0 = pm.Normal.dist(1, 1) + + xs, updates = scan( fn=scan_step, - sequences=[mu], - outputs_info=[None], - name="ys", + outputs_info=[x0], + n_steps=2, + name="xs", ) - return pt.sum(ys) + return xs[-1] with Model() as model: - CustomDist("x", pt.ones(2), dist=dist) - # assert_moment_is_expected(model, 5.43656365691809) - assert_moment_is_expected(model, np.sum(np.exp(np.ones(2)))) + CustomDist("x", dist=dist) + assert_moment_is_expected(model, 4, check_finite_logp=False) def test_logcdf_inference(self): def custom_dist(mu, sigma, size): From cee024d855e2f7ee6666d36dffc04e77342afc86 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 12 Sep 2023 16:10:26 +0300 Subject: [PATCH 12/41] remove comented code --- tests/distributions/test_distribution.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index ae3f591b5c..646ae79093 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -473,21 +473,6 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): assert_moment_is_expected(model, expected) def test_custom_dist_default_moment_inner_graph(self): - # def scan_step(mu): - # x = pt.exp(pm.Normal.dist(mu, 1)) - # x_update = collect_default_updates([x]) - # return x, x_update - - # def dist(mu, size): - # # size = size.reshape(mu.shape) - # ys, _ = pytensor.scan( - # fn=scan_step, - # sequences=[mu], - # outputs_info=[None], - # name="ys", - # ) - # return pt.sum(ys) - def scan_step(xtm1): x = xtm1 * 2 x_update = collect_default_updates([x]) From 6e84d6212b4bc2c6139015bf344af4980f6fde46 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Fri, 22 Sep 2023 18:10:35 +0300 Subject: [PATCH 13/41] add initial implementation of scan op replacement --- pymc/distributions/distribution.py | 57 +++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index ffcbe3e746..cd2c8d8181 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -25,12 +25,12 @@ from pytensor import tensor as pt from pytensor.compile.builders import OpFromGraph -from pytensor.graph import FunctionGraph, node_rewriter -from pytensor.graph.basic import Node, Variable +from pytensor.graph import FunctionGraph, clone_replace, node_rewriter +from pytensor.graph.basic import Node, Variable, io_toposort from pytensor.graph.features import ReplaceValidate -from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out from pytensor.graph.utils import MetaType +from pytensor.scan.op import Scan from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.rewriting import local_subtensor_rv_lift @@ -85,14 +85,60 @@ PLATFORM = sys.platform +def rewrite_moment_scan_node(node): + if not isinstance(node.op, Scan): + return + + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + op = node.op + + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + local_fgraph_outs_set = set(node_outputs) + local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} + + replace_with_moment = [] + to_replace_set = set() + + for nd in local_fgraph_topo: + if nd not in to_replace_set and isinstance( + node.op, (RandomVariable, SymbolicRandomVariable) + ): + for out in enumerate(nd.outputs): + # y_place_holder = safe_new(y, "_replace") + replace_with_moment.append(out) + to_replace_set.add(nd) + givens = {} + if len(replace_with_moment) > 0: + for item in replace_with_moment: + givens[item] = moment(item) + op_outs = clone_replace(node_outputs, replace=givens) + + nwScan = Scan( + node_inputs, + op_outs, + op.info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + name=op.name, + allow_gc=op.allow_gc, + ) + nw_node = nwScan(*(node.inputs), return_list=True)[0].owner + return nw_node + + class MomentRewrite(GraphRewriter): def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) def apply(self, fgraph): for node in fgraph.toposort(): - if hasattr(node.op, "fgraph"): - self.rewrite(node.op.fgraph) + if isinstance(node.op, Scan): + # inner_graph = node.op.fgraph.clone() + # self.rewrite(inner_graph) + # node.op.fgraph = inner_graph + new_node = rewrite_moment_scan_node(node) + fgraph.replace(node, new_node) elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): fgraph.replace(node.out, moment(node.out)) @@ -637,6 +683,7 @@ def dist( logcdf = default_not_implemented(class_name, "logcdf") def dist_moment(rv, size, *dist_params): + # TODO: add check for other op with inner graph (not Scan) fgraph = FunctionGraph(outputs=[dist(*dist_params, size=size)], clone=True) replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) From aa8fce9c9cc88c277819cd0d3a1c58687cc83d2b Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Fri, 22 Sep 2023 20:34:29 +0300 Subject: [PATCH 14/41] fix errors --- pymc/distributions/distribution.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index cd2c8d8181..ff9b030864 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -100,13 +100,9 @@ def rewrite_moment_scan_node(node): to_replace_set = set() for nd in local_fgraph_topo: - if nd not in to_replace_set and isinstance( - node.op, (RandomVariable, SymbolicRandomVariable) - ): - for out in enumerate(nd.outputs): - # y_place_holder = safe_new(y, "_replace") - replace_with_moment.append(out) - to_replace_set.add(nd) + if nd not in to_replace_set and isinstance(nd.op, (RandomVariable, SymbolicRandomVariable)): + replace_with_moment.append(nd.out) + to_replace_set.add(nd) givens = {} if len(replace_with_moment) > 0: for item in replace_with_moment: @@ -138,7 +134,8 @@ def apply(self, fgraph): # self.rewrite(inner_graph) # node.op.fgraph = inner_graph new_node = rewrite_moment_scan_node(node) - fgraph.replace(node, new_node) + for out1, out2 in zip(node.outputs, new_node.outputs): + fgraph.replace(out1, out2) elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): fgraph.replace(node.out, moment(node.out)) From 716f1f83e0b500dbd8bb36aea27957f8e4be9abb Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Fri, 22 Sep 2023 20:35:12 +0300 Subject: [PATCH 15/41] change test --- tests/distributions/test_distribution.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 646ae79093..8cea687f55 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -473,25 +473,22 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): assert_moment_is_expected(model, expected) def test_custom_dist_default_moment_inner_graph(self): - def scan_step(xtm1): - x = xtm1 * 2 + def scan_step(mu): + x = pm.Normal.dist(mu, 1) x_update = collect_default_updates([x]) return x, x_update def dist(size): - x0 = pm.Normal.dist(1, 1) - xs, updates = scan( fn=scan_step, - outputs_info=[x0], - n_steps=2, + sequences=[pt.ones(2)], name="xs", ) - return xs[-1] + return pt.sum(xs) with Model() as model: CustomDist("x", dist=dist) - assert_moment_is_expected(model, 4, check_finite_logp=False) + assert_moment_is_expected(model, 2, check_finite_logp=False) def test_logcdf_inference(self): def custom_dist(mu, sigma, size): From 62b3c5961165da9087803cbbcd4c144bb173afd4 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 25 Sep 2023 13:19:09 +0300 Subject: [PATCH 16/41] change normal ditribution by uniform in step func --- tests/distributions/test_distribution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 8cea687f55..a148c4eeba 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -473,22 +473,25 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): assert_moment_is_expected(model, expected) def test_custom_dist_default_moment_inner_graph(self): - def scan_step(mu): - x = pm.Normal.dist(mu, 1) + def scan_step(left, right): + x = pm.Uniform.dist(left, right) x_update = collect_default_updates([x]) return x, x_update def dist(size): xs, updates = scan( fn=scan_step, - sequences=[pt.ones(2)], + sequences=[ + pt.as_tensor_variable(np.array([-4, -3])), + pt.as_tensor_variable(np.array([-2, -1])), + ], name="xs", ) - return pt.sum(xs) + return xs with Model() as model: CustomDist("x", dist=dist) - assert_moment_is_expected(model, 2, check_finite_logp=False) + assert_moment_is_expected(model, np.array([-3, -2])) def test_logcdf_inference(self): def custom_dist(mu, sigma, size): From 875e2ed8fc2b853d8bbc1ef3a5effe791a78f452 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 3 Oct 2023 17:38:25 +0300 Subject: [PATCH 17/41] add test case for nested dist --- tests/distributions/test_distribution.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index a148c4eeba..8c16559581 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -493,6 +493,22 @@ def dist(size): CustomDist("x", dist=dist) assert_moment_is_expected(model, np.array([-3, -2])) + @pytest.mark.parametrize( + "left, right, size, expected", + [ + (-1, 1, None, 0 + 5), + (-3, -1, None, -2 + 5), + (-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])), + ], + ) + def test_custom_dist_default_moment_nested(self, left, right, size, expected): + def dist_fn(left, right, size): + return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5 + + with Model() as model: + CustomDist("x", left, right, size=size, dist=dist_fn) + assert_moment_is_expected(model, expected) + def test_logcdf_inference(self): def custom_dist(mu, sigma, size): return pt.exp(pm.Normal.dist(mu, sigma, size=size)) From e1b0f669c52c771d682c8b18b77a7071a8d38437 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Thu, 5 Oct 2023 19:36:24 +0300 Subject: [PATCH 18/41] remove unused and commented code --- pymc/distributions/distribution.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index ff9b030864..5621506054 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -93,8 +93,6 @@ def rewrite_moment_scan_node(node): op = node.op local_fgraph_topo = io_toposort(node_inputs, node_outputs) - local_fgraph_outs_set = set(node_outputs) - local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} replace_with_moment = [] to_replace_set = set() @@ -130,9 +128,6 @@ def add_requirements(self, fgraph): def apply(self, fgraph): for node in fgraph.toposort(): if isinstance(node.op, Scan): - # inner_graph = node.op.fgraph.clone() - # self.rewrite(inner_graph) - # node.op.fgraph = inner_graph new_node = rewrite_moment_scan_node(node) for out1, out2 in zip(node.outputs, new_node.outputs): fgraph.replace(out1, out2) From 3c30352a403b732b207f80c46aba9299fdf20a36 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Thu, 5 Oct 2023 19:37:09 +0300 Subject: [PATCH 19/41] change conditions order --- pymc/distributions/distribution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 5621506054..d556b20838 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -127,12 +127,12 @@ def add_requirements(self, fgraph): def apply(self, fgraph): for node in fgraph.toposort(): - if isinstance(node.op, Scan): + if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): + fgraph.replace(node.out, moment(node.out)) + elif isinstance(node.op, Scan): new_node = rewrite_moment_scan_node(node) for out1, out2 in zip(node.outputs, new_node.outputs): fgraph.replace(out1, out2) - elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): - fgraph.replace(node.out, moment(node.out)) class _Unpickling: From c0171958c8ed5951306b1e2a2b055e346eff5d44 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Thu, 5 Oct 2023 20:46:21 +0300 Subject: [PATCH 20/41] remove comment --- pymc/distributions/distribution.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index d556b20838..e767e20aa3 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -675,7 +675,6 @@ def dist( logcdf = default_not_implemented(class_name, "logcdf") def dist_moment(rv, size, *dist_params): - # TODO: add check for other op with inner graph (not Scan) fgraph = FunctionGraph(outputs=[dist(*dist_params, size=size)], clone=True) replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) From 6cb03f1008782a7c5354058923c97d229fbd128c Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 16 Oct 2023 17:15:48 +0300 Subject: [PATCH 21/41] add helper function --- pymc/distributions/distribution.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index e767e20aa3..2049c32444 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -85,6 +85,10 @@ PLATFORM = sys.platform +def filter_shared_RNGs(params): + return [p for p in params if not isinstance(p, RandomGeneratorSharedVariable)] + + def rewrite_moment_scan_node(node): if not isinstance(node.op, Scan): return @@ -741,8 +745,7 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs): @_moment.register(rv_type) def custom_dist_get_moment(op, rv, size, *params): - params = [i for i in params if not isinstance(i, RandomGeneratorSharedVariable)] - return moment(rv, size, *params[: len(params)]) + return moment(rv, size, *filter_shared_RNGs(params)) @_change_dist_size.register(rv_type) def change_custom_symbolic_dist_size(op, rv, new_size, expand): From 1b60994b1c94c7946f66a731d1675e9ccabdebf5 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 24 Oct 2023 20:35:39 +0300 Subject: [PATCH 22/41] add helper function for graph construction --- pymc/distributions/distribution.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 2049c32444..ebae5ad6b7 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -657,6 +657,13 @@ def update(self, node: Node): return updates +def get_rv_fgraph(dist_fn, dist_params, size): + rv = dist_fn(*dist_params, size=size) + outputs = [rv] + fgraph = FunctionGraph(outputs=outputs, clone=True) + return fgraph + + class _CustomSymbolicDist(Distribution): rv_type = CustomSymbolicDistRV @@ -679,7 +686,7 @@ def dist( logcdf = default_not_implemented(class_name, "logcdf") def dist_moment(rv, size, *dist_params): - fgraph = FunctionGraph(outputs=[dist(*dist_params, size=size)], clone=True) + fgraph = get_rv_fgraph(dist, dist_params, size) replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) [moment] = fgraph.outputs From 06569876623da4c63c27065a2db368b6c7ae949b Mon Sep 17 00:00:00 2001 From: Anatoly <44327258+aerubanov@users.noreply.github.com> Date: Mon, 30 Oct 2023 19:29:31 +0300 Subject: [PATCH 23/41] Update pymc/distributions/distribution.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/distributions/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index ebae5ad6b7..2f5f0abe8b 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -85,8 +85,8 @@ PLATFORM = sys.platform -def filter_shared_RNGs(params): - return [p for p in params if not isinstance(p, RandomGeneratorSharedVariable)] +def filter_RNGs(params): + return [p for p in params if not isinstance(p.type, RandomType)] def rewrite_moment_scan_node(node): From 155e44ba2c933e8532ec8265aad4f18bac484f3f Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 31 Oct 2023 17:25:01 +0300 Subject: [PATCH 24/41] fix function name and imports --- pymc/distributions/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 2f5f0abe8b..6694016485 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -34,8 +34,8 @@ from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.rewriting import local_subtensor_rv_lift +from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.utils import normalize_size_param -from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.variable import TensorVariable from typing_extensions import TypeAlias @@ -752,7 +752,7 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs): @_moment.register(rv_type) def custom_dist_get_moment(op, rv, size, *params): - return moment(rv, size, *filter_shared_RNGs(params)) + return moment(rv, size, *filter_RNGs(params)) @_change_dist_size.register(rv_type) def change_custom_symbolic_dist_size(op, rv, new_size, expand): From 097a05703325893197dab2e3a6b9bcedcdbb177c Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 31 Oct 2023 17:41:06 +0300 Subject: [PATCH 25/41] transform scan rewrite function into method of MomentRewrite --- pymc/distributions/distribution.py | 67 +++++++++++++++--------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 6694016485..3a649516d2 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -89,43 +89,44 @@ def filter_RNGs(params): return [p for p in params if not isinstance(p.type, RandomType)] -def rewrite_moment_scan_node(node): - if not isinstance(node.op, Scan): - return +class MomentRewrite(GraphRewriter): + def rewrite_moment_scan_node(node): + if not isinstance(node.op, Scan): + return - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - op = node.op + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + op = node.op - local_fgraph_topo = io_toposort(node_inputs, node_outputs) - - replace_with_moment = [] - to_replace_set = set() - - for nd in local_fgraph_topo: - if nd not in to_replace_set and isinstance(nd.op, (RandomVariable, SymbolicRandomVariable)): - replace_with_moment.append(nd.out) - to_replace_set.add(nd) - givens = {} - if len(replace_with_moment) > 0: - for item in replace_with_moment: - givens[item] = moment(item) - op_outs = clone_replace(node_outputs, replace=givens) - - nwScan = Scan( - node_inputs, - op_outs, - op.info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - name=op.name, - allow_gc=op.allow_gc, - ) - nw_node = nwScan(*(node.inputs), return_list=True)[0].owner - return nw_node + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + replace_with_moment = [] + to_replace_set = set() + + for nd in local_fgraph_topo: + if nd not in to_replace_set and isinstance( + nd.op, (RandomVariable, SymbolicRandomVariable) + ): + replace_with_moment.append(nd.out) + to_replace_set.add(nd) + givens = {} + if len(replace_with_moment) > 0: + for item in replace_with_moment: + givens[item] = moment(item) + op_outs = clone_replace(node_outputs, replace=givens) + + nwScan = Scan( + node_inputs, + op_outs, + op.info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + name=op.name, + allow_gc=op.allow_gc, + ) + nw_node = nwScan(*(node.inputs), return_list=True)[0].owner + return nw_node -class MomentRewrite(GraphRewriter): def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) From 2647993b3e014b66a3feafae228a9e7dc96f0980 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 31 Oct 2023 17:49:14 +0300 Subject: [PATCH 26/41] add check for no replacements needed --- pymc/distributions/distribution.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 3a649516d2..bbd2d8726a 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -112,6 +112,8 @@ def rewrite_moment_scan_node(node): if len(replace_with_moment) > 0: for item in replace_with_moment: givens[item] = moment(item) + else: + return op_outs = clone_replace(node_outputs, replace=givens) nwScan = Scan( @@ -135,9 +137,10 @@ def apply(self, fgraph): if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): fgraph.replace(node.out, moment(node.out)) elif isinstance(node.op, Scan): - new_node = rewrite_moment_scan_node(node) - for out1, out2 in zip(node.outputs, new_node.outputs): - fgraph.replace(out1, out2) + new_node = self.rewrite_moment_scan_node(node) + if new_node is not None: + for out1, out2 in zip(node.outputs, new_node.outputs): + fgraph.replace(out1, out2) class _Unpickling: From a0ae812271f2bedc4db1388609338043e9627475 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Tue, 31 Oct 2023 17:51:02 +0300 Subject: [PATCH 27/41] fix method arguments --- pymc/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index bbd2d8726a..33141358d4 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -90,7 +90,7 @@ def filter_RNGs(params): class MomentRewrite(GraphRewriter): - def rewrite_moment_scan_node(node): + def rewrite_moment_scan_node(self, node): if not isinstance(node.op, Scan): return From ef7a7a02ff7734795500e12579120fc1defd8438 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Wed, 1 Nov 2023 21:03:20 +0300 Subject: [PATCH 28/41] change fgraph construction --- pymc/distributions/distribution.py | 36 +++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 33141358d4..619b8c7015 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -34,7 +34,7 @@ from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.rewriting import local_subtensor_rv_lift -from pytensor.tensor.random.type import RandomType +from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.utils import normalize_size_param from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.variable import TensorVariable @@ -86,7 +86,7 @@ def filter_RNGs(params): - return [p for p in params if not isinstance(p.type, RandomType)] + return [p for p in params if not isinstance(p.type, (RandomType, RandomGeneratorType))] class MomentRewrite(GraphRewriter): @@ -661,13 +661,6 @@ def update(self, node: Node): return updates -def get_rv_fgraph(dist_fn, dist_params, size): - rv = dist_fn(*dist_params, size=size) - outputs = [rv] - fgraph = FunctionGraph(outputs=outputs, clone=True) - return fgraph - - class _CustomSymbolicDist(Distribution): rv_type = CustomSymbolicDistRV @@ -690,9 +683,32 @@ def dist( logcdf = default_not_implemented(class_name, "logcdf") def dist_moment(rv, size, *dist_params): - fgraph = get_rv_fgraph(dist, dist_params, size) + size = normalize_size_param(size) + dummy_size_param = size.type() + dummy_dist_params = [dist_param.type() for dist_param in dist_params] + dummy_rv = dist(*dummy_dist_params, dummy_size_param) + # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,)) + dummy_params = [dummy_size_param] + dummy_dist_params + # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) + rv_type = type( + class_name, + (CustomSymbolicDistRV,), + # If logp is not provided, we try to infer it from the dist graph + dict( + inline_logprob=logp is None, + ), + ) + rv_op = rv_type( + inputs=dummy_params, + # inputs=dummy_dist_params, + outputs=[dummy_rv], + ndim_supp=ndim_supp, + ) + fgraph = rv_op.fgraph.clone() replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) + for i, par in enumerate([size] + list(dist_params)): + fgraph.replace(fgraph.inputs[i], par) [moment] = fgraph.outputs return moment From 3004799df45f1da4bed69f731757fba5275560ce Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 6 Nov 2023 19:32:03 +0300 Subject: [PATCH 29/41] remove rv creation --- pymc/distributions/distribution.py | 68 ++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 619b8c7015..455c3b514a 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -23,6 +23,7 @@ import numpy as np +from pytensor import shared from pytensor import tensor as pt from pytensor.compile.builders import OpFromGraph from pytensor.graph import FunctionGraph, clone_replace, node_rewriter @@ -683,33 +684,54 @@ def dist( logcdf = default_not_implemented(class_name, "logcdf") def dist_moment(rv, size, *dist_params): - size = normalize_size_param(size) - dummy_size_param = size.type() - dummy_dist_params = [dist_param.type() for dist_param in dist_params] - dummy_rv = dist(*dummy_dist_params, dummy_size_param) - # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,)) - dummy_params = [dummy_size_param] + dummy_dist_params - # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) - rv_type = type( - class_name, - (CustomSymbolicDistRV,), - # If logp is not provided, we try to infer it from the dist graph - dict( - inline_logprob=logp is None, - ), - ) - rv_op = rv_type( - inputs=dummy_params, - # inputs=dummy_dist_params, - outputs=[dummy_rv], - ndim_supp=ndim_supp, - ) - fgraph = rv_op.fgraph.clone() + # size = normalize_size_param(size) + # dummy_size_param = size.type() + # dummy_dist_params = [dist_param.type() for dist_param in dist_params] + # dummy_rv = dist(*dummy_dist_params, dummy_size_param) + # # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,)) + # dummy_params = [dummy_size_param] + dummy_dist_params + # # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) + # rv_type = type( + # class_name, + # (CustomSymbolicDistRV,), + # # If logp is not provided, we try to infer it from the dist graph + # dict( + # inline_logprob=logp is None, + # ), + # ) + # rv_op = rv_type( + # inputs=dummy_params, + # # inputs=dummy_dist_params, + # outputs=[dummy_rv], + # ndim_supp=ndim_supp, + # ) + # fgraph = rv_op.fgraph.clone() + fgraph = rv.owner.op.fgraph.clone() + # inputs = filter_RNGs(rv.owner.op.fgraph.inputs) + # scan_node = rv.owner.op.fgraph.toposort()[-1] + # output = rv.owner.op.fgraph.outputs[-1] + # fgraph = FunctionGraph(inputs=inputs, outputs=[output], clone=True) replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) for i, par in enumerate([size] + list(dist_params)): fgraph.replace(fgraph.inputs[i], par) - [moment] = fgraph.outputs + # for i, inp in enumerate(replace_rng_nodes(rv.owner.inputs)): + # if isinstance(inp.type, RandomGeneratorType): + # fgraph.replace(fgraph.inputs[i], inp) + # breakpoint() + # outputs = replace_rng_nodes(fgraph.outputs) + for node in fgraph.toposort(): + if isinstance(node.op, Scan): + # breakpoint() + for inp in node.inputs: + if isinstance(inp.type, RandomGeneratorType): + fgraph.replace( + inp, + shared(np.random.Generator(np.random.PCG64())), + import_missing=True, + ) + moment = fgraph.outputs[-1] + # breakpoint() return moment if moment is None: From 6b7c11bafed13e72b3b644e85e800d33f670ea4b Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 6 Nov 2023 19:33:31 +0300 Subject: [PATCH 30/41] remove commented code --- pymc/distributions/distribution.py | 33 ------------------------------ 1 file changed, 33 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 455c3b514a..bb831664a6 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -684,45 +684,13 @@ def dist( logcdf = default_not_implemented(class_name, "logcdf") def dist_moment(rv, size, *dist_params): - # size = normalize_size_param(size) - # dummy_size_param = size.type() - # dummy_dist_params = [dist_param.type() for dist_param in dist_params] - # dummy_rv = dist(*dummy_dist_params, dummy_size_param) - # # dummy_updates_dict = collect_default_updates(inputs=dummy_dist_params, outputs=(dummy_rv,)) - # dummy_params = [dummy_size_param] + dummy_dist_params - # # dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) - # rv_type = type( - # class_name, - # (CustomSymbolicDistRV,), - # # If logp is not provided, we try to infer it from the dist graph - # dict( - # inline_logprob=logp is None, - # ), - # ) - # rv_op = rv_type( - # inputs=dummy_params, - # # inputs=dummy_dist_params, - # outputs=[dummy_rv], - # ndim_supp=ndim_supp, - # ) - # fgraph = rv_op.fgraph.clone() fgraph = rv.owner.op.fgraph.clone() - # inputs = filter_RNGs(rv.owner.op.fgraph.inputs) - # scan_node = rv.owner.op.fgraph.toposort()[-1] - # output = rv.owner.op.fgraph.outputs[-1] - # fgraph = FunctionGraph(inputs=inputs, outputs=[output], clone=True) replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) for i, par in enumerate([size] + list(dist_params)): fgraph.replace(fgraph.inputs[i], par) - # for i, inp in enumerate(replace_rng_nodes(rv.owner.inputs)): - # if isinstance(inp.type, RandomGeneratorType): - # fgraph.replace(fgraph.inputs[i], inp) - # breakpoint() - # outputs = replace_rng_nodes(fgraph.outputs) for node in fgraph.toposort(): if isinstance(node.op, Scan): - # breakpoint() for inp in node.inputs: if isinstance(inp.type, RandomGeneratorType): fgraph.replace( @@ -731,7 +699,6 @@ def dist_moment(rv, size, *dist_params): import_missing=True, ) moment = fgraph.outputs[-1] - # breakpoint() return moment if moment is None: From 9db17f681b95211fe1f69fa1af481e782cc69db7 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 6 Nov 2023 19:35:52 +0300 Subject: [PATCH 31/41] add comments --- pymc/distributions/distribution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index bb831664a6..4296c6306a 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -687,8 +687,10 @@ def dist_moment(rv, size, *dist_params): fgraph = rv.owner.op.fgraph.clone() replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) + # we need to replace dummy variable by actual dist params for i, par in enumerate([size] + list(dist_params)): fgraph.replace(fgraph.inputs[i], par) + # we need to replace dymmy random generators in Scan node inputs for node in fgraph.toposort(): if isinstance(node.op, Scan): for inp in node.inputs: From b3548cb862bf3e6ead8f4f19d4419b4c051765d3 Mon Sep 17 00:00:00 2001 From: Anatoly <44327258+aerubanov@users.noreply.github.com> Date: Fri, 10 Nov 2023 12:11:47 +0300 Subject: [PATCH 32/41] Update tests/distributions/test_distribution.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/distributions/test_distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 8c16559581..69ec62d00c 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -472,7 +472,7 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn): CustomDist("x", *dist_params, dist=dist_fn, size=size) assert_moment_is_expected(model, expected) - def test_custom_dist_default_moment_inner_graph(self): + def test_custom_dist_default_moment_scan(self): def scan_step(left, right): x = pm.Uniform.dist(left, right) x_update = collect_default_updates([x]) From 20e066ce5549facb459c1938f39e08b00a8156cb Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Fri, 10 Nov 2023 16:57:32 +0300 Subject: [PATCH 33/41] move dist_moment function outside of dist method --- pymc/distributions/distribution.py | 35 ++++++++++++------------------ 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 4296c6306a..f374192344 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -23,7 +23,6 @@ import numpy as np -from pytensor import shared from pytensor import tensor as pt from pytensor.compile.builders import OpFromGraph from pytensor.graph import FunctionGraph, clone_replace, node_rewriter @@ -662,6 +661,20 @@ def update(self, node: Node): return updates +def dist_moment(rv, *args): + node = rv.owner + op = node.op + rv_out_idx = node.outputs.index(rv) + + fgraph = op.fgraph.clone() + replace_moments = MomentRewrite() + replace_moments.rewrite(fgraph) + # Replace dummy inner inputs by outer inputs + fgraph.replace_all(tuple(zip(op.inner_inputs, node.inputs)), import_missing=True) + moment = fgraph.outputs[rv_out_idx] + return moment + + class _CustomSymbolicDist(Distribution): rv_type = CustomSymbolicDistRV @@ -683,26 +696,6 @@ def dist( if logcdf is None: logcdf = default_not_implemented(class_name, "logcdf") - def dist_moment(rv, size, *dist_params): - fgraph = rv.owner.op.fgraph.clone() - replace_moments = MomentRewrite() - replace_moments.rewrite(fgraph) - # we need to replace dummy variable by actual dist params - for i, par in enumerate([size] + list(dist_params)): - fgraph.replace(fgraph.inputs[i], par) - # we need to replace dymmy random generators in Scan node inputs - for node in fgraph.toposort(): - if isinstance(node.op, Scan): - for inp in node.inputs: - if isinstance(inp.type, RandomGeneratorType): - fgraph.replace( - inp, - shared(np.random.Generator(np.random.PCG64())), - import_missing=True, - ) - moment = fgraph.outputs[-1] - return moment - if moment is None: moment = dist_moment From db49e97ce65b2f07f37643533fe65883b6430cc0 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Fri, 10 Nov 2023 17:54:40 +0300 Subject: [PATCH 34/41] add new test case --- tests/distributions/test_distribution.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 69ec62d00c..b8ce198f8b 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -493,6 +493,25 @@ def dist(size): CustomDist("x", dist=dist) assert_moment_is_expected(model, np.array([-3, -2])) + def test_custom_dist_default_moment_scan_recurring(self): + def scan_step(xtm1): + x = pm.Normal.dist(xtm1) + x_update = collect_default_updates([x]) + return x, x_update + + def dist(size): + xs, _ = scan( + fn=scan_step, + outputs_info=pt.as_tensor_variable(np.array([0])).astype(float), + n_steps=3, + name="xs", + ) + return xs + + with Model() as model: + CustomDist("x", dist=dist) + assert_moment_is_expected(model, np.array([[0], [0], [0]])) + @pytest.mark.parametrize( "left, right, size, expected", [ From 80c6b02a6698e17ab6f3ee3e2162d8a9d243657a Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 13 Nov 2023 14:15:29 +0300 Subject: [PATCH 35/41] add changes from review --- pymc/distributions/distribution.py | 15 ++++++++------- tests/distributions/test_distribution.py | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index f374192344..12398274c3 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -139,8 +139,7 @@ def apply(self, fgraph): elif isinstance(node.op, Scan): new_node = self.rewrite_moment_scan_node(node) if new_node is not None: - for out1, out2 in zip(node.outputs, new_node.outputs): - fgraph.replace(out1, out2) + fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs))) class _Unpickling: @@ -696,8 +695,8 @@ def dist( if logcdf is None: logcdf = default_not_implemented(class_name, "logcdf") - if moment is None: - moment = dist_moment + # if moment is None: + moment = dist_moment return super().dist( dist_params, @@ -754,9 +753,11 @@ def custom_dist_logp(op, values, size, *params, **kwargs): def custom_dist_logcdf(op, value, size, *params, **kwargs): return logcdf(value, *params[: len(dist_params)]) - @_moment.register(rv_type) - def custom_dist_get_moment(op, rv, size, *params): - return moment(rv, size, *filter_RNGs(params)) + if moment is not None: + + @_moment.register(rv_type) + def custom_dist_get_moment(op, rv, size, *params): + return moment(rv, size, *filter_RNGs(params)) @_change_dist_size.register(rv_type) def change_custom_symbolic_dist_size(op, rv, new_size, expand): diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index b8ce198f8b..b324460bc2 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -495,7 +495,7 @@ def dist(size): def test_custom_dist_default_moment_scan_recurring(self): def scan_step(xtm1): - x = pm.Normal.dist(xtm1) + x = pm.Normal.dist(xtm1 + 1) x_update = collect_default_updates([x]) return x, x_update @@ -510,7 +510,7 @@ def dist(size): with Model() as model: CustomDist("x", dist=dist) - assert_moment_is_expected(model, np.array([[0], [0], [0]])) + assert_moment_is_expected(model, np.array([[1], [2], [3]])) @pytest.mark.parametrize( "left, right, size, expected", From ba00b38037f6dc57335cf4f371ffe4a57024d448 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 13 Nov 2023 14:31:06 +0300 Subject: [PATCH 36/41] remove filter_RNGs function --- pymc/distributions/distribution.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 12398274c3..e238a8bfaa 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -85,10 +85,6 @@ PLATFORM = sys.platform -def filter_RNGs(params): - return [p for p in params if not isinstance(p.type, (RandomType, RandomGeneratorType))] - - class MomentRewrite(GraphRewriter): def rewrite_moment_scan_node(self, node): if not isinstance(node.op, Scan): @@ -695,8 +691,8 @@ def dist( if logcdf is None: logcdf = default_not_implemented(class_name, "logcdf") - # if moment is None: - moment = dist_moment + if moment is None: + moment = dist_moment return super().dist( dist_params, @@ -757,7 +753,15 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs): @_moment.register(rv_type) def custom_dist_get_moment(op, rv, size, *params): - return moment(rv, size, *filter_RNGs(params)) + return moment( + rv, + size, + *[ + p + for p in params + if not isinstance(p.type, (RandomType, RandomGeneratorType)) + ], + ) @_change_dist_size.register(rv_type) def change_custom_symbolic_dist_size(op, rv, new_size, expand): From a3c9f143109cdeaf22a5dbe60d54daebf8b7020b Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 13 Nov 2023 14:35:40 +0300 Subject: [PATCH 37/41] remove moment from dist method --- pymc/distributions/distribution.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index e238a8bfaa..e7262711c8 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -691,9 +691,6 @@ def dist( if logcdf is None: logcdf = default_not_implemented(class_name, "logcdf") - if moment is None: - moment = dist_moment - return super().dist( dist_params, class_name=class_name, From ed5a3c7851ef2223da28f47c1b173abf1bbc4a39 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 13 Nov 2023 14:43:50 +0300 Subject: [PATCH 38/41] register moment fn --- pymc/distributions/distribution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index e7262711c8..871596a598 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -656,16 +656,18 @@ def update(self, node: Node): return updates +@_moment.register(CustomSymbolicDistRV) def dist_moment(rv, *args): - node = rv.owner + x = args[0] + node = x.owner op = node.op - rv_out_idx = node.outputs.index(rv) + rv_out_idx = node.outputs.index(x) fgraph = op.fgraph.clone() replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) # Replace dummy inner inputs by outer inputs - fgraph.replace_all(tuple(zip(op.inner_inputs, node.inputs)), import_missing=True) + fgraph.replace_all(tuple(zip(op.inner_inputs, args[1:])), import_missing=True) moment = fgraph.outputs[rv_out_idx] return moment From 9ec33c6707241de78773c7b7ef04a523c5207c3d Mon Sep 17 00:00:00 2001 From: Anatoly <44327258+aerubanov@users.noreply.github.com> Date: Mon, 13 Nov 2023 14:46:55 +0300 Subject: [PATCH 39/41] Update pymc/distributions/distribution.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 871596a598..17c041c512 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -657,7 +657,7 @@ def update(self, node: Node): @_moment.register(CustomSymbolicDistRV) -def dist_moment(rv, *args): +def dist_moment(op, rv, *args): x = args[0] node = x.owner op = node.op From d5899d472c610252795d01fcc31d453008ac2664 Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 13 Nov 2023 14:47:51 +0300 Subject: [PATCH 40/41] fix moment arguments --- pymc/distributions/distribution.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 17c041c512..969cc555d9 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -658,16 +658,15 @@ def update(self, node: Node): @_moment.register(CustomSymbolicDistRV) def dist_moment(op, rv, *args): - x = args[0] - node = x.owner + node = rv.owner op = node.op - rv_out_idx = node.outputs.index(x) + rv_out_idx = node.outputs.index(rv) fgraph = op.fgraph.clone() replace_moments = MomentRewrite() replace_moments.rewrite(fgraph) # Replace dummy inner inputs by outer inputs - fgraph.replace_all(tuple(zip(op.inner_inputs, args[1:])), import_missing=True) + fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True) moment = fgraph.outputs[rv_out_idx] return moment From 9b3c43d02561c282c0f9f8be8962b272c9daf1ff Mon Sep 17 00:00:00 2001 From: Anatoly Rubanov Date: Mon, 13 Nov 2023 16:30:56 +0300 Subject: [PATCH 41/41] remove separate var for op --- pymc/distributions/distribution.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 969cc555d9..321d364957 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -659,7 +659,6 @@ def update(self, node: Node): @_moment.register(CustomSymbolicDistRV) def dist_moment(op, rv, *args): node = rv.owner - op = node.op rv_out_idx = node.outputs.index(rv) fgraph = op.fgraph.clone()