
dgemm_ not found when using batched_dot #307

Closed
@DushyantSahoo

Description


Describe the issue:

I am using PyTensor's `batched_dot` on the output of a `scan` to multiply a tensor of shape (X, Y, Z) with a matrix of shape (Y, Z), and I want the final product to have shape (X, Y). When I used `batched_dot` I got the error below. Is there a way to compute this product without `batched_dot`, or a way to fix the PyTensor error?
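The requested product contracts the shared last axis: out[x, y] = sum over z of u[x, y, z] * s[y, z]. Assuming that is the intended contraction, here is a minimal NumPy sketch showing it can be written as a broadcast multiply followed by a sum (or as an einsum), with no `batched_dot` involved. The array names and sizes are illustrative only.

```python
import numpy as np

# Contract a (X, Y, Z) tensor with a (Y, Z) matrix over Z, giving (X, Y).
rng = np.random.default_rng(0)
X, Y, Z = 4, 3, 5                  # small illustrative sizes (hypothetical)
u = rng.normal(size=(X, Y, Z))
states = rng.normal(size=(Y, Z))

# Reference definition: out[x, y] = sum_z u[x, y, z] * states[y, z]
ref = np.empty((X, Y))
for i in range(X):
    for j in range(Y):
        ref[i, j] = np.dot(u[i, j], states[j])

# 1) states (Y, Z) broadcasts against u (X, Y, Z) by aligning trailing axes;
#    multiply elementwise, then sum over the last (Z) axis.
out_broadcast = (u * states).sum(axis=-1)

# 2) The same contraction written as an einsum.
out_einsum = np.einsum("xyz,yz->xy", u, states)
```

The broadcast form needs only elementwise multiplication and `.sum(axis=-1)`, so the same expression should carry over to PyTensor tensors. That is an assumption and has not been tested against the model here.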

Reproducible code example:

import jax
jax.config.update('jax_platform_name', 'cpu')

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

u = np.zeros((400, 86, 10))   # shape (400, T, latent_variables_ar)
y = np.ones((86, 400))        # observed counts, shape (T, 400)
T = 86
latent_variables_ar = 10

with pm.Model() as mod:
    def step(x, A, Q):
        # One AR transition: x_t = x_{t-1} @ A + innovation with precision Q
        innov = pm.MvNormal.dist(mu=0, tau=Q)
        next_x = pt.nlinalg.matrix_dot(x, A) + innov
        return next_x, collect_default_updates([x, A, Q], [next_x])

    x0_ar = pt.zeros(latent_variables_ar)
    mu2_ar = np.zeros(latent_variables_ar)
    sd_dist_ar = pm.Exponential.dist(1.0, shape=latent_variables_ar)
    chol2_ar, corr_ar, stds_ar = pm.LKJCholeskyCov(
        'chol_cov_ar', n=latent_variables_ar, eta=2,
        sd_dist=sd_dist_ar, compute_corr=True,
    )
    A_ar = pm.MvNormal('A_ar', mu=mu2_ar, chol=chol2_ar,
                       shape=(latent_variables_ar, latent_variables_ar))

    sigmas_Q_ar = pm.HalfNormal('sigmas_Q_ar', sigma=1, shape=latent_variables_ar)
    Q_ar = pt.diag(sigmas_Q_ar)

    # Latent AR states, shape (T, latent_variables_ar)
    ar_states_pt, ar_updates = pytensor.scan(step,
                                              outputs_info=[x0_ar],
                                              non_sequences=[A_ar, Q_ar],
                                              n_steps=T,
                                              strict=True)
    mod.register_rv(ar_states_pt, name='ar_states_pt',
                    initval=pt.zeros((T, latent_variables_ar)))

    # (T, 10) batched with (T, 10, 400) -> (T, 400); the batched_dot call
    # is where the dgemm_ error is reported
    lambdas = pm.Deterministic(
        "lambdas", pt.batched_dot(ar_states_pt, pt.transpose(u, axes=(1, 2, 0)))
    )
    obs = pm.Poisson('obs', lambdas, observed=y)

with mod:
    inference = pm.ADVI()
    tracker = pm.callbacks.Tracker(
        mean=inference.approx.mean.eval,  # callable that returns mean
        std=inference.approx.std.eval,    # callable that returns std
    )
    approx = pm.fit(n=20000, method=inference, callbacks=[tracker],
                    obj_optimizer=pm.adam(learning_rate=0.25), obj_n_mc=10)

idata = approx.sample(2000)
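To make the shapes in the reproducer concrete: `batched_dot` pairs up the leading axis of both arguments (out[t] = dot(a[t], b[t])), so `ar_states_pt` of shape (T, 10) with the transposed `u` of shape (T, 10, 400) gives (T, 400), which matches `y`. The NumPy sketch below uses random stand-in data, not the model's values. It emulates that call with a batched matmul and checks that a broadcast-and-sum expression followed by a transpose gives the same array.

```python
import numpy as np

# Shapes from the reproducer: T = 86 steps, 10 latent states, 400 columns in y.
T, K, N = 86, 10, 400
rng = np.random.default_rng(1)
ar_states = rng.normal(size=(T, K))   # stand-in for ar_states_pt
u = rng.normal(size=(N, T, K))        # same layout as `u` in the issue
u_t = u.transpose(1, 2, 0)            # (T, K, N), as pt.transpose(u, axes=(1, 2, 0))

# batched_dot semantics: out[t] = ar_states[t] @ u_t[t].
# Emulated as a (T, 1, K) x (T, K, N) batched matmul, then dropping the size-1 axis.
batched = np.matmul(ar_states[:, None, :], u_t)[:, 0, :]   # (T, N)

# Same result without batched_dot: broadcast-multiply, sum over K, then transpose.
alternative = (u * ar_states).sum(axis=-1).T                # (N, T) -> (T, N)
```

By analogy, a PyTensor candidate for the `lambdas` line would be `(ar_states_pt * u).sum(axis=-1).T`, keeping the tensor on the left of the multiplication. This is an untested suggestion, not a confirmed workaround.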

Error message:

.pytensor/compiledir_Linux-5.10-amzn2.x86_64-x86_64-with-glibc2.17-x86_64-3.8.11-64/tmpgwfgby2j/mbac13d27c3e2bdce3d5bc031b91c621a29f7883f3597e67f0dd72245e2913160.so: undefined symbol: dgemm_

Warning:
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
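`undefined symbol: dgemm_` means the compiled module calls the Fortran-style BLAS routine `dgemm_`, but the dynamic linker found no loaded library providing it. That is consistent with the warning that PyTensor fell back to a NumPy C-API based BLAS implementation. As a first diagnostic, this sketch checks whether any system BLAS library that exports `dgemm_` can be found. The candidate library names are assumptions and vary by system.

```python
import ctypes
import ctypes.util


def find_dgemm(candidates=("blas", "openblas", "mkl_rt")):
    """Return the path of the first candidate library exporting dgemm_, else None.

    The default candidate names are hypothetical examples; adjust for your system.
    """
    for name in candidates:
        path = ctypes.util.find_library(name)
        if path is None:
            continue  # library not discoverable under this name
        try:
            lib = ctypes.CDLL(path)
        except OSError:
            continue  # found but could not be loaded
        # CDLL raises AttributeError for missing symbols, so hasattr is a safe probe.
        if hasattr(lib, "dgemm_"):
            return path
    return None


print(find_dgemm())
```

If this returns a path, PyTensor's `blas__ldflags` config option (which can be set through `PYTENSOR_FLAGS`) is where link flags for a BLAS library belong. Whether pointing it at such a library resolves this particular error has not been verified here.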

PyTensor version information:

numpy==1.22.1 pymc==5.3.1 pytensor==2.11.2 python==3.8

Context for the issue:

No response

Metadata


Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
