Skip to content

Commit 7ebe0d5

Browse files
committed
visibility, zorder, parent/aliases
1 parent ab2a630 commit 7ebe0d5

File tree

9 files changed

+310
-135
lines changed

9 files changed

+310
-135
lines changed

data_prototype/artist.py

+120-57
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from bisect import insort
12
from typing import Sequence
23

4+
import numpy as np
5+
36
from .containers import DataContainer, ArrayContainer, DataUnion
47
from .description import Desc, desc_like
5-
from .conversion_edge import Edge, TransformEdge
8+
from .conversion_edge import Edge, Graph, TransformEdge
69

710

811
class Artist:
@@ -16,11 +19,28 @@ def __init__(
1619
self._container = DataUnion(container, kwargs_cont)
1720

1821
edges = edges or []
19-
self._edges = list(edges)
20-
21-
def draw(self, renderer, edges: Sequence[Edge]) -> None:
22+
self._visible = True
23+
self._graph = Graph(edges)
24+
self._clip_box: DataContainer = ArrayContainer(
25+
{"x": "parent", "y": "parent"},
26+
**{"x": np.asarray([0, 1]), "y": np.asarray([0, 1])}
27+
)
28+
29+
def draw(self, renderer, graph: Graph) -> None:
2230
return
2331

32+
def set_clip_box(self, container: DataContainer) -> None:
33+
self._clip_box = container
34+
35+
def get_clip_box(self, container: DataContainer) -> DataContainer:
36+
return self._clip_box
37+
38+
def get_visible(self):
39+
return self._visible
40+
41+
def set_visible(self, visible):
42+
self._visible = visible
43+
2444

2545
class CompatibilityArtist:
2646
"""A compatibility shim to ducktype as a classic Matplotlib Artist.
@@ -42,10 +62,44 @@ class CompatibilityArtist:
4262
def __init__(self, artist: Artist):
4363
self._artist = artist
4464

45-
self.axes = None
65+
self._axes = None
4666
self.figure = None
4767
self._clippath = None
68+
self._visible = True
4869
self.zorder = 2
70+
self._graph = Graph([])
71+
72+
@property
73+
def axes(self):
74+
return self._axes
75+
76+
@axes.setter
77+
def axes(self, ax):
78+
self._axes = ax
79+
80+
if self._axes is None:
81+
self._graph = Graph([])
82+
return
83+
84+
desc: Desc = Desc(("N",), coordinates="data")
85+
xy: dict[str, Desc] = {"x": desc, "y": desc}
86+
self._graph = Graph(
87+
[
88+
TransformEdge(
89+
"data",
90+
xy,
91+
desc_like(xy, coordinates="axes"),
92+
transform=self._axes.transData - self._axes.transAxes,
93+
),
94+
TransformEdge(
95+
"axes",
96+
desc_like(xy, coordinates="axes"),
97+
desc_like(xy, coordinates="display"),
98+
transform=self._axes.transAxes,
99+
),
100+
],
101+
aliases=(("parent", "axes"),),
102+
)
49103

50104
def set_figure(self, fig):
51105
self.figure = fig
@@ -65,32 +119,19 @@ def set_clip_path(self, path):
65119
def get_animated(self):
66120
return False
67121

68-
def draw(self, renderer, edges=None):
122+
def get_visible(self):
123+
return self._visible
69124

70-
if edges is None:
71-
edges = []
125+
def set_visible(self, visible):
126+
self._visible = visible
72127

73-
if self.axes is not None:
74-
desc: Desc = Desc(("N",), coordinates="data")
75-
xy: dict[str, Desc] = {"x": desc, "y": desc}
76-
edges.append(
77-
TransformEdge(
78-
"data",
79-
xy,
80-
desc_like(xy, coordinates="axes"),
81-
transform=self.axes.transData - self.axes.transAxes,
82-
)
83-
)
84-
edges.append(
85-
TransformEdge(
86-
"axes",
87-
desc_like(xy, coordinates="axes"),
88-
desc_like(xy, coordinates="display"),
89-
transform=self.axes.transAxes,
90-
)
91-
)
128+
def draw(self, renderer, graph=None):
129+
if not self.get_visible():
130+
return
92131

93-
self._artist.draw(renderer, edges)
132+
if graph is None:
133+
graph = Graph([])
134+
self._artist.draw(renderer, graph + self._graph)
94135

95136

96137
class CompatibilityAxes:
@@ -111,11 +152,44 @@ class CompatibilityAxes:
111152
"""
112153

113154
def __init__(self, axes):
114-
self.axes = axes
155+
self._axes = axes
115156
self.figure = None
116157
self._clippath = None
158+
self._visible = True
117159
self.zorder = 2
118-
self._children = []
160+
self._children: list[tuple[float, Artist]] = []
161+
162+
@property
163+
def axes(self):
164+
return self._axes
165+
166+
@axes.setter
167+
def axes(self, ax):
168+
self._axes = ax
169+
170+
if self._axes is None:
171+
self._graph = Graph([])
172+
return
173+
174+
desc: Desc = Desc(("N",), coordinates="data")
175+
xy: dict[str, Desc] = {"x": desc, "y": desc}
176+
self._graph = Graph(
177+
[
178+
TransformEdge(
179+
"data",
180+
xy,
181+
desc_like(xy, coordinates="axes"),
182+
transform=self._axes.transData - self._axes.transAxes,
183+
),
184+
TransformEdge(
185+
"axes",
186+
desc_like(xy, coordinates="axes"),
187+
desc_like(xy, coordinates="display"),
188+
transform=self._axes.transAxes,
189+
),
190+
],
191+
aliases=(("parent", "axes"),),
192+
)
119193

120194
def set_figure(self, fig):
121195
self.figure = fig
@@ -135,39 +209,28 @@ def set_clip_path(self, path):
135209
def get_animated(self):
136210
return False
137211

138-
def draw(self, renderer, edges=None):
139-
if edges is None:
140-
edges = []
212+
def draw(self, renderer, graph=None):
213+
if not self.visible:
214+
return
215+
if graph is None:
216+
graph = Graph([])
141217

142-
if self.axes is not None:
143-
desc: Desc = Desc(("N",), coordinates="data")
144-
xy: dict[str, Desc] = {"x": desc, "y": desc}
145-
edges.append(
146-
TransformEdge(
147-
"data",
148-
xy,
149-
desc_like(xy, coordinates="axes"),
150-
transform=self.axes.transData - self.axes.transAxes,
151-
)
152-
)
153-
edges.append(
154-
TransformEdge(
155-
"axes",
156-
desc_like(xy, coordinates="axes"),
157-
desc_like(xy, coordinates="display"),
158-
transform=self.axes.transAxes,
159-
)
160-
)
218+
graph = graph + self._graph
161219

162-
# TODO independent zorder
163-
for c in self._children:
164-
c.draw(renderer, edges)
220+
for _, c in self._children:
221+
c.draw(renderer, graph)
165222

166-
def add_artist(self, artist):
167-
self._children.append(artist)
223+
def add_artist(self, artist, zorder=1):
224+
insort(self._children, (zorder, artist), key=lambda x: x[0])
168225

169226
def set_xlim(self, min_=None, max_=None):
170227
self.axes.set_xlim(min_, max_)
171228

172229
def set_ylim(self, min_=None, max_=None):
173230
self.axes.set_ylim(min_, max_)
231+
232+
def get_visible(self):
233+
return self._visible
234+
235+
def set_visible(self, visible):
236+
self._visible = visible

data_prototype/containers.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,16 @@ class NoNewKeys(ValueError): ...
8282

8383

8484
class ArrayContainer:
85-
def __init__(self, **data):
85+
def __init__(self, coordinates: dict[str, str] | None = None, /, **data):
86+
coordinates = coordinates or {}
8687
self._data = data
8788
self._cache_key = str(uuid.uuid4())
8889
self._desc = {
89-
k: (Desc(v.shape) if isinstance(v, np.ndarray) else Desc(()))
90+
k: (
91+
Desc(v.shape, coordinates.get(k, "auto"))
92+
if isinstance(v, np.ndarray)
93+
else Desc(())
94+
)
9095
for k, v in data.items()
9196
}
9297

data_prototype/conversion_edge.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,11 @@ def inverse(self) -> "TransformEdge":
225225

226226

227227
class Graph:
228-
def __init__(self, edges: Sequence[Edge]):
229-
self._edges = edges
228+
def __init__(
229+
self, edges: Sequence[Edge], aliases: tuple[tuple[str, str], ...] = ()
230+
):
231+
self._edges = tuple(edges)
232+
self._aliases = aliases
230233

231234
self._subgraphs: list[tuple[set[str], list[Edge]]] = []
232235
for edge in self._edges:
@@ -245,29 +248,53 @@ def __init__(self, edges: Sequence[Edge]):
245248
s |= keys
246249
self._subgraphs[overlapping[0]][1].append(edge)
247250
else:
248-
edges_combined = []
251+
edges_combined = [edge]
249252
for n in overlapping:
250253
keys |= self._subgraphs[n][0]
251254
edges_combined.extend(self._subgraphs[n][1])
252255
for n in overlapping[::-1]:
253256
self._subgraphs.pop(n)
254257
self._subgraphs.append((keys, edges_combined))
255258

259+
def _resolve_alias(self, coord: str) -> str:
260+
while True:
261+
for coa, cob in self._aliases:
262+
if coord == coa:
263+
coord = cob
264+
break
265+
else:
266+
break
267+
return coord
268+
256269
def evaluator(self, input: dict[str, Desc], output: dict[str, Desc]) -> Edge:
257270
out_edges = []
271+
258272
for sub_keys, sub_edges in self._subgraphs:
259273
if not (sub_keys & set(output) or sub_keys & set(input)):
260274
continue
275+
261276
output_subset = {k: v for k, v in output.items() if k in sub_keys}
262277
sub_edges = sorted(sub_edges, key=lambda x: x.weight)
263278

264-
@dataclass(order=True)
279+
@dataclass
265280
class Node:
266281
weight: float
267282
desc: dict[str, Desc]
268283
prev_node: Node | None = None
269284
edge: Edge | None = None
270285

286+
def __le__(self, other):
287+
return self.weight <= other.weight
288+
289+
def __lt__(self, other):
290+
return self.weight < other.weight
291+
292+
def __ge__(self, other):
293+
return self.weight >= other.weight
294+
295+
def __gt__(self, other):
296+
return self.weight > other.weight
297+
271298
q: PriorityQueue[Node] = PriorityQueue()
272299
q.put(Node(0, input))
273300

@@ -276,12 +303,12 @@ class Node:
276303
n = q.get()
277304
if n.weight > best.weight:
278305
continue
279-
if Desc.compatible(n.desc, output_subset):
306+
if Desc.compatible(n.desc, output_subset, aliases=self._aliases):
280307
if n.weight < best.weight:
281308
best = n
282309
continue
283310
for e in sub_edges:
284-
if Desc.compatible(n.desc, e.input):
311+
if Desc.compatible(n.desc, e.input, aliases=self._aliases):
285312
d = n.desc | e.output
286313
w = n.weight + e.weight
287314

@@ -328,6 +355,7 @@ class Node:
328355
def visualize(self, input: dict[str, Desc] | None = None):
329356
import networkx as nx
330357
import matplotlib.pyplot as plt
358+
331359
from pprint import pformat
332360

333361
def node_format(x):
@@ -376,3 +404,9 @@ def node_format(x):
376404
nx.draw(G, pos=pos, with_labels=True)
377405
nx.draw_networkx_edge_labels(G, pos=pos)
378406
# plt.show()
407+
408+
def __add__(self, other: Graph) -> Graph:
409+
aself = {k: v for k, v in self._aliases}
410+
aother = {k: v for k, v in other._aliases}
411+
aliases = tuple((aself | aother).items())
412+
return Graph(self._edges + other._edges, aliases)

data_prototype/description.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,32 @@ def validate_shapes(
103103
return None
104104

105105
@staticmethod
106-
def compatible(a: dict[str, "Desc"], b: dict[str, "Desc"]) -> bool:
106+
def compatible(
107+
a: dict[str, "Desc"],
108+
b: dict[str, "Desc"],
109+
aliases: tuple[tuple[str, str], ...] = (),
110+
) -> bool:
107111
"""Determine if ``a`` is a valid input for ``b``.
108112
109113
Note: ``a`` _may_ have additional keys.
110114
"""
115+
116+
def resolve_aliases(coord):
117+
while True:
118+
for coa, cob in aliases:
119+
if coord == coa:
120+
coord = cob
121+
break
122+
else:
123+
break
124+
return coord
125+
111126
try:
112127
Desc.validate_shapes(b, a)
113128
except (KeyError, ValueError):
114129
return False
115130
for k, v in b.items():
116-
if a[k].coordinates != v.coordinates:
131+
if resolve_aliases(a[k].coordinates) != resolve_aliases(v.coordinates):
117132
return False
118133
return True
119134

0 commit comments

Comments
 (0)