Skip to content

Commit 20d1939

Browse files
committed
Merge pull request pydata#192 from shoyer/modify-in-place
Enhanced support for modifying Dataset & DataArray properties in place
2 parents 9597d9e + a7f5351 commit 20d1939

File tree

8 files changed

+216
-19
lines changed

8 files changed

+216
-19
lines changed

test/test_data_array.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def test_repr(self):
3737

3838
def test_properties(self):
3939
self.assertDatasetIdentical(self.dv.dataset, self.ds)
40-
self.assertEqual(self.dv.name, 'foo')
4140
self.assertVariableEqual(self.dv.variable, self.v)
4241
self.assertArrayEqual(self.dv.values, self.v.values)
4342
for attr in ['dimensions', 'dtype', 'shape', 'size', 'ndim', 'attrs']:
@@ -47,13 +46,39 @@ def test_properties(self):
4746
self.assertEqual(list(self.dv.coordinates), list(self.ds.coordinates))
4847
for k, v in iteritems(self.dv.coordinates):
4948
self.assertArrayEqual(v, self.ds.coordinates[k])
50-
with self.assertRaises(AttributeError):
51-
self.dv.name = 'bar'
5249
with self.assertRaises(AttributeError):
5350
self.dv.dataset = self.ds
5451
self.assertIsInstance(self.ds['x'].as_index, pd.Index)
5552
with self.assertRaisesRegexp(ValueError, 'must be 1-dimensional'):
5653
self.ds['foo'].as_index
54+
with self.assertRaises(AttributeError):
55+
self.dv.variable = self.v
56+
57+
def test_name(self):
58+
arr = self.dv
59+
self.assertEqual(arr.name, 'foo')
60+
61+
copied = arr.copy()
62+
arr.name = 'bar'
63+
self.assertEqual(arr.name, 'bar')
64+
self.assertDataArrayEqual(copied, arr)
65+
66+
actual = DataArray(Coordinate('x', [3]))
67+
actual.name = 'y'
68+
expected = DataArray(Coordinate('y', [3]))
69+
self.assertDataArrayIdentical(actual, expected)
70+
71+
def test_dimensions(self):
72+
arr = self.dv
73+
self.assertEqual(arr.dimensions, ('x', 'y'))
74+
75+
arr.dimensions = ('w', 'z')
76+
self.assertEqual(arr.dimensions, ('w', 'z'))
77+
78+
x = Dataset({'x': ('x', np.arange(5))})['x']
79+
x.dimensions = ('y',)
80+
self.assertEqual(x.dimensions, ('y',))
81+
self.assertEqual(x.name, 'y')
5782

5883
def test_encoding(self):
5984
expected = {'foo': 'bar'}
@@ -166,10 +191,13 @@ def test_constructor_from_self_described(self):
166191
expected = DataArray([data], expected.coordinates, ['dim_0', 'x', 'y'])
167192
self.assertDataArrayIdentical(expected, actual)
168193

169-
expected = DataArray(['a', 'b'], name='foo')
194+
expected = Dataset({'foo': ('foo', ['a', 'b'])})['foo']
170195
actual = DataArray(pd.Index(['a', 'b'], name='foo'))
171196
self.assertDataArrayIdentical(expected, actual)
172197

198+
actual = DataArray(Coordinate('foo', ['a', 'b']))
199+
self.assertDataArrayIdentical(expected, actual)
200+
173201
def test_equals_and_identical(self):
174202
da2 = self.dv.copy()
175203
self.assertTrue(self.dv.equals(da2))
@@ -275,6 +303,52 @@ def test_coordinates(self):
275303
actual = repr(da.coordinates)
276304
self.assertEquals(expected, actual)
277305

