
Commit 5d003dc

cuchoi, ricardoV94, and twiecki authored
Remove t suffix from Model methods (#5863)
* Future warning for logpt
* Future warning for dlogpt and d2logpt
* Updated references to logpt, and updated varlogpt, datalogpt, joint_logpt
* Fix issue with d2logpt
* Added tests
* Fix typo
* Updated release notes for 4.0
* Added potentiallogpt test
* Updated developer guide
* Update pymc/distributions/logprob.py (Co-authored-by: Ricardo Vieira <[email protected]>)
* Removed t from varlogp_nojact
* Revert Release Notes
* Revert changes to developer guide
* Future warning for logpt
* Future warning for dlogpt and d2logpt
* Updated references to logpt, and updated varlogpt, datalogpt, joint_logpt
* Fix issue with d2logpt
* Added tests
* Fix typo
* Updated release notes for 4.0
* Added potentiallogpt test
* Update pymc/distributions/logprob.py (Co-authored-by: Ricardo Vieira <[email protected]>)
* Removed t from varlogp_nojact
* Revert Release Notes
* Updated release notes for 4.0
* Revert Release Notes
* Added deprecation of functions/properties ending with t to release notes
* Update RELEASE-NOTES.md

Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent a5f3e45 commit 5d003dc

24 files changed: +268 -152 lines changed

RELEASE-NOTES.md

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,16 @@
 + Fixed an incorrect entry in `pm.Metropolis.stats_dtypes` (see #5582).
 + Added a check in `Empirical` approximation which does not yet support `InferenceData` inputs (see #5874, #5884).
 + Fixed bug when sampling discrete variables with SMC (see #5887).
++ Removed trailing `t` (for tensor) in functions and properties from the model class and from `jointlogpt` (see #5859).
+  + `Model.logpt` → `Model.logp`
+  + `Model.dlogpt` → `Model.dlogp`
+  + `Model.d2logpt` → `Model.d2logp`
+  + `Model.datalogpt` → `Model.datalogp`
+  + `Model.varlogpt` → `Model.varlogp`
+  + `Model.observedlogpt` → `Model.observedlogp`
+  + `Model.potentiallogpt` → `Model.potentiallogp`
+  + `Model.varlogp_nojact` → `Model.varlogp_nojac`
+  + `logprob.joint_logpt` → `logprob.joint_logp`
 
 ## PyMC 4.0.0 (2022-06-03)
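For users, the entries above are pure renames: the old spellings still work for now but emit a FutureWarning. A minimal usage sketch (the toy model below is illustrative, not taken from the diff):

    import pymc as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("obs", mu, 1, observed=[0.1, -0.3])

    logp_graph = model.logp()     # new spelling
    # model.logpt() still returns the same graph but now raises a FutureWarning
    dlogp_graph = model.dlogp()   # gradient graph, formerly model.dlogpt()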

docs/source/learn/core_notebooks/pymc_aesara.ipynb

Lines changed: 4 additions & 4 deletions
@@ -1844,7 +1844,7 @@
     }
    },
    "source": [
-    "`pymc` models provide some helpful routines to facilitating the conversion of `RandomVariable`s to probability functions. {meth}`~pymc.Model.logpt`, for instance can be used to extract the joint probability of all variables in the model:"
+    "`pymc` models provide some helpful routines to facilitating the conversion of `RandomVariable`s to probability functions. {meth}`~pymc.Model.logp`, for instance can be used to extract the joint probability of all variables in the model:"
    ]
   },
   {
@@ -1902,7 +1902,7 @@
     }
    ],
    "source": [
-    "aesara.dprint(model.logpt(sum=False))"
+    "aesara.dprint(model.logp(sum=False))"
    ]
   },
   {
@@ -2213,7 +2213,7 @@
    "sigma_log_value = model_2.rvs_to_values[sigma]\n",
    "x_value = model_2.rvs_to_values[x]\n",
    "# element-wise log-probability of the model (we do not take te sum)\n",
-   "logp_graph = at.stack(model_2.logpt(sum=False))\n",
+   "logp_graph = at.stack(model_2.logp(sum=False))\n",
    "# evaluate by passing concrete values\n",
    "logp_graph.eval({mu_value: 0, sigma_log_value: -10, x_value:0})"
   ]
@@ -2314,7 +2314,7 @@
     }
    },
    "source": [
-    "The {class}`~pymc.Model` class also has methods to extract the gradient ({meth}`~pymc.Model.dlogpt`) and the hessian ({meth}`~pymc.Model.d2logpt`) of the logp."
+    "The {class}`~pymc.Model` class also has methods to extract the gradient ({meth}`~pymc.Model.dlogp`) and the hessian ({meth}`~pymc.Model.d2logp`) of the logp."
    ]
   },
   {
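Outside the notebook the renamed method is used the same way; a short sketch of evaluating the element-wise log-probability, mirroring the cells edited above (the model itself is illustrative):

    import aesara.tensor as at
    import pymc as pm

    with pm.Model() as model_2:
        mu = pm.Normal("mu", 0, 1)
        sigma = pm.HalfNormal("sigma", 1)
        x = pm.Normal("x", mu, sigma)

    mu_value = model_2.rvs_to_values[mu]
    sigma_log_value = model_2.rvs_to_values[sigma]
    x_value = model_2.rvs_to_values[x]

    # element-wise log-probability of the model (no sum taken)
    logp_graph = at.stack(model_2.logp(sum=False))
    # evaluate by passing concrete values for the (transformed) value variables
    print(logp_graph.eval({mu_value: 0, sigma_log_value: -10, x_value: 0}))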

pymc/backends/arviz.py

Lines changed: 2 additions & 2 deletions
@@ -251,7 +251,7 @@ def _extract_log_likelihood(self, trace):
                 (
                     var,
                     self.model.compile_fn(
-                        self.model.logpt(var, sum=False)[0],
+                        self.model.logp(var, sum=False)[0],
                         inputs=self.model.value_vars,
                         on_unused_input="ignore",
                     ),
@@ -263,7 +263,7 @@ def _extract_log_likelihood(self, trace):
                 (
                     var,
                     self.model.compile_fn(
-                        self.model.logpt(var, sum=False)[0],
+                        self.model.logp(var, sum=False)[0],
                         inputs=self.model.value_vars,
                         on_unused_input="ignore",
                     ),

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from pymc.distributions.logprob import (  # isort:skip
     logcdf,
     logp,
+    joint_logp,
     joint_logpt,
 )
 
@@ -191,6 +192,7 @@
     "CAR",
     "PolyaGamma",
     "joint_logpt",
+    "joint_logp",
     "logp",
     "logcdf",
 ]

pymc/distributions/continuous.py

Lines changed: 1 addition & 1 deletion
@@ -2558,7 +2558,7 @@ def logcdf(value, nu):
         return logcdf(Gamma.dist(alpha=nu / 2, beta=0.5), value)
 
 
-# TODO: Remove this once logpt for multiplication is working!
+# TODO: Remove this once logp for multiplication is working!
 class WeibullBetaRV(WeibullRV):
     ndims_params = [0, 0]

pymc/distributions/logprob.py

Lines changed: 13 additions & 4 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 
 from collections.abc import Mapping
 from typing import Dict, List, Optional, Sequence, Union
@@ -119,7 +120,15 @@ def _get_scaling(
     )
 
 
-def joint_logpt(
+def joint_logpt(*args, **kwargs):
+    warnings.warn(
+        "joint_logpt has been deprecated. Use joint_logp instead.",
+        FutureWarning,
+    )
+    return joint_logp(*args, **kwargs)
+
+
+def joint_logp(
     var: Union[TensorVariable, List[TensorVariable]],
     rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
     *,
@@ -159,14 +168,14 @@ def joint_logpt(
 
     """
     # TODO: In future when we drop support for tag.value_var most of the following
-    # logic can be removed and logpt can just be a wrapper function that calls aeppl's
+    # logic can be removed and logp can just be a wrapper function that calls aeppl's
     # joint_logprob directly.
 
     # If var is not a list make it one.
     if not isinstance(var, (list, tuple)):
         var = [var]
 
-    # If logpt isn't provided values it is assumed that the tagged value var or
+    # If logp isn't provided values it is assumed that the tagged value var or
     # observation is the value variable for that particular RV.
     if rv_values is None:
         rv_values = {}
@@ -251,7 +260,7 @@
                 "reference nonlocal variables."
             )
 
-    # aeppl returns the logpt for every single value term we provided to it. This includes
+    # aeppl returns the logp for every single value term we provided to it. This includes
     # the extra values we plugged in above, so we filter those we actually wanted in the
     # same order they were given in.
     logp_var_dict = {}
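The hunk above introduces the deprecation-shim pattern used throughout this commit: the old name becomes a thin wrapper that forwards *args/**kwargs to the new name and emits a FutureWarning. The same pattern in isolation (generic names, not from the PyMC codebase):

    import warnings


    def new_name(x, scale=1.0):
        # the renamed implementation
        return x * scale


    def old_name(*args, **kwargs):
        # deprecated alias kept so existing callers keep working for now
        warnings.warn(
            "old_name has been deprecated. Use new_name instead.",
            FutureWarning,
        )
        return new_name(*args, **kwargs)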

pymc/model.py

Lines changed: 83 additions & 22 deletions
@@ -57,7 +57,7 @@
 )
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import GenTensorVariable, Minibatch
-from pymc.distributions import joint_logpt
+from pymc.distributions import joint_logp
 from pymc.distributions.logprob import _get_scaling
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
@@ -623,9 +623,9 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
                 raise ValueError(f"Can only compute the gradient of continuous types: {var}")
 
         if tempered:
-            costs = [self.varlogpt, self.datalogpt]
+            costs = [self.varlogp, self.datalogp]
         else:
-            costs = [self.logpt()]
+            costs = [self.logp()]
 
         input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
         extra_vars = [self.rvs_to_values.get(var, var) for var in self.free_RVs]
@@ -654,7 +654,7 @@ def compile_logp(
             Whether to sum all logp terms or return elemwise logp for each variable.
             Defaults to True.
         """
-        return self.model.compile_fn(self.logpt(vars=vars, jacobian=jacobian, sum=sum))
+        return self.model.compile_fn(self.logp(vars=vars, jacobian=jacobian, sum=sum))
 
     def compile_dlogp(
         self,
@@ -671,7 +671,7 @@ def compile_dlogp(
         jacobian:
             Whether to include jacobian terms in logprob graph. Defaults to True.
         """
-        return self.model.compile_fn(self.dlogpt(vars=vars, jacobian=jacobian))
+        return self.model.compile_fn(self.dlogp(vars=vars, jacobian=jacobian))
 
     def compile_d2logp(
         self,
@@ -688,9 +688,16 @@ def compile_d2logp(
         jacobian:
            Whether to include jacobian terms in logprob graph. Defaults to True.
        """
-        return self.model.compile_fn(self.d2logpt(vars=vars, jacobian=jacobian))
+        return self.model.compile_fn(self.d2logp(vars=vars, jacobian=jacobian))
 
-    def logpt(
+    def logpt(self, *args, **kwargs):
+        warnings.warn(
+            "Model.logpt has been deprecated. Use Model.logp instead.",
+            FutureWarning,
+        )
+        return self.logp(*args, **kwargs)
+
+    def logp(
         self,
         vars: Optional[Union[Variable, Sequence[Variable]]] = None,
         jacobian: bool = True,
@@ -742,7 +749,7 @@ def logpt(
 
         rv_logps: List[TensorVariable] = []
         if rv_values:
-            rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
+            rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
             assert isinstance(rv_logps, list)
 
         # Replace random variables by their value variables in potential terms
@@ -764,7 +771,14 @@ def logpt(
             logp_scalar.name = logp_scalar_name
         return logp_scalar
 
-    def dlogpt(
+    def dlogpt(self, *args, **kwargs):
+        warnings.warn(
+            "Model.dlogpt has been deprecated. Use Model.dlogp instead.",
+            FutureWarning,
+        )
+        return self.dlogp(*args, **kwargs)
+
+    def dlogp(
         self,
         vars: Optional[Union[Variable, Sequence[Variable]]] = None,
         jacobian: bool = True,
@@ -799,10 +813,17 @@ def dlogpt(
                     f"Requested variable {var} not found among the model variables"
                 )
 
-        cost = self.logpt(jacobian=jacobian)
+        cost = self.logp(jacobian=jacobian)
         return gradient(cost, value_vars)
 
-    def d2logpt(
+    def d2logpt(self, *args, **kwargs):
+        warnings.warn(
+            "Model.d2logpt has been deprecated. Use Model.d2logp instead.",
+            FutureWarning,
+        )
+        return self.d2logp(*args, **kwargs)
+
+    def d2logp(
         self,
         vars: Optional[Union[Variable, Sequence[Variable]]] = None,
         jacobian: bool = True,
@@ -837,34 +858,74 @@ def d2logpt(
                     f"Requested variable {var} not found among the model variables"
                )
 
-        cost = self.logpt(jacobian=jacobian)
+        cost = self.logp(jacobian=jacobian)
         return hessian(cost, value_vars)
 
     @property
-    def datalogpt(self) -> Variable:
+    def datalogpt(self):
+        warnings.warn(
+            "Model.datalogpt has been deprecated. Use Model.datalogp instead.",
+            FutureWarning,
+        )
+        return self.datalogp
+
+    @property
+    def datalogp(self) -> Variable:
         """Aesara scalar of log-probability of the observed variables and
         potential terms"""
-        return self.observedlogpt + self.potentiallogpt
+        return self.observedlogp + self.potentiallogp
 
     @property
-    def varlogpt(self) -> Variable:
+    def varlogpt(self):
+        warnings.warn(
+            "Model.varlogpt has been deprecated. Use Model.varlogp instead.",
+            FutureWarning,
+        )
+        return self.varlogp
+
+    @property
+    def varlogp(self) -> Variable:
         """Aesara scalar of log-probability of the unobserved random variables
         (excluding deterministic)."""
-        return self.logpt(vars=self.free_RVs)
+        return self.logp(vars=self.free_RVs)
 
     @property
-    def varlogp_nojact(self) -> Variable:
+    def varlogp_nojact(self):
+        warnings.warn(
+            "Model.varlogp_nojact has been deprecated. Use Model.varlogp_nojac instead.",
+            FutureWarning,
+        )
+        return self.varlogp_nojac
+
+    @property
+    def varlogp_nojac(self) -> Variable:
         """Aesara scalar of log-probability of the unobserved random variables
         (excluding deterministic) without jacobian term."""
-        return self.logpt(vars=self.free_RVs, jacobian=False)
+        return self.logp(vars=self.free_RVs, jacobian=False)
+
+    @property
+    def observedlogpt(self):
+        warnings.warn(
+            "Model.observedlogpt has been deprecated. Use Model.observedlogp instead.",
+            FutureWarning,
+        )
+        return self.observedlogp
 
     @property
-    def observedlogpt(self) -> Variable:
+    def observedlogp(self) -> Variable:
         """Aesara scalar of log-probability of the observed variables"""
-        return self.logpt(vars=self.observed_RVs)
+        return self.logp(vars=self.observed_RVs)
+
+    @property
+    def potentiallogpt(self):
+        warnings.warn(
+            "Model.potentiallogpt has been deprecated. Use Model.potentiallogp instead.",
+            FutureWarning,
+        )
+        return self.potentiallogp
 
     @property
-    def potentiallogpt(self) -> Variable:
+    def potentiallogp(self) -> Variable:
         """Aesara scalar of log-probability of the Potential terms"""
         # Convert random variables in Potential expression into their log-likelihood
         # inputs and apply their transforms, if any
@@ -1755,7 +1816,7 @@ def point_logps(self, point=None, round_vals=2):
             point = self.initial_point()
 
         factors = self.basic_RVs + self.potentials
-        factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
+        factor_logps_fn = [at.sum(factor) for factor in self.logp(factors, sum=False)]
         return {
             factor.name: np.round(np.asarray(factor_logp), round_vals)
             for factor, factor_logp in zip(
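The commit message mentions added tests for these warnings; they are not part of this excerpt, but a hedged sketch of the kind of check involved could look like this (test structure assumed, not copied from the PR):

    import warnings

    import pymc as pm

    with pm.Model() as m:
        pm.Normal("x", 0, 1)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        m.logpt()  # deprecated spelling, still callable
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    m.logp()  # replacement spelling, no warning expected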

pymc/sampling.py

Lines changed: 2 additions & 2 deletions
@@ -204,15 +204,15 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
     # Use competence classmethods to select step methods for remaining
     # variables
     selected_steps = defaultdict(list)
-    model_logpt = model.logpt()
+    model_logp = model.logp()
 
     for var in model.value_vars:
         if var not in assigned_vars:
             # determine if a gradient can be computed
             has_gradient = var.dtype not in discrete_types
             if has_gradient:
                 try:
-                    tg.grad(model_logpt, var)
+                    tg.grad(model_logp, var)
                 except (NotImplementedError, tg.NullTypeGradError):
                     has_gradient = False
 
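assign_step_methods probes each value variable of a single joint logp graph for differentiability; the same probe can be reproduced standalone (toy model; the dtype check below is an illustrative stand-in for PyMC's internal discrete_types set):

    import aesara.gradient as tg
    import pymc as pm

    with pm.Model() as model:
        p = pm.Beta("p", 1, 1)
        pm.Bernoulli("k", p)

    model_logp = model.logp()
    for var in model.value_vars:
        # crude stand-in for the discrete_types membership test used in pymc
        has_gradient = not var.dtype.startswith(("int", "uint", "bool"))
        if has_gradient:
            try:
                tg.grad(model_logp, var)
            except (NotImplementedError, tg.NullTypeGradError):
                has_gradient = False
        print(var.name, has_gradient)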

pymc/sampling_jax.py

Lines changed: 5 additions & 5 deletions
@@ -100,10 +100,10 @@ def get_jaxified_graph(
 
 
 def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
-    model_logpt = model.logpt()
+    model_logp = model.logp()
     if not negative_logp:
-        model_logpt = -model_logpt
-    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
+        model_logp = -model_logp
+    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
 
     def logp_fn_wrap(x):
         return logp_fn(*x)[0]
@@ -136,8 +136,8 @@ def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
     """Compute log-likelihood for all observations"""
     data = {}
     for v in model.observed_RVs:
-        v_elemwise_logpt = model.logpt(v, sum=False)
-        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=v_elemwise_logpt)
+        v_elemwise_logp = model.logp(v, sum=False)
+        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=v_elemwise_logp)
         result = jax.jit(jax.vmap(jax.vmap(jax_fn)), backend=backend)(*samples)[0]
         data[v.name] = result
     return data
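get_jaxified_logp returns a plain callable over a sequence of value-variable inputs, which is what the JAX samplers feed it; a brief usage sketch (the toy model is illustrative, and JAX must be installed for this module to import):

    import numpy as np
    import pymc as pm
    from pymc.sampling_jax import get_jaxified_logp

    with pm.Model() as model:
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("obs", mu, 1, observed=np.array([0.2, -0.1]))

    logp_fn = get_jaxified_logp(model)
    point = model.initial_point()
    # inputs must follow the order of model.value_vars
    print(logp_fn([point[v.name] for v in model.value_vars]))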

pymc/smc/smc.py

Lines changed: 2 additions & 2 deletions
@@ -219,10 +219,10 @@ def _initialize_kernel(self):
         shared = make_shared_replacements(initial_point, self.variables, self.model)
 
         self.prior_logp_func = _logp_forw(
-            initial_point, [self.model.varlogpt], self.variables, shared
+            initial_point, [self.model.varlogp], self.variables, shared
         )
         self.likelihood_logp_func = _logp_forw(
-            initial_point, [self.model.datalogpt], self.variables, shared
+            initial_point, [self.model.datalogp], self.variables, shared
         )
 
         priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior]
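The SMC kernel relies on the prior/likelihood split these properties provide: varlogp covers the free variables, datalogp covers observed variables plus potentials, and together they sum to the full model log-probability. A quick sanity check of that identity (toy model, illustrative only):

    import pymc as pm

    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu, 1, observed=[0.5, 1.5])

    point = m.initial_point()
    full = m.compile_logp()(point)
    prior = m.compile_fn(m.varlogp)(point)
    likelihood = m.compile_fn(m.datalogp)(point)
    assert abs(full - (prior + likelihood)) < 1e-8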
