This repository was archived by the owner on Aug 31, 2022. It is now read-only.

GW default sinkhorn kwargs raises AttributeError #17

@michalk8

Description

Calling gromov_wasserstein without passing sinkhorn_kwargs raises an AttributeError. To reproduce:

from ott.core.gromov_wasserstein import gromov_wasserstein
from ott.geometry.geometry import Geometry
import jax.numpy as jnp

x = Geometry(cost_matrix=jnp.ones((10, 10)))
y = Geometry(cost_matrix=jnp.ones((5, 5)))
gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
gromov_wasserstein(x, y)  # raises the error below
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_246820/4027485735.py in <module>
      6 y = Geometry(cost_matrix=jnp.ones((5, 5)))
      7 gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
----> 8 gromov_wasserstein(x, y)  # raises the error below

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/core/gromov_wasserstein.py in gromov_wasserstein(geom_x, geom_y, a, b, epsilon, loss, max_iterations, jit, warm_start, sinkhorn_kwargs, **kwargs)
    151     raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
    152                      f'a string among: [{",".join(GW_LOSSES.keys())}]')
--> 153   tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
    154   tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
    155   if tau_a != 1.0 or tau_b != 1.0:

AttributeError: 'NoneType' object has no attribute 'get'

Observed with version 0.1.17.
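
The traceback suggests that `.get` is called on `sinkhorn_kwargs` while it still holds its `None` default. A minimal sketch of one possible guard follows, assuming the default really is `None`. The helper name `resolve_tau` is hypothetical and not part of ott; it only mirrors the two lines shown in the traceback.

```python
from typing import Any, Mapping, Optional, Tuple


def resolve_tau(
    sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
) -> Tuple[float, float]:
  """Hypothetical helper: read tau_a/tau_b, tolerating a missing dict."""
  # Normalize the None default to an empty dict before calling .get on it.
  sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs
  tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
  tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
  return tau_a, tau_b


# Both call styles from the reproduction now behave the same way.
print(resolve_tau())
print(resolve_tau({}))
```

Normalizing inside the function, rather than using `{}` as the default value, avoids sharing one mutable dict across calls.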