306+
def test_coordinates_modify(self):
307+
da = DataArray(np.zeros((2, 3)), dimensions=['x', 'y'])
308+
309+
for k, v in [('x', ['a', 'b']), (0, ['c', 'd']), (-2, ['e', 'f'])]:
310+
da.coordinates[k] = v
311+
self.assertArrayEqual(da.coordinates[k], v)
312+
313+
actual = da.copy()
314+
orig_dataset = actual.dataset
315+
actual.coordinates = [[5, 6], [7, 8, 9]]
316+
expected = DataArray(np.zeros((2, 3)), coordinates=[[5, 6], [7, 8, 9]],
317+
dimensions=['x', 'y'])
318+
self.assertDataArrayIdentical(actual, expected)
319+
self.assertIsNot(actual.dataset, orig_dataset)
320+
321+
actual = da.copy()
322+
actual.coordinates = expected.coordinates
323+
self.assertDataArrayIdentical(actual, expected)
324+
325+
actual = da.copy()
326+
expected = DataArray(np.zeros((2, 3)), coordinates=[[5, 6], [7, 8, 9]],
327+
dimensions=['foo', 'bar'])
328+
actual.coordinates = expected.coordinates
329+
self.assertDataArrayIdentical(actual, expected)
330+
331+
with self.assertRaisesRegexp(ValueError, 'coordinate has size'):
332+
da.coordinates['x'] = ['a']
333+
334+
with self.assertRaises(IndexError):
335+
da.coordinates['foobar'] = np.arange(4)
336+
337+
with self.assertRaisesRegexp(ValueError, 'coordinate has size'):
338+
da.coordinates = da.isel(y=slice(2)).coordinates
339+
340+
# modify the coordinates on a coordinate itself
341+
x = DataArray(Coordinate('x', [10.0, 20.0, 30.0]))
342+
343+
actual = x.copy()
344+
actual.coordinates = [[0, 1, 2]]
345+
expected = DataArray(Coordinate('x', range(3)))
346+
self.assertDataArrayIdentical(actual, expected)
347+
348+
actual = DataArray(Coordinate('y', [-10, -20, -30]))
349+
actual.coordinates = expected.coordinates
350+
self.assertDataArrayIdentical(actual, expected)
351+
278352
def test_reindex(self):
279353
foo = self.dv
280354
bar = self.dv[:2, :2]

test/test_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,22 @@ def test_coordinates_properties(self):
183183
actual = repr(data.coordinates)
184184
self.assertEquals(expected, actual)
185185

186+
def test_coordinates_modify(self):
187+
data = Dataset({'x': ('x', [-1, -2]),
188+
'y': ('y', [0, 1, 2]),
189+
'foo': (['x', 'y'], np.random.randn(2, 3))})
190+
191+
actual = data.copy(deep=True)
192+
actual.coordinates['x'] = ['a', 'b']
193+
self.assertArrayEqual(actual['x'], ['a', 'b'])
194+
195+
actual = data.copy(deep=True)
196+
actual.coordinates['z'] = ['a', 'b']
197+
self.assertArrayEqual(actual['z'], ['a', 'b'])
198+
199+
with self.assertRaisesRegexp(ValueError, 'coordinate has size'):
200+
data.coordinates['x'] = [-1]
201+
186202
def test_equals_and_identical(self):
187203
data = create_test_data(seed=42)
188204
self.assertTrue(data.equals(data))
@@ -429,6 +445,15 @@ def test_rename(self):
429445
with self.assertRaises(UnexpectedDataAccess):
430446
renamed['renamed_var1'].values
431447

448+
def test_rename_inplace(self):
449+
data = Dataset({'z': ('x', [2, 3, 4])})
450+
copied = data.copy()
451+
renamed = data.rename({'x': 'y'})
452+
data.rename({'x': 'y'}, inplace=True)
453+
self.assertDatasetIdentical(data, renamed)
454+
self.assertFalse(data.equals(copied))
455+
self.assertEquals(data.dimensions, {'y': 3})
456+
432457
def test_update(self):
433458
data = create_test_data(seed=0)
434459
var2 = Variable('dim1', np.arange(100))

test/test_variable.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,6 @@ def test_data(self):
585585
self.assertEqual(float, x.dtype)
586586
self.assertArrayEqual(np.arange(3), x)
587587
self.assertEqual(float, x.values.dtype)
588-
self.assertEqual('x', x.name)
589588
# after inspecting x.values, the Coordinate value will be saved as an Index
590589
self.assertIsInstance(x._data, PandasIndexAdapter)
591590
with self.assertRaisesRegexp(TypeError, 'cannot be modified'):
@@ -603,6 +602,13 @@ def test_avoid_index_dtype_inference(self):
603602
self.assertEqual(t.dtype, object)
604603
self.assertEqual(t[:2].dtype, object)
605604

