From 574bcb7bf84a152a63441dfbb7e55a8831fe326e Mon Sep 17 00:00:00 2001 From: Sayam753 Date: Sun, 30 Jan 2022 20:52:28 +0530 Subject: [PATCH 1/2] Initial commit - Batched MvNormal distribution --- pymc/distributions/multivariate.py | 26 ++++---- pymc/distributions/multivariate_utils.py | 76 ++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 pymc/distributions/multivariate_utils.py diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 45d634449f..04e1fb1fe1 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -28,7 +28,7 @@ from aesara.graph.op import Op from aesara.raise_op import Assert from aesara.sparse.basic import sp_sum -from aesara.tensor import gammaln, sigmoid +from aesara.tensor import gammaln, sigmoid, swapaxes from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal from aesara.tensor.random.op import RandomVariable, default_shape_from_params @@ -57,6 +57,7 @@ multigammaln, ) from pymc.distributions.distribution import Continuous, Discrete +from pymc.distributions.multivariate_utils import batched_matrix_inverse from pymc.distributions.shape_utils import ( broadcast_dist_samples_to, rv_size_is_none, @@ -92,28 +93,23 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs): if chol is not None and not lower: - chol = chol.T + chol = swapaxes(chol, -1, -2) if len([i for i in [tau, cov, chol] if i is not None]) != 1: raise ValueError("Incompatible parameterization. Specify exactly one of tau, cov, or chol.") if cov is not None: cov = at.as_tensor_variable(cov) - if cov.ndim != 2: - raise ValueError("cov must be two dimensional.") elif tau is not None: tau = at.as_tensor_variable(tau) - if tau.ndim != 2: - raise ValueError("tau must be two dimensional.") # TODO: What's the correct order/approach (in the non-square case)? # `aesara.tensor.nlinalg.tensorinv`? - cov = matrix_inverse(tau) + cov = batched_matrix_inverse(tau) else: # TODO: What's the correct order/approach (in the non-square case)? chol = at.as_tensor_variable(chol) - if chol.ndim != 2: - raise ValueError("chol must be two dimensional.") - cov = chol.dot(chol.T) + chol_transpose = swapaxes(chol, -1, -2) + cov = chol * chol_transpose return cov @@ -240,8 +236,14 @@ class MvNormal(Continuous): def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs): mu = at.as_tensor_variable(mu) cov = quaddist_matrix(cov, chol, tau, lower) - # Aesara is stricter about the shape of mu, than PyMC used to be - mu = at.broadcast_arrays(mu, cov[..., -1])[0] + + distribution_shape = at.broadcast_shape(mu.shape, cov.shape[:-1], arrays_are_shapes=True) + mu = at.broadcast_to(mu, distribution_shape) + + event_shape = distribution_shape[-1] + cov_shape = at.broadcast_shape(mu[..., None].shape, event_shape, arrays_are_shapes=True) + + cov = at.broadcast_to(cov, cov_shape) return super().dist([mu, cov], **kwargs) def get_moment(rv, size, mu, cov): diff --git a/pymc/distributions/multivariate_utils.py b/pymc/distributions/multivariate_utils.py new file mode 100644 index 0000000000..65e6561ec7 --- /dev/null +++ b/pymc/distributions/multivariate_utils.py @@ -0,0 +1,76 @@ +import numpy as np + +from aesara.graph.basic import Apply +from aesara.graph.op import Op +from aesara.tensor import swapaxes +from aesara.tensor.basic import as_tensor_variable + + +class BatchedMatrixInverse(Op): + """Computes the inverse of a matrix. + + `aesara.tensor.nlinalg.matrix_inverse` can only inverse square matrices. + This Op can inverse batches of square matrices. + """ + + __props__ = () + + def __init__(self): + pass + + def make_node(self, x): + x = as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, outputs): + (x,) = inputs + (z,) = outputs + z[0] = np.linalg.inv(x).astype(x.dtype) + + def grad(self, inputs, g_outputs): + """ + Checkout Page 10 of Matrix Cookbook: + https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf + """ + (x,) = inputs + xi = self(x) + (gz,) = g_outputs + + # Take transpose of last two dimensions + gz_transpose = swapaxes(gz, -1, -2) + + output = xi * gz_transpose * xi + + output_transpose = swapaxes(output, -1, -2) + return [-output_transpose] + + def R_op(self, inputs, eval_points): + r"""The gradient function should return + + .. math:: \frac{\partial X^{-1}}{\partial X}V, + + where :math:`V` corresponds to ``g_outputs`` and :math:`X` to + ``inputs``. Using the `matrix cookbook + `_, + one can deduce that the relation corresponds to + + .. math:: X^{-1} \cdot V \cdot X^{-1}. + + """ + (x,) = inputs + xi = self(x) + (ev,) = eval_points + if ev is None: + return [None] + return [-xi * ev * xi] + + def infer_shape(self, fgraph, node, shapes): + return shapes + + +batched_matrix_inverse = BatchedMatrixInverse() +# import aesara.tensor as at + +# array = np.stack([np.eye(3), np.eye(3)]) +# array_tensor = at.as_tensor_variable(array) +# at.grad(batched_matrix_inverse(array_tensor).mean(), array_tensor) From 5ee641f63c4c7ae34d71c9b28b208be82272df60 Mon Sep 17 00:00:00 2001 From: Sayam753 Date: Mon, 31 Jan 2022 00:33:48 +0530 Subject: [PATCH 2/2] RV object creation works and so does random sampling --- pymc/distributions/multivariate.py | 6 +++--- pymc/tests/test_distributions_random.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 04e1fb1fe1..b652232fec 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -237,11 +237,11 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs): mu = at.as_tensor_variable(mu) cov = quaddist_matrix(cov, chol, tau, lower) - distribution_shape = at.broadcast_shape(mu.shape, cov.shape[:-1], arrays_are_shapes=True) + distribution_shape = at.stack(at.broadcast_shape(mu, cov[..., -1])) mu = at.broadcast_to(mu, distribution_shape) - event_shape = distribution_shape[-1] - cov_shape = at.broadcast_shape(mu[..., None].shape, event_shape, arrays_are_shapes=True) + event_shape = [distribution_shape[-1]] + cov_shape = at.concatenate((distribution_shape, event_shape)) cov = at.broadcast_to(cov, cov_shape) return super().dist([mu, cov], **kwargs) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 0ec757c37b..0718b73bbf 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1094,9 +1094,9 @@ def check_mu_broadcast_helper(self): assert mu.eval().shape == (10, 3) # Cov is artificually limited to being 2D - # x = pm.MvNormal.dist(mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3))) - # mu = x.owner.inputs[3] - # assert mu.eval().shape == (10, 2, 3) + x = pm.MvNormal.dist(mu=np.ones((10, 1, 1)), cov=np.full((2, 3, 3), np.eye(3))) + mu = x.owner.inputs[3] + assert mu.eval().shape == (10, 2, 3) class TestMvNormalChol(BaseTestDistributionRandom):