diff --git a/nipype/algorithms/misc.py b/nipype/algorithms/misc.py index 48b45e6e01..3df52cf3f8 100644 --- a/nipype/algorithms/misc.py +++ b/nipype/algorithms/misc.py @@ -447,8 +447,8 @@ def _list_outputs(self): class TSNRInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, - desc='realigned 4D file') + in_file = InputMultiPath(File(exists=True), mandatory=True, + desc='realigned 4D file or a list of 3D files') regress_poly = traits.Int(min=1, desc='Remove polynomials') @@ -475,39 +475,40 @@ class TSNR(BaseInterface): input_spec = TSNRInputSpec output_spec = TSNROutputSpec - def _gen_output_file_name(self, out_ext=None): - _, base, _ = split_filename(self.inputs.in_file) - if out_ext in ['mean', 'stddev']: - return os.path.abspath(base + "_tsnr_" + out_ext + ".nii.gz") - elif out_ext in ['detrended']: - return os.path.abspath(base + "_" + out_ext + ".nii.gz") + def _gen_output_file_name(self, suffix=None): + _, base, ext = split_filename(self.inputs.in_file[0]) + if suffix in ['mean', 'stddev']: + return os.path.abspath(base + "_tsnr_" + suffix + ext) + elif suffix in ['detrended']: + return os.path.abspath(base + "_" + suffix + ext) else: - return os.path.abspath(base + "_tsnr.nii.gz") + return os.path.abspath(base + "_tsnr" + ext) def _run_interface(self, runtime): - img = nb.load(self.inputs.in_file) - data = img.get_data() + img = nb.load(self.inputs.in_file[0]) + vollist = [nb.load(filename) for filename in self.inputs.in_file] + data = np.concatenate([vol.get_data().reshape(vol.get_shape()[:3] + (-1,)) for vol in vollist], axis=3) if isdefined(self.inputs.regress_poly): timepoints = img.get_shape()[-1] - X = np.ones((timepoints,1)) + X = np.ones((timepoints, 1)) for i in range(self.inputs.regress_poly): - X = np.hstack((X,legendre(i+1)(np.linspace(-1, 1, timepoints))[:, None])) + X = np.hstack((X, legendre(i + 1)(np.linspace(-1, 1, timepoints))[:, None])) betas = np.dot(np.linalg.pinv(X), np.rollaxis(data, 3, 2)) - datahat = np.rollaxis(np.dot(X[:,1:], + datahat = np.rollaxis(np.dot(X[:, 1:], np.rollaxis(betas[1:, :, :, :], 0, 3)), 0, 4) data = data - datahat img = nb.Nifti1Image(data, img.get_affine(), img.get_header()) - nb.save(img, self._gen_output_file_name('detrended')) + nb.save(img, self._gen_output_file_name('detrended')) meanimg = np.mean(data, axis=3) stddevimg = np.std(data, axis=3) - tsnr = meanimg/stddevimg + tsnr = meanimg / stddevimg img = nb.Nifti1Image(tsnr, img.get_affine(), img.get_header()) - nb.save(img, self._gen_output_file_name()) + nb.save(img, self._gen_output_file_name()) img = nb.Nifti1Image(meanimg, img.get_affine(), img.get_header()) - nb.save(img, self._gen_output_file_name('mean')) + nb.save(img, self._gen_output_file_name('mean')) img = nb.Nifti1Image(stddevimg, img.get_affine(), img.get_header()) - nb.save(img, self._gen_output_file_name('stddev')) + nb.save(img, self._gen_output_file_name('stddev')) return runtime def _list_outputs(self):