Skip to content

Commit ad450a6

Browse files
Default moment for CustomDist provided with a dist function (#6873)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent d7415de commit ad450a6

File tree

2 files changed

+184
-15
lines changed

2 files changed

+184
-15
lines changed

pymc/distributions/distribution.py

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@
2525

2626
from pytensor import tensor as pt
2727
from pytensor.compile.builders import OpFromGraph
28-
from pytensor.graph import FunctionGraph, node_rewriter
29-
from pytensor.graph.basic import Node, Variable
30-
from pytensor.graph.replace import clone_replace
31-
from pytensor.graph.rewriting.basic import in2out
28+
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
29+
from pytensor.graph.basic import Node, Variable, io_toposort
30+
from pytensor.graph.features import ReplaceValidate
31+
from pytensor.graph.rewriting.basic import GraphRewriter, in2out
3232
from pytensor.graph.utils import MetaType
33+
from pytensor.scan.op import Scan
3334
from pytensor.tensor.basic import as_tensor_variable
3435
from pytensor.tensor.random.op import RandomVariable
3536
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
37+
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
3638
from pytensor.tensor.random.utils import normalize_size_param
3739
from pytensor.tensor.rewriting.shape import ShapeFeature
3840
from pytensor.tensor.variable import TensorVariable
@@ -83,6 +85,59 @@
8385
PLATFORM = sys.platform
8486

8587

88+
class MomentRewrite(GraphRewriter):
89+
def rewrite_moment_scan_node(self, node):
90+
if not isinstance(node.op, Scan):
91+
return
92+
93+
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
94+
op = node.op
95+
96+
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
97+
98+
replace_with_moment = []
99+
to_replace_set = set()
100+
101+
for nd in local_fgraph_topo:
102+
if nd not in to_replace_set and isinstance(
103+
nd.op, (RandomVariable, SymbolicRandomVariable)
104+
):
105+
replace_with_moment.append(nd.out)
106+
to_replace_set.add(nd)
107+
givens = {}
108+
if len(replace_with_moment) > 0:
109+
for item in replace_with_moment:
110+
givens[item] = moment(item)
111+
else:
112+
return
113+
op_outs = clone_replace(node_outputs, replace=givens)
114+
115+
nwScan = Scan(
116+
node_inputs,
117+
op_outs,
118+
op.info,
119+
mode=op.mode,
120+
profile=op.profile,
121+
truncate_gradient=op.truncate_gradient,
122+
name=op.name,
123+
allow_gc=op.allow_gc,
124+
)
125+
nw_node = nwScan(*(node.inputs), return_list=True)[0].owner
126+
return nw_node
127+
128+
def add_requirements(self, fgraph):
129+
fgraph.attach_feature(ReplaceValidate())
130+
131+
def apply(self, fgraph):
132+
for node in fgraph.toposort():
133+
if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
134+
fgraph.replace(node.out, moment(node.out))
135+
elif isinstance(node.op, Scan):
136+
new_node = self.rewrite_moment_scan_node(node)
137+
if new_node is not None:
138+
fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs)))
139+
140+
86141
class _Unpickling:
87142
pass
88143

@@ -601,6 +656,20 @@ def update(self, node: Node):
601656
return updates
602657

603658

659+
@_moment.register(CustomSymbolicDistRV)
660+
def dist_moment(op, rv, *args):
661+
node = rv.owner
662+
rv_out_idx = node.outputs.index(rv)
663+
664+
fgraph = op.fgraph.clone()
665+
replace_moments = MomentRewrite()
666+
replace_moments.rewrite(fgraph)
667+
# Replace dummy inner inputs by outer inputs
668+
fgraph.replace_all(tuple(zip(op.inner_inputs, args)), import_missing=True)
669+
moment = fgraph.outputs[rv_out_idx]
670+
return moment
671+
672+
604673
class _CustomSymbolicDist(Distribution):
605674
rv_type = CustomSymbolicDistRV
606675

