Skip to content

ENH: Revise the implementation of FuzzyOverlap #2530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 66 additions & 71 deletions nipype/algorithms/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

from .. import config, logging

from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
InputMultiPath, BaseInterfaceInputSpec,
isdefined)
from ..interfaces.base import (
SimpleInterface, BaseInterface, traits, TraitedSpec, File,
InputMultiPath, BaseInterfaceInputSpec,
isdefined)
from ..interfaces.nipy.base import NipyBaseInterface
from ..utils import NUMPY_MMAP

iflogger = logging.getLogger('interface')

Expand Down Expand Up @@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
File(exists=True),
mandatory=True,
desc='Test image. Requires the same dimensions as in_ref.')
in_mask = File(exists=True, desc='calculate overlap only within mask')
weighting = traits.Enum(
'none',
'volume',
Expand All @@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
class FuzzyOverlapOutputSpec(TraitedSpec):
jaccard = traits.Float(desc='Fuzzy Jaccard Index (fJI), all the classes')
dice = traits.Float(desc='Fuzzy Dice Index (fDI), all the classes')
diff_file = File(
exists=True,
desc=
'resulting difference-map of all classes, using the chosen weighting')
class_fji = traits.List(
traits.Float(),
desc='Array containing the fJIs of each computed class')
Expand All @@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
desc='Array containing the fDIs of each computed class')


