Skip to content

Commit b7cdaa0

Browse files
authored
Allow assigning Coercible values + NamedNode internal class (pydata#115)
* test assigning int * allow assigning coercible values * refactor name-related methods to intermediate class * refactor tests to match * fix now-exposed bug with naming * moved test showing lack of name permanence * whatsnew
1 parent 4cb72fb commit b7cdaa0

File tree

5 files changed

+164
-126
lines changed

5 files changed

+164
-126
lines changed

datatree/datatree.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
MappedDataWithCoords,
4242
)
4343
from .render import RenderTree
44-
from .treenode import NodePath, Tree, TreeNode
44+
from .treenode import NamedNode, NodePath, Tree
4545

4646
if TYPE_CHECKING:
4747
from xarray.core.merge import CoercibleValue
@@ -228,7 +228,7 @@ def _replace(
228228

229229

230230
class DataTree(
231-
TreeNode,
231+
NamedNode,
232232
MappedDatasetMethodsMixin,
233233
MappedDataWithCoords,
234234
DataTreeArithmeticMixin,
@@ -343,20 +343,6 @@ def __init__(
343343
)
344344
self._close = ds._close
345345

346-
@property
347-
def name(self) -> str | None:
348-
"""The name of this node."""
349-
return self._name
350-
351-
@name.setter
352-
def name(self, name: str | None) -> None:
353-
if name is not None:
354-
if not isinstance(name, str):
355-
raise TypeError("node name must be a string or None")
356-
if "/" in name:
357-
raise ValueError("node names cannot contain forward slashes")
358-
self._name = name
359-
360346
@property
361347
def parent(self: DataTree) -> DataTree | None:
362348
"""Parent of this node."""
@@ -699,14 +685,17 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
699685
if isinstance(val, DataTree):
700686
val.name = key
701687
val.parent = self
702-
elif isinstance(val, (DataArray, Variable)):
703-
# TODO this should also accomodate other types that can be coerced into Variables
704-
self.update({key: val})
705688
else:
706-
raise TypeError(f"Type {type(val)} cannot be assigned to a DataTree")
689+
if not isinstance(val, (DataArray, Variable)):
690+
# accommodate other types that can be coerced into Variables
691+
val = DataArray(val)
692+
693+
self.update({key: val})
707694

708695
def __setitem__(
709-
self, key: str, value: DataTree | Dataset | DataArray | Variable
696+
self,
697+
key: str,
698+
value: Any,
710699
) -> None:
711700
"""
712701
Add either a child node or an array to the tree, at any position.

datatree/tests/test_datatree.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,6 @@ def test_setitem_unnamed_child_node_becomes_named(self):
249249
john2["sonny"] = DataTree()
250250
assert john2["sonny"].name == "sonny"
251251

252-
@pytest.mark.xfail(reason="bug with name overwriting")
253-
def test_setitem_child_node_keeps_name(self):
254-
john = DataTree(name="john")
255-
r2d2 = DataTree(name="R2D2")
256-
john["Mary"] = r2d2
257-
assert r2d2.name == "R2D2"
258-
259252
def test_setitem_new_grandchild_node(self):
260253
john = DataTree(name="john")
261254
DataTree(name="mary", parent=john)
@@ -314,6 +307,17 @@ def test_setitem_unnamed_dataarray(self):
314307
folder1["results"] = data
315308
xrt.assert_equal(folder1["results"], data)
316309

310+
def test_setitem_variable(self):
311+
var = xr.Variable(data=[0, 50], dims="x")
312+
folder1 = DataTree(name="folder1")
313+
folder1["results"] = var
314+
xrt.assert_equal(folder1["results"], xr.DataArray(var))
315+
316+
def test_setitem_coerce_to_dataarray(self):
317+
folder1 = DataTree(name="folder1")
318+
folder1["results"] = 0
319+
xrt.assert_equal(folder1["results"], xr.DataArray(0))
320+
317321
def test_setitem_add_new_variable_to_empty_node(self):
318322
results = DataTree(name="results")
319323
results["pressure"] = xr.DataArray(data=[2, 3])

datatree/tests/test_treenode.py

Lines changed: 72 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from datatree.iterators import LevelOrderIter, PreOrderIter
4-
from datatree.treenode import TreeError, TreeNode
4+
from datatree.treenode import NamedNode, TreeError, TreeNode
55

66

77
class TestFamilyTree:
@@ -143,43 +143,6 @@ def test_get_from_root(self):
143143
assert sue._get_item("/Mary") is mary
144144

145145

146-
class TestPaths:
147-
def test_path_property(self):
148-
sue = TreeNode()
149-
mary = TreeNode(children={"Sue": sue})
150-
john = TreeNode(children={"Mary": mary}) # noqa
151-
assert sue.path == "/Mary/Sue"
152-
assert john.path == "/"
153-
154-
def test_path_roundtrip(self):
155-
sue = TreeNode()
156-
mary = TreeNode(children={"Sue": sue})
157-
john = TreeNode(children={"Mary": mary}) # noqa
158-
assert john._get_item(sue.path) == sue
159-
160-
def test_same_tree(self):
161-
mary = TreeNode()
162-
kate = TreeNode()
163-
john = TreeNode(children={"Mary": mary, "Kate": kate}) # noqa
164-
assert mary.same_tree(kate)
165-
166-
def test_relative_paths(self):
167-
sue = TreeNode()
168-
mary = TreeNode(children={"Sue": sue})
169-
annie = TreeNode()
170-
john = TreeNode(children={"Mary": mary, "Annie": annie})
171-
172-
assert sue.relative_to(john) == "Mary/Sue"
173-
assert john.relative_to(sue) == "../.."
174-
assert annie.relative_to(sue) == "../../Annie"
175-
assert sue.relative_to(annie) == "../Mary/Sue"
176-
assert sue.relative_to(sue) == "."
177-
178-
evil_kate = TreeNode()
179-
with pytest.raises(ValueError, match="nodes do not lie within the same tree"):
180-
sue.relative_to(evil_kate)
181-
182-
183146
class TestSetNodes:
184147
def test_set_child_node(self):
185148
john = TreeNode()
@@ -261,16 +224,66 @@ def test_del_child(self):
261224
del john["Mary"]
262225

263226

227+
class TestNames:
228+
def test_child_gets_named_on_attach(self):
229+
sue = NamedNode()
230+
mary = NamedNode(children={"Sue": sue}) # noqa
231+
assert sue.name == "Sue"
232+
233+
@pytest.mark.xfail(reason="requires refactoring to retain name")
234+
def test_grafted_subtree_retains_name(self):
235+
subtree = NamedNode("original")
236+
root = NamedNode(children={"new_name": subtree}) # noqa
237+
assert subtree.name == "original"
238+
239+
240+
class TestPaths:
241+
def test_path_property(self):
242+
sue = NamedNode()
243+
mary = NamedNode(children={"Sue": sue})
244+
john = NamedNode(children={"Mary": mary}) # noqa
245+
assert sue.path == "/Mary/Sue"
246+
assert john.path == "/"
247+
248+
def test_path_roundtrip(self):
249+
sue = NamedNode()
250+
mary = NamedNode(children={"Sue": sue})
251+
john = NamedNode(children={"Mary": mary}) # noqa
252+
assert john._get_item(sue.path) == sue
253+
254+
def test_same_tree(self):
255+
mary = NamedNode()
256+
kate = NamedNode()
257+
john = NamedNode(children={"Mary": mary, "Kate": kate}) # noqa
258+
assert mary.same_tree(kate)
259+
260+
def test_relative_paths(self):
261+
sue = NamedNode()
262+
mary = NamedNode(children={"Sue": sue})
263+
annie = NamedNode()
264+
john = NamedNode(children={"Mary": mary, "Annie": annie})
265+
266+
assert sue.relative_to(john) == "Mary/Sue"
267+
assert john.relative_to(sue) == "../.."
268+
assert annie.relative_to(sue) == "../../Annie"
269+
assert sue.relative_to(annie) == "../Mary/Sue"
270+
assert sue.relative_to(sue) == "."
271+
272+
evil_kate = NamedNode()
273+
with pytest.raises(ValueError, match="nodes do not lie within the same tree"):
274+
sue.relative_to(evil_kate)
275+
276+
264277
def create_test_tree():
265-
f = TreeNode()
266-
b = TreeNode()
267-
a = TreeNode()
268-
d = TreeNode()
269-
c = TreeNode()
270-
e = TreeNode()
271-
g = TreeNode()
272-
i = TreeNode()
273-
h = TreeNode()
278+
f = NamedNode()
279+
b = NamedNode()
280+
a = NamedNode()
281+
d = NamedNode()
282+
c = NamedNode()
283+
e = NamedNode()
284+
g = NamedNode()
285+
i = NamedNode()
286+
h = NamedNode()
274287

275288
f.children = {"b": b, "g": g}
276289
b.children = {"a": a, "d": d}
@@ -286,7 +299,7 @@ def test_preorderiter(self):
286299
tree = create_test_tree()
287300
result = [node.name for node in PreOrderIter(tree)]
288301
expected = [
289-
None, # root TreeNode is unnamed
302+
None, # root Node is unnamed
290303
"b",
291304
"a",
292305
"d",
@@ -302,7 +315,7 @@ def test_levelorderiter(self):
302315
tree = create_test_tree()
303316
result = [node.name for node in LevelOrderIter(tree)]
304317
expected = [
305-
None, # root TreeNode is unnamed
318+
None, # root Node is unnamed
306319
"b",
307320
"g",
308321
"a",
@@ -317,19 +330,19 @@ def test_levelorderiter(self):
317330

318331
class TestRenderTree:
319332
def test_render_nodetree(self):
320-
sam = TreeNode()
321-
ben = TreeNode()
322-
mary = TreeNode(children={"Sam": sam, "Ben": ben})
323-
kate = TreeNode()
324-
john = TreeNode(children={"Mary": mary, "Kate": kate})
333+
sam = NamedNode()
334+
ben = NamedNode()
335+
mary = NamedNode(children={"Sam": sam, "Ben": ben})
336+
kate = NamedNode()
337+
john = NamedNode(children={"Mary": mary, "Kate": kate})
325338

326339
printout = john.__str__()
327340
expected_nodes = [
328-
"TreeNode()",
329-
"TreeNode('Mary')",
330-
"TreeNode('Sam')",
331-
"TreeNode('Ben')",
332-
"TreeNode('Kate')",
341+
"NamedNode()",
342+
"NamedNode('Mary')",
343+
"NamedNode('Sam')",
344+
"NamedNode('Ben')",
345+
"NamedNode('Kate')",
333346
]
334347
for expected_node, printed_node in zip(expected_nodes, printout.splitlines()):
335348
assert expected_node in printed_node

0 commit comments

Comments
 (0)