Skip to content
This repository was archived by the owner on Aug 31, 2022. It is now read-only.

Specifying Geometry with kernel matrix and without epsilon causes RecursionError #18

@michalk8

Description

@michalk8

Code to reproduce:

from ott.core.sinkhorn import sinkhorn
from ott.geometry.geometry import Geometry
import jax.numpy as jnp

# Works: epsilon is passed explicitly alongside the kernel matrix.
geom_e = Geometry(kernel_matrix=jnp.ones((10, 10)), epsilon=1e-2)
print(geom_e.cost_matrix)  # ok
# Fails: no epsilon is given, so cost_matrix must go through the default
# epsilon path (see the traceback below).
geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
print(geom.cost_matrix)  # raises the error below
---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
/tmp/ipykernel_246820/705571593.py in <module>
      2 print(geom_e.cost_matrix)
      3 geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
----> 4 print(geom.cost_matrix)

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    107   def cost_matrix(self):
    108     if self._cost_matrix is None:
--> 109       return -self.epsilon * jnp.log(self._kernel_matrix)
    110     return self._cost_matrix
    111 

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in epsilon(self)
    130   @property
    131   def epsilon(self):
--> 132     return self._epsilon.target
    133 
    134   @property

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in _epsilon(self)
    102       return self._epsilon_init
    103     eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
--> 104     return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs)
    105 
    106   @property

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in scale(self)
     92     if (self._scale is None) and (trigger is not None):  # for dry run
     93       return jnp.where(
---> 94           trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0)
     95     else:
     96       return self._scale

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in mean_cost_matrix(self)
    116   @property
    117   def mean_cost_matrix(self):
--> 118     if isinstance(self.shape[0], int) and (self.shape[0] > 0):
    119       return jnp.sum(self.apply_cost(jnp.ones((self.shape[0],)))) / (
    120           self.shape[0] * self.shape[1])

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
    134   @property
    135   def shape(self):
--> 136     mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
    137     if mat is not None:
    138       return mat.shape

... last 6 frames repeated, from the frame below ...

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    107   def cost_matrix(self):
    108     if self._cost_matrix is None:
--> 109       return -self.epsilon * jnp.log(self._kernel_matrix)
    110     return self._cost_matrix
    111 

RecursionError: maximum recursion depth exceeded while calling a Python object

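The frames above form a cycle: `cost_matrix` → `epsilon` → `_epsilon` → `scale` → `mean_cost_matrix` → `shape` → `cost_matrix`. When only a kernel matrix is stored, `cost_matrix` needs `epsilon`. Its default scheduler needs `scale`, which (with no scale set) needs `mean_cost_matrix`. That needs `shape`, and `shape` evaluates `cost_matrix` again.

Version: 0.1.17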

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions