This repository was archived by the owner on Aug 31, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
This repository was archived by the owner on Aug 31, 2022. It is now read-only.
Specifying a Geometry with a kernel matrix but without epsilon causes a RecursionError #18
Copy link
Copy link
Closed
Description
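Code to reproduce (accessing `cost_matrix` works when `epsilon` is passed, but recurses until Python's recursion limit when it is omitted; the traceback below shows the cycle `cost_matrix` → `epsilon` → `_epsilon` → `scale` → `mean_cost_matrix` → `shape` → `cost_matrix`):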
# Minimal reproduction for issue #18: a Geometry built from a kernel matrix
# alone cannot derive its cost matrix unless `epsilon` is given explicitly.
import jax.numpy as jnp
from ott.geometry.geometry import Geometry

# With an explicit epsilon, deriving the cost matrix from the kernel works.
geom_e = Geometry(kernel_matrix=jnp.ones((10, 10)), epsilon=1e-2)
print(geom_e.cost_matrix)  # ok

# Without epsilon, `cost_matrix` needs `epsilon`, which (per the traceback)
# goes through `scale` -> `mean_cost_matrix` -> `shape`, and `shape` reads
# `cost_matrix` again, so the property lookups recurse without end.
geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
print(geom.cost_matrix)  # raises the error below
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
/tmp/ipykernel_246820/705571593.py in <module>
2 print(geom_e.cost_matrix)
3 geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
----> 4 print(geom.cost_matrix)
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
107 def cost_matrix(self):
108 if self._cost_matrix is None:
--> 109 return -self.epsilon * jnp.log(self._kernel_matrix)
110 return self._cost_matrix
111
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in epsilon(self)
130 @property
131 def epsilon(self):
--> 132 return self._epsilon.target
133
134 @property
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in _epsilon(self)
102 return self._epsilon_init
103 eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
--> 104 return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs)
105
106 @property
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in scale(self)
92 if (self._scale is None) and (trigger is not None): # for dry run
93 return jnp.where(
---> 94 trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0)
95 else:
96 return self._scale
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in mean_cost_matrix(self)
116 @property
117 def mean_cost_matrix(self):
--> 118 if isinstance(self.shape[0], int) and (self.shape[0] > 0):
119 return jnp.sum(self.apply_cost(jnp.ones((self.shape[0],)))) / (
120 self.shape[0] * self.shape[1])
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
134 @property
135 def shape(self):
--> 136 mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
137 if mat is not None:
138 return mat.shape
... last 6 frames repeated, from the frame below ...
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
107 def cost_matrix(self):
108 if self._cost_matrix is None:
--> 109 return -self.epsilon * jnp.log(self._kernel_matrix)
110 return self._cost_matrix
111
RecursionError: maximum recursion depth exceeded while calling a Python object
Version: ott-jax 0.1.17
Metadata
Metadata
Assignees
Labels
No labels