Skip to content

ENH: Revise the implementation of FuzzyOverlap #2530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 25, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 66 additions & 69 deletions nipype/algorithms/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from .. import config, logging
from ..utils.misc import package_check

from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
InputMultiPath, BaseInterfaceInputSpec,
isdefined)
from ..utils import NUMPY_MMAP
from ..interfaces.base import (
SimpleInterface, BaseInterface, traits, TraitedSpec, File,
InputMultiPath, BaseInterfaceInputSpec,
isdefined)

iflogger = logging.getLogger('interface')

Expand Down Expand Up @@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
File(exists=True),
mandatory=True,
desc='Test image. Requires the same dimensions as in_ref.')
in_mask = File(exists=True, desc='calculate overlap only within mask')
weighting = traits.Enum(
'none',
'volume',
Expand All @@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
class FuzzyOverlapOutputSpec(TraitedSpec):
jaccard = traits.Float(desc='Fuzzy Jaccard Index (fJI), all the classes')
dice = traits.Float(desc='Fuzzy Dice Index (fDI), all the classes')
diff_file = File(
exists=True,
desc=
'resulting difference-map of all classes, using the chosen weighting')
class_fji = traits.List(
traits.Float(),
desc='Array containing the fJIs of each computed class')
Expand All @@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
desc='Array containing the fDIs of each computed class')


