|
16 | 16 | import numpy as np |
17 | 17 | import torch |
18 | 18 | from parameterized import parameterized |
| 19 | +from scipy.ndimage import distance_transform_edt |
19 | 20 |
|
20 | 21 | from monai.metrics import SurfaceDistanceMetric |
| 22 | +from monai.metrics.utils import get_mask_edges, get_surface_distance |
21 | 23 |
|
22 | 24 | _device = "cuda:0" if torch.cuda.is_available() else "cpu" |
23 | 25 |
|
@@ -182,5 +184,43 @@ def test_nans(self, input_data): |
182 | 184 | np.testing.assert_allclose(0, not_nans, rtol=1e-5) |
183 | 185 |
|
184 | 186 |
|
| 187 | +KDTREE_SPACINGS = [["isotropic_default", None], ["isotropic", (1.0, 1.0, 1.0)], ["anisotropic", (1.0, 2.5, 0.5)]] |
| 188 | + |
| 189 | + |
| 190 | +def _edge_masks(seed=0): |
| 191 | + # two offset spheres plus a few scattered false positives in the prediction, so the |
| 192 | + # surfaces are non-trivially apart and an outlier expands the cropped bounding box. |
| 193 | + gt = create_spherical_seg_3d(radius=20, centre=(30, 30, 30)) |
| 194 | + pred = create_spherical_seg_3d(radius=20, centre=(32, 31, 30)) |
| 195 | + rng = np.random.RandomState(seed) |
| 196 | + for _ in range(5): |
| 197 | + pred[tuple(rng.randint(0, s) for s in pred.shape)] = 1 |
| 198 | + edges_pred, edges_gt = get_mask_edges(pred, gt) |
| 199 | + return np.asarray(edges_pred, dtype=bool), np.asarray(edges_gt, dtype=bool) |
| 200 | + |
| 201 | + |
| 202 | +class TestSurfaceDistanceKDTreeMatchesEDT(unittest.TestCase): |
| 203 | + @parameterized.expand(KDTREE_SPACINGS) |
| 204 | + def test_cpu_kdtree_euclidean_distances_match_dense_edt(self, _name, spacing): |
| 205 | + edges_pred, edges_gt = _edge_masks() |
| 206 | + result = np.asarray(get_surface_distance(edges_pred, edges_gt, distance_metric="euclidean", spacing=spacing)) |
| 207 | + reference = distance_transform_edt(~edges_gt, sampling=spacing)[edges_pred] |
| 208 | + # same multiset of distances (downstream metrics only use max/percentile/mean) |
| 209 | + np.testing.assert_allclose(np.sort(result), np.sort(reference), rtol=1e-5, atol=1e-5) |
| 210 | + self.assertEqual(result.dtype, np.float32) |
| 211 | + self.assertEqual(result.shape, reference.shape) |
| 212 | + |
| 213 | + def test_torch_input_preserves_type_device_and_matches_dense_edt(self): |
| 214 | + edges_pred, edges_gt = _edge_masks() |
| 215 | + spacing = (1.0, 2.5, 0.5) |
| 216 | + seg_pred, seg_gt = torch.as_tensor(edges_pred), torch.as_tensor(edges_gt) |
| 217 | + result = get_surface_distance(seg_pred, seg_gt, distance_metric="euclidean", spacing=spacing) |
| 218 | + self.assertIsInstance(result, torch.Tensor) |
| 219 | + self.assertEqual(result.dtype, torch.float32) |
| 220 | + self.assertEqual(result.device, seg_pred.device) |
| 221 | + reference = distance_transform_edt(~edges_gt, sampling=spacing)[edges_pred] |
| 222 | + np.testing.assert_allclose(np.sort(result.cpu().numpy()), np.sort(reference), rtol=1e-5, atol=1e-5) |
| 223 | + |
| 224 | + |
185 | 225 | if __name__ == "__main__": |
186 | 226 | unittest.main() |
0 commit comments