Skip to content

Commit 9fc331e

Browse files
committed
Merge pull request #360 from chrisfilo/enh/tbss_3d_files
Added support for a list of 3d files for TSNR calculation.
2 parents 75d2e14 + 7616655 commit 9fc331e

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

nipype/algorithms/misc.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,8 @@ def _list_outputs(self):
447447

448448

449449
class TSNRInputSpec(BaseInterfaceInputSpec):
450-
in_file = File(exists=True, mandatory=True,
451-
desc='realigned 4D file')
450+
in_file = InputMultiPath(File(exists=True), mandatory=True,
451+
desc='realigned 4D file or a list of 3D files')
452452
regress_poly = traits.Int(min=1, desc='Remove polynomials')
453453

454454

@@ -475,39 +475,40 @@ class TSNR(BaseInterface):
475475
input_spec = TSNRInputSpec
476476
output_spec = TSNROutputSpec
477477

478-
def _gen_output_file_name(self, out_ext=None):
479-
_, base, _ = split_filename(self.inputs.in_file)
480-
if out_ext in ['mean', 'stddev']:
481-
return os.path.abspath(base + "_tsnr_" + out_ext + ".nii.gz")
482-
elif out_ext in ['detrended']:
483-
return os.path.abspath(base + "_" + out_ext + ".nii.gz")
478+
def _gen_output_file_name(self, suffix=None):
479+
_, base, ext = split_filename(self.inputs.in_file[0])
480+
if suffix in ['mean', 'stddev']:
481+
return os.path.abspath(base + "_tsnr_" + suffix + ext)
482+
elif suffix in ['detrended']:
483+
return os.path.abspath(base + "_" + suffix + ext)
484484
else:
485-
return os.path.abspath(base + "_tsnr.nii.gz")
485+
return os.path.abspath(base + "_tsnr" + ext)
486486

487487
def _run_interface(self, runtime):
488-
img = nb.load(self.inputs.in_file)
489-
data = img.get_data()
488+
img = nb.load(self.inputs.in_file[0])
489+
vollist = [nb.load(filename) for filename in self.inputs.in_file]
490+
data = np.concatenate([vol.get_data().reshape(vol.get_shape()[:3] + (-1,)) for vol in vollist], axis=3)
490491
if isdefined(self.inputs.regress_poly):
491492
timepoints = img.get_shape()[-1]
492-
X = np.ones((timepoints,1))
493+
X = np.ones((timepoints, 1))
493494
for i in range(self.inputs.regress_poly):
494-
X = np.hstack((X,legendre(i+1)(np.linspace(-1, 1, timepoints))[:, None]))
495+
X = np.hstack((X, legendre(i + 1)(np.linspace(-1, 1, timepoints))[:, None]))
495496
betas = np.dot(np.linalg.pinv(X), np.rollaxis(data, 3, 2))
496-
datahat = np.rollaxis(np.dot(X[:,1:],
497+
datahat = np.rollaxis(np.dot(X[:, 1:],
497498
np.rollaxis(betas[1:, :, :, :], 0, 3)),
498499
0, 4)
499500
data = data - datahat
500501
img = nb.Nifti1Image(data, img.get_affine(), img.get_header())
501-
nb.save(img, self._gen_output_file_name('detrended'))
502+
nb.save(img, self._gen_output_file_name('detrended'))
502503
meanimg = np.mean(data, axis=3)
503504
stddevimg = np.std(data, axis=3)
504-
tsnr = meanimg/stddevimg
505+
tsnr = meanimg / stddevimg
505506
img = nb.Nifti1Image(tsnr, img.get_affine(), img.get_header())
506-
nb.save(img, self._gen_output_file_name())
507+
nb.save(img, self._gen_output_file_name())
507508
img = nb.Nifti1Image(meanimg, img.get_affine(), img.get_header())
508-
nb.save(img, self._gen_output_file_name('mean'))
509+
nb.save(img, self._gen_output_file_name('mean'))
509510
img = nb.Nifti1Image(stddevimg, img.get_affine(), img.get_header())
510-
nb.save(img, self._gen_output_file_name('stddev'))
511+
nb.save(img, self._gen_output_file_name('stddev'))
511512
return runtime
512513

513514
def _list_outputs(self):

0 commit comments

Comments
 (0)