This repository was archived by the owner on Aug 31, 2022. It is now read-only.
Omitting `sinkhorn_kwargs` when calling `gromov_wasserstein` raises an `AttributeError`. To reproduce:
```python
from ott.core.gromov_wasserstein import gromov_wasserstein
from ott.geometry.geometry import Geometry
import jax.numpy as jnp

x = Geometry(cost_matrix=jnp.ones((10, 10)))
y = Geometry(cost_matrix=jnp.ones((5, 5)))

gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
gromov_wasserstein(x, y)  # raises the error below
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_246820/4027485735.py in <module>
      6 y = Geometry(cost_matrix=jnp.ones((5, 5)))
      7 gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
----> 8 gromov_wasserstein(x, y)  # raises the error below

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/core/gromov_wasserstein.py in gromov_wasserstein(geom_x, geom_y, a, b, epsilon, loss, max_iterations, jit, warm_start, sinkhorn_kwargs, **kwargs)
    151     raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
    152                      f'a string among: [{",".join(GW_LOSSES.keys())}]')
--> 153   tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
    154   tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
    155   if tau_a != 1.0 or tau_b != 1.0:

AttributeError: 'NoneType' object has no attribute 'get'
```
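The traceback suggests that `sinkhorn_kwargs` arrives as `None` when it is omitted, and line 153 then calls `.get` on it. The sketch below shows the usual guard: normalize a `None` default to an empty dict before reading from it. `extract_taus` is a hypothetical helper that only mirrors lines 153 and 154 of the traceback; it is not the library's code or its actual fix.

```python
from typing import Any, Mapping, Optional, Tuple


def extract_taus(
    sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
) -> Tuple[float, float]:
    """Hypothetical helper mirroring the tau_a/tau_b lookup in the traceback."""
    # Treat a missing (None) argument as "no overrides" instead of calling
    # .get() on None, which is what raises the AttributeError above.
    # Copying into a new dict also avoids sharing a mutable default.
    kwargs = {} if sinkhorn_kwargs is None else dict(sinkhorn_kwargs)
    tau_a = kwargs.get('tau_a', 1.0)
    tau_b = kwargs.get('tau_b', 1.0)
    return tau_a, tau_b
```

With the guard in place, omitting the argument behaves the same as passing `sinkhorn_kwargs={}`, which matches the "works as expected" call in the reproduction. Using `None` as the default, rather than `{}`, sidesteps Python's shared-mutable-default pitfall.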