-
Notifications
You must be signed in to change notification settings - Fork 535
ENH: Revise the implementation of FuzzyOverlap #2530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
76b0e34
9d4f39c
36e43d2
63519b8
140b159
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,10 +21,10 @@ | |
from .. import config, logging | ||
from ..utils.misc import package_check | ||
|
||
from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File, | ||
InputMultiPath, BaseInterfaceInputSpec, | ||
isdefined) | ||
from ..utils import NUMPY_MMAP | ||
from ..interfaces.base import ( | ||
SimpleInterface, BaseInterface, traits, TraitedSpec, File, | ||
InputMultiPath, BaseInterfaceInputSpec, | ||
isdefined) | ||
|
||
iflogger = logging.getLogger('interface') | ||
|
||
|
@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec): | |
File(exists=True), | ||
mandatory=True, | ||
desc='Test image. Requires the same dimensions as in_ref.') | ||
in_mask = File(exists=True, desc='calculate overlap only within mask') | ||
weighting = traits.Enum( | ||
'none', | ||
'volume', | ||
|
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec): | |
class FuzzyOverlapOutputSpec(TraitedSpec): | ||
jaccard = traits.Float(desc='Fuzzy Jaccard Index (fJI), all the classes') | ||
dice = traits.Float(desc='Fuzzy Dice Index (fDI), all the classes') | ||
diff_file = File( | ||
exists=True, | ||
desc= | ||
'resulting difference-map of all classes, using the chosen weighting') | ||
class_fji = traits.List( | ||
traits.Float(), | ||
desc='Array containing the fJIs of each computed class') | ||
|
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec): | |
desc='Array containing the fDIs of each computed class') | ||
|
||
|
||
class FuzzyOverlap(BaseInterface): | ||
class FuzzyOverlap(SimpleInterface): | ||
"""Calculates various overlap measures between two maps, using the fuzzy | ||
definition proposed in: Crum et al., Generalized Overlap Measures for | ||
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med. | ||
|
@@ -439,77 +436,77 @@ class FuzzyOverlap(BaseInterface): | |
output_spec = FuzzyOverlapOutputSpec | ||
|
||
def _run_interface(self, runtime): | ||
ncomp = len(self.inputs.in_ref) | ||
assert (ncomp == len(self.inputs.in_tst)) | ||
weights = np.ones(shape=ncomp) | ||
|
||
img_ref = np.array([ | ||
nb.load(fname, mmap=NUMPY_MMAP).get_data() | ||
for fname in self.inputs.in_ref | ||
]) | ||
img_tst = np.array([ | ||
nb.load(fname, mmap=NUMPY_MMAP).get_data() | ||
for fname in self.inputs.in_tst | ||
]) | ||
|
||
msk = np.sum(img_ref, axis=0) | ||
msk[msk > 0] = 1.0 | ||
tst_msk = np.sum(img_tst, axis=0) | ||
tst_msk[tst_msk > 0] = 1.0 | ||
|
||
# check that volumes are normalized | ||
# img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0] | ||
# img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0] | ||
|
||
self._jaccards = [] | ||
volumes = [] | ||
|
||
diff_im = np.zeros(img_ref.shape) | ||
|
||
for ref_comp, tst_comp, diff_comp in zip(img_ref, img_tst, diff_im): | ||
num = np.minimum(ref_comp, tst_comp) | ||
ddr = np.maximum(ref_comp, tst_comp) | ||
diff_comp[ddr > 0] += 1.0 - (num[ddr > 0] / ddr[ddr > 0]) | ||
self._jaccards.append(np.sum(num) / np.sum(ddr)) | ||
volumes.append(np.sum(ref_comp)) | ||
|
||
self._dices = 2.0 * (np.array(self._jaccards) / | ||
(np.array(self._jaccards) + 1.0)) | ||
# Load data | ||
refdata = nb.concat_images(self.inputs.in_ref).get_data() | ||
tstdata = nb.concat_images(self.inputs.in_tst).get_data() | ||
|
||
# Data must have same shape | ||
if not refdata.shape == tstdata.shape: | ||
raise RuntimeError( | ||
'Size of "in_tst" %s must match that of "in_ref" %s.' % | ||
(tstdata.shape, refdata.shape)) | ||
|
||
# Load mask | ||
mask = np.ones_like(refdata[..., 0], dtype=bool) | ||
if isdefined(self.inputs.in_mask): | ||
mask = nb.load(self.inputs.in_mask).get_data() | ||
mask = mask > 0 | ||
assert mask.shape == refdata.shape[:-1] | ||
|
||
ncomp = refdata.shape[-1] | ||
|
||
# Drop data outside mask | ||
refdata = refdata[mask[..., np.newaxis]] | ||
tstdata = tstdata[mask[..., np.newaxis]] | ||
|
||
if np.any(refdata < 0.0): | ||
iflogger.warning('Negative values encountered in "in_ref" input, ' | ||
'taking absolute values.') | ||
refdata = np.abs(refdata) | ||
|
||
if np.any(tstdata < 0.0): | ||
iflogger.warning('Negative values encountered in "in_tst" input, ' | ||
'taking absolute values.') | ||
tstdata = np.abs(tstdata) | ||
|
||
if np.any(refdata > 1.0): | ||
iflogger.warning('Values greater than 1.0 found in "in_ref" input, ' | ||
'scaling values.') | ||
refdata /= refdata.max() | ||
|
||
if np.any(tstdata > 1.0): | ||
iflogger.warning('Values greater than 1.0 found in "in_tst" input, ' | ||
'scaling values.') | ||
tstdata /= tstdata.max() | ||
|
||
numerators = np.atleast_2d( | ||
np.minimum(refdata, tstdata).reshape((-1, ncomp))) | ||
denominators = np.atleast_2d( | ||
np.maximum(refdata, tstdata).reshape((-1, ncomp))) | ||
|
||
jaccards = numerators.sum(axis=0) / denominators.sum(axis=0) | ||
|
||
# Calculate weights | ||
weights = np.ones_like(jaccards, dtype=float) | ||
if self.inputs.weighting != "none": | ||
volumes = np.sum((refdata + tstdata) > 0, axis=1).reshape((-1, ncomp)) | ||
weights = 1.0 / np.array(volumes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if self.inputs.weighting == "squared_vol": | ||
weights = weights**2 | ||
|
||
weights = weights / np.sum(weights) | ||
dices = 2.0 * jaccards / (jaccards + 1.0) | ||
|
||
setattr(self, '_jaccard', np.sum(weights * self._jaccards)) | ||
setattr(self, '_dice', np.sum(weights * self._dices)) | ||
|
||
diff = np.zeros(diff_im[0].shape) | ||
|
||
for w, ch in zip(weights, diff_im): | ||
ch[msk == 0] = 0 | ||
diff += w * ch | ||
|
||
nb.save( | ||
nb.Nifti1Image(diff, | ||
nb.load(self.inputs.in_ref[0]).affine, | ||
nb.load(self.inputs.in_ref[0]).header), | ||
self.inputs.out_file) | ||
# Fill-in the results object | ||
self._results['jaccard'] = float(np.sum(weights * jaccards)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can just use the dot product here? self._results['jaccard'] = weights.dot(jaccards) |
||
self._results['dice'] = float(np.sum(weights * dices)) | ||
self._results['class_fji'] = [ | ||
float(v) for v in jaccards.astype(float).tolist()] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is me over-worrying about output traits. In the past I've seen that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah. If that's the case, then |
||
self._results['class_fdi'] = [ | ||
float(v) for v in dices.astype(float).tolist()] | ||
|
||
return runtime | ||
|
||
def _list_outputs(self): | ||
outputs = self._outputs().get() | ||
for method in ("dice", "jaccard"): | ||
outputs[method] = getattr(self, '_' + method) | ||
# outputs['volume_difference'] = self._volume | ||
outputs['diff_file'] = os.path.abspath(self.inputs.out_file) | ||
outputs['class_fji'] = np.array(self._jaccards).astype(float).tolist() | ||
outputs['class_fdi'] = self._dices.astype(float).tolist() | ||
return outputs | ||
|
||
|
||
class ErrorMapInputSpec(BaseInterfaceInputSpec): | ||
in_ref = File( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These shapes look guaranteed to be 2D, already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If only one item in the input lists are passed, numpy squeezes the redundant dimension. This was necessary for this interface to work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well that's surprising and annoying. Okay.