@@ -241,41 +241,8 @@ def sinkhorn(
   Raises:
     ValueError: If the momentum parameter is set to an invalid value.
   """
-  if jit:
-    call_to_sinkhorn = functools.partial(
-        jax.jit, static_argnums=(3, 4, 6, 7, 8, 9) + tuple(range(11, 17)))(
-            _sinkhorn)
-  else:
-    call_to_sinkhorn = _sinkhorn
-  return call_to_sinkhorn(geom, a, b, tau_a, tau_b, threshold, norm_error,
-                          inner_iterations, min_iterations, max_iterations,
-                          momentum, chg_momentum_from, lse_mode,
-                          implicit_differentiation,
-                          linear_solve_kwargs, parallel_dual_updates,
-                          use_danskin, init_dual_a, init_dual_b)
-
-
-def _sinkhorn(
-    geom: geometry.Geometry,
-    a: Optional[jnp.ndarray] = None,
-    b: Optional[jnp.ndarray] = None,
-    tau_a: float = 1.0,
-    tau_b: float = 1.0,
-    threshold: float = 1e-3,
-    norm_error: int = 1,
-    inner_iterations: int = 10,
-    min_iterations: int = 0,
-    max_iterations: int = 2000,
-    momentum: float = 1.0,
-    chg_momentum_from: int = 0,
-    lse_mode: bool = True,
-    implicit_differentiation: bool = True,
-    linear_solve_kwargs: Optional[Mapping[str, Union[Callable, float]]] = None,
-    parallel_dual_updates: bool = False,
-    use_danskin: bool = None,
-    init_dual_a: Optional[jnp.ndarray] = None,
-    init_dual_b: Optional[jnp.ndarray] = None) -> SinkhornOutput:
-  """Checks inputs and forks between implicit/backprop exec of Sinkhorn."""
+
+  # Start by checking inputs.
   num_a, num_b = geom.shape
   a = jnp.ones((num_a,)) / num_a if a is None else a
   b = jnp.ones((num_b,)) / num_b if b is None else b
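
The prologue above now runs in the public sinkhorn() wrapper before any
jitted call, so omitted marginals are replaced by uniform weights first. A
hedged usage sketch of that behavior (this call site and the name my_a are
illustrative, not part of the diff):

    # With a and b omitted, uniform marginals are inferred from geom.shape
    # before _sinkhorn is invoked; passing one marginal defaults only the other.
    out = sinkhorn(geom)            # a = ones(n)/n, b = ones(m)/m
    out = sinkhorn(geom, a=my_a)    # only b is defaulted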
@@ -298,11 +265,49 @@ def _sinkhorn(
   # if that was not the error requested by the user.
   norm_error = (norm_error,) if norm_error == 1 else (norm_error, 1)
 
+  if jit:
+    call_to_sinkhorn = functools.partial(
+        jax.jit, static_argnums=(3, 4, 6, 7, 8, 9) + tuple(range(11, 17)))(
+            _sinkhorn)
+  else:
+    call_to_sinkhorn = _sinkhorn
+  return call_to_sinkhorn(geom, a, b, tau_a, tau_b, threshold, norm_error,
+                          inner_iterations, min_iterations, max_iterations,
+                          momentum, chg_momentum_from, lse_mode,
+                          implicit_differentiation,
+                          linear_solve_kwargs, parallel_dual_updates,
+                          use_danskin, init_dual_a, init_dual_b)
+
+
+def _sinkhorn(
+    geom: geometry.Geometry,
+    a: jnp.ndarray,
+    b: jnp.ndarray,
+    tau_a: float,
+    tau_b: float,
+    threshold: float,
+    norm_error: int,
+    inner_iterations: int,
+    min_iterations: int,
+    max_iterations: int,
+    momentum: float,
+    chg_momentum_from: int,
+    lse_mode: bool,
+    implicit_differentiation: bool,
+    linear_solve_kwargs: Mapping[str, Union[Callable, float]],
+    parallel_dual_updates: bool,
+    use_danskin: bool,
+    init_dual_a: jnp.ndarray,
+    init_dual_b: jnp.ndarray) -> SinkhornOutput:
+  """Forks between implicit/backprop exec of Sinkhorn."""
+
   if implicit_differentiation:
     iteration_fun = _sinkhorn_iterations_implicit
   else:
     iteration_fun = _sinkhorn_iterations
 
+  # By default, use Danskin's theorem to differentiate
+  # the objective when using implicit_differentiation.
   use_danskin = implicit_differentiation if use_danskin is None else use_danskin
 
   f, g, errors = iteration_fun(tau_a, tau_b, inner_iterations, min_iterations,
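
In the relocated dispatch, static_argnums marks positions 3-4, 6-9 and 11-16
of _sinkhorn (tau_a, tau_b, norm_error, the iteration counts, and the
boolean/config arguments) as compile-time constants, while geom, a, b,
threshold, momentum and the dual initializations remain traced. A minimal
sketch of the same jax.jit dispatch pattern (the solver and its arguments
below are illustrative, not from the diff):

    import functools
    import jax
    import jax.numpy as jnp

    def _solver(x, n_iter, use_log):
      # n_iter and use_log are Python values at trace time (static), so the
      # loop and the branch below are resolved during tracing.
      for _ in range(n_iter):
        x = jnp.log1p(jnp.exp(x)) if use_log else 0.5 * (x + jnp.exp(-x))
      return x

    # Same pattern as the hunk above: wrap with jax.jit, marking the
    # configuration arguments static so each new setting triggers a retrace.
    solver = functools.partial(jax.jit, static_argnums=(1, 2))(_solver)
    print(solver(jnp.ones(3), 5, True))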
@@ -337,6 +342,7 @@ def _sinkhorn(
   converged = jnp.logical_and(
       jnp.sum(errors == -1) > 0,
       jnp.sum(jnp.isnan(errors)) == 0)
+
   return SinkhornOutput(f, g, reg_ot_cost, errors, converged)
 
 
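The convergence test above reads as a sentinel convention: -1 entries in the
errors buffer indicate the iteration stopped before exhausting max_iterations
(threshold reached), while any NaN indicates divergence. An illustrative,
standalone check of that logic (the example arrays are assumptions):

    import jax.numpy as jnp

    # -1 entries mark iterations after early convergence; NaN marks divergence.
    errors_ok = jnp.array([0.5, 0.01, -1.0, -1.0])
    errors_bad = jnp.array([0.5, jnp.nan, -1.0, -1.0])
    for e in (errors_ok, errors_bad):
      print(jnp.logical_and(jnp.sum(e == -1) > 0,
                            jnp.sum(jnp.isnan(e)) == 0))  # True, then False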
@@ -845,7 +851,7 @@ def apply_inv_hessian(gr: Tuple[np.ndarray],
     tau_b: float, ratio lam/(lam+eps), ratio of regularizers, second marginal.
     lse_mode: bool, log-sum-exp mode if True, kernel else.
     linear_solver_fun: Callable, should return (solution, ...)
-    ridge_kernel: promotes zero-sum solutions.
+    ridge_kernel: promotes zero-sum solutions. Only used if tau_a = tau_b = 1.0.
     ridge_identity: handles rank deficient transport matrices (this happens
       typically when rows/cols in cost/kernel matrices are collinear, or,
       equivalently when two points from either measure are close).
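
For intuition, both ridge terms enter the linear operator that is inverted
during implicit differentiation. A hedged sketch of how such terms are
typically added to a Hessian-vector product (the operator below is
illustrative, not the diff's code):

    import jax.numpy as jnp

    def regularized(apply_h, z, ridge_identity, ridge_kernel):
      # ridge_identity * z shifts every eigenvalue away from zero, guarding
      # against rank deficiency; ridge_kernel * sum(z) penalizes the constant
      # direction, promoting zero-sum solutions in the balanced case.
      return apply_h(z) + ridge_identity * z + ridge_kernel * jnp.sum(z)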
@@ -866,8 +872,12 @@ def apply_inv_hessian(gr: Tuple[np.ndarray],
 
   solve_fun = lambda lin_op, b: linear_solver_fun(lin_op, b)[0]
 
-  # Forks on using Schur complement of either A or D, depending on size.
   n, m = geom.shape
+  # Remove ridge on kernel space if problem is balanced.
+  ridge_kernel = jnp.where(tau_a == 1.0 and tau_b == 1.0,
+                           ridge_kernel,
+                           0.0)
+  # Forks on using Schur complement of either A or D, depending on size.
   if n > m:  # if n is bigger, run m x m linear system.
     inv_vjp_ff = lambda z: z / diag_hess_a
     vjp_gg = lambda z: z * diag_hess_b
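
The fork after the ridge adjustment solves whichever Schur-complement system
is smaller. The diff itself works with matrix-free operators through
solve_fun; the dense helper below is a hypothetical illustration of the
n > m branch, with assumed names:

    import jax.numpy as jnp

    def solve_via_schur_of_a(a_diag, d_diag, b_mat, rhs_f, rhs_g):
      # Block system [[A, B], [B^T, D]] [f; g] = [rhs_f; rhs_g], with
      # diagonal A (n x n) and D (m x m). When n > m, eliminate f and solve
      # the m x m Schur complement S = D - B^T A^{-1} B instead.
      s = jnp.diag(d_diag) - b_mat.T @ (b_mat / a_diag[:, None])
      g = jnp.linalg.solve(s, rhs_g - b_mat.T @ (rhs_f / a_diag))
      f = (rhs_f - b_mat @ g) / a_diag
      return f, g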