Skip to content

Commit e470d13

Browse files
committed
Do not rely on tag information for rv and logp conversions
1 parent 44dc340 commit e470d13

25 files changed

+535
-264
lines changed

pymc/aesaraf.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from typing import (
1517
Callable,
1618
Dict,
@@ -31,6 +33,7 @@
3133
import scipy.sparse as sps
3234

3335
from aeppl.logprob import CheckParameterValue
36+
from aeppl.transforms import RVTransform
3437
from aesara import scalar
3538
from aesara.compile.mode import Mode, get_mode
3639
from aesara.gradient import grad
@@ -205,7 +208,7 @@ def expand(var):
205208
yield from walk(graphs, expand, bfs=False)
206209

207210

208-
def replace_rvs_in_graphs(
211+
def _replace_rvs_in_graphs(
209212
graphs: Iterable[TensorVariable],
210213
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
211214
**kwargs,
@@ -282,6 +285,10 @@ def rvs_to_value_vars(
282285
apply_transforms
283286
If ``True``, apply each value variable's transform.
284287
"""
288+
warnings.warn(
289+
"rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead",
290+
FutureWarning,
291+
)
285292

286293
def populate_replacements(
287294
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
@@ -313,7 +320,7 @@ def populate_replacements(
313320
equiv = clone_get_equiv(inputs, graphs, False, False, {})
314321
graphs = [equiv[n] for n in graphs]
315322

316-
graphs, _ = replace_rvs_in_graphs(
323+
graphs, _ = _replace_rvs_in_graphs(
317324
graphs,
318325
replacement_fn=populate_replacements,
319326
**kwargs,
@@ -322,6 +329,69 @@ def populate_replacements(
322329
return graphs
323330

324331

332+
def replace_rvs_by_values(
333+
graphs: Sequence[TensorVariable],
334+
*,
335+
rvs_to_values: Dict[TensorVariable, TensorVariable],
336+
rvs_to_transforms: Dict[TensorVariable, RVTransform],
337+
**kwargs,
338+
) -> List[TensorVariable]:
339+
"""Clone and replace random variables in graphs with their value variables.
340+
341+
This will *not* recompute test values in the resulting graphs.
342+
343+
Parameters
344+
----------
345+
graphs
346+
The graphs in which to perform the replacements.
347+
rvs_to_values
348+
Mapping between the original graph RVs and respective value variables
349+
rvs_to_transforms
350+
Mapping between the original graph RVs and respective value transforms
351+
"""
352+
353+
# Clone original graphs so that we don't modify variables in place
354+
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
355+
equiv = clone_get_equiv(inputs, graphs, False, False, {})
356+
graphs = [equiv[n] for n in graphs]
357+
358+
# Get needed mappings for equivalent cloned variables
359+
equiv_rvs_to_values = {}
360+
equiv_rvs_to_transforms = {}
361+
for rv, value in rvs_to_values.items():
362+
equiv_rv = equiv.get(rv, rv)
363+
equiv_rvs_to_values[equiv_rv] = equiv.get(value, value)
364+
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]
365+
366+
def poulate_replacements(rv, replacements):
367+
# Populate replacements dict with {rv: value} pairs indicating which graph
368+
# RVs should be replaced by what value variables.
369+
370+
# No value variable to replace RV with
371+
value = equiv_rvs_to_values.get(rv, None)
372+
if value is None:
373+
return []
374+
375+
transform = equiv_rvs_to_transforms.get(rv, None)
376+
if transform is not None:
377+
# We want to replace uses of the RV by the back-transformation of its value
378+
value = transform.backward(value, *rv.owner.inputs)
379+
value.name = rv.name
380+
381+
replacements[rv] = value
382+
# Also walk the graph of the value variable to make any additional
383+
# replacements if that is not a simple input variable
384+
return [value]
385+
386+
graphs, _ = _replace_rvs_in_graphs(
387+
graphs,
388+
replacement_fn=poulate_replacements,
389+
**kwargs,
390+
)
391+
392+
return graphs
393+
394+
325395
def inputvars(a):
326396
"""
327397
Get the inputs into Aesara variables

pymc/backends/arviz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]:
4747
"""If there are observations available, return them as a dictionary."""
4848
observations = {}
4949
for obs in model.observed_RVs:
50-
aux_obs = getattr(obs.tag, "observations", None)
50+
aux_obs = model.rvs_to_values.get(obs, None)
5151
if aux_obs is not None:
5252
try:
5353
obs_data = extract_obs_data(aux_obs)
@@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):
261261

262262
if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
263263
try:
264-
obs_data = extract_obs_data(var.tag.observations)
264+
obs_data = extract_obs_data(self.model.rvs_to_values[var])
265265
except TypeError:
266266
warnings.warn(f"Could not extract data from symbolic observation {var}")
267267

pymc/distributions/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
logcdf,
1717
logp,
1818
joint_logp,
19-
joint_logpt,
2019
)
2120

2221
from pymc.distributions.bound import Bound
@@ -199,7 +198,6 @@
199198
"Censored",
200199
"CAR",
201200
"PolyaGamma",
202-
"joint_logpt",
203201
"joint_logp",
204202
"logp",
205203
"logcdf",

pymc/distributions/logprob.py

Lines changed: 76 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@
2525
from aeppl.logprob import logcdf as logcdf_aeppl
2626
from aeppl.logprob import logprob as logp_aeppl
2727
from aeppl.tensor import MeasurableJoin
28-
from aeppl.transforms import TransformValuesRewrite
28+
from aeppl.transforms import RVTransform, TransformValuesRewrite
2929
from aesara import tensor as at
3030
from aesara.graph.basic import graph_inputs, io_toposort
3131
from aesara.tensor.random.op import RandomVariable
3232
from aesara.tensor.var import TensorVariable
3333

3434
from pymc.aesaraf import constant_fold, floatX
3535

36+
TOTAL_SIZE = Union[int, Sequence[int], None]
3637

37-
def _get_scaling(
38-
total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int
39-
) -> TensorVariable:
38+
39+
def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
4040
"""
4141
Gets scaling constant for logp.
4242
@@ -104,12 +104,26 @@ def _get_scaling(
104104
return at.as_tensor(coef, dtype=aesara.config.floatX)
105105

106106

107-
def joint_logpt(*args, **kwargs):
108-
warnings.warn(
109-
"joint_logpt has been deprecated. Use joint_logp instead.",
110-
FutureWarning,
111-
)
112-
return joint_logp(*args, **kwargs)
107+
def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
108+
# Raise if there are unexpected RandomVariables in the logp graph
109+
# Only SimulatorRVs are allowed
110+
from pymc.distributions.simulator import SimulatorRV
111+
112+
unexpected_rv_nodes = [
113+
node
114+
for node in aesara.graph.ancestors(logp_terms)
115+
if (
116+
node.owner
117+
and isinstance(node.owner.op, RandomVariable)
118+
and not isinstance(node.owner.op, SimulatorRV)
119+
)
120+
]
121+
if unexpected_rv_nodes:
122+
raise ValueError(
123+
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
124+
"This can happen when DensityDist logp or Interval transform functions "
125+
"reference nonlocal variables."
126+
)
113127

114128

115129
def joint_logp(
@@ -151,6 +165,10 @@ def joint_logp(
151165
Sum the log-likelihood or return each term as a separate list item.
152166
153167
"""
168+
warnings.warn(
169+
"joint_logp has been deprecated, use model.logp instead",
170+
FutureWarning,
171+
)
154172
# TODO: In future when we drop support for tag.value_var most of the following
155173
# logic can be removed and logp can just be a wrapper function that calls aeppl's
156174
# joint_logprob directly.
@@ -223,33 +241,15 @@ def joint_logp(
223241
**kwargs,
224242
)
225243

226-
# Raise if there are unexpected RandomVariables in the logp graph
227-
# Only SimulatorRVs are allowed
228-
from pymc.distributions.simulator import SimulatorRV
229-
230-
unexpected_rv_nodes = [
231-
node
232-
for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
233-
if (
234-
node.owner
235-
and isinstance(node.owner.op, RandomVariable)
236-
and not isinstance(node.owner.op, SimulatorRV)
237-
)
238-
]
239-
if unexpected_rv_nodes:
240-
raise ValueError(
241-
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
242-
"This can happen when DensityDist logp or Interval transform functions "
243-
"reference nonlocal variables."
244-
)
245-
246244
# aeppl returns the logp for every single value term we provided to it. This includes
247245
# the extra values we plugged in above, so we filter those we actually wanted in the
248246
# same order they were given in.
249247
logp_var_dict = {}
250248
for value_var in rv_values.values():
251249
logp_var_dict[value_var] = temp_logp_var_dict[value_var]
252250

251+
_check_no_rvs(list(logp_var_dict.values()))
252+
253253
if scaling:
254254
for value_var in logp_var_dict.keys():
255255
if value_var in rv_scalings:
@@ -263,6 +263,52 @@ def joint_logp(
263263
return logp_var
264264

265265

266+
def _joint_logp(
267+
rvs: Sequence[TensorVariable],
268+
*,
269+
rvs_to_values: Dict[TensorVariable, TensorVariable],
270+
rvs_to_transforms: Dict[TensorVariable, RVTransform],
271+
jacobian: bool = True,
272+
rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE],
273+
**kwargs,
274+
) -> List[TensorVariable]:
275+
"""Thin wrapper around aeppl.factorized_joint_logprob, extended with PyMC specific
276+
concerns such as transforms, jacobian, and scaling"""
277+
278+
transform_rewrite = None
279+
values_to_transforms = {
280+
rvs_to_values[rv]: transform
281+
for rv, transform in rvs_to_transforms.items()
282+
if transform is not None
283+
}
284+
if values_to_transforms:
285+
# There seems to be an incorrect type hint in TransformValuesRewrite
286+
transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore
287+
288+
temp_logp_terms = factorized_joint_logprob(
289+
rvs_to_values,
290+
extra_rewrites=transform_rewrite,
291+
use_jacobian=jacobian,
292+
**kwargs,
293+
)
294+
295+
# aeppl returns the logp for every single value term we provided to it. This includes
296+
# the extra values we plugged in above, so we filter those we actually wanted in the
297+
# same order they were given in.
298+
logp_terms = {}
299+
for rv in rvs:
300+
value_var = rvs_to_values[rv]
301+
logp_term = temp_logp_terms[value_var]
302+
total_size = rvs_to_total_sizes.get(rv, None)
303+
if total_size is not None:
304+
scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
305+
logp_term *= scaling
306+
logp_terms[value_var] = logp_term
307+
308+
_check_no_rvs(list(logp_terms.values()))
309+
return list(logp_terms.values())
310+
311+
266312
def logp(rv: TensorVariable, value) -> TensorVariable:
267313
"""Return the log-probability graph of a Random Variable"""
268314

pymc/initial_point.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import aesara.tensor as at
2121
import numpy as np
2222

23+
from aeppl.transforms import RVTransform
2324
from aesara.graph.basic import Variable
2425
from aesara.graph.fg import FunctionGraph
2526
from aesara.tensor.var import TensorVariable
@@ -43,9 +44,7 @@ def convert_str_to_rv_dict(
4344
if isinstance(key, str):
4445
if is_transformed_name(key):
4546
rv = model[get_untransformed_name(key)]
46-
initvals[rv] = model.rvs_to_values[rv].tag.transform.backward(
47-
initval, *rv.owner.inputs
48-
)
47+
initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
4948
else:
5049
initvals[model[key]] = initval
5150
else:
@@ -158,7 +157,7 @@ def make_initial_point_fn(
158157

159158
initial_values = make_initial_point_expression(
160159
free_rvs=model.free_RVs,
161-
rvs_to_values=model.rvs_to_values,
160+
rvs_to_transforms=model.rvs_to_transforms,
162161
initval_strategies=initval_strats,
163162
jitter_rvs=jitter_rvs,
164163
default_strategy=default_strategy,
@@ -172,7 +171,7 @@ def make_initial_point_fn(
172171

173172
varnames = []
174173
for var in model.free_RVs:
175-
transform = getattr(model.rvs_to_values[var].tag, "transform", None)
174+
transform = model.rvs_to_transforms[var]
176175
if transform is not None and return_transformed:
177176
name = get_transformed_name(var.name, transform)
178177
else:
@@ -197,7 +196,7 @@ def inner(seed, *args, **kwargs):
197196
def make_initial_point_expression(
198197
*,
199198
free_rvs: Sequence[TensorVariable],
200-
rvs_to_values: Dict[TensorVariable, TensorVariable],
199+
rvs_to_transforms: Dict[TensorVariable, RVTransform],
201200
initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
202201
jitter_rvs: Set[TensorVariable] = None,
203202
default_strategy: str = "moment",
@@ -265,7 +264,7 @@ def make_initial_point_expression(
265264
else:
266265
value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)
267266

268-
transform = getattr(rvs_to_values[variable].tag, "transform", None)
267+
transform = rvs_to_transforms.get(variable, None)
269268

270269
if transform is not None:
271270
value = transform.forward(value, *variable.owner.inputs)

0 commit comments

Comments
 (0)