Closed
Description
Describe the issue:
I am using pytensor batched_dot scan to multiply a tensor of shape (X,Y,Z) with matrix of shape (Y,Z). I want the final product to have shape (X,Y). I planned to use batched_dot but got an error. Is there a way to get the above multiplication without using batched_dot? or is there a way to solve the above pytensor error?
Reproducable code example:
import jax
jax.config.update('jax_platform_name', 'cpu')
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates
u = np.zeros((400,86,10))
y = np.ones((86,400))
T = 86
latent_variables_ar = 10
with pm.Model()as mod:
def step(x, A, Q):
innov = pm.MvNormal.dist(mu=0, tau=Q)
next_x = pt.nlinalg.matrix_dot(x,A) + innov
return next_x, collect_default_updates([x, A, Q], [next_x])
x0_ar = pt.zeros(latent_variables_ar)
mu2_ar = np.zeros(latent_variables_ar)
sd_dist_ar = pm.Exponential.dist(1.0, shape=latent_variables_ar)
chol2_ar, corr_ar, stds_ar = pm.LKJCholeskyCov('chol_cov_ar', n=latent_variables_ar, eta=2,
sd_dist=sd_dist_ar, compute_corr=True)
A_ar = pm.MvNormal('A_ar', mu=mu2_ar, chol=chol2_ar, shape=(latent_variables_ar,latent_variables_ar))
sigmas_Q_ar = pm.HalfNormal('sigmas_Q_ar', sigma=1, shape= (latent_variables_ar))
Q_ar = pt.diag(sigmas_Q_ar)
ar_states_pt, ar_updates = pytensor.scan(step,
outputs_info=[x0_ar],
non_sequences=[A_ar, Q_ar],
n_steps=T,
strict=True)
mod.register_rv(ar_states_pt, name='ar_states_pt', initval=pt.zeros((T, latent_variables_ar)))
lambdas = pm.Deterministic("lambdas",pt.batched_dot(ar_states_pt, pt.transpose(u,axes=(1,2,0))) )
obs = pm.Poisson('obs', lambdas, observed=y)
with mod:
inference = pm.ADVI()
tracker = pm.callbacks.Tracker(
mean= inference.approx.mean.eval, # callable that returns mean
std= inference.approx.std.eval # callable that returns std
)
approx = pm.fit(n= 20000, method=inference, callbacks=[tracker],obj_optimizer=pm.adam(learning_rate=0.25), obj_n_mc=10)
idata = approx.sample(2000)
Error message:
.pytensor/compiledir_Linux-5.10-amzn2.x86_64-x86_64-with-glibc2.17-x86_64-3.8.11-64/tmpgwfgby2j/mbac13d27c3e2bdce3d5bc031b91c621a29f7883f3597e67f0dd72245e2913160.so: undefined symbol: dgemm_
Warning-
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
PyTensor version information:
numpy==1.22.1
pymc==5.3.1
pytensor==2.11.2
python==3.8
Context for the issue:
No response