
Commit d340fe4

- A few more tests to check differentiability of optimal potentials.
- kernel_ridge is now only applied if the problem is balanced (it can perturb results in the unbalanced case).
- Clean-up in the transport interface to pass through parameters from epsilon_scheduler.

PiperOrigin-RevId: 386455511
1 parent: 01438ee

7 files changed: +159 −50 lines changed

ott/core/sinkhorn.py

Lines changed: 46 additions & 36 deletions
@@ -241,41 +241,8 @@ def sinkhorn(
   Raises:
     ValueError: If momentum parameter is not set correctly, or to a wrong value.
   """
-  if jit:
-    call_to_sinkhorn = functools.partial(
-        jax.jit, static_argnums=(3, 4, 6, 7, 8, 9) + tuple(range(11, 17)))(
-            _sinkhorn)
-  else:
-    call_to_sinkhorn = _sinkhorn
-  return call_to_sinkhorn(geom, a, b, tau_a, tau_b, threshold, norm_error,
-                          inner_iterations, min_iterations, max_iterations,
-                          momentum, chg_momentum_from, lse_mode,
-                          implicit_differentiation,
-                          linear_solve_kwargs, parallel_dual_updates,
-                          use_danskin, init_dual_a, init_dual_b)
-
 
-def _sinkhorn(
-    geom: geometry.Geometry,
-    a: Optional[jnp.ndarray] = None,
-    b: Optional[jnp.ndarray] = None,
-    tau_a: float = 1.0,
-    tau_b: float = 1.0,
-    threshold: float = 1e-3,
-    norm_error: int = 1,
-    inner_iterations: int = 10,
-    min_iterations: int = 0,
-    max_iterations: int = 2000,
-    momentum: float = 1.0,
-    chg_momentum_from: int = 0,
-    lse_mode: bool = True,
-    implicit_differentiation: bool = True,
-    linear_solve_kwargs: Optional[Mapping[str, Union[Callable, float]]] = None,
-    parallel_dual_updates: bool = False,
-    use_danskin: bool = None,
-    init_dual_a: Optional[jnp.ndarray] = None,
-    init_dual_b: Optional[jnp.ndarray] = None) -> SinkhornOutput:
-  """Checks inputs and forks between implicit/backprop exec of Sinkhorn."""
+  # Start by checking inputs.
   num_a, num_b = geom.shape
   a = jnp.ones((num_a,)) / num_a if a is None else a
   b = jnp.ones((num_b,)) / num_b if b is None else b

@@ -298,11 +265,49 @@ def _sinkhorn(
   # if that was not the error requested by the user.
   norm_error = (norm_error,) if norm_error == 1 else (norm_error, 1)
 
+  if jit:
+    call_to_sinkhorn = functools.partial(
+        jax.jit, static_argnums=(3, 4, 6, 7, 8, 9) + tuple(range(11, 17)))(
+            _sinkhorn)
+  else:
+    call_to_sinkhorn = _sinkhorn
+  return call_to_sinkhorn(geom, a, b, tau_a, tau_b, threshold, norm_error,
+                          inner_iterations, min_iterations, max_iterations,
+                          momentum, chg_momentum_from, lse_mode,
+                          implicit_differentiation,
+                          linear_solve_kwargs, parallel_dual_updates,
+                          use_danskin, init_dual_a, init_dual_b)
+
+
+def _sinkhorn(
+    geom: geometry.Geometry,
+    a: jnp.ndarray,
+    b: jnp.ndarray,
+    tau_a: float,
+    tau_b: float,
+    threshold: float,
+    norm_error: int,
+    inner_iterations: int,
+    min_iterations: int,
+    max_iterations: int,
+    momentum: float,
+    chg_momentum_from: int,
+    lse_mode: bool,
+    implicit_differentiation: bool,
+    linear_solve_kwargs: Mapping[str, Union[Callable, float]],
+    parallel_dual_updates: bool,
+    use_danskin: bool,
+    init_dual_a: jnp.ndarray,
+    init_dual_b: jnp.ndarray) -> SinkhornOutput:
+  """Forks between implicit/backprop exec of Sinkhorn."""
+
   if implicit_differentiation:
     iteration_fun = _sinkhorn_iterations_implicit
   else:
     iteration_fun = _sinkhorn_iterations
 
+  # By default, use Danskin theorem to differentiate
+  # the objective when using implicit_differentiation.
   use_danskin = implicit_differentiation if use_danskin is None else use_danskin
 
   f, g, errors = iteration_fun(tau_a, tau_b, inner_iterations, min_iterations,

@@ -337,6 +342,7 @@ def _sinkhorn(
   converged = jnp.logical_and(
       jnp.sum(errors == -1) > 0,
       jnp.sum(jnp.isnan(errors)) == 0)
+
   return SinkhornOutput(f, g, reg_ot_cost, errors, converged)
 
 
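Note on the refactor above: the `jit` branch now runs after the input checks, so uniform default marginals and the `norm_error` tuple are filled in before `_sinkhorn` is traced, and `_sinkhorn` takes every argument positionally, which is what the `static_argnums` indices refer to. A minimal usage sketch; the point-cloud geometry and random data below are illustrative assumptions, not part of this commit:

import jax
import jax.numpy as jnp
from ott.core import sinkhorn
from ott.geometry import pointcloud

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (16, 2))   # 16 source points in 2-D
y = jax.random.normal(key_y, (19, 2))   # 19 target points in 2-D
geom = pointcloud.PointCloud(x, y, epsilon=1e-2)

# a and b are omitted: sinkhorn() fills in uniform weights itself,
# before jitting the positional-only _sinkhorn worker.
out = sinkhorn.sinkhorn(geom, threshold=1e-3, lse_mode=True, jit=True)
print(out.reg_ot_cost, out.converged)
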
@@ -845,7 +851,7 @@ def apply_inv_hessian(gr: Tuple[np.ndarray],
     tau_b: float, ratio lam/(lam+eps), ratio of regularizers, second marginal.
     lse_mode: bool, log-sum-exp mode if True, kernel else.
     linear_solver_fun: Callable, should return (solution, ...)
-    ridge_kernel: promotes zero-sum solutions.
+    ridge_kernel: promotes zero-sum solutions. only used if tau_a = tau_b = 1.0
     ridge_identity: handles rank deficient transport matrices (this happens
       typically when rows/cols in cost/kernel matrices are colinear, or,
       equivalently when two points from either measure are close).

@@ -866,8 +872,12 @@ def apply_inv_hessian(gr: Tuple[np.ndarray],
 
   solve_fun = lambda lin_op, b: linear_solver_fun(lin_op, b)[0]
 
-  # Forks on using Schur complement of either A or D, depending on size.
   n, m = geom.shape
+  # Remove ridge on kernel space if problem is balanced.
+  ridge_kernel = jnp.where(tau_a == 1.0 and tau_b == 1.0,
+                           ridge_kernel,
+                           0.0)
+  # Forks on using Schur complement of either A or D, depending on size.
   if n > m:  # if n is bigger, run m x m linear system.
     inv_vjp_ff = lambda z: z / diag_hess_a
     vjp_gg = lambda z: z * diag_hess_b
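
Context for the `ridge_kernel` change: in the balanced case (tau_a = tau_b = 1.0) the dual potentials are only defined up to an additive constant, so the linear system used for implicit differentiation has a degenerate direction and the small kernel-space ridge simply selects the zero-sum solution; in the unbalanced case that degeneracy is absent and the same ridge would bias the solve, hence the `jnp.where` guard. A hedged sketch of passing these regularizers through `linear_solve_kwargs`; the geometry setup and ridge values are illustrative assumptions, not recommendations:

import jax
import jax.numpy as jnp
from ott.core import sinkhorn
from ott.geometry import pointcloud

x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (41, 3))
geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
a = jnp.ones((32,)) / 32

def reg_ot_cost(a):
  # Balanced problem (tau_a = tau_b = 1.0): ridge_kernel stays active.
  out = sinkhorn.sinkhorn(
      geom, a=a, tau_a=1.0, tau_b=1.0,
      implicit_differentiation=True,
      linear_solve_kwargs={'ridge_kernel': 1e-4, 'ridge_identity': 1e-4})
  return out.reg_ot_cost

grad_a = jax.grad(reg_ot_cost)(a)  # gradient flows through the implicit solve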

ott/tools/soft_sort.py

Lines changed: 4 additions & 9 deletions
@@ -29,6 +29,7 @@ def transport_for_sort(
     weights: jnp.ndarray,
     target_weights: jnp.ndarray,
     squashing_fun: Callable[[jnp.ndarray], jnp.ndarray] = None,
+    epsilon: float = 1e-2,
     **kwargs) -> jnp.ndarray:
   r"""Solves reg. OT, from inputs to a weighted family of increasing values.
 

@@ -53,21 +54,15 @@ def transport_for_sort(
 
   x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
   if squashing_fun is None:
-    squashing_fun = lambda z : jax.nn.sigmoid(
-        (z - jnp.mean(z)) / (jnp.std(z) + 1e-10))
+    squashing_fun = lambda z: jax.nn.sigmoid((z - jnp.mean(z)) /
+                                             (jnp.std(z) + 1e-10))
   x = squashing_fun(x)
   a = jnp.squeeze(weights)
   b = jnp.squeeze(target_weights)
   num_targets = b.shape[0]
   y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
 
-  # When runnning soft-sort, the entries are remapped into the segment [0,1].
-  # For that reason, it makes sense to have a default epsilon value adapted
-  # to that scale. If none is passed, we provide a default of 1e-2.
-  epsilon = kwargs.pop('epsilon', None)
-  scale = kwargs.pop('scale', None)
-  kwargs.update(epsilon=(1e-2 if epsilon is None else epsilon))
-  return transport.Transport(x, y, a=a, b=b, **kwargs)
+  return transport.Transport(x, y, a=a, b=b, epsilon=epsilon, **kwargs)
 
 
 def apply_on_axis(op, inputs, axis, *args, **kwargs):
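
With `epsilon` promoted to an explicit argument (default 1e-2, a scale suited to inputs squashed into [0, 1]), callers override it directly rather than through `**kwargs`, and the old `scale` pop is gone. A small sketch of a direct call; the weight choices below are illustrative assumptions:

import jax
import jax.numpy as jnp
from ott.tools import soft_sort

inputs = jax.random.normal(jax.random.PRNGKey(0), (20,))
weights = jnp.ones((20,)) / 20          # one weight per input value
target_weights = jnp.ones((5,)) / 5     # 5 sorted target positions in [0, 1]

# Returns a transport.Transport solved between the squashed inputs and
# jnp.linspace(0, 1, 5); epsilon overrides the 1e-2 default.
ot = soft_sort.transport_for_sort(
    inputs, weights, target_weights, epsilon=5e-3)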

ott/tools/transport.py

Lines changed: 2 additions & 1 deletion
@@ -58,7 +58,8 @@ def __init__(self, *args, a=None, b=None, **kwargs):
       self.geom = args[0]
     else:
       pc_kw = {}
-      for key in ['epsilon', 'cost_fn', 'power', 'online']:
+      for key in ['epsilon', 'cost_fn', 'power', 'online', 'relative_epsilon',
+                  'target', 'scale', 'init', 'decay']:
         value = kwargs.pop(key, None)
         if value is not None:
           pc_kw[key] = value
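
The widened key list means `Transport` now forwards epsilon-scheduler settings ('relative_epsilon', 'target', 'scale', 'init', 'decay') to the point-cloud geometry it builds, instead of silently dropping them. A sketch of the intended call, assuming these keywords are ultimately consumed by the geometry's epsilon scheduler; the annealing values are illustrative, not recommendations:

import jax
from ott.tools import transport

x = jax.random.normal(jax.random.PRNGKey(0), (50, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (60, 4))

# 'init' and 'decay' describe an annealed regularization: epsilon starts
# near init and decays geometrically toward its target value.
ot = transport.Transport(x, y, epsilon=1e-2, init=1.0, decay=0.99)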

ott/version.py

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@
 
 """Current ott version."""
 
-__version__ = "0.1.13"
+__version__ = "0.1.14"

tests/core/sinkhorn_hessian_test.py

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,8 @@ def test_hessian_sinkhorn(self, lse_mode, tau_a, tau_b, shape, arg):
     eps = 1e-3
     n, m = shape
     # use slightly different parameter to test linear_solve_kwargs
-    linear_solve_kwargs = {'ridge_kernel' : 1.2e-4}
+    linear_solve_kwargs = {'ridge_kernel' : 1.2e-4, 'ridge_identity': .9e-4}
+
     dim = 3
     rngs = jax.random.split(self.rng, 6)
     x = jax.random.uniform(rngs[0], (n, dim))
Lines changed: 102 additions & 0 deletions (new file)
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2021 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Tests for the Policy."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import jax.numpy as jnp
+import jax.test_util
+from ott.tools import transport
+
+
+class SinkhornHessianTest(jax.test_util.JaxTestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.rng = jax.random.PRNGKey(0)
+
+  @parameterized.product(
+      lse_mode=[True, False],
+      tau_a=[1.0, .93],
+      tau_b=[1.0, .91],
+      shape=[(12, 15), (27, 18), (345, 434)],
+      arg=[0, 1])
+  def test_potential_jacobian_sinkhorn(self, lse_mode,
+                                       tau_a, tau_b, shape, arg):
+    """Test Jacobian of optimal potential w.r.t. weights and locations."""
+    n, m = shape
+    dim = 3
+    rngs = jax.random.split(self.rng, 6)
+    x = jax.random.uniform(rngs[0], (n, dim))
+    y = jax.random.uniform(rngs[1], (m, dim))
+    a = jax.random.uniform(rngs[2], (n,)) + .2
+    b = jax.random.uniform(rngs[3], (m,)) + .2
+    a = a / jnp.sum(a)
+    b = b / jnp.sum(b)
+
+    # As expected, lse_mode False has a harder time with small epsilon.
+    epsilon = 0.01 if lse_mode else 0.1
+
+    random_dir = jax.random.uniform(rngs[2], (n,)) / n
+    random_dir = random_dir - jnp.mean(random_dir)
+
+    def loss_from_potential(a, x, implicit):
+      out = transport.Transport(
+          x, y, epsilon=epsilon, a=a, b=b, tau_a=tau_a, tau_b=tau_b,
+          lse_mode=lse_mode, implicit_differentiation=implicit
+      )
+      return jnp.sum(random_dir * out._f)
+
+    delta_a = jax.random.uniform(rngs[4], (n,))
+    delta_a = delta_a - jnp.mean(delta_a)
+    delta_x = jax.random.uniform(rngs[5], (n, dim))
+
+    # Compute implicit gradient
+    loss_imp = jax.jit(jax.value_and_grad(
+        lambda a, x: loss_from_potential(a, x, True), argnums=arg))
+    _, g_imp = loss_imp(a, x)
+    imp_dif = jnp.sum(g_imp * (delta_a if arg == 0 else delta_x))
+    # Compute backprop (unrolling) gradient
+
+    loss_back = jax.jit(jax.grad(
+        lambda a, x: loss_from_potential(a, x, False), argnums=arg))
+    g_back = loss_back(a, x)
+    back_dif = jnp.sum(g_back * (delta_a if arg == 0 else delta_x))
+
+    # Compute finite difference
+    perturb_scale = 1e-4
+    a_p = a + perturb_scale * delta_a if arg == 0 else a
+    x_p = x if arg == 0 else x + perturb_scale * delta_x
+    a_m = a - perturb_scale * delta_a if arg == 0 else a
+    x_m = x if arg == 0 else x - perturb_scale * delta_x
+
+    val_p, _ = loss_imp(a_p, x_p)
+    val_m, _ = loss_imp(a_m, x_m)
+    fin_dif = (val_p - val_m) / (2 * perturb_scale)
+
+    self.assertAllClose(fin_dif, back_dif, atol=1e-2, rtol=1e-2)
+    self.assertAllClose(fin_dif, imp_dif, atol=1e-2, rtol=1e-2)
+
+    # center g_imp, g_back if balanced problem testing gradient w.r.t weights
+    if tau_a == 1.0 and tau_b == 1.0 and arg == 0:
+      g_imp = g_imp - jnp.mean(g_imp)
+      g_back = g_back - jnp.mean(g_back)
+      self.assertAllClose(g_imp, g_back, atol=5e-2, rtol=1e-2)
+
+if __name__ == '__main__':
+  absltest.main()

tests/tools/soft_sort_test.py

Lines changed: 2 additions & 2 deletions
@@ -168,9 +168,9 @@ def loss_fn(logits, implicit=False):
     val_peps = loss_fn(z + eps * delta)
     val_meps = loss_fn(z - eps * delta)
     self.assertAllClose((val_peps - val_meps)/(2 * eps),
-                        jnp.sum(grad_b * delta), atol=0.1, rtol=0.01)
+                        jnp.sum(grad_b * delta), atol=0.001, rtol=0.001)
     self.assertAllClose((val_peps - val_meps)/(2 * eps),
-                        jnp.sum(grad_i * delta), atol=0.1, rtol=0.01)
+                        jnp.sum(grad_i * delta), atol=0.001, rtol=0.001)
 
 
 if __name__ == '__main__':
