From 660da7ca8e4f5d4b99af7abc0cb9fa309b0667c7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 14 Feb 2024 10:22:20 +0800 Subject: [PATCH 1/4] Delete `pytensor_scipy.py` --- .../statespace/filters/kalman_filter.py | 3 +- .../statespace/models/SARIMAX.py | 2 +- pymc_experimental/statespace/models/VARMAX.py | 2 +- .../statespace/utils/pytensor_scipy.py | 85 ---------------- .../tests/statespace/test_pytensor_scipy.py | 99 ------------------- 5 files changed, 3 insertions(+), 188 deletions(-) delete mode 100644 pymc_experimental/statespace/utils/pytensor_scipy.py delete mode 100644 pymc_experimental/tests/statespace/test_pytensor_scipy.py diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index bf17d743..87fdc746 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -9,7 +9,7 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.nlinalg import matrix_dot -from pytensor.tensor.slinalg import solve_triangular +from pytensor.tensor.slinalg import solve_discrete_are, solve_triangular from pymc_experimental.statespace.filters.utilities import ( quad_form_sym, @@ -17,7 +17,6 @@ stabilize, ) from pymc_experimental.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL -from pymc_experimental.statespace.utils.pytensor_scipy import solve_discrete_are MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] diff --git a/pymc_experimental/statespace/models/SARIMAX.py b/pymc_experimental/statespace/models/SARIMAX.py index 9350b5b1..55d3f01b 100644 --- a/pymc_experimental/statespace/models/SARIMAX.py +++ b/pymc_experimental/statespace/models/SARIMAX.py @@ -2,6 +2,7 @@ import numpy as np import pytensor.tensor as pt +from pytensor.tensor.slinalg import solve_discrete_lyapunov from 
pymc_experimental.statespace.core.statespace import PyMCStateSpace, floatX from pymc_experimental.statespace.models.utilities import ( @@ -19,7 +20,6 @@ SEASONAL_AR_PARAM_DIM, SEASONAL_MA_PARAM_DIM, ) -from pymc_experimental.statespace.utils.pytensor_scipy import solve_discrete_lyapunov def _verify_order(p, d, q, P, D, Q, S): diff --git a/pymc_experimental/statespace/models/VARMAX.py b/pymc_experimental/statespace/models/VARMAX.py index 0fe1ec64..0942d6db 100644 --- a/pymc_experimental/statespace/models/VARMAX.py +++ b/pymc_experimental/statespace/models/VARMAX.py @@ -3,6 +3,7 @@ import numpy as np import pytensor import pytensor.tensor as pt +from pytensor.tensor.slinalg import solve_discrete_lyapunov from pymc_experimental.statespace.core.statespace import PyMCStateSpace from pymc_experimental.statespace.models.utilities import make_default_coords @@ -16,7 +17,6 @@ SHOCK_AUX_DIM, SHOCK_DIM, ) -from pymc_experimental.statespace.utils.pytensor_scipy import solve_discrete_lyapunov floatX = pytensor.config.floatX diff --git a/pymc_experimental/statespace/utils/pytensor_scipy.py b/pymc_experimental/statespace/utils/pytensor_scipy.py deleted file mode 100644 index 996a1a9f..00000000 --- a/pymc_experimental/statespace/utils/pytensor_scipy.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytensor -import pytensor.tensor as pt -import scipy -from pytensor.tensor import TensorVariable, as_tensor_variable -from pytensor.tensor.nlinalg import matrix_dot -from pytensor.tensor.slinalg import solve_discrete_lyapunov - -floatX = pytensor.config.floatX - - -class SolveDiscreteARE(pt.Op): - __props__ = ("enforce_Q_symmetric",) - - def __init__(self, enforce_Q_symmetric=False): - self.enforce_Q_symmetric = enforce_Q_symmetric - - def make_node(self, A, B, Q, R): - A = as_tensor_variable(A) - B = as_tensor_variable(B) - Q = as_tensor_variable(Q) - R = as_tensor_variable(R) - - out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype) - X = 
pytensor.tensor.matrix(dtype=out_dtype) - - return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X]) - - def perform(self, node, inputs, output_storage): - A, B, Q, R = inputs - X = output_storage[0] - - if self.enforce_Q_symmetric: - Q = 0.5 * (Q + Q.T) - X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R).astype(floatX) - - def infer_shape(self, fgraph, node, shapes): - return [shapes[0]] - - def grad(self, inputs, output_grads): - # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf - A, B, Q, R = inputs - - (dX,) = output_grads - X = self(A, B, Q, R) - - K_inner = R + pt.linalg.matrix_dot(B.T, X, B) - K_inner_inv = pt.linalg.solve(K_inner, pt.eye(R.shape[0])) - K = matrix_dot(K_inner_inv, B.T, X, A) - - A_tilde = A - B.dot(K) - - dX_symm = 0.5 * (dX + dX.T) - S = solve_discrete_lyapunov(A_tilde, dX_symm).astype(floatX) - - A_bar = 2 * matrix_dot(X, A_tilde, S) - B_bar = -2 * matrix_dot(X, A_tilde, S, K.T) - Q_bar = S - R_bar = matrix_dot(K, S, K.T) - - return [A_bar, B_bar, Q_bar, R_bar] - - -def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable: - """ - Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`. 
- Parameters - ---------- - A: ArrayLike - Square matrix of shape M x M - B: ArrayLike - Square matrix of shape M x M - Q: ArrayLike - Symmetric square matrix of shape M x M - R: ArrayLike - Square matrix of shape N x N - enforce_Q_symmetric: bool - If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry - - Returns - ------- - X: pt.matrix - Square matrix of shape M x M, representing the solution to the DARE - """ - - return SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R) diff --git a/pymc_experimental/tests/statespace/test_pytensor_scipy.py b/pymc_experimental/tests/statespace/test_pytensor_scipy.py deleted file mode 100644 index 8ceae336..00000000 --- a/pymc_experimental/tests/statespace/test_pytensor_scipy.py +++ /dev/null @@ -1,99 +0,0 @@ -import sys -import unittest - -import numpy as np -import pytensor -from numpy.testing import assert_allclose -from pytensor.configdefaults import config -from pytensor.gradient import verify_grad as orig_verify_grad - -from pymc_experimental.statespace.utils.pytensor_scipy import ( - SolveDiscreteARE, - solve_discrete_are, -) - -floatX = pytensor.config.floatX -solve_discrete_are_enforce = SolveDiscreteARE(enforce_Q_symmetric=True) - - -def fetch_seed(pseed=None): - """ - Copied from pytensor.test.unittest_tools - """ - - seed = pseed or config.unittests__rseed - if seed == "random": - seed = None - - try: - if seed: - seed = int(seed) - else: - seed = None - except ValueError: - print( - ("Error: config.unittests__rseed contains " "invalid seed, using None instead"), - file=sys.stderr, - ) - seed = None - - return seed - - -def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs): - """ - Copied from pytensor.test.unittest_tools - """ - if rng is None: - rng = np.random.default_rng(fetch_seed()) - - # TODO: Needed to increase tolerance for certain tests when migrating to - # Generators from RandomStates. Caused flaky test failures. 
Needs further investigation - if "rel_tol" not in kwargs: - kwargs["rel_tol"] = 0.05 - if "abs_tol" not in kwargs: - kwargs["abs_tol"] = 0.05 - orig_verify_grad(op, pt, n_tests, rng, *args, **kwargs) - - -class TestSolveDiscreteARE(unittest.TestCase): - def test_forward(self): - # TEST CASE 4 : darex #1 -- taken from Scipy tests - a, b, q, r = ( - np.array([[4, 3], [-4.5, -3.5]], dtype=floatX), - np.array([[1], [-1]], dtype=floatX), - np.array([[9, 6], [6, 4]], dtype=floatX), - np.array([[1]], dtype=floatX), - ) - a, b, q, r = (x.astype(floatX) for x in [a, b, q, r]) - - x = solve_discrete_are(a, b, q, r).eval() - res = a.T.dot(x.dot(a)) - x + q - res -= ( - a.conj() - .T.dot(x.dot(b)) - .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a))) - ) - - atol = 1e-4 if floatX == "float32" else 1e-12 - assert_allclose(res, np.zeros_like(res), atol=atol) - - def test_backward(self): - a, b, q, r = ( - np.array([[4, 3], [-4.5, -3.5]], dtype=floatX), - np.array([[1], [-1]], dtype=floatX), - np.array([[9, 6], [6, 4]], dtype=floatX), - np.array([[1]], dtype=floatX), - ) - a, b, q, r = (x.astype(floatX) for x in [a, b, q, r]) - - rng = np.random.default_rng(fetch_seed()) - - # TODO: Is there a "theoretically motivated" value to use here? 
I pulled 1e-4 out of a hat - atol = 1e-4 if floatX == "float32" else 1e-12 - - verify_grad(solve_discrete_are_enforce, pt=[a, b, q, r], rng=rng, abs_tol=atol) - - -if __name__ == "__main__": - unittest.main() From 667596a8a0ddfa47f2799f9185f692ba1adfd669 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 14 Feb 2024 10:23:08 +0800 Subject: [PATCH 2/4] Delete `block_diagonal`, use `pt.linalg.block_diag` --- .../statespace/models/structural.py | 27 ++++--------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/pymc_experimental/statespace/models/structural.py b/pymc_experimental/statespace/models/structural.py index 4c98695c..3c6f5af7 100644 --- a/pymc_experimental/statespace/models/structural.py +++ b/pymc_experimental/statespace/models/structural.py @@ -44,23 +44,6 @@ def _frequency_transition_block(s, j): return pt.stack([[pt.cos(lam), pt.sin(lam)], [-pt.sin(lam), pt.cos(lam)]]).squeeze() -def block_diagonal(matrices: list[pt.matrix]): - rows = [x.shape[0] for x in matrices] - cols = [x.shape[1] for x in matrices] - out = pt.zeros((sum(rows), sum(cols))) - row_cursor = 0 - col_cursor = 0 - - for row, col, mat in zip(rows, cols, matrices): - row_slice = slice(row_cursor, row_cursor + row) - col_slice = slice(col_cursor, col_cursor + col) - row_cursor += row - col_cursor += col - - out = pt.set_subtensor(out[row_slice, col_slice], mat) - return out - - class StructuralTimeSeries(PyMCStateSpace): r""" Structural Time Series Model @@ -527,7 +510,7 @@ def make_slice(name, x, o_x): initial_state = pt.concatenate(conform_time_varying_and_time_invariant_matrices(x0, o_x0)) initial_state.name = x0.name - initial_state_cov = block_diagonal([P0, o_P0]) + initial_state_cov = pt.linalg.block_diag(P0, o_P0) initial_state_cov.name = P0.name state_intercept = pt.concatenate(conform_time_varying_and_time_invariant_matrices(c, o_c)) @@ -536,19 +519,19 @@ def make_slice(name, x, o_x): obs_intercept = d + o_d obs_intercept.name = d.name - transition = 
block_diagonal([T, o_T]) + transition = pt.linalg.block_diag(T, o_T) transition.name = T.name design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1) design.name = Z.name - selection = block_diagonal([R, o_R]) + selection = pt.linalg.block_diag(R, o_R) selection.name = R.name obs_cov = H + o_H obs_cov.name = H.name - state_cov = block_diagonal([Q, o_Q]) + state_cov = pt.linalg.block_diag(Q, o_Q) state_cov.name = Q.name new_ssm = PytensorRepresentation( @@ -1326,7 +1309,7 @@ def make_symbolic_graph(self) -> None: self.ssm["initial_state", init_state_idx] = init_state T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)] - T = block_diagonal(T_mats) + T = pt.linalg.block_diag(*T_mats) self.ssm["transition", :, :] = T if self.innovations: From d32b71afd105a6e580efd8a81b5d2ef34907e8a7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 14 Feb 2024 12:04:49 +0800 Subject: [PATCH 3/4] Relax tolerance in loglikelihood tests --- pymc_experimental/tests/statespace/test_distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/tests/statespace/test_distributions.py b/pymc_experimental/tests/statespace/test_distributions.py index 910a8929..1da4be60 100644 --- a/pymc_experimental/tests/statespace/test_distributions.py +++ b/pymc_experimental/tests/statespace/test_distributions.py @@ -26,8 +26,8 @@ # TODO: These are pretty loose because of all the stabilizing of covariance matrices that is done inside the kalman # filters. When that is improved, this should be tightened. 
-ATOL = 1e-6 if floatX.endswith("64") else 1e-4 -RTOL = 1e-6 if floatX.endswith("64") else 1e-4 +ATOL = 1e-5 if floatX.endswith("64") else 1e-4 +RTOL = 1e-5 if floatX.endswith("64") else 1e-4 filter_names = [ "standard", From faf1c87ea122337dd475340c6a5025ed678fffb9 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 14 Feb 2024 10:23:29 +0800 Subject: [PATCH 4/4] Use updated pandas frequency string --- .../tests/statespace/utilities/test_helpers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/tests/statespace/utilities/test_helpers.py b/pymc_experimental/tests/statespace/utilities/test_helpers.py index 078c6649..de5999fb 100644 --- a/pymc_experimental/tests/statespace/utilities/test_helpers.py +++ b/pymc_experimental/tests/statespace/utilities/test_helpers.py @@ -19,8 +19,15 @@ def load_nile_test_data(): + from importlib.metadata import version + nile = pd.read_csv("pymc_experimental/tests/statespace/test_data/nile.csv", dtype={"x": floatX}) - nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq="AS-Jan") + major, minor = map(int, version("pandas").split(".")[:2]) + if (major, minor) >= (2, 2):  # pandas 2.2 deprecated the "AS" alias in favor of "YS" + freq_str = "YS-JAN" + else: + freq_str = "AS-JAN" + nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq=freq_str) nile.rename(columns={"x": "height"}, inplace=True) nile = (nile - nile.mean()) / nile.std() nile = nile.astype(floatX)