605+
def test_name(self):
606+
coord = Coordinate('x', [10.0])
607+
self.assertEqual(coord.name, 'x')
608+
609+
with self.assertRaises(AttributeError):
610+
coord.name = 'y'
611+
606612

607613
class TestAsCompatibleData(TestCase):
608614
def test_unchanged_types(self):

xray/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ def __repr__(self):
119119
return '\n'.join(_wrap_indent(repr(v.as_index), '%s: ' % k)
120120
for k, v in self.items())
121121

122+
@staticmethod
123+
def _convert_to_coord(key, value, expected_size=None):
124+
from .variable import Coordinate, as_variable
125+
126+
if not isinstance(value, AbstractArray):
127+
value = Coordinate(key, value)
128+
coord = as_variable(value).to_coord()
129+
130+
if expected_size is not None and coord.size != expected_size:
131+
raise ValueError('new coordinate has size %s but the existing '
132+
'coordinate has size %s'
133+
% (coord.size, expected_size))
134+
return coord
135+
122136

123137
def _summarize_attributes(data):
124138
if data.attrs:

xray/data_array.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import functools
23
import operator
34
import warnings
@@ -80,6 +81,12 @@ def __setitem__(self, key, value):
8081
self.data_array[self._remap_key(key)] = value
8182

8283

