Skip to content

Commit 4548d10

Browse files
committed
Implemented the Dataset.apply method
Fixes pydata#140
1 parent 6c394b1 commit 4548d10

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ Computations
7777
.. autosummary::
7878
:toctree: generated/
7979

80+
Dataset.apply
81+
Dataset.reduce
8082
Dataset.all
8183
Dataset.any
8284
Dataset.argmax

test/test_dataset.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,3 +779,32 @@ def test_reduce_keep_attrs(self):
779779
ds = data.mean(keep_attrs=True)
780780
self.assertEqual(len(ds.attrs), len(_attrs))
781781
self.assertTrue(ds.attrs, attrs)
782+
783+
def test_apply(self):
    """Exercise Dataset.apply: defaults, keep_attrs, `to`, errors, kwargs."""
    data = create_test_data()
    data.attrs['foo'] = 'bar'

    # Applying np.mean to every noncoordinate should match Dataset.mean,
    # both without and with attribute propagation.
    applied_mean = data.apply(np.mean)
    self.assertDatasetIdentical(applied_mean, data.mean())
    applied_mean_attrs = data.apply(np.mean, keep_attrs=True)
    self.assertDatasetIdentical(applied_mean_attrs,
                                data.mean(keep_attrs=True))

    # The identity function with keep_attrs=True is expected to equal the
    # dataset with the 'time' variable dropped.
    identity_result = data.apply(lambda x: x, keep_attrs=True)
    self.assertDatasetIdentical(identity_result, data.drop_vars('time'))

    # Listing every variable explicitly is the same as the default.
    all_listed = data.apply(np.mean, to=['var1', 'var2', 'var3'])
    self.assertDatasetIdentical(all_listed, data.mean())

    # A single name given as a string only transforms that variable; the
    # others pass through unchanged.
    only_var1 = data.apply(np.mean, to='var1')
    reduced_part = data.select_vars('var1').mean()
    untouched_part = data.select_vars('var2', 'var3')
    self.assertDatasetIdentical(only_var1,
                                reduced_part.merge(untouched_part))

    # Unknown names in `to` are rejected.
    with self.assertRaisesRegexp(ValueError, 'does not contain'):
        data.apply(np.mean, to='foobarbaz')

    # Extra keyword arguments are forwarded to the applied function.
    def scale(x, multiple=1):
        return multiple * x

    doubled = data.apply(scale, multiple=2)
    self.assertDataArrayEqual(doubled['var1'], 2 * data['var1'])

xray/dataset.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,50 @@ def reduce(self, func, dimension=None, keep_attrs=False, **kwargs):
10911091

10921092
return Dataset(variables=variables, attributes=attrs)
10931093

1094+
def apply(self, func, to=None, keep_attrs=False, **kwargs):
    """Apply a function to the noncoordinates of this dataset.

    Parameters
    ----------
    func : function
        Callable invoked as `func(x, **kwargs)` for each selected
        noncoordinate `x`; its return value replaces `x` in the result.
    to : str or sequence of str, optional
        Names of the noncoordinates to which `func` is applied. A single
        name may be given as a plain string. Noncoordinates that are not
        listed are passed through unchanged. If omitted, `func` is applied
        to every noncoordinate in this dataset.
    keep_attrs : bool, optional
        If True, the dataset's attributes (`attrs`) are passed on to the
        new dataset. If False, the new dataset is created with empty
        attributes.
    **kwargs : dict
        Additional keyword arguments forwarded to `func`.

    Returns
    -------
    applied : Dataset
        New dataset built from the (possibly transformed) noncoordinates.
        Coordinates which are no longer used as the dimension of a
        noncoordinate are dropped.

    Raises
    ------
    ValueError
        If `to` names anything that is not a noncoordinate of this
        dataset.
    """
    if to is None:
        # Default: every noncoordinate is a target.
        targets = set(self.noncoordinates)
    else:
        # A bare string is a single name, not a sequence of characters.
        if isinstance(to, basestring):
            to = [to]
        targets = set(to)
        unknown = targets - set(self.noncoordinates)
        if unknown:
            raise ValueError('Dataset does not contain the '
                             'noncoordinates: %r' % list(unknown))

    # Preserve the original ordering of the noncoordinates.
    results = OrderedDict()
    for key, array in iteritems(self.noncoordinates):
        if key in targets:
            results[key] = func(array, **kwargs)
        else:
            results[key] = array

    new_attrs = {}
    if keep_attrs:
        new_attrs = self.attrs

    return Dataset(results, new_attrs)
1137+
10941138
@classmethod
10951139
def concat(cls, datasets, dimension='concat_dimension', indexers=None,
10961140
mode='different', concat_over=None, compat='equals'):

0 commit comments

Comments
 (0)