Skip to content

Batched MvNormal distribution #5424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from aesara.graph.op import Op
from aesara.raise_op import Assert
from aesara.sparse.basic import sp_sum
from aesara.tensor import gammaln, sigmoid
from aesara.tensor import gammaln, sigmoid, swapaxes
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
Expand Down Expand Up @@ -57,6 +57,7 @@
multigammaln,
)
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.multivariate_utils import batched_matrix_inverse
from pymc.distributions.shape_utils import (
broadcast_dist_samples_to,
rv_size_is_none,
Expand Down Expand Up @@ -92,28 +93,23 @@

# Build the covariance tensor from exactly one of `cov`, `tau` or `chol`.
# NOTE(review): this hunk is a rendered diff without +/- markers, so the old
# and the new version of several statements appear back to back below.
def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
if chol is not None and not lower:
# `.T` reverses every axis; `swapaxes(chol, -1, -2)` transposes only the
# last two axes, which is what a stack of matrices needs.
chol = chol.T
chol = swapaxes(chol, -1, -2)

# Exactly one of the three parameterizations may be supplied.
if len([i for i in [tau, cov, chol] if i is not None]) != 1:
raise ValueError("Incompatible parameterization. Specify exactly one of tau, cov, or chol.")

if cov is not None:
cov = at.as_tensor_variable(cov)
if cov.ndim != 2:
raise ValueError("cov must be two dimensional.")
elif tau is not None:
tau = at.as_tensor_variable(tau)
if tau.ndim != 2:
raise ValueError("tau must be two dimensional.")
# TODO: What's the correct order/approach (in the non-square case)?
# `aesara.tensor.nlinalg.tensorinv`?
# The covariance is the inverse of the precision matrix `tau`.
cov = matrix_inverse(tau)
cov = batched_matrix_inverse(tau)
else:
# TODO: What's the correct order/approach (in the non-square case)?
chol = at.as_tensor_variable(chol)
if chol.ndim != 2:
raise ValueError("chol must be two dimensional.")
cov = chol.dot(chol.T)
chol_transpose = swapaxes(chol, -1, -2)
# NOTE(review): `*` is an elementwise product, whereas `chol.dot(chol.T)`
# above is a matrix product. A covariance built from a Cholesky factor
# needs the matrix product over the last two axes -- confirm and fix.
cov = chol * chol_transpose

return cov

Expand Down Expand Up @@ -240,8 +236,14 @@ class MvNormal(Continuous):
# Create the distribution from `mu` and one of `cov`, `tau` or `chol`,
# broadcasting `mu` and the covariance to a common shape first.
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
mu = at.as_tensor_variable(mu)
# quaddist_matrix takes (cov, chol, tau, lower) positionally, in that order.
cov = quaddist_matrix(cov, chol, tau, lower)
# Aesara is stricter about the shape of mu, than PyMC used to be
mu = at.broadcast_arrays(mu, cov[..., -1])[0]

# `cov[..., -1]` drops the last axis of cov, so broadcasting it against mu
# yields the shape mu is expanded to.
distribution_shape = at.stack(at.broadcast_shape(mu, cov[..., -1]))
mu = at.broadcast_to(mu, distribution_shape)

# cov gets the same shape plus one more trailing axis whose length repeats
# the last entry of `distribution_shape`.
event_shape = [distribution_shape[-1]]
cov_shape = at.concatenate((distribution_shape, event_shape))

cov = at.broadcast_to(cov, cov_shape)
return super().dist([mu, cov], **kwargs)

def get_moment(rv, size, mu, cov):
Expand Down
76 changes: 76 additions & 0 deletions pymc/distributions/multivariate_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import numpy as np

from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor import swapaxes
from aesara.tensor.basic import as_tensor_variable


class BatchedMatrixInverse(Op):
"""Computes the inverse of a matrix.

`aesara.tensor.nlinalg.matrix_inverse` can only invert a single square matrix.
This Op can invert batches of square matrices.
"""
Comment on lines +9 to +14
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two questions:

  • Is there an existing function in aesara that can inverse batches of square matrices in a vectorized way?
  • If no, I have added this Op in a new file. Where should this class ideally be placed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distributions/dist_math.py?

Copy link
Member

@ricardoV94 ricardoV94 Jan 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If numpy has support for batched matrix inversion but not Aesara, we can open an issue there. Here is a similar case: aesara-devs/aesara#791

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the Aesara Op is just doing the same as this one but has an unnecessary ndim check?

We should just open a PR to fix it there


__props__ = ()

def __init__(self):
    """Create the Op; there is nothing to set up."""

def make_node(self, x):
    """Build the Apply node that inverts ``x``.

    ``x`` is converted to a tensor variable, and the single output is a
    new variable created from the type of that input.
    """
    tensor_in = as_tensor_variable(x)
    tensor_out = tensor_in.type()
    return Apply(self, [tensor_in], [tensor_out])

def perform(self, node, inputs, outputs):
    """Invert the input numerically and store the result.

    ``inputs`` must hold exactly one array and ``outputs`` exactly one
    storage cell. Slot 0 of that cell receives ``np.linalg.inv`` of the
    array, cast back to the dtype of the input.
    """
    [matrices] = inputs
    [storage] = outputs
    inverse = np.linalg.inv(matrices)
    storage[0] = inverse.astype(matrices.dtype)

def grad(self, inputs, g_outputs):
    """Return the gradient of the cost with respect to the inverted input.

    For ``Y = X^{-1}`` and an upstream gradient ``G`` (the gradient of the
    cost with respect to ``Y``), the gradient with respect to ``X`` is
    ``-(X^{-1} G^T X^{-1})^T``. Every product in that expression is a
    matrix product over the last two axes; the leading axes are treated
    as batch axes.

    Checkout Page 10 of Matrix Cookbook:
    https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf

    Parameters
    ----------
    inputs
        Sequence holding the single input ``X``.
    g_outputs
        Sequence holding the single upstream gradient ``G``.

    Returns
    -------
    list
        One-element list with the gradient with respect to ``X``.
    """

    def _matmul(a, b):
        # `*` on tensors is elementwise, so the matrix product over the
        # last two axes is spelled out with broadcasting:
        # out[..., i, j] = sum_k a[..., i, k] * b[..., k, j].
        # NOTE(review): this builds an (..., n, n, n) intermediate; switch
        # to a native batched matmul if the pinned Aesara offers one.
        return (a[..., :, :, None] * b[..., None, :, :]).sum(axis=-2)

    (x,) = inputs
    xi = self(x)
    (gz,) = g_outputs

    # Take transpose of last two dimensions
    gz_transpose = swapaxes(gz, -1, -2)

    output = _matmul(_matmul(xi, gz_transpose), xi)

    output_transpose = swapaxes(output, -1, -2)
    return [-output_transpose]

def R_op(self, inputs, eval_points):
    r"""Apply the Jacobian of the inverse to the evaluation point.

    The R-operator should return

    .. math:: \frac{\partial X^{-1}}{\partial X}V,

    where :math:`V` corresponds to ``eval_points`` and :math:`X` to
    ``inputs``. Using the `matrix cookbook
    <http://www2.imm.dtu.dk/pubdb/views/publication_details.php?id=3274>`_,
    one can deduce that the relation corresponds to

    .. math:: -X^{-1} \cdot V \cdot X^{-1},

    where each product is a matrix product over the last two axes and
    the leading axes are treated as batch axes.

    Returns ``[None]`` when the evaluation point is ``None``.
    """

    def _matmul(a, b):
        # `*` on tensors is elementwise, so the matrix product over the
        # last two axes is spelled out with broadcasting:
        # out[..., i, j] = sum_k a[..., i, k] * b[..., k, j].
        # NOTE(review): this builds an (..., n, n, n) intermediate; switch
        # to a native batched matmul if the pinned Aesara offers one.
        return (a[..., :, :, None] * b[..., None, :, :]).sum(axis=-2)

    (x,) = inputs
    xi = self(x)
    (ev,) = eval_points
    if ev is None:
        return [None]
    return [-_matmul(_matmul(xi, ev), xi)]

def infer_shape(self, fgraph, node, shapes):
    """Return the output shapes, which are the input shapes unchanged."""
    output_shapes = shapes
    return output_shapes


# Ready-to-call instance of the Op; `pymc/distributions/multivariate.py`
# imports it under this name and applies it to `tau`.
batched_matrix_inverse = BatchedMatrixInverse()
# import aesara.tensor as at

# array = np.stack([np.eye(3), np.eye(3)])
# array_tensor = at.as_tensor_variable(array)
# at.grad(batched_matrix_inverse(array_tensor).mean(), array_tensor)
6 changes: 3 additions & 3 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,9 @@ def check_mu_broadcast_helper(self):
assert mu.eval().shape == (10, 3)

# Cov is artificially limited to being 2D
# x = pm.MvNormal.dist(mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
# mu = x.owner.inputs[3]
# assert mu.eval().shape == (10, 2, 3)
x = pm.MvNormal.dist(mu=np.ones((10, 1, 1)), cov=np.full((2, 3, 3), np.eye(3)))
mu = x.owner.inputs[3]
assert mu.eval().shape == (10, 2, 3)


class TestMvNormalChol(BaseTestDistributionRandom):
Expand Down