class FuzzyOverlap(BaseInterface):
class FuzzyOverlap(SimpleInterface):
"""Calculates various overlap measures between two maps, using the fuzzy
definition proposed in: Crum et al., Generalized Overlap Measures for
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
Expand All @@ -439,77 +436,75 @@ class FuzzyOverlap(BaseInterface):
output_spec = FuzzyOverlapOutputSpec

def _run_interface(self, runtime):
ncomp = len(self.inputs.in_ref)
assert (ncomp == len(self.inputs.in_tst))
weights = np.ones(shape=ncomp)

img_ref = np.array([
nb.load(fname, mmap=NUMPY_MMAP).get_data()
for fname in self.inputs.in_ref
])
img_tst = np.array([
nb.load(fname, mmap=NUMPY_MMAP).get_data()
for fname in self.inputs.in_tst
])

msk = np.sum(img_ref, axis=0)
msk[msk > 0] = 1.0
tst_msk = np.sum(img_tst, axis=0)
tst_msk[tst_msk > 0] = 1.0

# check that volumes are normalized
# img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
# img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]

self._jaccards = []
volumes = []

diff_im = np.zeros(img_ref.shape)

for ref_comp, tst_comp, diff_comp in zip(img_ref, img_tst, diff_im):
num = np.minimum(ref_comp, tst_comp)
ddr = np.maximum(ref_comp, tst_comp)
diff_comp[ddr > 0] += 1.0 - (num[ddr > 0] / ddr[ddr > 0])
self._jaccards.append(np.sum(num) / np.sum(ddr))
volumes.append(np.sum(ref_comp))

self._dices = 2.0 * (np.array(self._jaccards) /
(np.array(self._jaccards) + 1.0))
# Load data
refdata = nb.concat_images(self.inputs.in_ref).get_data()
tstdata = nb.concat_images(self.inputs.in_tst).get_data()

# Data must have same shape
if not refdata.shape == tstdata.shape:
raise RuntimeError(
'Size of "in_tst" %s must match that of "in_ref" %s.' %
(tstdata.shape, refdata.shape))

ncomp = refdata.shape[-1]

# Load mask
mask = np.ones_like(refdata, dtype=bool)
if isdefined(self.inputs.in_mask):
mask = nb.load(self.inputs.in_mask).get_data()
mask = mask > 0
mask = np.repeat(mask[..., np.newaxis], ncomp, -1)
assert mask.shape == refdata.shape

# Drop data outside mask
refdata = refdata[mask]
tstdata = tstdata[mask]

if np.any(refdata < 0.0):
iflogger.warning('Negative values encountered in "in_ref" input, '
'taking absolute values.')
refdata = np.abs(refdata)

if np.any(tstdata < 0.0):
iflogger.warning('Negative values encountered in "in_tst" input, '
'taking absolute values.')
tstdata = np.abs(tstdata)

if np.any(refdata > 1.0):
iflogger.warning('Values greater than 1.0 found in "in_ref" input, '
'scaling values.')
refdata /= refdata.max()

if np.any(tstdata > 1.0):
iflogger.warning('Values greater than 1.0 found in "in_tst" input, '
'scaling values.')
tstdata /= tstdata.max()

numerators = np.atleast_2d(
np.minimum(refdata, tstdata).reshape((-1, ncomp)))
denominators = np.atleast_2d(
np.maximum(refdata, tstdata).reshape((-1, ncomp)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These shapes look guaranteed to be 2D, already.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If only one item in the input lists are passed, numpy squeezes the redundant dimension. This was necessary for this interface to work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well that's surprising and annoying. Okay.


jaccards = numerators.sum(axis=0) / denominators.sum(axis=0)

# Calculate weights
weights = np.ones_like(jaccards, dtype=float)
if self.inputs.weighting != "none":
weights = 1.0 / np.array(volumes)
volumes = np.sum((refdata + tstdata) > 0, axis=1).reshape((-1, ncomp))
weights = 1.0 / volumes
if self.inputs.weighting == "squared_vol":
weights = weights**2

weights = weights / np.sum(weights)
dices = 2.0 * jaccards / (jaccards + 1.0)

setattr(self, '_jaccard', np.sum(weights * self._jaccards))
setattr(self, '_dice', np.sum(weights * self._dices))

diff = np.zeros(diff_im[0].shape)

for w, ch in zip(weights, diff_im):
ch[msk == 0] = 0
diff += w * ch

nb.save(
nb.Nifti1Image(diff,
nb.load(self.inputs.in_ref[0]).affine,
nb.load(self.inputs.in_ref[0]).header),
self.inputs.out_file)

# Fill-in the results object
self._results['jaccard'] = float(weights.dot(jaccards))
self._results['dice'] = float(weights.dot(dices))
self._results['class_fji'] = [float(v) for v in jaccards]
self._results['class_fdi'] = [float(v) for v in dices]
return runtime

def _list_outputs(self):
outputs = self._outputs().get()
for method in ("dice", "jaccard"):
outputs[method] = getattr(self, '_' + method)
# outputs['volume_difference'] = self._volume
outputs['diff_file'] = os.path.abspath(self.inputs.out_file)
outputs['class_fji'] = np.array(self._jaccards).astype(float).tolist()
outputs['class_fdi'] = self._dices.astype(float).tolist()
return outputs


class ErrorMapInputSpec(BaseInterfaceInputSpec):
in_ref = File(
Expand Down
2 changes: 1 addition & 1 deletion nipype/algorithms/tests/test_auto_FuzzyOverlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def test_FuzzyOverlap_inputs():
nohash=True,
usedefault=True,
),
in_mask=dict(),
in_ref=dict(mandatory=True, ),
in_tst=dict(mandatory=True, ),
out_file=dict(usedefault=True, ),
Expand All @@ -25,7 +26,6 @@ def test_FuzzyOverlap_outputs():
class_fdi=dict(),
class_fji=dict(),
dice=dict(),
diff_file=dict(),
jaccard=dict(),
)
outputs = FuzzyOverlap.output_spec()
Expand Down
58 changes: 58 additions & 0 deletions nipype/algorithms/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:

import numpy as np
import nibabel as nb
from nipype.testing import example_data
from ..metrics import FuzzyOverlap


def test_fuzzy_overlap(tmpdir):
tmpdir.chdir()

# Tests with tissue probability maps
in_mask = example_data('tpms_msk.nii.gz')
tpms = [example_data('tpm_%02d.nii.gz' % i) for i in range(3)]
out = FuzzyOverlap(in_ref=tpms[0], in_tst=tpms[0]).run().outputs
assert out.dice == 1

out = FuzzyOverlap(
in_mask=in_mask, in_ref=tpms[0], in_tst=tpms[0]).run().outputs
assert out.dice == 1

out = FuzzyOverlap(
in_mask=in_mask, in_ref=tpms[0], in_tst=tpms[1]).run().outputs
assert 0 < out.dice < 1

out = FuzzyOverlap(in_ref=tpms, in_tst=tpms).run().outputs
assert out.dice == 1.0

out = FuzzyOverlap(
in_mask=in_mask, in_ref=tpms, in_tst=tpms).run().outputs
assert out.dice == 1.0

# Tests with synthetic 3x3x3 images
data = np.zeros((3, 3, 3), dtype=float)
data[0, 0, 0] = 0.5
data[2, 2, 2] = 0.25
data[1, 1, 1] = 0.3
nb.Nifti1Image(data, np.eye(4)).to_filename('test1.nii.gz')

data = np.zeros((3, 3, 3), dtype=float)
data[0, 0, 0] = 0.6
data[1, 1, 1] = 0.3
nb.Nifti1Image(data, np.eye(4)).to_filename('test2.nii.gz')

out = FuzzyOverlap(in_ref='test1.nii.gz', in_tst='test2.nii.gz').run().outputs
assert np.allclose(out.dice, 0.82051)

# Just considering the mask, the central pixel
# that raised the index now is left aside.
data = np.zeros((3, 3, 3), dtype=int)
data[0, 0, 0] = 1
data[2, 2, 2] = 1
nb.Nifti1Image(data, np.eye(4)).to_filename('mask.nii.gz')

out = FuzzyOverlap(in_ref='test1.nii.gz', in_tst='test2.nii.gz',
in_mask='mask.nii.gz').run().outputs
assert np.allclose(out.dice, 0.74074)