@@ -447,8 +447,8 @@ def _list_outputs(self):
447
447
448
448
449
449
class TSNRInputSpec (BaseInterfaceInputSpec ):
450
- in_file = File (exists = True , mandatory = True ,
451
- desc = 'realigned 4D file' )
450
+ in_file = InputMultiPath ( File (exists = True ) , mandatory = True ,
451
+ desc = 'realigned 4D file or a list of 3D files ' )
452
452
regress_poly = traits .Int (min = 1 , desc = 'Remove polynomials' )
453
453
454
454
@@ -475,39 +475,40 @@ class TSNR(BaseInterface):
475
475
input_spec = TSNRInputSpec
476
476
output_spec = TSNROutputSpec
477
477
478
- def _gen_output_file_name (self , out_ext = None ):
479
- _ , base , _ = split_filename (self .inputs .in_file )
480
- if out_ext in ['mean' , 'stddev' ]:
481
- return os .path .abspath (base + "_tsnr_" + out_ext + ".nii.gz" )
482
- elif out_ext in ['detrended' ]:
483
- return os .path .abspath (base + "_" + out_ext + ".nii.gz" )
478
+ def _gen_output_file_name (self , suffix = None ):
479
+ _ , base , ext = split_filename (self .inputs .in_file [ 0 ] )
480
+ if suffix in ['mean' , 'stddev' ]:
481
+ return os .path .abspath (base + "_tsnr_" + suffix + ext )
482
+ elif suffix in ['detrended' ]:
483
+ return os .path .abspath (base + "_" + suffix + ext )
484
484
else :
485
- return os .path .abspath (base + "_tsnr.nii.gz" )
485
+ return os .path .abspath (base + "_tsnr" + ext )
486
486
487
487
def _run_interface (self , runtime ):
488
- img = nb .load (self .inputs .in_file )
489
- data = img .get_data ()
488
+ img = nb .load (self .inputs .in_file [0 ])
489
+ vollist = [nb .load (filename ) for filename in self .inputs .in_file ]
490
+ data = np .concatenate ([vol .get_data ().reshape (vol .get_shape ()[:3 ] + (- 1 ,)) for vol in vollist ], axis = 3 )
490
491
if isdefined (self .inputs .regress_poly ):
491
492
timepoints = img .get_shape ()[- 1 ]
492
- X = np .ones ((timepoints ,1 ))
493
+ X = np .ones ((timepoints , 1 ))
493
494
for i in range (self .inputs .regress_poly ):
494
- X = np .hstack ((X ,legendre (i + 1 )(np .linspace (- 1 , 1 , timepoints ))[:, None ]))
495
+ X = np .hstack ((X , legendre (i + 1 )(np .linspace (- 1 , 1 , timepoints ))[:, None ]))
495
496
betas = np .dot (np .linalg .pinv (X ), np .rollaxis (data , 3 , 2 ))
496
- datahat = np .rollaxis (np .dot (X [:,1 :],
497
+ datahat = np .rollaxis (np .dot (X [:, 1 :],
497
498
np .rollaxis (betas [1 :, :, :, :], 0 , 3 )),
498
499
0 , 4 )
499
500
data = data - datahat
500
501
img = nb .Nifti1Image (data , img .get_affine (), img .get_header ())
501
- nb .save (img , self ._gen_output_file_name ('detrended' ))
502
+ nb .save (img , self ._gen_output_file_name ('detrended' ))
502
503
meanimg = np .mean (data , axis = 3 )
503
504
stddevimg = np .std (data , axis = 3 )
504
- tsnr = meanimg / stddevimg
505
+ tsnr = meanimg / stddevimg
505
506
img = nb .Nifti1Image (tsnr , img .get_affine (), img .get_header ())
506
- nb .save (img , self ._gen_output_file_name ())
507
+ nb .save (img , self ._gen_output_file_name ())
507
508
img = nb .Nifti1Image (meanimg , img .get_affine (), img .get_header ())
508
- nb .save (img , self ._gen_output_file_name ('mean' ))
509
+ nb .save (img , self ._gen_output_file_name ('mean' ))
509
510
img = nb .Nifti1Image (stddevimg , img .get_affine (), img .get_header ())
510
- nb .save (img , self ._gen_output_file_name ('stddev' ))
511
+ nb .save (img , self ._gen_output_file_name ('stddev' ))
511
512
return runtime
512
513
513
514
def _list_outputs (self ):
0 commit comments