class FuzzyOverlap(BaseInterface):
class FuzzyOverlap(SimpleInterface):
"""Calculates various overlap measures between two maps, using the fuzzy
definition proposed in: Crum et al., Generalized Overlap Measures for
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
Expand All @@ -439,77 +436,77 @@ class FuzzyOverlap(BaseInterface):
output_spec = FuzzyOverlapOutputSpec

def _run_interface(self, runtime):
ncomp = len(self.inputs.in_ref)
assert (ncomp == len(self.inputs.in_tst))
weights = np.ones(shape=ncomp)

img_ref = np.array([
nb.load(fname, mmap=NUMPY_MMAP).get_data()
for fname in self.inputs.in_ref
])
img_tst = np.array([
nb.load(fname, mmap=NUMPY_MMAP).get_data()
for fname in self.inputs.in_tst
])

msk = np.sum(img_ref, axis=0)
msk[msk > 0] = 1.0
tst_msk = np.sum(img_tst, axis=0)
tst_msk[tst_msk > 0] = 1.0

# check that volumes are normalized
# img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
# img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]

self._jaccards = []
volumes = []

diff_im = np.zeros(img_ref.shape)

for ref_comp, tst_comp, diff_comp in zip(img_ref, img_tst, diff_im):
num = np.minimum(ref_comp, tst_comp)
ddr = np.maximum(ref_comp, tst_comp)
diff_comp[ddr > 0] += 1.0 - (num[ddr > 0] / ddr[ddr > 0])
self._jaccards.append(np.sum(num) / np.sum(ddr))
volumes.append(np.sum(ref_comp))

self._dices = 2.0 * (np.array(self._jaccards) /
(np.array(self._jaccards) + 1.0))
# Load data
refdata = nb.concat_images(self.inputs.in_ref).get_data()
tstdata = nb.concat_images(self.inputs.in_tst).get_data()

# Data must have same shape
if not refdata.shape == tstdata.shape:
raise RuntimeError(
'Size of "in_tst" %s must match that of "in_ref" %s.' %
(tstdata.shape, refdata.shape))

# Load mask
mask = np.ones_like(refdata[..., 0], dtype=bool)
if isdefined(self.inputs.in_mask):
mask = nb.load(self.inputs.in_mask).get_data()
mask = mask > 0
assert mask.shape == refdata.shape[:-1]

ncomp = refdata.shape[-1]

# Drop data outside mask
refdata = refdata[mask[..., np.newaxis]]
tstdata = tstdata[mask[..., np.newaxis]]

if np.any(refdata < 0.0):
iflogger.warning('Negative values encountered in "in_ref" input, '
'taking absolute values.')
refdata = np.abs(refdata)

if np.any(tstdata < 0.0):
iflogger.warning('Negative values encountered in "in_tst" input, '
'taking absolute values.')
tstdata = np.abs(tstdata)

if np.any(refdata > 1.0):
iflogger.warning('Values greater than 1.0 found in "in_ref" input, '
'scaling values.')
refdata /= refdata.max()

if np.any(tstdata > 1.0):
iflogger.warning('Values greater than 1.0 found in "in_tst" input, '
'scaling values.')
tstdata /= tstdata.max()

numerators = np.atleast_2d(
np.minimum(refdata, tstdata).reshape((-1, ncomp)))
denominators = np.atleast_2d(
np.maximum(refdata, tstdata).reshape((-1, ncomp)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These shapes already look guaranteed to be 2D.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If only one item is passed in each of the input lists, numpy squeezes the redundant dimension. This was necessary for this interface to work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well that's surprising and annoying. Okay.


jaccards = numerators.sum(axis=0) / denominators.sum(axis=0)

# Calculate weights
weights = np.ones_like(jaccards, dtype=float)
if self.inputs.weighting != "none":
volumes = np.sum((refdata + tstdata) > 0, axis=1).reshape((-1, ncomp))
weights = 1.0 / np.array(volumes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

volumes is already an array.

if self.inputs.weighting == "squared_vol":
weights = weights**2

weights = weights / np.sum(weights)
dices = 2.0 * jaccards / (jaccards + 1.0)

setattr(self, '_jaccard', np.sum(weights * self._jaccards))
setattr(self, '_dice', np.sum(weights * self._dices))

diff = np.zeros(diff_im[0].shape)

for w, ch in zip(weights, diff_im):
ch[msk == 0] = 0
diff += w * ch

nb.save(
nb.Nifti1Image(diff,
nb.load(self.inputs.in_ref[0]).affine,
nb.load(self.inputs.in_ref[0]).header),
self.inputs.out_file)
# Fill-in the results object
self._results['jaccard'] = float(np.sum(weights * jaccards))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just use the dot product here?

self._results['jaccard'] = weights.dot(jaccards)

self._results['dice'] = float(np.sum(weights * dices))
self._results['class_fji'] = [
float(v) for v in jaccards.astype(float).tolist()]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jaccards.astype(float).tolist() seems like it should be fine? I think this and the next comprehension are redundant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is me over-worrying about output traits. In the past I've seen that ndarray.astype(float).tolist() returned a list of numpy floats that traits didn't like. I'll check whether this is still necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. If that's the case, then [float(v) for v in jaccards] is another option. There shouldn't be any change to values or iteration by dropping astype(float) or tolist().

self._results['class_fdi'] = [
float(v) for v in dices.astype(float).tolist()]

return runtime

def _list_outputs(self):
outputs = self._outputs().get()
for method in ("dice", "jaccard"):
outputs[method] = getattr(self, '_' + method)
# outputs['volume_difference'] = self._volume
outputs['diff_file'] = os.path.abspath(self.inputs.out_file)
outputs['class_fji'] = np.array(self._jaccards).astype(float).tolist()
outputs['class_fdi'] = self._dices.astype(float).tolist()
return outputs


class ErrorMapInputSpec(BaseInterfaceInputSpec):
in_ref = File(
Expand Down
2 changes: 1 addition & 1 deletion nipype/algorithms/tests/test_auto_FuzzyOverlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def test_FuzzyOverlap_inputs():
nohash=True,
usedefault=True,
),
in_mask=dict(),
in_ref=dict(mandatory=True, ),
in_tst=dict(mandatory=True, ),
out_file=dict(usedefault=True, ),
Expand All @@ -25,7 +26,6 @@ def test_FuzzyOverlap_outputs():
class_fdi=dict(),
class_fji=dict(),
dice=dict(),
diff_file=dict(),
jaccard=dict(),
)
outputs = FuzzyOverlap.output_spec()
Expand Down