@@ -622,14 +691,6 @@ def dist(
622691
if logcdf is None:
623692
logcdf = default_not_implemented(class_name, "logcdf")
624693

625-
if moment is None:
626-
moment = functools.partial(
627-
default_moment,
628-
rv_name=class_name,
629-
has_fallback=True,
630-
ndim_supp=ndim_supp,
631-
)
632-
633694
return super().dist(
634695
dist_params,
635696
class_name=class_name,
@@ -685,9 +746,19 @@ def custom_dist_logp(op, values, size, *params, **kwargs):
685746
def custom_dist_logcdf(op, value, size, *params, **kwargs):
686747
return logcdf(value, *params[: len(dist_params)])
687748

688-
@_moment.register(rv_type)
689-
def custom_dist_get_moment(op, rv, size, *params):
690-
return moment(rv, size, *params[: len(params)])
749+
if moment is not None:
750+
751+
@_moment.register(rv_type)
752+
def custom_dist_get_moment(op, rv, size, *params):
753+
return moment(
754+
rv,
755+
size,
756+
*[
757+
p
758+
for p in params
759+
if not isinstance(p.type, (RandomType, RandomGeneratorType))
760+
],
761+
)
691762

692763
@_change_dist_size.register(rv_type)
693764
def change_custom_symbolic_dist_size(op, rv, new_size, expand):

tests/distributions/test_distribution.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,104 @@ def custom_dist(mu, sigma, size):
430430
ip = m.initial_point()
431431
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))
432432

433+
@pytest.mark.parametrize(
434+
"dist_params, size, expected, dist_fn",
435+
[
436+
(
437+
(5, 1),
438+
None,
439+
np.exp(5),
440+
lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size)),
441+
),
442+
(
443+
(2, np.ones(5)),
444+
None,
445+
np.exp([2, 2, 2, 2, 2] + np.ones(5)),
446+
lambda mu, sigma, size: pt.exp(
447+
pm.Normal.dist(mu, sigma, size=size) + pt.ones(size)
448+
),
449+
),
450+
(
451+
(1, 2),
452+
None,
453+
np.sqrt(np.exp(1 + 0.5 * 2**2)),
454+
lambda mu, sigma, size: pt.sqrt(pm.LogNormal.dist(mu, sigma, size=size)),
455+
),
456+
(
457+
(4,),
458+
(3,),
459+
np.log([4, 4, 4]),
460+
lambda nu, size: pt.log(pm.ChiSquared.dist(nu, size=size)),
461+
),
462+
(
463+
(12, 1),
464+
None,
465+
12,
466+
lambda mu1, sigma, size: pm.Normal.dist(mu1, sigma, size=size),
467+
),
468+
],
469+
)
470+
def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn):
471+
with Model() as model:
472+
CustomDist("x", *dist_params, dist=dist_fn, size=size)
473+
assert_moment_is_expected(model, expected)
474+
475+
def test_custom_dist_default_moment_scan(self):
476+
def scan_step(left, right):
477+
x = pm.Uniform.dist(left, right)
478+
x_update = collect_default_updates([x])
479+
return x, x_update
480+
481+
def dist(size):
482+
xs, updates = scan(
483+
fn=scan_step,
484+
sequences=[
485+
pt.as_tensor_variable(np.array([-4, -3])),
486+
pt.as_tensor_variable(np.array([-2, -1])),
487+
],
488+
name="xs",
489+
)
490+
return xs
491+
492+
with Model() as model:
493+
CustomDist("x", dist=dist)
494+
assert_moment_is_expected(model, np.array([-3, -2]))
495+
496+
def test_custom_dist_default_moment_scan_recurring(self):
497+
def scan_step(xtm1):
498+
x = pm.Normal.dist(xtm1 + 1)
499+
x_update = collect_default_updates([x])
500+
return x, x_update
501+
502+
def dist(size):
503+
xs, _ = scan(
504+
fn=scan_step,
505+
outputs_info=pt.as_tensor_variable(np.array([0])).astype(float),
506+
n_steps=3,
507+
name="xs",
508+
)
509+
return xs
510+
511+
with Model() as model:
512+
CustomDist("x", dist=dist)
513+
assert_moment_is_expected(model, np.array([[1], [2], [3]]))
514+
515+
@pytest.mark.parametrize(
516+
"left, right, size, expected",
517+
[
518+
(-1, 1, None, 0 + 5),
519+
(-3, -1, None, -2 + 5),
520+
(-3, 1, (3,), np.array([-1 + 5, -1 + 5, -1 + 5])),
521+
],
522+
)
523+
def test_custom_dist_default_moment_nested(self, left, right, size, expected):
524+
def dist_fn(left, right, size):
525+
return pm.Truncated.dist(pm.Normal.dist(0, 1), left, right, size=size) + 5
526+
527+
with Model() as model:
528+
CustomDist("x", left, right, size=size, dist=dist_fn)
529+
assert_moment_is_expected(model, expected)
530+
433531
def test_logcdf_inference(self):
434532
def custom_dist(mu, sigma, size):
435533
return pt.exp(pm.Normal.dist(mu, sigma, size=size))

0 commit comments

Comments
 (0)