Skip to content

Commit 53dbaa9

Browse files
Jack-Sandbergmeta-codesync[bot]
authored andcommitted
Clamp to 0 to avoid sqrt of negative numbers in mvn_hellinger_distance (#3109)
Summary: ## Motivation When trying out ScoreBO, I noticed that it tended to explore excessively. I found that the acquisition function occasionally contained NaNs that originated from the mvn_hellinger_distance function. When computing 1 - x.exp() for small x, the result is sometimes negative due to floating point inaccuriacies which causes the sqrt to output NaN. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/meta-pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #3109 Test Plan: The change is rather minor but I've added a test to `test_community/utils/test_stat_dist.py`. Reviewed By: saitcakmak Differential Revision: D88863196 Pulled By: hvarfner fbshipit-source-id: a5f72db77143c0e1c43a784f6e0b24b33a9b2e36
1 parent 645d9e5 commit 53dbaa9

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

botorch_community/utils/stat_dist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ def mvn_hellinger_distance(
8585
L_mean_diff = torch.matmul(L_inv, mean_diff)
8686
exp_logterm = -0.125 * torch.matmul(L_mean_diff.transpose(-2, -1), L_mean_diff)
8787
sq_hdist = 1 - (base_logterm + exp_logterm.squeeze(-1)).exp()
88-
return sq_hdist.sqrt()
88+
return sq_hdist.clamp_min(0.0).sqrt()

test_community/utils/test_stat_dist.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,14 @@ def test_mvn_hellinger_distance(self):
7676
)
7777
self.assertTrue(torch.all(permean_res > 0))
7878
self.assertTrue(torch.all(percov_res > 0))
79+
80+
def test_mvn_hellinger_distance_numerical_stability(self):
81+
# Two almost equal distributions. Distance is approximatly 0.
82+
dist1_mean = torch.tensor([[-0.4615126826879162]], dtype=torch.float64)
83+
dist2_mean = torch.tensor([[-0.46151268268791173]], dtype=torch.float64)
84+
dist1_cov = torch.tensor([[0.12132352941175625]], dtype=torch.float64)
85+
dist2_cov = torch.tensor([[0.12132352941176472]], dtype=torch.float64)
86+
87+
res = mvn_hellinger_distance(dist1_mean, dist2_mean, dist1_cov, dist2_cov)
88+
self.assertFalse(res.isnan().any())
89+
self.assertAllClose(res, torch.tensor([0.0], dtype=torch.float64))

0 commit comments

Comments
 (0)