
dgemm_ not found when using batched_dot #307

Closed
DushyantSahoo opened this issue May 18, 2023 · 5 comments
Labels
bug Something isn't working

Comments

@DushyantSahoo

DushyantSahoo commented May 18, 2023

Describe the issue:

I am using PyTensor's batched_dot on the output of a scan to multiply a tensor of shape (X, Y, Z) with a matrix of shape (Y, Z), and I want the final product to have shape (X, Y). When I use batched_dot, I get an error. Is there a way to compute this product without batched_dot, or is there a way to fix the PyTensor error below?
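One way to get the same contraction without batched_dot (and without the dgemm_ call the reporter attributes to it) is to broadcast-multiply and sum over the shared Z axis. Below is a minimal sketch in NumPy with small, hypothetical sizes standing in for the issue's (400, 86, 10) tensor and (86, 10) matrix. It assumes the analogous expression, (u * states).sum(axis=-1), carries over to pytensor.tensor variables because PyTensor follows NumPy's broadcasting rules; that is not confirmed anywhere in this thread.

```python
import numpy as np

# Hypothetical small sizes standing in for the issue's X=400, Y=86, Z=10.
X, Y, Z = 4, 3, 2
rng = np.random.default_rng(0)
u = rng.normal(size=(X, Y, Z))     # tensor of shape (X, Y, Z)
states = rng.normal(size=(Y, Z))   # matrix of shape (Y, Z)

# Reference: the batched_dot contraction from the repro, written as a loop.
# u is transposed to (Y, Z, X); for each y, a (Z,) @ (Z, X) product gives (X,).
u_t = np.transpose(u, axes=(1, 2, 0))
via_loop = np.stack([states[i] @ u_t[i] for i in range(Y)])  # shape (Y, X)

# Alternative without a matrix product: broadcast (Y, Z) against (X, Y, Z),
# multiply elementwise, then sum over the shared Z axis.
via_sum = (u * states).sum(axis=-1)  # shape (X, Y)

# The same contraction written as an einsum, as a cross-check.
via_einsum = np.einsum("xyz,yz->xy", u, states)

print(via_sum.shape)  # -> (4, 3)
```

The broadcast form gives (X, Y) directly; the repro's batched_dot produces the transposed (Y, X) layout, i.e. (86, 400), which is what y uses, so via_sum.T matches it. Because only an elementwise product and a sum are involved, this form should not need a BLAS gemm, though whether PyTensor's graph rewrites keep it that way has not been checked here.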

Reproducible code example:

import jax
jax.config.update('jax_platform_name', 'cpu')

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import collect_default_updates
u = np.zeros((400, 86, 10))  # tensor of shape (X, Y, Z)
y = np.ones((86, 400))       # observed counts, shape (T, 400)
T = 86
latent_variables_ar = 10
with pm.Model() as mod:
    def step(x, A, Q):
        innov = pm.MvNormal.dist(mu=0, tau=Q)

        next_x = pt.nlinalg.matrix_dot(x, A) + innov

        return next_x, collect_default_updates([x, A, Q], [next_x])

    x0_ar = pt.zeros(latent_variables_ar)
    mu2_ar = np.zeros(latent_variables_ar)
    sd_dist_ar = pm.Exponential.dist(1.0, shape=latent_variables_ar)
    chol2_ar, corr_ar, stds_ar = pm.LKJCholeskyCov('chol_cov_ar', n=latent_variables_ar, eta=2,
                                                   sd_dist=sd_dist_ar, compute_corr=True)
    A_ar = pm.MvNormal('A_ar', mu=mu2_ar, chol=chol2_ar, shape=(latent_variables_ar, latent_variables_ar))

    sigmas_Q_ar = pm.HalfNormal('sigmas_Q_ar', sigma=1, shape=latent_variables_ar)
    Q_ar = pt.diag(sigmas_Q_ar)
    ar_states_pt, ar_updates = pytensor.scan(step,
                                             outputs_info=[x0_ar],
                                             non_sequences=[A_ar, Q_ar],
                                             n_steps=T,
                                             strict=True)
    mod.register_rv(ar_states_pt, name='ar_states_pt', initval=pt.zeros((T, latent_variables_ar)))
    # ar_states_pt: (86, 10); u transposed to (86, 10, 400); batched_dot -> (86, 400), matching y
    lambdas = pm.Deterministic("lambdas", pt.batched_dot(ar_states_pt, pt.transpose(u, axes=(1, 2, 0))))
    obs = pm.Poisson('obs', lambdas, observed=y)
with mod:
    inference = pm.ADVI()
    tracker = pm.callbacks.Tracker(
        mean=inference.approx.mean.eval,  # callable that returns mean
        std=inference.approx.std.eval,  # callable that returns std
    )
    approx = pm.fit(n=20000, method=inference, callbacks=[tracker],
                    obj_optimizer=pm.adam(learning_rate=0.25), obj_n_mc=10)

idata = approx.sample(2000)

Error message:

.pytensor/compiledir_Linux-5.10-amzn2.x86_64-x86_64-with-glibc2.17-x86_64-3.8.11-64/tmpgwfgby2j/mbac13d27c3e2bdce3d5bc031b91c621a29f7883f3597e67f0dd72245e2913160.so: undefined symbol: dgemm_

Warning:
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.

PyTensor version information:

numpy==1.22.1 pymc==5.3.1 pytensor==2.11.2 python==3.8

Context for the issue:

No response

DushyantSahoo added the bug (Something isn't working) label May 18, 2023
@DushyantSahoo
Author

Has anyone had a chance to look at this? Thanks in advance for the help.

@ricardoV94
Member

I don't see an error, just a warning, probably due to incorrect installation?

As this seems like a user question, please open a topic in our discourse forum: https://discourse.pymc.io/

@DushyantSahoo
Author

DushyantSahoo commented May 30, 2023

Sorry for the confusion; I have just updated the error message. When using batched_dot, I get the error "undefined symbol: dgemm_". The batched_dot function calls dgemm, and in my environment that symbol cannot be found. Do you think this is due to an incorrect installation? I have been installing pytensor and pymc from wheel files. I have also posted this on the PyMC discourse forum.
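One quick way to test the "incorrect installation" theory is to check whether a BLAS shared library the dynamic loader can find actually exports the dgemm_ symbol that the compiled module is missing. Below is a minimal sketch using only the Python standard library. The helper name blas_exports_dgemm is hypothetical, and looking for libraries named blas or openblas is an assumption; other BLAS builds (for example MKL) use different names.

```python
import ctypes
import ctypes.util


def blas_exports_dgemm(libname="blas"):
    """Hypothetical helper: check whether lib<libname> exports dgemm_.

    Returns True or False when the library is found and loadable,
    or None when the loader cannot locate or open it.
    """
    path = ctypes.util.find_library(libname)
    if path is None:
        return None
    try:
        lib = ctypes.CDLL(path)
    except OSError:
        return None
    # Attribute lookup on a CDLL resolves the symbol and raises
    # AttributeError when it is not exported, so hasattr reports it.
    return hasattr(lib, "dgemm_")


for name in ("blas", "openblas"):
    print(name, ctypes.util.find_library(name), blas_exports_dgemm(name))
```

If every candidate prints None or False, the compiled module likely had no library providing dgemm_ to link against, which would be consistent with an installation or linker-flag problem rather than a bug in batched_dot itself.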

@ricardoV94
Copy link
Member

We recommend conda-forge which takes care of all the compiler dependencies.

If you are unsuccessful, please reach out on discourse.

@DushyantSahoo
Author

Thanks for the suggestion! Setting os.environ["PYTENSOR_FLAGS"]="blas__ldflags= -L/opt/omniai/work/instance1/jupyter/ssm-env/lib -lblas " solved the problem.
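A slightly more defensive variant of that fix is sketched below. It assumes PYTENSOR_FLAGS is read when pytensor is first imported, so it must be set before that import, and that the variable holds comma-separated key=value pairs, so any existing flags should be kept rather than overwritten. The helper set_blas_ldflags is hypothetical, and "/path/to/env/lib" is a placeholder for whichever directory contains libblas in your environment. The path in the comment above is specific to that user's setup.

```python
import os


def set_blas_ldflags(lib_dir, lib="blas"):
    """Hypothetical helper: add a blas__ldflags entry to PYTENSOR_FLAGS.

    PYTENSOR_FLAGS holds comma-separated key=value pairs; any earlier
    blas__ldflags entry is dropped so only one remains.
    """
    entry = f"blas__ldflags=-L{lib_dir} -l{lib}"
    existing = os.environ.get("PYTENSOR_FLAGS", "")
    kept = [
        flag for flag in existing.split(",")
        if flag.strip() and not flag.strip().startswith("blas__ldflags")
    ]
    os.environ["PYTENSOR_FLAGS"] = ",".join(kept + [entry])
    return os.environ["PYTENSOR_FLAGS"]


# Placeholder directory: point this at the lib/ folder that contains libblas.
print(set_blas_ldflags("/path/to/env/lib"))
# import pytensor  # import only after the flags are set
```

Keeping other flags (for example a floatX setting) intact means this can run in a notebook that already configures PyTensor through the same environment variable.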
