From f534a7f9c80bf23c43c7a8333eddd66bcbd21283 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Mon, 24 Jan 2022 12:17:44 +0100 Subject: [PATCH 01/10] Rename `BaseTestDistribution` to `BaseTestDistributionRandom` --- pymc/tests/test_distributions_random.py | 170 ++++++++++++------------ 1 file changed, 85 insertions(+), 85 deletions(-) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 62f477df89..02a775baeb 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -256,7 +256,7 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase): default_shape = (1,) -class BaseTestDistribution(SeededTest): +class BaseTestDistributionRandom(SeededTest): """ This class provides a base for tests that new RandomVariables are correctly implemented, and that the mapping of parameters between the PyMC @@ -411,7 +411,7 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable: ) -class TestFlat(BaseTestDistribution): +class TestFlat(BaseTestDistributionRandom): pymc_dist = pm.Flat pymc_dist_params = {} expected_rv_op_params = {} @@ -426,7 +426,7 @@ def check_not_implemented(self): self.pymc_rv.eval() -class TestHalfFlat(BaseTestDistribution): +class TestHalfFlat(BaseTestDistributionRandom): pymc_dist = pm.HalfFlat pymc_dist_params = {} expected_rv_op_params = {} @@ -441,7 +441,7 @@ def check_not_implemented(self): self.pymc_rv.eval() -class TestDiscreteWeibull(BaseTestDistribution): +class TestDiscreteWeibull(BaseTestDistributionRandom): def discrete_weibul_rng_fn(self, size, q, beta, uniform_rng_fct): return np.ceil(np.power(np.log(1 - uniform_rng_fct(size=size)) / np.log(q), 1.0 / beta)) - 1 @@ -463,7 +463,7 @@ def seeded_discrete_weibul_rng_fn(self): ] -class TestPareto(BaseTestDistribution): +class TestPareto(BaseTestDistributionRandom): pymc_dist = pm.Pareto pymc_dist_params = {"alpha": 3.0, "m": 2.0} expected_rv_op_params = {"alpha": 3.0, "m": 2.0} @@ -476,7 +476,7 @@ class TestPareto(BaseTestDistribution): ] -class TestLaplace(BaseTestDistribution): +class TestLaplace(BaseTestDistributionRandom): pymc_dist = pm.Laplace pymc_dist_params = {"mu": 0.0, "b": 1.0} expected_rv_op_params = {"mu": 0.0, "b": 1.0} @@ -489,7 +489,7 @@ class TestLaplace(BaseTestDistribution): ] -class TestAsymmetricLaplace(BaseTestDistribution): +class TestAsymmetricLaplace(BaseTestDistributionRandom): def asymmetriclaplace_rng_fn(self, b, kappa, mu, size, uniform_rng_fct): u = uniform_rng_fct(size=size) switch = kappa ** 2 / (1 + kappa ** 2) @@ -517,7 +517,7 @@ def seeded_asymmetriclaplace_rng_fn(self): ] -class TestExGaussian(BaseTestDistribution): +class TestExGaussian(BaseTestDistributionRandom): def exgaussian_rng_fn(self, mu, sigma, nu, size, normal_rng_fct, exponential_rng_fct): return normal_rng_fct(mu, sigma, size=size) + exponential_rng_fct(scale=nu, size=size) @@ -547,7 +547,7 @@ def seeded_exgaussian_rng_fn(self): ] -class TestGumbel(BaseTestDistribution): +class TestGumbel(BaseTestDistributionRandom): pymc_dist = pm.Gumbel pymc_dist_params = {"mu": 1.5, "beta": 3.0} expected_rv_op_params = {"mu": 1.5, "beta": 3.0} @@ -559,7 +559,7 @@ class TestGumbel(BaseTestDistribution): ] -class TestStudentT(BaseTestDistribution): +class TestStudentT(BaseTestDistributionRandom): pymc_dist = pm.StudentT pymc_dist_params = {"nu": 5.0, "mu": -1.0, "sigma": 2.0} expected_rv_op_params = {"nu": 5.0, "mu": -1.0, "sigma": 2.0} @@ -572,7 +572,7 @@ class TestStudentT(BaseTestDistribution): ] -class 
TestMoyal(BaseTestDistribution): +class TestMoyal(BaseTestDistributionRandom): pymc_dist = pm.Moyal pymc_dist_params = {"mu": 0.0, "sigma": 1.0} expected_rv_op_params = {"mu": 0.0, "sigma": 1.0} @@ -585,7 +585,7 @@ class TestMoyal(BaseTestDistribution): ] -class TestKumaraswamy(BaseTestDistribution): +class TestKumaraswamy(BaseTestDistributionRandom): def kumaraswamy_rng_fn(self, a, b, size, uniform_rng_fct): return (1 - (1 - uniform_rng_fct(size=size)) ** (1 / b)) ** (1 / a) @@ -607,7 +607,7 @@ def seeded_kumaraswamy_rng_fn(self): ] -class TestTruncatedNormal(BaseTestDistribution): +class TestTruncatedNormal(BaseTestDistributionRandom): pymc_dist = pm.TruncatedNormal lower, upper, mu, sigma = -2.0, 2.0, 0, 1.0 pymc_dist_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} @@ -626,7 +626,7 @@ class TestTruncatedNormal(BaseTestDistribution): ] -class TestTruncatedNormalTau(BaseTestDistribution): +class TestTruncatedNormalTau(BaseTestDistributionRandom): pymc_dist = pm.TruncatedNormal lower, upper, mu, tau = -2.0, 2.0, 0, 1.0 tau, sigma = get_tau_sigma(tau=tau, sigma=None) @@ -637,7 +637,7 @@ class TestTruncatedNormalTau(BaseTestDistribution): ] -class TestTruncatedNormalLowerTau(BaseTestDistribution): +class TestTruncatedNormalLowerTau(BaseTestDistributionRandom): pymc_dist = pm.TruncatedNormal lower, upper, mu, tau = -2.0, np.inf, 0, 1.0 tau, sigma = get_tau_sigma(tau=tau, sigma=None) @@ -648,7 +648,7 @@ class TestTruncatedNormalLowerTau(BaseTestDistribution): ] -class TestTruncatedNormalUpperTau(BaseTestDistribution): +class TestTruncatedNormalUpperTau(BaseTestDistributionRandom): pymc_dist = pm.TruncatedNormal lower, upper, mu, tau = -np.inf, 2.0, 0, 1.0 tau, sigma = get_tau_sigma(tau=tau, sigma=None) @@ -659,7 +659,7 @@ class TestTruncatedNormalUpperTau(BaseTestDistribution): ] -class TestTruncatedNormalUpperArray(BaseTestDistribution): +class TestTruncatedNormalUpperArray(BaseTestDistributionRandom): pymc_dist = pm.TruncatedNormal lower, upper, mu, tau = ( np.array([-np.inf, -np.inf]), @@ -681,7 +681,7 @@ class TestTruncatedNormalUpperArray(BaseTestDistribution): ] -class TestWald(BaseTestDistribution): +class TestWald(BaseTestDistributionRandom): pymc_dist = pm.Wald mu, lam, alpha = 1.0, 1.0, 0.0 mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None) @@ -701,7 +701,7 @@ def check_pymc_draws_match_reference(self): ) -class TestWaldMuPhi(BaseTestDistribution): +class TestWaldMuPhi(BaseTestDistributionRandom): pymc_dist = pm.Wald mu, phi, alpha = 1.0, 3.0, 0.0 mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=None, phi=phi) @@ -712,7 +712,7 @@ class TestWaldMuPhi(BaseTestDistribution): ] -class TestSkewNormal(BaseTestDistribution): +class TestSkewNormal(BaseTestDistributionRandom): pymc_dist = pm.SkewNormal pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0} expected_rv_op_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0} @@ -725,7 +725,7 @@ class TestSkewNormal(BaseTestDistribution): ] -class TestSkewNormalTau(BaseTestDistribution): +class TestSkewNormalTau(BaseTestDistributionRandom): pymc_dist = pm.SkewNormal tau, sigma = get_tau_sigma(tau=2.0) pymc_dist_params = {"mu": 0.0, "tau": tau, "alpha": 5.0} @@ -733,7 +733,7 @@ class TestSkewNormalTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestRice(BaseTestDistribution): +class TestRice(BaseTestDistributionRandom): pymc_dist = pm.Rice b, sigma = 1, 2 pymc_dist_params = {"b": b, "sigma": sigma} @@ -747,7 +747,7 @@ class TestRice(BaseTestDistribution): ] 
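The `discrete_weibul_rng_fn` reference sampler earlier in this file is an inverse-CDF (inverse-transform) draw: for U ~ Uniform(0, 1) and the discrete Weibull CDF F(x) = 1 - q**((x + 1)**beta), the smallest integer x with F(x) >= U is ceil((log(1 - U) / log(q))**(1 / beta)) - 1, which is exactly the expression in the diff. A minimal standalone check of that identity (the seed, parameter values, and tolerance below are illustrative and not part of the patch):

import numpy as np

rng = np.random.default_rng(123)
q, beta = 0.25, 2.0
u = rng.uniform(size=200_000)
# Same expression as the reference sampler in the diff above
draws = np.ceil(np.power(np.log(1 - u) / np.log(q), 1.0 / beta)) - 1

# The empirical CDF should match F(x) = 1 - q ** ((x + 1) ** beta)
for x in range(3):
    empirical = (draws <= x).mean()
    exact = 1 - q ** ((x + 1) ** beta)
    assert abs(empirical - exact) < 5e-3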
-class TestRiceNu(BaseTestDistribution): +class TestRiceNu(BaseTestDistributionRandom): pymc_dist = pm.Rice nu = sigma = 2 pymc_dist_params = {"nu": nu, "sigma": sigma} @@ -755,7 +755,7 @@ class TestRiceNu(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestStudentTLam(BaseTestDistribution): +class TestStudentTLam(BaseTestDistributionRandom): pymc_dist = pm.StudentT lam, sigma = get_tau_sigma(tau=2.0) pymc_dist_params = {"nu": 5.0, "mu": -1.0, "lam": lam} @@ -765,7 +765,7 @@ class TestStudentTLam(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestNormal(BaseTestDistribution): +class TestNormal(BaseTestDistributionRandom): pymc_dist = pm.Normal pymc_dist_params = {"mu": 5.0, "sigma": 10.0} expected_rv_op_params = {"mu": 5.0, "sigma": 10.0} @@ -779,7 +779,7 @@ class TestNormal(BaseTestDistribution): ] -class TestLogitNormal(BaseTestDistribution): +class TestLogitNormal(BaseTestDistributionRandom): def logit_normal_rng_fn(self, rng, size, loc, scale): return expit(st.norm.rvs(loc=loc, scale=scale, size=size, random_state=rng)) @@ -797,7 +797,7 @@ def logit_normal_rng_fn(self, rng, size, loc, scale): ] -class TestLogitNormalTau(BaseTestDistribution): +class TestLogitNormalTau(BaseTestDistributionRandom): pymc_dist = pm.LogitNormal tau, sigma = get_tau_sigma(tau=25.0) pymc_dist_params = {"mu": 1.0, "tau": tau} @@ -805,7 +805,7 @@ class TestLogitNormalTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestNormalTau(BaseTestDistribution): +class TestNormalTau(BaseTestDistributionRandom): pymc_dist = pm.Normal tau, sigma = get_tau_sigma(tau=25.0) pymc_dist_params = {"mu": 1.0, "tau": tau} @@ -813,21 +813,21 @@ class TestNormalTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestNormalSd(BaseTestDistribution): +class TestNormalSd(BaseTestDistributionRandom): pymc_dist = pm.Normal pymc_dist_params = {"mu": 1.0, "sd": 5.0} expected_rv_op_params = {"mu": 1.0, "sigma": 5.0} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestUniform(BaseTestDistribution): +class TestUniform(BaseTestDistributionRandom): pymc_dist = pm.Uniform pymc_dist_params = {"lower": 0.5, "upper": 1.5} expected_rv_op_params = {"lower": 0.5, "upper": 1.5} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestHalfNormal(BaseTestDistribution): +class TestHalfNormal(BaseTestDistributionRandom): pymc_dist = pm.HalfNormal pymc_dist_params = {"sigma": 10.0} expected_rv_op_params = {"mean": 0, "sigma": 10.0} @@ -839,7 +839,7 @@ class TestHalfNormal(BaseTestDistribution): ] -class TestHalfNormalTau(BaseTestDistribution): +class TestHalfNormalTau(BaseTestDistributionRandom): pymc_dist = pm.Normal tau, sigma = get_tau_sigma(tau=25.0) pymc_dist_params = {"tau": tau} @@ -847,14 +847,14 @@ class TestHalfNormalTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestHalfNormalSd(BaseTestDistribution): +class TestHalfNormalSd(BaseTestDistributionRandom): pymc_dist = pm.Normal pymc_dist_params = {"sd": 5.0} expected_rv_op_params = {"mu": 0.0, "sigma": 5.0} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestBeta(BaseTestDistribution): +class TestBeta(BaseTestDistributionRandom): pymc_dist = pm.Beta pymc_dist_params = {"alpha": 2.0, "beta": 5.0} expected_rv_op_params = {"alpha": 2.0, "beta": 5.0} @@ -870,7 +870,7 @@ class TestBeta(BaseTestDistribution): ] -class TestBetaMuSigma(BaseTestDistribution): +class TestBetaMuSigma(BaseTestDistributionRandom): pymc_dist = 
pm.Beta pymc_dist_params = {"mu": 0.5, "sigma": 0.25} expected_alpha, expected_beta = pm.Beta.get_alpha_beta( @@ -880,7 +880,7 @@ class TestBetaMuSigma(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestExponential(BaseTestDistribution): +class TestExponential(BaseTestDistributionRandom): pymc_dist = pm.Exponential pymc_dist_params = {"lam": 10.0} expected_rv_op_params = {"mu": 1.0 / pymc_dist_params["lam"]} @@ -892,7 +892,7 @@ class TestExponential(BaseTestDistribution): ] -class TestCauchy(BaseTestDistribution): +class TestCauchy(BaseTestDistributionRandom): pymc_dist = pm.Cauchy pymc_dist_params = {"alpha": 2.0, "beta": 5.0} expected_rv_op_params = {"alpha": 2.0, "beta": 5.0} @@ -904,7 +904,7 @@ class TestCauchy(BaseTestDistribution): ] -class TestHalfCauchy(BaseTestDistribution): +class TestHalfCauchy(BaseTestDistributionRandom): pymc_dist = pm.HalfCauchy pymc_dist_params = {"beta": 5.0} expected_rv_op_params = {"alpha": 0.0, "beta": 5.0} @@ -916,7 +916,7 @@ class TestHalfCauchy(BaseTestDistribution): ] -class TestGamma(BaseTestDistribution): +class TestGamma(BaseTestDistributionRandom): pymc_dist = pm.Gamma pymc_dist_params = {"alpha": 2.0, "beta": 5.0} expected_rv_op_params = {"alpha": 2.0, "beta": 1 / 5.0} @@ -928,7 +928,7 @@ class TestGamma(BaseTestDistribution): ] -class TestGammaMuSigma(BaseTestDistribution): +class TestGammaMuSigma(BaseTestDistributionRandom): pymc_dist = pm.Gamma pymc_dist_params = {"mu": 0.5, "sigma": 0.25} expected_alpha, expected_beta = pm.Gamma.get_alpha_beta( @@ -938,7 +938,7 @@ class TestGammaMuSigma(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestInverseGamma(BaseTestDistribution): +class TestInverseGamma(BaseTestDistributionRandom): pymc_dist = pm.InverseGamma pymc_dist_params = {"alpha": 2.0, "beta": 5.0} expected_rv_op_params = {"alpha": 2.0, "beta": 5.0} @@ -950,7 +950,7 @@ class TestInverseGamma(BaseTestDistribution): ] -class TestInverseGammaMuSigma(BaseTestDistribution): +class TestInverseGammaMuSigma(BaseTestDistributionRandom): pymc_dist = pm.InverseGamma pymc_dist_params = {"mu": 0.5, "sigma": 0.25} expected_alpha, expected_beta = pm.InverseGamma._get_alpha_beta( @@ -963,7 +963,7 @@ class TestInverseGammaMuSigma(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestChiSquared(BaseTestDistribution): +class TestChiSquared(BaseTestDistributionRandom): pymc_dist = pm.ChiSquared pymc_dist_params = {"nu": 2.0} expected_rv_op_params = {"nu": 2.0} @@ -976,21 +976,21 @@ class TestChiSquared(BaseTestDistribution): ] -class TestBinomial(BaseTestDistribution): +class TestBinomial(BaseTestDistributionRandom): pymc_dist = pm.Binomial pymc_dist_params = {"n": 100, "p": 0.33} expected_rv_op_params = {"n": 100, "p": 0.33} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestNegativeBinomial(BaseTestDistribution): +class TestNegativeBinomial(BaseTestDistributionRandom): pymc_dist = pm.NegativeBinomial pymc_dist_params = {"n": 100, "p": 0.33} expected_rv_op_params = {"n": 100, "p": 0.33} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestNegativeBinomialMuSigma(BaseTestDistribution): +class TestNegativeBinomialMuSigma(BaseTestDistributionRandom): pymc_dist = pm.NegativeBinomial pymc_dist_params = {"mu": 5.0, "alpha": 8.0} expected_n, expected_p = pm.NegativeBinomial.get_n_p( @@ -1003,7 +1003,7 @@ class TestNegativeBinomialMuSigma(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestBernoulli(BaseTestDistribution): 
+class TestBernoulli(BaseTestDistributionRandom): pymc_dist = pm.Bernoulli pymc_dist_params = {"p": 0.33} expected_rv_op_params = {"p": 0.33} @@ -1015,21 +1015,21 @@ class TestBernoulli(BaseTestDistribution): ] -class TestBernoulliLogitP(BaseTestDistribution): +class TestBernoulliLogitP(BaseTestDistributionRandom): pymc_dist = pm.Bernoulli pymc_dist_params = {"logit_p": 1.0} expected_rv_op_params = {"p": expit(1.0)} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestPoisson(BaseTestDistribution): +class TestPoisson(BaseTestDistributionRandom): pymc_dist = pm.Poisson pymc_dist_params = {"mu": 4.0} expected_rv_op_params = {"mu": 4.0} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestMvNormalCov(BaseTestDistribution): +class TestMvNormalCov(BaseTestDistributionRandom): pymc_dist = pm.MvNormal pymc_dist_params = { "mu": np.array([1.0, 2.0]), @@ -1077,7 +1077,7 @@ def check_mu_broadcast_helper(self): # assert mu.eval().shape == (10, 2, 3) -class TestMvNormalChol(BaseTestDistribution): +class TestMvNormalChol(BaseTestDistributionRandom): pymc_dist = pm.MvNormal pymc_dist_params = { "mu": np.array([1.0, 2.0]), @@ -1090,7 +1090,7 @@ class TestMvNormalChol(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestMvNormalTau(BaseTestDistribution): +class TestMvNormalTau(BaseTestDistributionRandom): pymc_dist = pm.MvNormal pymc_dist_params = { "mu": np.array([1.0, 2.0]), @@ -1103,7 +1103,7 @@ class TestMvNormalTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestMvStudentTCov(BaseTestDistribution): +class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): chi2_samples = rng.chisquare(nu, size=size) mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size) @@ -1173,7 +1173,7 @@ def check_mu_broadcast_helper(self): # assert mu.eval().shape == (10, 2, 3) -class TestMvStudentTChol(BaseTestDistribution): +class TestMvStudentTChol(BaseTestDistributionRandom): pymc_dist = pm.MvStudentT pymc_dist_params = { "nu": 5, @@ -1188,7 +1188,7 @@ class TestMvStudentTChol(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestMvStudentTTau(BaseTestDistribution): +class TestMvStudentTTau(BaseTestDistributionRandom): pymc_dist = pm.MvStudentT pymc_dist_params = { "nu": 5, @@ -1203,7 +1203,7 @@ class TestMvStudentTTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestDirichlet(BaseTestDistribution): +class TestDirichlet(BaseTestDistributionRandom): pymc_dist = pm.Dirichlet pymc_dist_params = {"a": np.array([1.0, 2.0])} expected_rv_op_params = {"a": np.array([1.0, 2.0])} @@ -1218,7 +1218,7 @@ class TestDirichlet(BaseTestDistribution): ] -class TestStickBreakingWeights(BaseTestDistribution): +class TestStickBreakingWeights(BaseTestDistributionRandom): pymc_dist = pm.StickBreakingWeights pymc_dist_params = {"alpha": 2.0, "K": 19} expected_rv_op_params = {"alpha": 2.0, "K": 19} @@ -1253,7 +1253,7 @@ def check_basic_properties(self): assert np.all(draws <= 1) -class TestMultinomial(BaseTestDistribution): +class TestMultinomial(BaseTestDistributionRandom): pymc_dist = pm.Multinomial pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])} expected_rv_op_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])} @@ -1268,7 +1268,7 @@ class TestMultinomial(BaseTestDistribution): ] -class TestDirichletMultinomial(BaseTestDistribution): +class TestDirichletMultinomial(BaseTestDistributionRandom): pymc_dist = 
pm.DirichletMultinomial pymc_dist_params = {"n": 85, "a": np.array([1.0, 2.0, 1.5, 1.5])} @@ -1298,7 +1298,7 @@ def check_random_draws(self): assert np.all((draws.sum(-2)[:, :, 3] > 3) & (draws.sum(-2)[:, :, 3] <= 5)) -class TestDirichletMultinomial_1d_n_2d_a(BaseTestDistribution): +class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom): pymc_dist = pm.DirichletMultinomial pymc_dist_params = { "n": np.array([23, 29]), @@ -1309,7 +1309,7 @@ class TestDirichletMultinomial_1d_n_2d_a(BaseTestDistribution): checks_to_run = ["check_rv_size"] -class TestCategorical(BaseTestDistribution): +class TestCategorical(BaseTestDistributionRandom): pymc_dist = pm.Categorical pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])} expected_rv_op_params = {"p": np.array([0.28, 0.62, 0.10])} @@ -1319,14 +1319,14 @@ class TestCategorical(BaseTestDistribution): ] -class TestGeometric(BaseTestDistribution): +class TestGeometric(BaseTestDistributionRandom): pymc_dist = pm.Geometric pymc_dist_params = {"p": 0.9} expected_rv_op_params = {"p": 0.9} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestHyperGeometric(BaseTestDistribution): +class TestHyperGeometric(BaseTestDistributionRandom): pymc_dist = pm.HyperGeometric pymc_dist_params = {"N": 20, "k": 12, "n": 5} expected_rv_op_params = { @@ -1342,21 +1342,21 @@ class TestHyperGeometric(BaseTestDistribution): ] -class TestLogistic(BaseTestDistribution): +class TestLogistic(BaseTestDistributionRandom): pymc_dist = pm.Logistic pymc_dist_params = {"mu": 1.0, "s": 2.0} expected_rv_op_params = {"mu": 1.0, "s": 2.0} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestLogNormal(BaseTestDistribution): +class TestLogNormal(BaseTestDistributionRandom): pymc_dist = pm.LogNormal pymc_dist_params = {"mu": 1.0, "sigma": 5.0} expected_rv_op_params = {"mu": 1.0, "sigma": 5.0} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestLognormalTau(BaseTestDistribution): +class TestLognormalTau(BaseTestDistributionRandom): pymc_dist = pm.Lognormal tau, sigma = get_tau_sigma(tau=25.0) pymc_dist_params = {"mu": 1.0, "tau": 25.0} @@ -1364,14 +1364,14 @@ class TestLognormalTau(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestLognormalSd(BaseTestDistribution): +class TestLognormalSd(BaseTestDistributionRandom): pymc_dist = pm.Lognormal pymc_dist_params = {"mu": 1.0, "sd": 5.0} expected_rv_op_params = {"mu": 1.0, "sigma": 5.0} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestTriangular(BaseTestDistribution): +class TestTriangular(BaseTestDistributionRandom): pymc_dist = pm.Triangular pymc_dist_params = {"lower": 0, "upper": 1, "c": 0.5} expected_rv_op_params = {"lower": 0, "c": 0.5, "upper": 1} @@ -1383,14 +1383,14 @@ class TestTriangular(BaseTestDistribution): ] -class TestVonMises(BaseTestDistribution): +class TestVonMises(BaseTestDistributionRandom): pymc_dist = pm.VonMises pymc_dist_params = {"mu": -2.1, "kappa": 5} expected_rv_op_params = {"mu": -2.1, "kappa": 5} checks_to_run = ["check_pymc_params_match_rv_op"] -class TestWeibull(BaseTestDistribution): +class TestWeibull(BaseTestDistributionRandom): def weibull_rng_fn(self, size, alpha, beta, std_weibull_rng_fct): return beta * std_weibull_rng_fct(alpha, size=size) @@ -1412,7 +1412,7 @@ def seeded_weibul_rng_fn(self): ] -class TestBetaBinomial(BaseTestDistribution): +class TestBetaBinomial(BaseTestDistributionRandom): pymc_dist = pm.BetaBinomial pymc_dist_params = {"alpha": 2.0, "beta": 1.0, "n": 5} expected_rv_op_params = {"n": 5, "alpha": 2.0, "beta": 
1.0} @@ -1429,7 +1429,7 @@ class TestBetaBinomial(BaseTestDistribution): condition=_polyagamma_not_installed, reason="`polyagamma package is not available/installed.", ) -class TestPolyaGamma(BaseTestDistribution): +class TestPolyaGamma(BaseTestDistributionRandom): def polyagamma_rng_fn(self, size, h, z, rng): return random_polyagamma(h, z, size=size, random_state=rng._bit_generator) @@ -1447,7 +1447,7 @@ def polyagamma_rng_fn(self, size, h, z, rng): ] -class TestDiscreteUniform(BaseTestDistribution): +class TestDiscreteUniform(BaseTestDistributionRandom): def discrete_uniform_rng_fn(self, size, lower, upper, rng): return st.randint.rvs(lower, upper + 1, size=size, random_state=rng) @@ -1465,7 +1465,7 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng): ] -class TestConstant(BaseTestDistribution): +class TestConstant(BaseTestDistributionRandom): def constant_rng_fn(self, size, c): if size is None: return c @@ -1483,7 +1483,7 @@ def constant_rng_fn(self, size, c): ] -class TestZeroInflatedPoisson(BaseTestDistribution): +class TestZeroInflatedPoisson(BaseTestDistributionRandom): def zero_inflated_poisson_rng_fn(self, size, psi, theta, poisson_rng_fct, random_rng_fct): return poisson_rng_fct(theta, size=size) * (random_rng_fct(size=size) < psi) @@ -1514,7 +1514,7 @@ def seeded_zero_inflated_poisson_rng_fn(self): ] -class TestZeroInflatedBinomial(BaseTestDistribution): +class TestZeroInflatedBinomial(BaseTestDistributionRandom): def zero_inflated_binomial_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct): return binomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi) @@ -1545,7 +1545,7 @@ def seeded_zero_inflated_binomial_rng_fn(self): ] -class TestZeroInflatedNegativeBinomialMuSigma(BaseTestDistribution): +class TestZeroInflatedNegativeBinomialMuSigma(BaseTestDistributionRandom): def zero_inflated_negbinomial_rng_fn( self, size, psi, n, p, negbinomial_rng_fct, random_rng_fct ): @@ -1580,7 +1580,7 @@ def seeded_zero_inflated_negbinomial_rng_fn(self): ] -class TestZeroInflatedNegativeBinomial(BaseTestDistribution): +class TestZeroInflatedNegativeBinomial(BaseTestDistributionRandom): pymc_dist = pm.ZeroInflatedNegativeBinomial pymc_dist_params = {"psi": 0.9, "n": 12, "p": 0.7} expected_rv_op_params = {"psi": 0.9, "n": 12, "p": 0.7} @@ -1588,7 +1588,7 @@ class TestZeroInflatedNegativeBinomial(BaseTestDistribution): checks_to_run = ["check_pymc_params_match_rv_op"] -class TestOrderedLogistic(BaseTestDistribution): +class TestOrderedLogistic(BaseTestDistributionRandom): pymc_dist = _OrderedLogistic pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])} expected_rv_op_params = {"p": np.array([0.11920292, 0.38079708, 0.38079708, 0.11920292])} @@ -1598,7 +1598,7 @@ class TestOrderedLogistic(BaseTestDistribution): ] -class TestOrderedProbit(BaseTestDistribution): +class TestOrderedProbit(BaseTestDistributionRandom): pymc_dist = _OrderedProbit pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])} expected_rv_op_params = {"p": np.array([0.02275013, 0.47724987, 0.47724987, 0.02275013])} @@ -1608,7 +1608,7 @@ class TestOrderedProbit(BaseTestDistribution): ] -class TestOrderedMultinomial(BaseTestDistribution): +class TestOrderedMultinomial(BaseTestDistributionRandom): pymc_dist = _OrderedMultinomial pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2]), "n": 1000} sizes_to_check = [None, (1), (4,), (3, 2)] @@ -1623,7 +1623,7 @@ class TestOrderedMultinomial(BaseTestDistribution): ] -class TestWishart(BaseTestDistribution): +class 
TestWishart(BaseTestDistributionRandom): def wishart_rng_fn(self, size, nu, V, rng): return st.wishart.rvs(np.int(nu), V, size=size, random_state=rng) @@ -1649,7 +1649,7 @@ def wishart_rng_fn(self, size, nu, V, rng): ] -class TestMatrixNormal(BaseTestDistribution): +class TestMatrixNormal(BaseTestDistributionRandom): pymc_dist = pm.MatrixNormal @@ -1734,7 +1734,7 @@ def check_errors(self): ) -class TestInterpolated(BaseTestDistribution): +class TestInterpolated(BaseTestDistributionRandom): def interpolated_rng_fn(self, size, mu, sigma, rng): return st.norm.rvs(loc=mu, scale=sigma, size=size) @@ -1781,7 +1781,7 @@ def dist(cls, **kwargs): ) -class TestKroneckerNormal(BaseTestDistribution): +class TestKroneckerNormal(BaseTestDistributionRandom): def kronecker_rng_fn(self, size, mu, covs=None, sigma=None, rng=None): cov = pm.math.kronecker(covs[0], covs[1]).eval() cov += sigma ** 2 * np.identity(cov.shape[0]) From 346daa10bd4f3012d57b063c16126151fb438f50 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Mon, 24 Jan 2022 12:19:38 +0100 Subject: [PATCH 02/10] Test expected (inferred) and actual shape of draws in `BaseTestDistributionRandom` * Fixes bug in returned samples from `Wishart` when `size=1` --- pymc/distributions/multivariate.py | 10 ++++++--- pymc/tests/test_distributions_random.py | 30 ++++++++++++++++++++----- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 1c7f4d42cf..7ef3dd1c19 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -891,9 +891,13 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): return dist_params[1].shape @classmethod - def rng_fn(cls, rng, nu, V, size=None): - size = size if size else 1 # Default size for Scipy's wishart.rvs is 1 - return stats.wishart.rvs(np.int(nu), V, size=size, random_state=rng) + def rng_fn(cls, rng, nu, V, size): + scipy_size = size if size else 1 # Default size for Scipy's wishart.rvs is 1 + result = stats.wishart.rvs(np.int(nu), V, size=scipy_size, random_state=rng) + if size == (1,): + return result[np.newaxis, ...]
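+            # np.newaxis above restores the length-1 leading axis that scipy's
+            # wishart.rvs squeezes out of its output when a single draw is requested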
+ else: + return result wishart = WishartRV() diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 02a775baeb..7232657e59 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -369,8 +369,9 @@ def check_rv_size(self): sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)] for size, expected in zip(sizes_to_check, sizes_expected): pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size) - actual = tuple(pymc_rv.shape.eval()) - assert actual == expected, f"size={size}, expected={expected}, actual={actual}" + expected_symbolic = tuple(pymc_rv.shape.eval()) + actual = pymc_rv.eval().shape + assert actual == expected_symbolic == expected # test multi-parameters sampling for univariate distributions (with univariate inputs) if ( @@ -390,8 +391,9 @@ def check_rv_size(self): ] for size, expected in zip(sizes_to_check, sizes_expected): pymc_rv = self.pymc_dist.dist(**params, size=size) - actual = tuple(pymc_rv.shape.eval()) - assert actual == expected + expected_symbolic = tuple(pymc_rv.shape.eval()) + actual = pymc_rv.eval().shape + assert actual == expected_symbolic == expected def validate_tests_list(self): assert len(self.checks_to_run) == len( @@ -417,10 +419,18 @@ class TestFlat(BaseTestDistributionRandom): expected_rv_op_params = {} checks_to_run = [ "check_pymc_params_match_rv_op", - "check_rv_size", + "check_rv_inferred_size", "check_not_implemented", ] + def check_rv_inferred_size(self): + sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] + sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)] + for size, expected in zip(sizes_to_check, sizes_expected): + pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size) + expected_symbolic = tuple(pymc_rv.shape.eval()) + assert expected_symbolic == expected + def check_not_implemented(self): with pytest.raises(NotImplementedError): self.pymc_rv.eval() @@ -432,10 +442,18 @@ class TestHalfFlat(BaseTestDistributionRandom): expected_rv_op_params = {} checks_to_run = [ "check_pymc_params_match_rv_op", - "check_rv_size", + "check_rv_inferred_size", "check_not_implemented", ] + def check_rv_inferred_size(self): + sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] + sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)] + for size, expected in zip(sizes_to_check, sizes_expected): + pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size) + expected_symbolic = tuple(pymc_rv.shape.eval()) + assert expected_symbolic == expected + def check_not_implemented(self): with pytest.raises(NotImplementedError): self.pymc_rv.eval() From 2daa76676c267c68f915d78962516026f30eed90 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Tue, 25 Jan 2022 19:29:27 +0100 Subject: [PATCH 03/10] Update aesara dependency to 2.3.8 --- conda-envs/environment-dev-py37.yml | 2 +- conda-envs/environment-dev-py38.yml | 2 +- conda-envs/environment-dev-py39.yml | 2 +- conda-envs/environment-test-py37.yml | 2 +- conda-envs/environment-test-py38.yml | 2 +- conda-envs/environment-test-py39.yml | 2 +- conda-envs/windows-environment-dev-py38.yml | 2 +- conda-envs/windows-environment-test-py38.yml | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index 
5458f8e866..f461596420 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index 4db5da94df..d5617ffa61 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index fd9ebf0747..e6e476ec4c 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml index a0a484e7eb..af99718472 100644 --- a/conda-envs/environment-test-py37.yml +++ b/conda-envs/environment-test-py37.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index d10434332b..c63bf4d214 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 3300d80ea5..5cd354fe4d 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools - cloudpickle diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index 9872f2baa0..788ae6d28a 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -5,7 +5,7 @@ channels: dependencies: # base dependencies (see install guide for Windows) - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index 41e377d51e..2d8fc3fe21 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -5,7 +5,7 @@ channels: dependencies: # base dependencies (see install guide for Windows) - aeppl=0.0.18 -- aesara=2.3.6 +- aesara=2.3.8 - arviz>=0.11.2 - cachetools - cloudpickle diff --git a/requirements-dev.txt b/requirements-dev.txt index d037798484..4a317a1f5a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ # See that file for comments about the need/usage of each dependency. 
aeppl==0.0.18 -aesara==2.3.6 +aesara==2.3.8 arviz>=0.11.4 cachetools>=4.2.1 cloudpickle diff --git a/requirements.txt b/requirements.txt index 6027a8f6f3..46ec438b54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aeppl==0.0.18 -aesara==2.3.6 +aesara==2.3.8 arviz>=0.11.4 cachetools>=4.2.1 cloudpickle From 5fd94ca2647aff00343aefa8a001fec9c69c1084 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Mon, 24 Jan 2022 12:24:33 +0100 Subject: [PATCH 04/10] Refactor LKJCorr distribution to V4 --- pymc/distributions/multivariate.py | 205 ++++++++++++++---------- pymc/tests/test_distributions.py | 3 +- pymc/tests/test_distributions_random.py | 61 ++++--- 3 files changed, 156 insertions(+), 113 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 7ef3dd1c19..8e7432f637 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -25,7 +25,7 @@ import scipy from aesara.assert_op import Assert -from aesara.graph.basic import Apply +from aesara.graph.basic import Apply, Constant from aesara.graph.op import Op from aesara.sparse.basic import sp_sum from aesara.tensor import gammaln, sigmoid @@ -43,7 +43,12 @@ from pymc.aesaraf import floatX, intX from pymc.distributions import transforms -from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support +from pymc.distributions.continuous import ( + BoundedContinuous, + ChiSquared, + Normal, + assert_negative_support, +) from pymc.distributions.dist_math import ( betaln, check_parameters, @@ -57,7 +62,9 @@ rv_size_is_none, to_tuple, ) +from pymc.distributions.transforms import interval from pymc.math import kron_diag, kron_dot +from pymc.util import UNSET __all__ = [ "MvNormal", @@ -1079,6 +1086,11 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv def _lkj_normalizing_constant(eta, n): + # TODO: This is mixing python branching with the potentially symbolic n and eta variables + if not isinstance(eta, (int, float)): + raise NotImplementedError("eta must be an int or float") + if not isinstance(n, int): + raise NotImplementedError("n must be an integer") if eta == 1: result = gammaln(2.0 * at.arange(1, int((n - 1) / 2) + 1)).sum() if n % 2 == 1: @@ -1431,7 +1443,74 @@ def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=Tru return chol, corr, stds -class LKJCorr(Continuous): +class LKJCorrRV(RandomVariable): + name = "lkjcorr" + ndim_supp = 1 + ndims_params = [0, 0] + dtype = "floatX" + _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}") + + def make_node(self, rng, size, dtype, n, eta): + n = at.as_tensor_variable(n) + if not n.ndim == 0: + raise ValueError("n must be a scalar (ndim=0).") + + eta = at.as_tensor_variable(eta) + if not eta.ndim == 0: + raise ValueError("eta must be a scalar (ndim=0).") + + return super().make_node(rng, size, dtype, n, eta) + + def _shape_from_params(self, dist_params, **kwargs): + n = dist_params[0] + dist_shape = ((n * (n - 1)) // 2,) + return dist_shape + + @classmethod + def rng_fn(cls, rng, n, eta, size): + + # We flatten the size to make operations easier, and then rebuild it + if size is None: + flat_size = 1 + else: + flat_size = np.prod(size) + + C = cls._random_corr_matrix(rng, n, eta, flat_size) + + triu_idx = np.triu_indices(n, k=1) + samples = C[..., triu_idx[0], triu_idx[1]] + + if size is None: + samples = samples[0] + else: + dist_shape = (n * (n - 1)) // 2 + samples = np.reshape(samples, (*size, 
dist_shape)) + return samples + + @classmethod + def _random_corr_matrix(cls, rng, n, eta, flat_size): + # original implementation in R see: + # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r + beta = eta - 1.0 + n / 2.0 + r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0 + P = np.full((flat_size, n, n), np.eye(n)) + P[..., 0, 1] = r12 + P[..., 1, 1] = np.sqrt(1.0 - r12 ** 2) + for mp1 in range(2, n): + beta -= 0.5 + y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=flat_size, random_state=rng) + z = stats.norm.rvs(loc=0, scale=1, size=(flat_size, mp1), random_state=rng) + z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis] + P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z + P[..., mp1, mp1] = np.sqrt(1.0 - y) + C = np.einsum("...ji,...jk->...ik", P, P) + return C + + +lkjcorr = LKJCorrRV() + + +class LKJCorr(BoundedContinuous): r""" The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood. @@ -1473,112 +1552,60 @@ class LKJCorr(Continuous): 100(9), pp.1989-2001. """ - def __init__(self, eta=None, n=None, p=None, transform="interval", *args, **kwargs): - if (p is not None) and (n is not None) and (eta is None): - warnings.warn( - "Parameters to LKJCorr have changed: shape parameter n -> eta " - "dimension parameter p -> n. Please update your code. " - "Automatically re-assigning parameters for backwards compatibility.", - FutureWarning, - ) - self.n = p - self.eta = n - eta = self.eta - n = self.n - elif (n is not None) and (eta is not None) and (p is None): - self.n = n - self.eta = eta - else: - raise ValueError( - "Invalid parameter: please use eta as the shape parameter and " - "n as the dimension parameter." - ) - - shape = n * (n - 1) // 2 - self.mean = floatX(np.zeros(shape)) - - if transform == "interval": - transform = transforms.interval(-1, 1) - - super().__init__(shape=shape, transform=transform, *args, **kwargs) - warnings.warn( - "Parameters in LKJCorr have been rename: shape parameter n -> eta " - "dimension parameter p -> n. Please double check your initialization.", - FutureWarning, - ) - self.tri_index = np.zeros([n, n], dtype="int32") - self.tri_index[np.triu_indices(n, k=1)] = np.arange(shape) - self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape) - - def _random(self, n, eta, size=None): - size = size if isinstance(size, tuple) else (size,) - # original implementation in R see: - # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r - beta = eta - 1.0 + n / 2.0 - r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0 - P = np.eye(n)[:, :, np.newaxis] * np.ones(size) - P[0, 1] = r12 - P[1, 1] = np.sqrt(1.0 - r12 ** 2) - for mp1 in range(2, n): - beta -= 0.5 - y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size) - z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size) - z = z / np.sqrt(np.einsum("ij,ij->j", z, z)) - P[0:mp1, mp1] = np.sqrt(y) * z - P[mp1, mp1] = np.sqrt(1.0 - y) - C = np.einsum("ji...,jk...->...ik", P, P) - triu_idx = np.triu_indices(n, k=1) - return C[..., triu_idx[0], triu_idx[1]] + rv_op = lkjcorr - def random(self, point=None, size=None): - """ - Draw random values from LKJ distribution. 
+ def __new__(cls, *args, **kwargs): + transform = kwargs.get("transform", UNSET) + if transform is UNSET: + kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0))) + return super().__new__(cls, *args, **kwargs) - Parameters - ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). - - Returns - ------- - array - """ - # n, eta = draw_values([self.n, self.eta], point=point, size=size) - # size = 1 if size is None else size - # samples = generate_samples(self._random, n, eta, broadcast_shape=(size,)) - # return samples + @classmethod + def dist(cls, n, eta, **kwargs): + n = at.as_tensor_variable(intX(n)) + eta = at.as_tensor_variable(floatX(eta)) + return super().dist([n, eta], **kwargs) - def logp(self, x): + def logp(value, n, eta): """ Calculate log-probability of LKJ distribution at specified value. Parameters ---------- - x: numeric + value: numeric Value for which log-probability is calculated. Returns ------- TensorVariable """ - n = self.n - eta = self.eta - X = x[self.tri_index] - X = at.fill_diagonal(X, 1) + # TODO: Aesara does not have a `triu_indices`, so we can only work with constant + # n (or else find a different expression) + if not isinstance(n, Constant): + raise NotImplementedError("logp only implemented for constant `n`") + + n = int(n.data) + shape = n * (n - 1) // 2 + tri_index = np.zeros((n, n), dtype="int32") + tri_index[np.triu_indices(n, k=1)] = np.arange(shape) + tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape) + value = at.take(value, tri_index) + value = at.fill_diagonal(value, 1) + + # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants + if not isinstance(eta, Constant): + raise NotImplementedError("logp only implemented for constant `eta`") + eta = float(eta.data) result = _lkj_normalizing_constant(eta, n) - result += (eta - 1.0) * at.log(det(X)) + result += (eta - 1.0) * at.log(det(value)) return check_parameters( result, - X >= -1, - X <= 1, - matrix_pos_def(X), + value >= -1, + value <= 1, + matrix_pos_def(value), eta > 0, ) diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 3903c207de..62b151dd69 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2094,8 +2094,7 @@ def test_wishart(self, n): ) @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES) - @pytest.mark.xfail(reason="Distribution not refactored yet") - def test_lkj(self, x, eta, n, lp): + def test_lkjcorr(self, x, eta, n, lp): with Model() as model: LKJCorr("lkj", eta=eta, n=n, transform=None) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 7232657e59..6ac217ac7c 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1828,29 +1828,46 @@ def kronecker_rng_fn(self, size, mu, covs=None, sigma=None, rng=None): ] -class TestScalarParameterSamples(SeededTest): - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_lkj(self): - for n in [2, 10, 50]: - # pylint: disable=cell-var-from-loop - shape = n * (n - 1) // 2 - - def ref_rand(size, eta): - beta = eta - 1 + n / 2 - return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2 - - class TestedLKJCorr(pm.LKJCorr): - def __init__(self, **kwargs): - kwargs.pop("shape", None) - super().__init__(n=n, 
**kwargs) - - pymc_random( - TestedLKJCorr, - {"eta": Domain([1.0, 10.0, 100.0])}, - size=10000 // n, - ref_rand=ref_rand, - ) +class TestLKJCorr(BaseTestDistributionRandom): + pymc_dist = pm.LKJCorr + pymc_dist_params = {"n": 3, "eta": 1.0} + expected_rv_op_params = {"n": 3, "eta": 1.0} + + sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] + sizes_expected = [ + (3,), + (3,), + (1, 3), + (1, 3), + (5, 3), + (4, 5, 3), + (2, 4, 2, 3), + ] + + tests_to_run = [ + "check_pymc_params_match_rv_op", + "check_rv_size", + "check_draws_match_expected", + ] + def check_draws_match_expected(self): + def ref_rand(size, n, eta): + shape = int(n * (n - 1) // 2) + beta = eta - 1 + n / 2 + return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2 + + pymc_random( + pm.LKJCorr, + { + "n": Domain([2, 10, 50], edges=(None, None)), + "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)), + }, + ref_rand=ref_rand, + size=1000, + ) + + +class TestScalarParameterSamples(SeededTest): @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def test_normalmixture(self): def ref_rand(size, w, mu, sigma): From eed60c378cb07f017ead5bacbee1090c07a0907a Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Mon, 24 Jan 2022 15:55:35 +0100 Subject: [PATCH 05/10] Add LKJCorr moment --- pymc/distributions/multivariate.py | 3 +++ pymc/tests/test_distributions_moments.py | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 8e7432f637..722d94b5c4 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1566,6 +1566,9 @@ def dist(cls, n, eta, **kwargs): eta = at.as_tensor_variable(floatX(eta)) return super().dist([n, eta], **kwargs) + def get_moment(rv, *args): + return at.zeros_like(rv) + def logp(value, n, eta): """ Calculate log-probability of LKJ distribution at specified diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 9f31ef768e..0dc6789b70 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -39,12 +39,14 @@ KroneckerNormal, Kumaraswamy, Laplace, + LKJCorr, Logistic, LogitNormal, LogNormal, MatrixNormal, Moyal, Multinomial, + MvNormal, MvStudentT, NegativeBinomial, Normal, @@ -68,7 +70,6 @@ ) from pymc.distributions.distribution import _get_moment, get_moment from pymc.distributions.logprob import joint_logpt -from pymc.distributions.multivariate import MvNormal from pymc.distributions.shape_utils import rv_size_is_none, to_tuple from pymc.initial_point import make_initial_point_fn from pymc.model import Model @@ -97,7 +98,6 @@ def test_all_distributions_have_moments(): # Distributions that have not been refactored for V4 yet not_implemented = { - dist_module.multivariate.LKJCorr, dist_module.mixture.Mixture, dist_module.mixture.MixtureSameFamily, dist_module.mixture.NormalMixture, @@ -1424,3 +1424,18 @@ def test_kronecker_normal_moments(mu, covs, size, expected): with Model() as model: KroneckerNormal("x", mu=mu, covs=covs, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "n, eta, size, expected", + [ + (3, 1, None, np.zeros(3)), + (5, 1, None, np.zeros(10)), + (3, 1, 1, np.zeros((1, 3))), + (5, 1, (2, 3), np.zeros((2, 3, 10))), + ], +) +def test_lkjcorr_moment(n, eta, size, expected): + with Model() as model: + LKJCorr("x", n=n, eta=eta, size=size) + 
assert_moment_is_expected(model, expected) From 1a35a3dfc0f125dfa1515e3c8d3575c9f17a3074 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Tue, 25 Jan 2022 13:40:01 +0100 Subject: [PATCH 06/10] Refactor LKJCholeskyCov for V4 Changes: * compute_corr now defaults to True * LKJCholeskyCov now also provides a `.dist` interface --- RELEASE-NOTES.md | 1 + pymc/distributions/multivariate.py | 295 +++++++++++------------- pymc/distributions/transforms.py | 26 ++- pymc/tests/sampler_fixtures.py | 4 +- pymc/tests/test_distributions.py | 37 +++ pymc/tests/test_distributions_random.py | 51 +++- pymc/tests/test_idata_conversion.py | 3 +- pymc/tests/test_mixture.py | 2 +- pymc/tests/test_posteriors.py | 1 - 9 files changed, 248 insertions(+), 172 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index cf4e7976df..92b95d3814 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -87,6 +87,7 @@ All of the above apply to: - ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc/pull/4471) and `3.11.2` release notes). - `pm.sample_posterior_predictive(vars=...)` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc/pull/4343)). - `ElemwiseCategorical` step method was removed (see [#4701](https://github.com/pymc-devs/pymc/pull/4701)) +- `LKJCholeskyCov` `compute_corr` keyword argument is now set to `True` by default (see[#5382](https://github.com/pymc-devs/pymc/pull/5382)) ### Ongoing deprecations - Old API still works in `v4` and has a deprecation warning. diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 722d94b5c4..7c1e53a55f 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -24,9 +24,9 @@ import numpy as np import scipy -from aesara.assert_op import Assert -from aesara.graph.basic import Apply, Constant +from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op +from aesara.raise_op import Assert from aesara.sparse.basic import sp_sum from aesara.tensor import gammaln, sigmoid from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace @@ -64,7 +64,7 @@ ) from pymc.distributions.transforms import interval from pymc.math import kron_diag, kron_dot -from pymc.util import UNSET +from pymc.util import UNSET, check_dist_not_registered __all__ = [ "MvNormal", @@ -1113,67 +1113,122 @@ def _lkj_normalizing_constant(eta, n): return result +class _LKJCholeskyCovRV(RandomVariable): + name = "_lkjcholeskycov" + ndim_supp = 1 + ndims_params = [0, 0, 1] + dtype = "floatX" + _print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}") + + def make_node(self, rng, size, dtype, n, eta, D): + n = at.as_tensor_variable(n) + if not n.ndim == 0: + raise ValueError("n must be a scalar (ndim=0).") + + eta = at.as_tensor_variable(eta) + if not eta.ndim == 0: + raise ValueError("eta must be a scalar (ndim=0).") + + D = at.as_tensor_variable(D) + + return super().make_node(rng, size, dtype, n, eta, D) + + def _infer_shape(self, size, dist_params, param_shapes=None): + n = dist_params[0] + dist_shape = tuple(size) + ((n * (n + 1)) // 2,) + return dist_shape + + def rng_fn(self, rng, n, eta, D, size): + # We flatten the size to make operations easier, and then rebuild it + if size is None: + flat_size = 1 + else: + flat_size = np.prod(size) + + C = 
LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size) + + D = D.reshape(flat_size, n) + C *= D[..., :, np.newaxis] * D[..., np.newaxis, :] + + tril_idx = np.tril_indices(n, k=0) + samples = np.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]] + + if size is None: + samples = samples[0] + else: + dist_shape = (n * (n + 1)) // 2 + samples = np.reshape(samples, (*size, dist_shape)) + + return samples + + +_ljk_cholesky_cov = _LKJCholeskyCovRV() + + class _LKJCholeskyCov(Continuous): r"""Underlying class for covariance matrix with LKJ distributed correlations. See docs for LKJCholeskyCov function for more details on how to use it in models. """ + rv_op = _ljk_cholesky_cov - def __init__(self, eta, n, sd_dist, *args, **kwargs): - self.n = at.as_tensor_variable(n) - self.eta = at.as_tensor_variable(eta) - - if "transform" in kwargs and kwargs["transform"] is not None: - raise ValueError("Invalid parameter: transform.") - if "shape" in kwargs: - raise ValueError("Invalid parameter: shape.") - - shape = n * (n + 1) // 2 + def __new__(cls, name, eta, n, sd_dist, **kwargs): + transform = kwargs.get("transform", UNSET) + if transform is UNSET: + kwargs["transform"] = transforms.CholeskyCovPacked(n) - if sd_dist.shape.ndim not in [0, 1]: - raise ValueError("Invalid shape for sd_dist.") + check_dist_not_registered(sd_dist) - def transform_params(rv_var): - _, _, _, n, eta = rv_var.owner.inputs - return np.arange(1, n + 1).cumsum() - 1 + return super().__new__(cls, name, eta, n, sd_dist, **kwargs) - transform = transforms.CholeskyCovPacked(transform_params) + @classmethod + def dist(cls, eta, n, sd_dist, size=None, **kwargs): + eta = at.as_tensor_variable(floatX(eta)) + n = at.as_tensor_variable(intX(n)) - kwargs["shape"] = shape - kwargs["transform"] = transform - super().__init__(*args, **kwargs) + if not ( + isinstance(sd_dist, Variable) + and sd_dist.owner is not None + and isinstance(sd_dist.owner.op, RandomVariable) + ): + raise TypeError("sd_dist must be a Distribution variable") - self.sd_dist = sd_dist - self.diag_idxs = transform.diag_idxs + # sd_dist is part of the generative graph, but should be completely ignored + # by the logp graph, since the LKJ logp explicitly includes these terms. + # Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about + # an unnacounted RandomVariable in the graph + # TODO: Things could be simplified a bit if we managed to extract the + # sd_dist prior components from the logp expression. + sd_dist.tag.ignore_logprob = True - self.mode = floatX(np.zeros(shape)) - self.mode[self.diag_idxs] = 1 + return super().dist([n, eta, sd_dist], size=size, **kwargs) - def logp(self, x): + def logp(value, n, eta, sd_dist): """ Calculate log-probability of Covariance matrix with LKJ distributed correlations at specified value. Parameters ---------- - x: numeric + value: numeric Value for which log-probability is calculated. 
Returns ------- TensorVariable """ - n = self.n - eta = self.eta - diag_idxs = self.diag_idxs - cumsum = at.cumsum(x ** 2) - variance = at.zeros(n) - variance = at.inc_subtensor(variance[0], x[0] ** 2) + if value.ndim > 1: + raise ValueError("LKJCholeskyCov logp is only implemented for vector values (ndim=1)") + + diag_idxs = at.cumsum(at.arange(1, n + 1)) - 1 + cumsum = at.cumsum(value ** 2) + variance = at.zeros(at.atleast_1d(n)) + variance = at.inc_subtensor(variance[0], value[0] ** 2) variance = at.inc_subtensor(variance[1:], cumsum[diag_idxs[1:]] - cumsum[diag_idxs[:-1]]) sd_vals = at.sqrt(variance) - logp_sd = self.sd_dist.logp(sd_vals).sum() - corr_diag = x[diag_idxs] / sd_vals + logp_sd = pm.logp(sd_dist, sd_vals).sum() + corr_diag = value[diag_idxs] / sd_vals logp_lkj = (2 * eta - 3 + n - at.arange(n)) * at.log(corr_diag) logp_lkj = at.sum(logp_lkj) @@ -1184,114 +1239,22 @@ def logp(self, x): det_invjac = at.log(corr_diag) - idx * at.log(sd_vals) det_invjac = det_invjac.sum() - norm = _lkj_normalizing_constant(eta, n) + # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants + if not isinstance(n, Constant): + raise NotImplementedError("logp only implemented for constant `n`") + n = int(n.data) - return norm + logp_lkj + logp_sd + det_invjac + if not isinstance(eta, Constant): + raise NotImplementedError("logp only implemented for constant `eta`") + eta = float(eta.data) - def _random(self, n, eta, size=1): - eta_sample_shape = (size,) + eta.shape - P = np.eye(n) * np.ones(eta_sample_shape + (n, n)) - # original implementation in R see: - # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r - beta = eta - 1.0 + n / 2.0 - r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=eta_sample_shape) - 1.0 - P[..., 0, 1] = r12 - P[..., 1, 1] = np.sqrt(1.0 - r12 ** 2) - for mp1 in range(2, n): - beta -= 0.5 - y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=eta_sample_shape) - z = stats.norm.rvs(loc=0, scale=1, size=eta_sample_shape + (mp1,)) - z = z / np.sqrt(np.einsum("ij,ij->j", z, z)) - P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z - P[..., mp1, mp1] = np.sqrt(1.0 - y) - C = np.einsum("...ji,...jk->...ik", P, P) - D = np.atleast_1d(self.sd_dist.random(size=P.shape[:-2])) - if D.shape in [tuple(), (1,)]: - D = self.sd_dist.random(size=P.shape[:-1]) - elif D.ndim < C.ndim - 1: - D = [D] + [self.sd_dist.random(size=P.shape[:-2]) for _ in range(n - 1)] - D = np.moveaxis(np.array(D), 0, C.ndim - 2) - elif D.ndim == C.ndim - 1: - if D.shape[-1] == 1: - D = [D] + [self.sd_dist.random(size=P.shape[:-2]) for _ in range(n - 1)] - D = np.concatenate(D, axis=-1) - elif D.shape[-1] != n: - raise ValueError( - "The size of the samples drawn from the " - "supplied sd_dist.random have the wrong " - "size. Expected {} but got {} instead.".format(n, D.shape[-1]) - ) - else: - raise ValueError( - "Supplied sd_dist.random generates samples with " - "too many dimensions. It must yield samples " - "with 0 or 1 dimensions. Got {} instead".format(D.ndim - C.ndim - 2) - ) - C *= D[..., :, np.newaxis] * D[..., np.newaxis, :] - tril_idx = np.tril_indices(n, k=0) - return np.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]] + norm = _lkj_normalizing_constant(eta, n) - def random(self, point=None, size=None): - """ - Draw random values from Covariance matrix with LKJ - distributed correlations. 
+        return norm + logp_lkj + logp_sd + det_invjac
 
-        Parameters
-        ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
-
-        Returns
-        -------
-        array
-        """
-        # # Get parameters and broadcast them
-        # n, eta = draw_values([self.n, self.eta], point=point, size=size)
-        # broadcast_shape = np.broadcast(n, eta).shape
-        # # We can only handle cov matrices with a constant n per random call
-        # n = np.unique(n)
-        # if len(n) > 1:
-        #     raise RuntimeError("Varying n is not supported for LKJCholeskyCov")
-        # n = int(n[0])
-        # dist_shape = ((n * (n + 1)) // 2,)
-        # # We make sure that eta and the drawn n get their shapes broadcasted
-        # eta = np.broadcast_to(eta, broadcast_shape)
-        # # We change the size of the draw depending on the broadcast shape
-        # sample_shape = broadcast_shape + dist_shape
-        # if size is not None:
-        #     if not isinstance(size, tuple):
-        #         try:
-        #             size = tuple(size)
-        #         except TypeError:
-        #             size = (size,)
-        #     if size == sample_shape:
-        #         size = None
-        #     elif size == broadcast_shape:
-        #         size = None
-        #     elif size[-len(sample_shape) :] == sample_shape:
-        #         size = size[: len(size) - len(sample_shape)]
-        #     elif size[-len(broadcast_shape) :] == broadcast_shape:
-        #         size = size[: len(size) - len(broadcast_shape)]
-        # # We will always provide _random with an integer size and then reshape
-        # # the output to get the correct size
-        # if size is not None:
-        #     _size = np.prod(size)
-        # else:
-        #     _size = 1
-        # samples = self._random(n, eta, size=_size)
-        # if size is None:
-        #     samples = samples[0]
-        # else:
-        #     samples = np.reshape(samples, size + sample_shape)
-        # return samples
-
-
-def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=True, *args, **kwargs):
-    r"""Wrapper function for covariance matrix with LKJ distributed correlations.
+class LKJCholeskyCov:
+    r"""Wrapper class for covariance matrix with LKJ distributed correlations.
 
     This defines a distribution over Cholesky decomposed covariance
     matrices, such that the underlying correlation matrices follow an
@@ -1309,11 +1272,11 @@
     n: int
         Dimension of the covariance matrix (n > 1).
     sd_dist: pm.Distribution
-        A distribution for the standard deviations.
-    compute_corr: bool, default=False
+        A distribution for the standard deviations; it should have `size=n`.
+    compute_corr: bool, default=True
         If `True`, returns three values: the Cholesky decomposition, the correlations
         and the standard deviations of the covariance matrix. Otherwise, only returns
-        the packed Cholesky decomposition. Defaults to `False` to ensure backwards
-        compatibility.
+        the packed Cholesky decomposition. Defaults to `True`.
     store_in_trace: bool, default=True
         Whether to store the correlations and standard deviations of the covariance
@@ -1323,14 +1286,14 @@
 
     Returns
     -------
-    packed_chol: TensorVariable
-        If `compute_corr=False` (default). The packed Cholesky covariance decomposition.
     chol: TensorVariable
         If `compute_corr=True`. The unpacked Cholesky covariance decomposition.
     corr: TensorVariable
         If `compute_corr=True`. The correlations of the covariance matrix.
     stds: TensorVariable
         If `compute_corr=True`. The standard deviations of the covariance matrix.
+    packed_chol: TensorVariable
+        If `compute_corr=False`. The packed Cholesky covariance decomposition.
 
     Notes
     -----
@@ -1355,12 +1318,15 @@
         with pm.Model() as model:
             # Note that we access the distribution for the standard
             # deviations, and do not create a new random variable.
-            sd_dist = pm.Exponential.dist(1.0)
-            chol, corr, sigmas = pm.LKJCholeskyCov('chol_cov', eta=4, n=10,
-            sd_dist=sd_dist, compute_corr=True)
+            sd_dist = pm.Exponential.dist(1.0, size=10)
+            chol, corr, sigmas = pm.LKJCholeskyCov(
+                'chol_cov', eta=4, n=10, sd_dist=sd_dist
+            )
 
-            # if you only want the packed Cholesky (default behavior):
-            # packed_chol = pm.LKJCholeskyCov('chol_cov', eta=4, n=10, sd_dist=sd_dist)
+            # if you only want the packed Cholesky:
+            # packed_chol = pm.LKJCholeskyCov(
+            #     'chol_cov', eta=4, n=10, sd_dist=sd_dist, compute_corr=False
+            # )
             # chol = pm.expand_packed_triangular(10, packed_chol, lower=True)
 
             # Define a new MvNormal with the given covariance
@@ -1423,12 +1389,29 @@
         determinant, URL (version: 2012-04-14):
         http://math.stackexchange.com/q/130026
     """
-    # compute Cholesky decomposition
-    packed_chol = _LKJCholeskyCov(name, eta=eta, n=n, sd_dist=sd_dist)
-    if not compute_corr:
-        return packed_chol
-    else:
+    def __new__(cls, name, eta, n, sd_dist, *, compute_corr=True, store_in_trace=True, **kwargs):
+        packed_chol = _LKJCholeskyCov(name, eta=eta, n=n, sd_dist=sd_dist, **kwargs)
+        if not compute_corr:
+            return packed_chol
+        else:
+            chol, corr, stds = cls.helper_deterministics(n, packed_chol)
+            if store_in_trace:
+                corr = pm.Deterministic(f"{name}_corr", corr)
+                stds = pm.Deterministic(f"{name}_stds", stds)
+            return chol, corr, stds
+
+    @classmethod
+    def dist(cls, eta, n, sd_dist, *, compute_corr=True, **kwargs):
+        # compute Cholesky decomposition
+        packed_chol = _LKJCholeskyCov.dist(eta=eta, n=n, sd_dist=sd_dist, **kwargs)
+        if not compute_corr:
+            return packed_chol
+        else:
+            return cls.helper_deterministics(n, packed_chol)
+
+    @classmethod
+    def helper_deterministics(cls, n, packed_chol):
         chol = pm.expand_packed_triangular(n, packed_chol, lower=True)
         # compute covariance matrix
         cov = at.dot(chol, chol.T)
@@ -1436,10 +1419,6 @@
         stds = at.sqrt(at.diag(cov))
         inv_stds = 1 / stds
         corr = inv_stds[None, :] * cov * inv_stds[:, None]
-        if store_in_trace:
-            stds = pm.Deterministic(f"{name}_stds", stds)
-            corr = pm.Deterministic(f"{name}_corr", corr)
-
         return chol, corr, stds
 
 
diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index 9b750ac7c9..ff287635e1 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -22,7 +22,6 @@
     RVTransform,
     Simplex,
 )
-from aesara.tensor.subtensor import advanced_set_subtensor1
 
 __all__ = [
     "RVTransform",
@@ -97,22 +96,31 @@ def log_jac_det(self, value, *inputs):
 
 
 class CholeskyCovPacked(RVTransform):
+    """
+    Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
+    log scale.
+    """
+
     name = "cholesky-cov-packed"
 
-    def __init__(self, param_extract_fn):
-        self.param_extract_fn = param_extract_fn
+    def __init__(self, n):
+        """
+        Parameters
+        ----------
+        n: int
+            Number of diagonal entries in the LKJCholeskyCov distribution.
+        """
+        self.diag_idxs = at.arange(1, n + 1).cumsum() - 1
 
     def backward(self, value, *inputs):
-        diag_idxs = self.param_extract_fn(inputs)
-        return advanced_set_subtensor1(value, at.exp(value[diag_idxs]), diag_idxs)
+        return at.set_subtensor(value[..., self.diag_idxs], at.exp(value[..., self.diag_idxs]))
 
     def forward(self, value, *inputs):
-        diag_idxs = self.param_extract_fn(inputs)
-        return advanced_set_subtensor1(value, at.log(value[diag_idxs]), diag_idxs)
+        return at.set_subtensor(value[..., self.diag_idxs], at.log(value[..., self.diag_idxs]))
 
     def log_jac_det(self, value, *inputs):
-        diag_idxs = self.param_extract_fn(inputs)
-        return at.sum(value[diag_idxs])
+        return at.sum(value[..., self.diag_idxs], axis=-1)
 
 
 class Chain(RVTransform):
diff --git a/pymc/tests/sampler_fixtures.py b/pymc/tests/sampler_fixtures.py
index ce5f4f0490..6bd085a59f 100644
--- a/pymc/tests/sampler_fixtures.py
+++ b/pymc/tests/sampler_fixtures.py
@@ -122,7 +122,9 @@ def make_model(cls):
         with pm.Model() as model:
             sd_mu = np.array([1, 2, 3, 4, 5])
             sd_dist = pm.LogNormal.dist(mu=sd_mu, sigma=sd_mu / 10.0, size=5)
-            chol_packed = pm.LKJCholeskyCov("chol_packed", eta=3, n=5, sd_dist=sd_dist)
+            chol_packed = pm.LKJCholeskyCov(
+                "chol_packed", eta=3, n=5, sd_dist=sd_dist, compute_corr=False
+            )
             chol = pm.expand_packed_triangular(5, chol_packed, lower=True)
             cov = at.dot(chol, chol.T)
             stds = at.sqrt(at.diag(cov))
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index 62b151dd69..c84abf85c3 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -3352,3 +3352,40 @@ def test_censored_invalid_dist(self):
             match="The dist dist was already registered in the current model",
         ):
             x = pm.Censored("x", registered_dist, lower=None, upper=None)
+
+
+class TestLKJCholeskyCov:
+    def test_dist(self):
+        sd_dist = pm.Exponential.dist(1, size=(10, 3))
+        x = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist, size=10, compute_corr=False)
+        assert x.eval().shape == (10, 6)
+
+        sd_dist = pm.Exponential.dist(1, size=3)
+        chol, corr, stds = pm.LKJCholeskyCov.dist(n=3, eta=1, sd_dist=sd_dist)
+        assert chol.eval().shape == (3, 3)
+        assert corr.eval().shape == (3, 3)
+        assert stds.eval().shape == (3,)
+
+    def test_sd_dist_distribution(self):
+        with pm.Model() as m:
+            sd_dist = at.constant([1, 2, 3])
+            with pytest.raises(TypeError, match="sd_dist must be a Distribution variable"):
+                x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
+
+    def test_sd_dist_registered(self):
+        with pm.Model() as m:
+            sd_dist = pm.Exponential("sd_dist", 1, size=3)
+            with pytest.raises(
+                ValueError, match="The dist sd_dist was already registered in the current model"
+            ):
+                x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
+
+    def test_no_warning_logp(self):
+        # Check that calling logp of a model with LKJCholeskyCov does not issue any warnings
+        # due to the RandomVariable in the graph
+        with pm.Model() as m:
+            sd_dist = pm.Exponential.dist(1, size=3)
+            x = pm.LKJCholeskyCov("x", n=3, eta=1, sd_dist=sd_dist)
+            with pytest.warns(None) as record:
+                m.logpt()
+            assert not record
diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py
index 6ac217ac7c..55b7c1ac3f 100644
--- a/pymc/tests/test_distributions_random.py
+++ b/pymc/tests/test_distributions_random.py
@@ -47,7 +47,11 @@ def random_polyagamma(*args, **kwargs):
 from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit
 from pymc.distributions.dist_math import clipped_beta_rvs
 from pymc.distributions.logprob import logp
-from pymc.distributions.multivariate import _OrderedMultinomial, quaddist_matrix
+from pymc.distributions.multivariate import (
+    _LKJCholeskyCov,
+    _OrderedMultinomial,
+    quaddist_matrix,
+)
 from pymc.distributions.shape_utils import to_tuple
 from pymc.tests.helpers import SeededTest, select_by_precision
 from pymc.tests.test_distributions import (
@@ -1867,6 +1871,43 @@ def ref_rand(size, n, eta):
     )
 
 
+class TestLKJCholeskyCov(BaseTestDistributionRandom):
+    pymc_dist = _LKJCholeskyCov
+    pymc_dist_params = {"eta": 1.0, "n": 3, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
+    expected_rv_op_params = {"n": 3, "eta": 1.0, "sd_dist": pm.Constant.dist([0.5, 1.0, 2.0])}
+    size = None
+
+    sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
+    sizes_expected = [
+        (6,),
+        (6,),
+        (1, 6),
+        (1, 6),
+        (5, 6),
+        (4, 5, 6),
+        (2, 4, 2, 6),
+    ]
+
+    checks_to_run = [
+        "check_rv_size",
+        "check_draws_match_expected",
+    ]
+
+    def check_rv_size(self):
+        for size, expected in zip(self.sizes_to_check, self.sizes_expected):
+            sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), 3))
+            pymc_rv = self.pymc_dist.dist(n=3, eta=1, sd_dist=sd_dist, size=size)
+            expected_symbolic = tuple(pymc_rv.shape.eval())
+            actual = pymc_rv.eval().shape
+            assert actual == expected_symbolic == expected
+
+    def check_draws_match_expected(self):
+        # TODO: Find better comparison
+        rng = aesara.shared(self.get_random_state(reset=True))
+        x = _LKJCholeskyCov.dist(n=2, eta=10_000, sd_dist=pm.Constant.dist([0.5, 2.0]), rng=rng)
+        assert np.all(np.abs(x.eval() - np.array([0.5, 0, 2.0])) < 0.01)
+
+
 class TestScalarParameterSamples(SeededTest):
     @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
     def test_normalmixture(self):
@@ -2346,9 +2387,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
             sd_dist = pm.Exponential.dist(1.0, shape=3)
+            # pylint: disable=unpacking-non-sequence
             chol, corr, stds = pm.LKJCholeskyCov(
                 "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
             )
+            # pylint: enable=unpacking-non-sequence
             mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
             prior = pm.sample_prior_predictive(samples=sample_shape)
 
@@ -2363,9 +2406,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
             sd_dist = pm.Exponential.dist(1.0, shape=3)
+            # pylint: disable=unpacking-non-sequence
             chol, corr, stds = pm.LKJCholeskyCov(
                 "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
             )
+            # pylint: enable=unpacking-non-sequence
             mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
             prior = pm.sample_prior_predictive(samples=sample_shape)
 
@@ -2457,9 +2502,11 @@ def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
             sd_dist = pm.Exponential.dist(1.0, shape=3)
+            # pylint: disable=unpacking-non-sequence
             chol, corr, stds = pm.LKJCholeskyCov(
                 "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
             )
+            # pylint: enable=unpacking-non-sequence
             mv = pm.MvGaussianRandomWalk("mv", mu, chol=chol, shape=dist_shape)
             prior = pm.sample_prior_predictive(samples=sample_shape)
 
@@ -2475,9 +2522,11 @@ def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
             sd_dist = pm.Exponential.dist(1.0, shape=3)
+            # pylint: disable=unpacking-non-sequence
             chol, corr, stds = pm.LKJCholeskyCov(
                 "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
             )
+            # pylint: 
enable=unpacking-non-sequence mv = pm.MvGaussianRandomWalk("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape) prior = pm.sample_prior_predictive(samples=sample_shape) diff --git a/pymc/tests/test_idata_conversion.py b/pymc/tests/test_idata_conversion.py index 1a659108ba..c3105768c4 100644 --- a/pymc/tests/test_idata_conversion.py +++ b/pymc/tests/test_idata_conversion.py @@ -332,7 +332,6 @@ def test_missing_data_model(self): assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3) @pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4") - @pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4") def test_mv_missing_data_model(self): data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1) @@ -340,7 +339,9 @@ def test_mv_missing_data_model(self): with model: mu = pm.Normal("mu", 0, 1, size=2) sd_dist = pm.HalfNormal.dist(1.0) + # pylint: disable=unpacking-non-sequence chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True) + # pylint: enable=unpacking-non-sequence y = pm.MvNormal("y", mu=mu, chol=chol, observed=data) inference_data = pm.sample(100, chains=2, return_inferencedata=True) diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py index 3262f83b8f..bd231959d4 100644 --- a/pymc/tests/test_mixture.py +++ b/pymc/tests/test_mixture.py @@ -368,7 +368,7 @@ def build_toy_dataset(N, K): mu.append(pm.Normal("mu%i" % i, 0, 10, shape=D)) packed_chol.append( pm.LKJCholeskyCov( - "chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5) + "chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5, size=D) ) ) chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True)) diff --git a/pymc/tests/test_posteriors.py b/pymc/tests/test_posteriors.py index 8122588a23..312e1e63c2 100644 --- a/pymc/tests/test_posteriors.py +++ b/pymc/tests/test_posteriors.py @@ -95,7 +95,6 @@ class TestNUTSNormalLong(sf.NutsFixture, sf.NormalFixture): atol = 0.001 -@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4") class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture): n_samples = 2000 tune = 1000 From 6f13f7eb5dd8619ace99253f9e3f40d7f3766282 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Tue, 25 Jan 2022 14:45:37 +0100 Subject: [PATCH 07/10] Add LKJCholeskyCov moment --- pymc/distributions/multivariate.py | 6 ++++++ pymc/tests/test_distributions_moments.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 7c1e53a55f..45d634449f 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1202,6 +1202,12 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs): return super().dist([n, eta, sd_dist], size=size, **kwargs) + def get_moment(rv, size, n, eta, sd_dists): + diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32") + moment = at.zeros_like(rv) + moment = at.set_subtensor(moment[..., diag_idxs], 1) + return moment + def logp(value, n, eta, sd_dist): """ Calculate log-probability of Covariance matrix with LKJ diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 0dc6789b70..0f163d73c2 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -39,6 +39,7 @@ KroneckerNormal, Kumaraswamy, Laplace, + LKJCholeskyCov, LKJCorr, Logistic, LogitNormal, @@ -1439,3 +1440,24 @@ def 
test_lkjcorr_moment(n, eta, size, expected): with Model() as model: LKJCorr("x", n=n, eta=eta, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "n, eta, size, expected", + [ + (3, 1, None, np.array([1, 0, 1, 0, 0, 1])), + (4, 1, None, np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])), + (3, 1, 1, np.array([[1, 0, 1, 0, 0, 1]])), + ( + 4, + 1, + (2, 3), + np.full((2, 3, 10), np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])), + ), + ], +) +def test_lkjcholeskycov_moment(n, eta, size, expected): + with Model() as model: + sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), n)) + LKJCholeskyCov("x", n=n, eta=eta, sd_dist=sd_dist, size=size, compute_corr=False) + assert_moment_is_expected(model, expected, check_finite_logp=size is None) From 45bbc9e1eb03711e060b729c8261034b10836c13 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Tue, 25 Jan 2022 13:40:43 +0100 Subject: [PATCH 08/10] Reenable old MvNormal tests --- pymc/tests/test_distributions_random.py | 158 +++++++++++------------- 1 file changed, 69 insertions(+), 89 deletions(-) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 55b7c1ac3f..d3e51f0d27 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1125,6 +1125,75 @@ class TestMvNormalTau(BaseTestDistributionRandom): checks_to_run = ["check_pymc_params_match_rv_op"] +class TestMvNormalMisc: + def test_with_chol_rv(self): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0, size=3) + sd_dist = pm.Exponential.dist(1.0, size=3) + # pylint: disable=unpacking-non-sequence + chol, _, _ = pm.LKJCholeskyCov( + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) + # pylint: enable=unpacking-non-sequence + mv = pm.MvNormal("mv", mu, chol=chol, size=4) + prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False) + + assert prior["mv"].shape == (10, 4, 3) + + def test_with_cov_rv( + self, + ): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0, shape=3) + sd_dist = pm.Exponential.dist(1.0, shape=3) + # pylint: disable=unpacking-non-sequence + chol, corr, stds = pm.LKJCholeskyCov( + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) + # pylint: enable=unpacking-non-sequence + mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), size=4) + prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False) + + assert prior["mv"].shape == (10, 4, 3) + + def test_issue_3758(self): + np.random.seed(42) + ndim = 50 + with pm.Model() as model: + a = pm.Normal("a", sigma=100, shape=ndim) + b = pm.Normal("b", mu=a, sigma=1, shape=ndim) + c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim) + d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim) + samples = pm.sample_prior_predictive(1000, return_inferencedata=False) + + for var in "abcd": + assert not np.isnan(np.std(samples[var])) + + for var in "bcd": + std = np.std(samples[var] - samples["a"]) + npt.assert_allclose(std, 1, rtol=1e-2) + + def test_issue_3829(self): + with pm.Model() as model: + x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5)) + trace_pp = pm.sample_prior_predictive(50, return_inferencedata=False) + + assert np.shape(trace_pp["x"][0]) == (2, 5) + + def test_issue_3706(self): + N = 10 + Sigma = np.eye(2) + + with pm.Model() as model: + X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2)) + betas = pm.Normal("betas", 0, 1, shape=2) 
+ y = pm.Deterministic("y", pm.math.dot(X, betas)) + + prior_pred = pm.sample_prior_predictive(1, return_inferencedata=False) + + assert prior_pred["X"].shape == (1, N, 2) + + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): chi2_samples = rng.chisquare(nu, size=size) @@ -2366,94 +2435,6 @@ def generate_shapes(include_params=False): return data -@pytest.mark.skip(reason="This test is covered by Aesara") -class TestMvNormal(SeededTest): - @pytest.mark.parametrize( - ["sample_shape", "dist_shape", "mu_shape", "param"], - generate_shapes(include_params=True), - ids=str, - ) - def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param): - dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape) - output_shape = to_tuple(sample_shape) + dist_shape - assert dist.random(size=sample_shape).shape == output_shape - - @pytest.mark.parametrize( - ["sample_shape", "dist_shape", "mu_shape"], - generate_shapes(include_params=False), - ids=str, - ) - def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape): - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape) - sd_dist = pm.Exponential.dist(1.0, shape=3) - # pylint: disable=unpacking-non-sequence - chol, corr, stds = pm.LKJCholeskyCov( - "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True - ) - # pylint: enable=unpacking-non-sequence - mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape) - prior = pm.sample_prior_predictive(samples=sample_shape) - - assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape - - @pytest.mark.parametrize( - ["sample_shape", "dist_shape", "mu_shape"], - generate_shapes(include_params=False), - ids=str, - ) - def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape): - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape) - sd_dist = pm.Exponential.dist(1.0, shape=3) - # pylint: disable=unpacking-non-sequence - chol, corr, stds = pm.LKJCholeskyCov( - "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True - ) - # pylint: enable=unpacking-non-sequence - mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape) - prior = pm.sample_prior_predictive(samples=sample_shape) - - assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape - - def test_issue_3758(self): - np.random.seed(42) - ndim = 50 - with pm.Model() as model: - a = pm.Normal("a", sigma=100, shape=ndim) - b = pm.Normal("b", mu=a, sigma=1, shape=ndim) - c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim) - d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim) - samples = pm.sample_prior_predictive(1000) - - for var in "abcd": - assert not np.isnan(np.std(samples[var])) - - for var in "bcd": - std = np.std(samples[var] - samples["a"]) - npt.assert_allclose(std, 1, rtol=1e-2) - - def test_issue_3829(self): - with pm.Model() as model: - x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5)) - trace_pp = pm.sample_prior_predictive(50) - - assert np.shape(trace_pp["x"][0]) == (2, 5) - - def test_issue_3706(self): - N = 10 - Sigma = np.eye(2) - - with pm.Model() as model: - X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2)) - betas = pm.Normal("betas", 0, 1, shape=2) - y = pm.Deterministic("y", pm.math.dot(X, betas)) - - prior_pred = pm.sample_prior_predictive(1) - - assert prior_pred["X"].shape == (1, N, 2) - - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def 
test_matrix_normal_random_with_random_variables(): """ @@ -2492,7 +2473,6 @@ def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param): output_shape = to_tuple(sample_shape) + dist_shape assert dist.random(size=sample_shape).shape == output_shape - @pytest.mark.xfail @pytest.mark.parametrize( ["sample_shape", "dist_shape", "mu_shape"], generate_shapes(include_params=False), From 5f43bb401fb4622b987806b4758ccb04ce6c174a Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Tue, 25 Jan 2022 13:49:35 +0100 Subject: [PATCH 09/10] Reenable old MatrixNormal test --- pymc/tests/test_distributions_random.py | 47 ++++++++++++------------- pymc/tests/test_idata_conversion.py | 4 +-- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index d3e51f0d27..202eff3fbe 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1756,6 +1756,7 @@ class TestMatrixNormal(BaseTestDistributionRandom): "check_pymc_params_match_rv_op", "check_draws", "check_errors", + "check_random_variable_prior", ] def check_draws(self): @@ -1824,6 +1825,28 @@ def check_errors(self): shape=15, ) + def check_random_variable_prior(self): + """ + This test checks for shape correctness when using MatrixNormal distribution + with parameters as random variables. + Originally reported - https://github.com/pymc-devs/pymc/issues/3585 + """ + K = 3 + D = 15 + mu_0 = np.zeros((D, K)) + lambd = 1.0 + with pm.Model() as model: + sd_dist = pm.HalfCauchy.dist(beta=2.5, size=D) + packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist, compute_corr=False) + L = pm.expand_packed_triangular(D, packedL, lower=True) + Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance + mu = pm.MatrixNormal( + "mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K) + ) + prior = pm.sample_prior_predictive(2, return_inferencedata=False) + + assert prior["mu"].shape == (2, D, K) + class TestInterpolated(BaseTestDistributionRandom): def interpolated_rng_fn(self, size, mu, sigma, rng): @@ -2435,30 +2458,6 @@ def generate_shapes(include_params=False): return data -@pytest.mark.xfail(reason="This distribution has not been refactored for v4") -def test_matrix_normal_random_with_random_variables(): - """ - This test checks for shape correctness when using MatrixNormal distribution - with parameters as random variables. 
- Originally reported - https://github.com/pymc-devs/pymc/issues/3585 - """ - K = 3 - D = 15 - mu_0 = np.zeros((D, K)) - lambd = 1.0 - with pm.Model() as model: - sd_dist = pm.HalfCauchy.dist(beta=2.5) - packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist) - L = pm.expand_packed_triangular(D, packedL, lower=True) - Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance - mu = pm.MatrixNormal( - "mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K) - ) - prior = pm.sample_prior_predictive(2) - - assert prior["mu"].shape == (2, D, K) - - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") class TestMvGaussianRandomWalk(SeededTest): @pytest.mark.parametrize( diff --git a/pymc/tests/test_idata_conversion.py b/pymc/tests/test_idata_conversion.py index c3105768c4..acc9093960 100644 --- a/pymc/tests/test_idata_conversion.py +++ b/pymc/tests/test_idata_conversion.py @@ -331,14 +331,14 @@ def test_missing_data_model(self): # See https://github.com/pymc-devs/pymc/issues/5255 assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3) - @pytest.mark.xfal(reason="Multivariate partial observed RVs not implemented for V4") + @pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4") def test_mv_missing_data_model(self): data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1) model = pm.Model() with model: mu = pm.Normal("mu", 0, 1, size=2) - sd_dist = pm.HalfNormal.dist(1.0) + sd_dist = pm.HalfNormal.dist(1.0, size=2) # pylint: disable=unpacking-non-sequence chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True) # pylint: enable=unpacking-non-sequence From 8666bbc2745056b7ff5070c1f21e7465f641a9b7 Mon Sep 17 00:00:00 2001 From: Ricardo <ricardo.vieira1994@gmail.com> Date: Fri, 28 Jan 2022 12:16:01 +0100 Subject: [PATCH 10/10] Make `TestMvNormalMisc.test_issue_3758` bound less strict --- pymc/tests/test_distributions_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 202eff3fbe..0ec757c37b 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1171,7 +1171,7 @@ def test_issue_3758(self): for var in "bcd": std = np.std(samples[var] - samples["a"]) - npt.assert_allclose(std, 1, rtol=1e-2) + npt.assert_allclose(std, 1, rtol=2e-2) def test_issue_3829(self): with pm.Model() as model:
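Usage sketch (reviewer note, not part of the patch series): the snippet below assumes PyMC v4 with these patches applied; the variable names and observed data are made up for illustration. After the refactor, `sd_dist` is passed as an unregistered `.dist()` variable sized to `n`, and `pm.LKJCholeskyCov` returns the unpacked Cholesky factor, correlations, and standard deviations by default:

    import numpy as np
    import pymc as pm

    with pm.Model() as model:
        # sd_dist must be an unregistered distribution with size=n
        sd_dist = pm.Exponential.dist(1.0, size=3)
        # compute_corr now defaults to True, so three values are returned
        # instead of the packed Cholesky vector
        chol, corr, stds = pm.LKJCholeskyCov("chol_cov", eta=2, n=3, sd_dist=sd_dist)
        mu = pm.Normal("mu", 0.0, 1.0, size=3)
        obs = pm.MvNormal("obs", mu=mu, chol=chol, observed=np.random.randn(100, 3))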
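And a small NumPy check of what the reworked `CholeskyCovPacked` transform does: it logs the diagonal entries of the packed vector (at indices `cumsum(1..n) - 1`) on the forward pass and exponentiates them on the backward pass. This is only a sketch with made-up packed values for n=3:

    import numpy as np

    n = 3
    # the packed vector has n * (n + 1) // 2 entries; the diagonal entries
    # sit at indices cumsum(1..n) - 1, i.e. [0, 2, 5] for n=3
    diag_idxs = np.cumsum(np.arange(1, n + 1)) - 1
    packed = np.array([1.5, 0.3, 2.0, -0.1, 0.4, 0.8])

    forward = packed.copy()
    forward[diag_idxs] = np.log(forward[diag_idxs])  # to unconstrained space

    backward = forward.copy()
    backward[diag_idxs] = np.exp(backward[diag_idxs])  # back to constrained space
    assert np.allclose(backward, packed)

    # log_jac_det of the backward map is the sum of the unconstrained diagonals
    print(forward[diag_idxs].sum())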