84+
def _assert_coordinates_same_size(orig, new):
85+
if not new.size == orig.size:
86+
raise ValueError('new coordinate has size %s but the existing '
87+
'coordinate has size %s' % (new.size, orig.size))
88+
89+
8390
class DataArrayCoordinates(AbstractCoordinates):
8491
"""Dictionary like container for DataArray coordinates.
8592
@@ -96,6 +103,17 @@ def __getitem__(self, key):
96103
else:
97104
raise KeyError(repr(key))
98105

106+
def __setitem__(self, key, value):
107+
if isinstance(key, (int, np.integer)):
108+
key = self._data.dimensions[key]
109+
110+
if key not in self:
111+
raise IndexError('%s is not a coordinate')
112+
113+
coord = self._convert_to_coord(key, value, self[key].size)
114+
with self._data._set_new_dataset() as ds:
115+
ds._variables[key] = coord
116+
99117

100118
class DataArray(AbstractArray):
101119
"""N-dimensional array with labeled coordinates and dimensions.
@@ -178,6 +196,8 @@ def __init__(self, data=None, coordinates=None, dimensions=None, name=None,
178196
coordinates = [data.index]
179197
elif isinstance(data, pd.DataFrame):
180198
coordinates = [data.index, data.columns]
199+
elif isinstance(data, (pd.Index, variable.Coordinate)):
200+
coordinates = [data]
181201
elif isinstance(data, pd.Panel):
182202
coordinates = [data.items, data.major_axis, data.minor_axis]
183203
if dimensions is None:
@@ -197,12 +217,10 @@ def __init__(self, data=None, coordinates=None, dimensions=None, name=None,
197217
dimensions, data, attributes, encoding)
198218
dataset = xray.Dataset(variables)
199219
else:
220+
# move this back to an alternate constructor?
200221
if name not in dataset and name not in dataset.virtual_variables:
201222
raise ValueError('name %r must be a variable in dataset %s' %
202223
(name, dataset))
203-
# make a shallow copy of the dataset so we can safely modify the
204-
# array in-place?
205-
# dataset = dataset.copy(deep=False)
206224

207225
self._dataset = dataset
208226
self._name = name
@@ -220,19 +238,25 @@ def name(self):
220238
"""
221239
return self._name
222240

241+
@contextlib.contextmanager
242+
def _set_new_dataset(self):
243+
"""Context manager to use for modifying _dataset, in a manner that
244+
can be safely rolled back if an error is encountered.
245+
"""
246+
ds = self.dataset.copy(deep=False)
247+
yield ds
248+
self._dataset = ds
249+
223250
@name.setter
224251
def name(self, value):
225-
raise AttributeError('cannot modify the name of a %s inplace; use the '
226-
"'rename' method instead" % type(self).__name__)
252+
with self._set_new_dataset() as ds:
253+
ds.rename({self.name: value}, inplace=True)
254+
self._name = value
227255

228256
@property
229257
def variable(self):
230258
return self.dataset.variables[self.name]
231259

232-
@variable.setter
233-
def variable(self, value):
234-
self.dataset[self.name] = value
235-
236260
@property
237261
def dtype(self):
238262
return self.variable.dtype
@@ -274,6 +298,17 @@ def as_index(self):
274298
def dimensions(self):
275299
return self.variable.dimensions
276300

301+
@dimensions.setter
302+
def dimensions(self, value):
303+
with self._set_new_dataset() as ds:
304+
if not len(value) == self.ndim:
305+
raise ValueError('%s dimensions supplied but data has ndim=%s'
306+
% (len(value), self.ndim))
307+
name_map = dict(zip(self.dimensions, value))
308+
ds.rename(name_map, inplace=True)
309+
if self.name in name_map:
310+
self._name = name_map[self.name]
311+
277312
def _key_to_indexers(self, key):
278313
return OrderedDict(
279314
zip(self.dimensions, indexing.expanded_indexer(key, self.ndim)))
@@ -350,6 +385,31 @@ def coordinates(self):
350385
"""
351386
return DataArrayCoordinates(self)
352387

388+
@coordinates.setter
389+
def coordinates(self, value):
390+
if not len(value) == self.ndim:
391+
raise ValueError('%s coordinates supplied but data has ndim=%s'
392+
% (len(value), self.ndim))
393+
with self._set_new_dataset() as ds:
394+
# TODO: allow setting to dict-like objects other than
395+
# DataArrayCoordinates?
396+
if isinstance(value, DataArrayCoordinates):
397+
# yes, this is regretably complex and probably slow
398+
name_map = dict(zip(self.dimensions, value.keys()))
399+
ds.rename(name_map, inplace=True)
400+
name = name_map.get(self.name, self.name)
401+
dimensions = ds[name].dimensions
402+
value = value.values()
403+
else:
404+
name = self.name
405+
dimensions = self.dimensions
406+
407+
for k, v in zip(dimensions, value):
408+
coord = DataArrayCoordinates._convert_to_coord(
409+
k, v, expected_size=ds.coordinates[k].size)
410+
ds[k] = coord
411+
self._name = name
412+
353413
def load_data(self):
354414
"""Manually trigger loading of this array's data from disk or a
355415
remote source into memory and return this array.
@@ -836,7 +896,7 @@ def _inplace_binary_op(f):
836896
def func(self, other):
837897
self._check_coords_compat(other)
838898
other_array = getattr(other, 'variable', other)
839-
self.variable = f(self.variable, other_array)
899+
f(self.variable, other_array)
840900
if hasattr(other, 'coordinates'):
841901
self.dataset.merge(other.coordinates, inplace=True)
842902
return self

xray/dataset.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ def __getitem__(self, key):
256256
else:
257257
raise KeyError(repr(key))
258258

259+
def __setitem__(self, key, value):
260+
expected_size = self[key].size if key in self else None
261+
self._data[key] = self._convert_to_coord(key, value, expected_size)
262+
259263

260264
def as_dataset(obj):
261265
"""Cast the given object to a Dataset.
@@ -852,14 +856,17 @@ def get_fill_value_and_dtype(dtype):
852856
variables[name] = new_var
853857
return type(self)(variables, self.attrs)
854858

855-
def rename(self, name_dict):
859+
def rename(self, name_dict, inplace=False):
856860
"""Returns a new object with renamed variables and dimensions.
857861
858862
Parameters
859863
----------
860864
name_dict : dict-like
861865
Dictionary whose keys are current variable or dimension names and
862866
whose values are new names.
867+
inplace : bool, optional
868+
If True, rename variables and dimensions in-place. Otherwise,
869+
return a new dataset object.
863870
864871
Returns
865872
-------
@@ -877,7 +884,14 @@ def rename(self, name_dict):
877884
var = v.copy(deep=False)
878885
var.dimensions = dims
879886
variables[name] = var
880-
return type(self)(variables, self.attrs)
887+
888+
if inplace:
889+
self._dimensions = _calculate_dimensions(variables)
890+
self._variables = variables
891+
obj = self
892+
else:
893+
obj = type(self)(variables, self.attrs)
894+
return obj
881895

882896
def update(self, other, inplace=True):
883897
"""Update this dataset's variables and attributes with those from

0 commit comments

Comments
 (0)