Skip to content

Merge statespace module from http://github.com/jessegrabowski/pymc_statespace #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 140 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
140 commits
Select commit Hold shift + click to select a range
6dba1dc
move http://github.com/jessegrabowski/pymc_statespace to pymc-experim…
jessegrabowski May 29, 2023
d9f3239
Make tests compatible with float32
jessegrabowski May 29, 2023
c37fb77
Fix remaining tests that failed when dtype is float32
jessegrabowski May 29, 2023
5f48976
Fix remaining tests that failed when dtype is float32
jessegrabowski May 29, 2023
d04adaf
Merge remote-tracking branch 'origin/statespace' into statespace
jessegrabowski May 29, 2023
6b0a123
Merge branch 'main' into statespace
jessegrabowski May 29, 2023
12b4db5
Merge branch 'pymc-devs:main' into statespace
jessegrabowski Jul 8, 2023
ead1036
Merge branch 'statespace' of https://github.com/jessegrabowski/pymc-e…
jessegrabowski Jul 8, 2023
c6a4f85
Replace print with `logger.info`
jessegrabowski Jul 12, 2023
8e18779
Use `getattr` to access model variables in `gather_required_random_va…
jessegrabowski Jul 12, 2023
8b05fb7
Reduce code duplication, eliminate use of `pathlib` in favor of expli…
jessegrabowski Jul 12, 2023
389bfbe
Refactor `PyMCStateSpace` and `PytensorRepresentation` to no longer r…
jessegrabowski Jul 12, 2023
b3d5c2a
Merge remote-tracking branch 'origin/statespace' into statespace
jessegrabowski Jul 13, 2023
e577590
Refactor kalman filters to remove singleton dimensions everywhere pos…
jessegrabowski Jul 13, 2023
c761d85
Add tests for sarimax, add first pass at state space distribution
jessegrabowski Jul 13, 2023
4d5cf7e
Merge remote-tracking branch 'origin/statespace' into statespace
jessegrabowski Jul 14, 2023
96cca74
Add more helpers to statespace class for building PyMC model
jessegrabowski Jul 15, 2023
9532901
Changes to SARIMAX to make states interpretable, plus bad draft pymc …
jessegrabowski Jul 16, 2023
953542e
Distribution attempt 2
jessegrabowski Jul 16, 2023
ec03342
Distribution attempt 3
jessegrabowski Jul 18, 2023
9cf56d1
Fix broken tests
jessegrabowski Jul 18, 2023
496f815
Fix remaining tests
jessegrabowski Jul 19, 2023
7e3ed2f
Preserve name/shape information after updates to matrices
jessegrabowski Jul 19, 2023
9efa465
Add test for `SequenceMvNormal`
jessegrabowski Jul 19, 2023
5ef77a9
Add logp as parameter to `SequenceMvNormal`
jessegrabowski Jul 20, 2023
8648155
Add logp argument to `SequenceMvNormal` test
jessegrabowski Jul 20, 2023
9a4636d
Refactor shape checking to use `type.shape`
jessegrabowski Jul 20, 2023
ceecd15
Refactor shape checking to use `type.shape`
jessegrabowski Jul 20, 2023
d31c220
Add float32 test skip back to `test_filters_match_statsmodel_output`
jessegrabowski Jul 21, 2023
cf49e92
Refactor `LinearGaussianStateSpace` for use in prior/posterior predic…
jessegrabowski Jul 21, 2023
570d613
Infer coordinates for statespace matrices from parameters
jessegrabowski Jul 21, 2023
90dcf69
Refactor posterior predictive sampling, coords, and data handling
jessegrabowski Jul 22, 2023
15a337c
Remove typecheck with `Union`
jessegrabowski Jul 22, 2023
d76a12c
Fix float32 error in tests
jessegrabowski Jul 22, 2023
6b8575a
Fix float32 error in tests
jessegrabowski Jul 22, 2023
2f19ec1
Fix bugs in SARMIAX, add measurement error option to SARIMAX, update …
jessegrabowski Jul 22, 2023
ef32c5b
always add `TIME_DIM` to model coords, even when data aren't given dims
jessegrabowski Jul 22, 2023
de72922
always add `TIME_DIM` to model coords, even when data aren't given dims
jessegrabowski Jul 22, 2023
d982566
Add forecast method to `Statespace`, add coord/dim info to `VARMAX`, …
jessegrabowski Jul 23, 2023
dd81173
Add `impulse_response_function` method to `Statespace`, update `ARIMA…
jessegrabowski Jul 24, 2023
6538d18
Update VARMAX notebook, remove `pm.DiracDelta` from `Statespace.forec…
jessegrabowski Jul 24, 2023
36d7901
Update VARMAX example notebook
jessegrabowski Jul 25, 2023
c1208db
Add docstrings to statespace methods and utilities
jessegrabowski Jul 25, 2023
d9caa01
Merge remote-tracking branch 'origin/statespace' into statespace
jessegrabowski Jul 25, 2023
8f47ef3
More docstrings
jessegrabowski Jul 26, 2023
96b1f37
documentation
jessegrabowski Jul 27, 2023
04ad689
More docs, fix failing tests
jessegrabowski Jul 27, 2023
4f41b2a
Implement conditional and unconditonal sampling from prior idata
jessegrabowski Jul 27, 2023
1b6297f
Trying to get the API docs to render
jessegrabowski Jul 27, 2023
3f766d3
Trying to get the API docs to render
jessegrabowski Jul 27, 2023
ca91078
Revert "Trying to get the API docs to render"
jessegrabowski Jul 27, 2023
1de05e8
Trying to get the API docs to render, revert local change to `pyproje…
jessegrabowski Jul 27, 2023
94876b1
Trying to get the API docs to render
jessegrabowski Jul 27, 2023
8cf95b4
Remove `specify_broadcast` in `SingleTimeseriesFilter.update`
jessegrabowski Jul 27, 2023
849ba1b
Fixing docstrings
jessegrabowski Jul 27, 2023
7ff39e2
Fixing docstrings
jessegrabowski Jul 27, 2023
e921698
create toctree hierarchy
jessegrabowski Jul 27, 2023
4e0b463
create toctree hierarchy
jessegrabowski Jul 27, 2023
a0f3343
create toctree hierarchy
jessegrabowski Jul 27, 2023
d91fc37
docstrings
jessegrabowski Jul 27, 2023
ad933b7
docstrings
jessegrabowski Jul 27, 2023
29a88ff
docstrings
jessegrabowski Jul 27, 2023
61b6945
Making a table in a docstring
jessegrabowski Jul 27, 2023
ae106c2
Making a table in a docstring
jessegrabowski Jul 27, 2023
b75ff41
More docs
jessegrabowski Jul 28, 2023
12f56c4
More docstrings
jessegrabowski Jul 28, 2023
812d6c5
More docstrings
jessegrabowski Jul 28, 2023
c7772ac
More docstrings
jessegrabowski Jul 28, 2023
7670f4e
Remove `BayesianLocalLevel`, add a new `StructuralTimeSeries` module
jessegrabowski Jul 29, 2023
f525166
Remove `BayesianLocalLevel`, add a new `StructuralTimeSeries` module
jessegrabowski Jul 29, 2023
32c58bc
Bug fixes in `structural`, new example notebook
jessegrabowski Jul 30, 2023
a231c29
Bug fixes in `structural`, update structural example
jessegrabowski Jul 31, 2023
e2b93e6
Kalman filter no longer returns a forecast on the predicted states
jessegrabowski Jul 31, 2023
0f2a2b9
Fix bugs in IRF API
jessegrabowski Jul 31, 2023
95a9860
Fix bugs in IRF API
jessegrabowski Jul 31, 2023
caa3e0b
Add `airpass.csv` test data
jessegrabowski Jul 31, 2023
cf2157e
More docstrings
jessegrabowski Jul 31, 2023
6b8c9c9
More docstrings
jessegrabowski Jul 31, 2023
01d063a
More docstrings
jessegrabowski Jul 31, 2023
7f824dc
More docstrings
jessegrabowski Jul 31, 2023
f2b5281
Fix VARMAX tests
jessegrabowski Jul 31, 2023
669382a
More docs
jessegrabowski Jul 31, 2023
f961eac
More docs, allow user to name states in seasonal components
jessegrabowski Jul 31, 2023
5631ad3
more docs, rename `test_structural_model` to `test_structural`
jessegrabowski Aug 1, 2023
21570c2
Allow non-integer seasonal length in `FrequencySeasonality` (breaks t…
jessegrabowski Aug 1, 2023
5b1dbc8
Seed `sample_posterior` and `sample_prior` tests, unbreak `test_struc…
jessegrabowski Aug 1, 2023
edcd9ca
Set test seed as a global constant
jessegrabowski Aug 2, 2023
ee4a817
Remove numba dependency
jessegrabowski Aug 2, 2023
68ac3c2
Add helper function to `StructuralTimeSeries` to reconstruct componen…
jessegrabowski Aug 2, 2023
f899096
Add `fast_eval` as test utility
jessegrabowski Aug 2, 2023
c4aee6b
Changes from review feedback
jessegrabowski Aug 2, 2023
29d1778
Try to fix failing test
jessegrabowski Aug 2, 2023
f324db2
Catch expected test warnings with `pytest.mark.filterwarnings`
jessegrabowski Aug 2, 2023
b9a6dd2
Scope fixtures in sampling tests to speed them up a bit
jessegrabowski Aug 2, 2023
69395ef
More fixes from review
jessegrabowski Aug 3, 2023
c6485a0
Trying to fix the failing test
jessegrabowski Aug 3, 2023
90f9fdd
Remove numba from `windows-environment-test.yml`
jessegrabowski Aug 3, 2023
86b7642
Test for NaNs in sampled statespace objects
jessegrabowski Aug 3, 2023
6b69719
Redo the `Custom Statespace` notebooks
jessegrabowski Aug 3, 2023
b4a2d87
Add `ImputationWarning` to `mask_missing_values_in_data`
jessegrabowski Aug 3, 2023
5a3d4cd
Add `ImputationWarning` to `mask_missing_values_in_data`
jessegrabowski Aug 3, 2023
b2cdc1c
Add off-diagonal averaging to covariance `stabilize` function.
jessegrabowski Aug 3, 2023
ab67bb8
Covariance matrices output by kalman filter not always PSD
jessegrabowski Aug 4, 2023
795910c
Add some stability tricks to Kalman filters
jessegrabowski Aug 4, 2023
cd7e712
Add `measurement_error` kwarg to `PyMCStateSpace`
jessegrabowski Aug 4, 2023
221a6ef
Allow `eig == 0.0` to pass PSD test
jessegrabowski Aug 4, 2023
ff0cd98
Add stability tests to VARMAX and SARIMAX
jessegrabowski Aug 4, 2023
adf5e82
LGSS distribution doesn't need to know anything about measurement error
jessegrabowski Aug 5, 2023
78d982f
Re-run Structural Timeseries Modeling.ipynb with some new features
jessegrabowski Aug 5, 2023
cf0f0c0
Delete out-of-date example notebook
jessegrabowski Aug 5, 2023
3060676
Updates to Making a Custom Statespace Model.ipynb
jessegrabowski Aug 5, 2023
53226cb
Link to the bVAR example notebook in the VARMAX Example.ipynb
jessegrabowski Aug 5, 2023
bd2e7e7
Make the IRF plots more readable in ARIMA Example.ipynb
jessegrabowski Aug 5, 2023
225810f
Increase `JITTER_DEFAULT` when pytensor is in `float32` mode
jessegrabowski Aug 5, 2023
fdefbee
Increase `JITTER_DEFAULT` when pytensor is in `float32` mode
jessegrabowski Aug 5, 2023
de9618d
Merge remote-tracking branch 'origin/statespace' into statespace
jessegrabowski Aug 5, 2023
b8995aa
Add stabilization to univariate filter covariance
jessegrabowski Aug 5, 2023
b0cf3de
Trying to get the last tests to pass
jessegrabowski Aug 5, 2023
3a143a3
Different stability strategy for UnivariateFilter
jessegrabowski Aug 5, 2023
259f6f4
Skip overly sensitive float32 tests
jessegrabowski Aug 5, 2023
7f299a3
Remove `update` function
jessegrabowski Aug 8, 2023
1ce2c7a
All tests pass
jessegrabowski Aug 9, 2023
593745c
Update Making a Custom Statespace Model.ipynb to reflect refactor
jessegrabowski Aug 9, 2023
5748bca
Relax float32 test tolerance in `test_structural.py`
jessegrabowski Aug 9, 2023
2f87538
Set dtype on numpy arrays used in test_structural.py
jessegrabowski Aug 9, 2023
c3dcfa8
Remove dictionary merge with pipe
jessegrabowski Aug 9, 2023
3cd9dd9
Disable all PSD tests for univariate filter when floatX=float32
jessegrabowski Aug 9, 2023
44d7db0
Remove `variable_by_shape` helper, use `pt.tensor` directly.
jessegrabowski Aug 9, 2023
c60ff4e
Remove unused helper functions
jessegrabowski Aug 9, 2023
27163c5
Adjust new tests for float32
jessegrabowski Aug 10, 2023
03c59db
Add test for equivalence between SARIMA representations
jessegrabowski Aug 10, 2023
731228f
Add test for equivalence between SARIMA representations
jessegrabowski Aug 10, 2023
0c8b590
Use `self.mode` in all kalman filter scans
jessegrabowski Aug 10, 2023
688fd01
Begin adding seasonal components to SARIMAX.py
jessegrabowski Aug 12, 2023
891a823
Implement seasonal lags and differences in SARIMAX
jessegrabowski Aug 13, 2023
a3b0958
Add an Exogenous Regression component to `structural.py`
jessegrabowski Aug 14, 2023
65b4294
Expand support for exogenous variables in statespace models
jessegrabowski Aug 18, 2023
cfe6e0a
Tweak tests
jessegrabowski Aug 18, 2023
e4cdd12
Tweak tests
jessegrabowski Aug 18, 2023
1342afa
Merge remote-tracking branch 'origin/statespace' into statespace
jessegrabowski Aug 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ dependencies:
- pytest>=3.0
- dask
- xhistogram
- statsmodels
- numba
- pip:
- pymc>=5.6.0 # CI was failing to resolve
- blackjax
Expand Down
2 changes: 2 additions & 0 deletions conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ dependencies:
- pytest>=3.0
- dask
- xhistogram
- statsmodels
- numba
- pip:
- pymc>=5.6.0 # CI was failing to resolve
- scikit-learn
875 changes: 875 additions & 0 deletions notebooks/ARMA Example.ipynb

Large diffs are not rendered by default.

602 changes: 602 additions & 0 deletions notebooks/Custom SSM - Daily Seasonality.ipynb

Large diffs are not rendered by default.

1,893 changes: 1,893 additions & 0 deletions notebooks/Nile Local Level Model.ipynb

Large diffs are not rendered by default.

1,020 changes: 1,020 additions & 0 deletions notebooks/VARMAX Example.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pymc_experimental/statespace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pymc_experimental.statespace.models.local_level import BayesianLocalLevel
from pymc_experimental.statespace.models.SARIMAX import BayesianARMA
from pymc_experimental.statespace.models.VARMAX import BayesianVARMAX

__all__ = ["BayesianLocalLevel", "BayesianARMA", "BayesianVARMAX"]
Empty file.
285 changes: 285 additions & 0 deletions pymc_experimental/statespace/core/representation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
from functools import reduce
from typing import List, Optional, Tuple, Type, Union

import numpy as np
import pandas.core.tools.datetimes
import pytensor
import pytensor.tensor as pt
from pandas import DataFrame

floatX = pytensor.config.floatX
KeyLike = Union[Tuple[Union[str, int]], str]

NEVER_TIME_VARYING = ["initial_state", "initial_state_cov", "a0", "P0"]
VECTOR_VALUED = ["initial_state", "state_intercept", "obs_intercept", "a0", "c", "d"]


def _preprocess_data(data: Union[DataFrame, np.ndarray], expected_dims=3):
if isinstance(data, pandas.DataFrame):
data = data.values
elif not isinstance(data, np.ndarray):
raise ValueError("Expected pandas Dataframe or numpy array as data")

if data.ndim < expected_dims:
n_dims = data.ndim
n_to_add = expected_dims - n_dims + 1
data = reduce(lambda a, b: np.expand_dims(a, -1), [data] * n_to_add)

return data


class PytensorRepresentation:
def __init__(
self,
k_endog: int,
k_states: int,
k_posdef: int,
design: Optional[np.ndarray] = None,
obs_intercept: Optional[np.ndarray] = None,
obs_cov=None,
transition=None,
state_intercept=None,
selection=None,
state_cov=None,
initial_state=None,
initial_state_cov=None,
) -> None:
"""
A representation of a State Space model, in Pytensor. Shamelessly copied from the Statsmodels.api implementation
found here:

https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py

Parameters
----------
k_endog: int
Number of observed states (called "endogeous states" in statsmodels)
k_states: int
Number of hidden states
k_posdef: int
Number of states that have exogenous shocks; also the rank of the selection matrix R.
design: ArrayLike, optional
Design matrix, denoted 'Z' in [1].
obs_intercept: ArrayLike, optional
Constant vector in the observation equation, denoted 'd' in [1]. Currently
not used.
obs_cov: ArrayLike, optional
Covariance matrix for multivariate-normal errors in the observation equation. Denoted 'H' in
[1].
transition: ArrayLike, optional
Transition equation that updates the hidden state between time-steps. Denoted 'T' in [1].
state_intercept: ArrayLike, optional
Constant vector for the observation equation, denoted 'c' in [1]. Currently not used.
selection: ArrayLike, optional
Selection matrix that matches shocks to hidden states, denoted 'R' in [1]. This is the identity
matrix when k_posdef = k_states.
state_cov: ArrayLike, optional
Covariance matrix for state equations, denoted 'Q' in [1]. Null matrix when there is no observation
noise.
initial_state: ArrayLike, optional
Experimental setting to allow for Bayesian estimation of the initial state, denoted `alpha_0` in [1]. Default
It should potentially be removed in favor of the closed-form diffuse initialization.
initial_state_cov: ArrayLike, optional
Experimental setting to allow for Bayesian estimation of the initial state, denoted `P_0` in [1]. Default
It should potentially be removed in favor of the closed-form diffuse initialization.

References
----------
.. [1] Durbin, James, and Siem Jan Koopman. 2012.
Time Series Analysis by State Space Methods: Second Edition.
Oxford University Press.
"""
self.k_states = k_states
self.k_endog = k_endog
self.k_posdef = k_posdef if k_posdef is not None else k_states

# The last dimension is for time varying matrices; it could be n_obs. Not thinking about that now.
self.shapes = {
"design": (self.k_endog, self.k_states, 1),
"obs_intercept": (self.k_endog, 1),
"obs_cov": (self.k_endog, self.k_endog, 1),
"transition": (self.k_states, self.k_states, 1),
"state_intercept": (self.k_states, 1),
"selection": (self.k_states, self.k_posdef, 1),
"state_cov": (self.k_posdef, self.k_posdef, 1),
"initial_state": (self.k_states,),
"initial_state_cov": (self.k_states, self.k_states),
}

# Initialize the representation matrices
scope = locals()
for name, shape in self.shapes.items():
if scope[name] is not None:
matrix = self._numpy_to_pytensor(name, scope[name])
setattr(self, name, matrix)

else:
setattr(self, name, pt.zeros(shape, dtype=floatX))

def _validate_key(self, key: KeyLike) -> None:
if key not in self.shapes:
raise IndexError(f"{key} is an invalid state space matrix name")

def update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.TensorType]) -> None:
# TODO: Get rid of these evals
if isinstance(value, (pt.TensorConstant, pt.TensorVariable)):
shape = value.shape.eval()
else:
shape = value.shape

old_shape = self.shapes[key]
check_slice = slice(None, 2) if key not in VECTOR_VALUED else slice(None, 1)

if not all([a == b for a, b in zip(shape[check_slice], old_shape[check_slice])]):
raise ValueError(
f"The first two dimensions of {key} must be {old_shape[check_slice]}, found {shape[check_slice]}"
)

# Add time dimension dummy if none present
if len(shape) == 2 and key not in NEVER_TIME_VARYING:
self.shapes[key] = shape + (1,)

self.shapes[key] = shape

def _add_time_dim_to_slice(
self, name: str, slice_: Union[List[int], Tuple[int]], n_dim: int
) -> Tuple[int]:
# Case 1: There is never a time dim. No changes needed.
if name in NEVER_TIME_VARYING:
return slice_

# Case 2: The matrix has a time dim, and it was requested. No changes needed.
if len(slice_) == n_dim:
return slice_

# Case 3: There's no time dim on the matrix, and none requested. Slice away the dummy dim.
if len(slice_) < n_dim:
empty_slice = (slice(None, None, None),)
n_omitted = n_dim - len(slice_) - 1
return tuple(slice_) + empty_slice * n_omitted + (0,)

@staticmethod
def _validate_key_and_get_type(key: KeyLike) -> Type[str]:
if isinstance(key, tuple) and not isinstance(key[0], str):
raise IndexError("First index must the name of a valid state space matrix.")

return type(key)

def _validate_matrix_shape(self, name: str, X: np.ndarray) -> None:
*expected_shape, time_dim = self.shapes[name]
expected_shape = tuple(expected_shape)

is_vector = name in VECTOR_VALUED
not_time_varying = name in NEVER_TIME_VARYING

if not_time_varying:
if is_vector:
if X.ndim != 1:
raise ValueError(
f"Array provided for {name} has {X.ndim} dimensions, but it must have exactly 1."
)

else:
if X.ndim != 2:
raise ValueError(
f"Array provided for {name} has {X.ndim} dimensions, but it must have exactly 2."
)

else:
if is_vector:
if X.ndim not in [1, 2]:
raise ValueError(
f"Array provided for {name} has {X.ndim} dimensions, "
f"expecting 1 (static) or 2 (time-varying)"
)
if X.ndim == 2 and X.shape[:-1] != expected_shape:
raise ValueError(
f"First dimension of array provided for {name} has shape {X.shape[0]}, "
f"expected {expected_shape}"
)

else:
if X.ndim not in [2, 3]:
raise ValueError(
f"Array provided for {name} has {X.ndim} dimensions, "
f"expecting 2 (static) or 3 (time-varying)"
)

if X.ndim == 3 and X.shape[:-1] != expected_shape:
raise ValueError(
f"First two dimensions of array provided for {name} has shape {X.shape[:-1]}, "
f"expected {expected_shape}"
)

# TODO: Think of another way to validate shapes of time-varying matrices if we don't know the data
# when the PytensorRepresentation is recreated
# if X.shape[-1] != self.data.shape[0]:
# raise ValueError(
# f"Last dimension (time dimension) of array provided for {name} has shape "
# f"{X.shape[-1]}, expected {self.data.shape[0]} (equal to the first dimension of the "
# f"provided data)"
# )

def _numpy_to_pytensor(self, name: str, X: np.ndarray) -> pt.TensorVariable:
X = X.copy()
self._validate_matrix_shape(name, X)

# Add a time dimension if one isn't provided
if name not in NEVER_TIME_VARYING:
if X.ndim == 1 and name in VECTOR_VALUED:
X = X[..., None]
elif X.ndim == 2:
X = X[..., None]
return pt.as_tensor(X, name=name, dtype=floatX)

def __getitem__(self, key: KeyLike) -> pt.TensorVariable:
_type = self._validate_key_and_get_type(key)

# Case 1: user asked for an entire matrix by name
if _type is str:
self._validate_key(key)
matrix = getattr(self, key)

# Slice away the time dimension if it's a dummy
if (self.shapes[key][-1] == 1) and (key not in NEVER_TIME_VARYING):
return matrix[(slice(None),) * (matrix.ndim - 1) + (0,)]

# If it's time varying, return everything
else:
return matrix

# Case 2: user asked for a particular matrix and some slices of it
elif _type is tuple:
name, *slice_ = key
self._validate_key(name)

matrix = getattr(self, name)
slice_ = self._add_time_dim_to_slice(name, slice_, matrix.ndim)

return matrix[slice_]

# Case 3: There is only one slice index, but it's not a string
else:
raise IndexError("First index must the name of a valid state space matrix.")

def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray]) -> None:
_type = type(key)
# Case 1: key is a string: we are setting an entire matrix.
if _type is str:
self._validate_key(key)
if isinstance(value, np.ndarray):
value = self._numpy_to_pytensor(key, value)
setattr(self, key, value)
self.update_shape(key, value)

# Case 2: key is a string plus a slice: we are setting a subset of a matrix
elif _type is tuple:
name, *slice_ = key
self._validate_key(name)

matrix = getattr(self, name)

slice_ = self._add_time_dim_to_slice(name, slice_, matrix.ndim)

matrix = pt.set_subtensor(matrix[slice_], value)
setattr(self, name, matrix)
Loading