From 64790998db6aa3a9b3b8e7f776ed8295b4c07881 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Fri, 14 Jun 2024 14:44:22 -0500 Subject: [PATCH 1/5] Update examples to use CompatibilityAxes --- examples/first.py | 1 + examples/mandelbrot.py | 18 +++++++++++------- examples/units.py | 28 +++++++++++++++++----------- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/examples/first.py b/examples/first.py index 4357c9e..d8baf8d 100644 --- a/examples/first.py +++ b/examples/first.py @@ -38,4 +38,5 @@ ax.add_artist(lw2, 2) ax.set_xlim(0, np.pi * 4) ax.set_ylim(-1.1, 1.1) + plt.show() diff --git a/examples/mandelbrot.py b/examples/mandelbrot.py index 8a4b878..2efed36 100644 --- a/examples/mandelbrot.py +++ b/examples/mandelbrot.py @@ -13,7 +13,8 @@ import matplotlib.pyplot as plt import numpy as np -from data_prototype.wrappers import ImageWrapper +from data_prototype.artist import CompatibilityAxes +from data_prototype.image import Image from data_prototype.containers import FuncContainer from matplotlib.colors import Normalize @@ -36,19 +37,22 @@ def mandelbrot_set(X, Y, maxiter, *, horizon=3, power=2): fc = FuncContainer( {}, xyfuncs={ - "xextent": ((2,), lambda x, y: [x[0], x[-1]]), - "yextent": ((2,), lambda x, y: [y[0], y[-1]]), + "x": ((2,), lambda x, y: [x[0], x[-1]]), + "y": ((2,), lambda x, y: [y[0], y[-1]]), "image": (("N", "M"), lambda x, y: mandelbrot_set(x, y, maxiter)[1]), }, ) cmap = plt.get_cmap() cmap.set_under("w") -im = ImageWrapper(fc, norm=Normalize(0, maxiter), cmap=cmap) +im = Image(fc, norm=Normalize(0, maxiter), cmap=cmap) -fig, ax = plt.subplots() +fig, nax = plt.subplots() +ax = CompatibilityAxes(nax) +nax.add_artist(ax) ax.add_artist(im) ax.set_xlim(-1, 1) ax.set_ylim(-1, 1) -ax.set_aspect("equal") -fig.colorbar(im) + +nax.set_aspect("equal") # No equivalent yet + plt.show() diff --git a/examples/units.py b/examples/units.py index 42c654d..deccbe5 100644 --- a/examples/units.py +++ b/examples/units.py @@ -7,14 +7,16 @@ """ import numpy as np +from collections import defaultdict import matplotlib.pyplot as plt import matplotlib.markers as mmarkers +from data_prototype.artist import CompatibilityAxes from data_prototype.containers import ArrayContainer -from data_prototype.conversion_node import DelayedConversionNode +from data_prototype.conversion_edge import FuncEdge -from data_prototype.wrappers import PathCollectionWrapper +from data_prototype.line import Line import pint @@ -23,7 +25,11 @@ marker_obj = mmarkers.MarkerStyle("o") + +coords = defaultdict(lambda: "auto") +coords["x"] = coords["y"] = "units" cont = ArrayContainer( + coords, x=np.array([0, 1, 2]) * ureg.m, y=np.array([1, 4, 2]) * ureg.m, paths=np.array([marker_obj.get_path()]), @@ -32,17 +38,17 @@ facecolors=np.array(["C3"]), ) -fig, ax = plt.subplots() +fig, nax = plt.subplots() +ax = CompatibilityAxes(nax) +nax.add_artist(ax) ax.set_xlim(-0.5, 7) ax.set_ylim(0, 5) -# DelayedConversionNode is used to identify the keys which undergo unit transformations -# The actual method which does conversions in this example is added by the -# `Axis`/`Axes`, but `PathCollectionWrapper` does not natively interact with the units. -xconv = DelayedConversionNode.from_keys(("x",), converter_key="xunits") -yconv = DelayedConversionNode.from_keys(("y",), converter_key="yunits") -lw = PathCollectionWrapper(cont, [xconv, yconv], offset_transform=ax.transData) +xconv = FuncEdge.from_func("xconv", lambda x, xunits: x.to(xunits), "units", "data") +yconv = FuncEdge.from_func("yconv", lambda y, yunits: y.to(yunits), "units", "data") +lw = Line(cont, [xconv, yconv]) ax.add_artist(lw) -ax.xaxis.set_units(ureg.feet) -ax.yaxis.set_units(ureg.m) +nax.xaxis.set_units(ureg.feet) +nax.yaxis.set_units(ureg.m) + plt.show() From 47cd02201fe991dbbcf3959a1d9da88c626e94ca Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Fri, 14 Jun 2024 14:47:26 -0500 Subject: [PATCH 2/5] Picking and Contains for new Artists --- data_prototype/artist.py | 106 ++++++++++++++++++++++++++---- data_prototype/conversion_edge.py | 10 ++- data_prototype/image.py | 34 ++++++++-- data_prototype/line.py | 64 ++++++++++++++++++ 4 files changed, 197 insertions(+), 17 deletions(-) diff --git a/data_prototype/artist.py b/data_prototype/artist.py index d2f2db1..4ff7a05 100644 --- a/data_prototype/artist.py +++ b/data_prototype/artist.py @@ -3,9 +3,12 @@ import numpy as np +from matplotlib.backend_bases import PickEvent +import matplotlib.artist as martist + from .containers import DataContainer, ArrayContainer, DataUnion from .description import Desc, desc_like -from .conversion_edge import Edge, Graph, TransformEdge +from .conversion_edge import Edge, FuncEdge, Graph, TransformEdge class Artist: @@ -18,6 +21,9 @@ def __init__( kwargs_cont = ArrayContainer(**kwargs) self._container = DataUnion(container, kwargs_cont) + self._children: list[tuple[float, Artist]] = [] + self._picker = None + edges = edges or [] self._visible = True self._graph = Graph(edges) @@ -41,6 +47,77 @@ def get_visible(self): def set_visible(self, visible): self._visible = visible + def pickable(self) -> bool: + return self._picker is not None + + def get_picker(self): + return self._picker + + def set_picker(self, picker): + self._picker = picker + + def contains(self, mouseevent, graph=None): + """ + Test whether the artist contains the mouse event. + + Parameters + ---------- + mouseevent : `~matplotlib.backend_bases.MouseEvent` + + Returns + ------- + contains : bool + Whether any values are within the radius. + details : dict + An artist-specific dictionary of details of the event context, + such as which points are contained in the pick radius. See the + individual Artist subclasses for details. + """ + return False, {} + + def get_children(self): + return [a[1] for a in self._children] + + def pick(self, mouseevent, graph: Graph | None = None): + """ + Process a pick event. + + Each child artist will fire a pick event if *mouseevent* is over + the artist and the artist has picker set. + + See Also + -------- + set_picker, get_picker, pickable + """ + if graph is None: + graph = self._graph + else: + graph = graph + self._graph + # Pick self + if self.pickable(): + picker = self.get_picker() + if callable(picker): + inside, prop = picker(self, mouseevent) + else: + inside, prop = self.contains(mouseevent, graph) + if inside: + PickEvent( + "pick_event", mouseevent.canvas, mouseevent, self, **prop + )._process() + + # Pick children + for a in self.get_children(): + # make sure the event happened in the same Axes + ax = getattr(a, "axes", None) + if mouseevent.inaxes is None or ax is None or mouseevent.inaxes == ax: + # we need to check if mouseevent.inaxes is None + # because some objects associated with an Axes (e.g., a + # tick label) can be outside the bounding box of the + # Axes and inaxes will be None + # also check that ax is None so that it traverse objects + # which do not have an axes property but children might + a.pick(mouseevent, graph) + class CompatibilityArtist: """A compatibility shim to ducktype as a classic Matplotlib Artist. @@ -59,7 +136,7 @@ class CompatibilityArtist: useful for avoiding accidental dependency. """ - def __init__(self, artist: Artist): + def __init__(self, artist: martist.Artist): self._artist = artist self._axes = None @@ -134,7 +211,7 @@ def draw(self, renderer, graph=None): self._artist.draw(renderer, graph + self._graph) -class CompatibilityAxes: +class CompatibilityAxes(Artist): """A compatibility shim to add to traditional matplotlib axes. At this time features are implemented on an "as needed" basis, and many @@ -152,12 +229,11 @@ class CompatibilityAxes: """ def __init__(self, axes): + super().__init__(ArrayContainer()) self._axes = axes self.figure = None self._clippath = None - self._visible = True self.zorder = 2 - self._children: list[tuple[float, Artist]] = [] @property def axes(self): @@ -187,6 +263,18 @@ def axes(self, ax): desc_like(xy, coordinates="display"), transform=self._axes.transAxes, ), + FuncEdge.from_func( + "xunits", + lambda: self._axes.xunits, + {}, + {"xunits": Desc((), "units")}, + ), + FuncEdge.from_func( + "yunits", + lambda: self._axes.yunits, + {}, + {"yunits": Desc((), "units")}, + ), ], aliases=(("parent", "axes"),), ) @@ -210,7 +298,7 @@ def get_animated(self): return False def draw(self, renderer, graph=None): - if not self.visible: + if not self.get_visible(): return if graph is None: graph = Graph([]) @@ -228,9 +316,3 @@ def set_xlim(self, min_=None, max_=None): def set_ylim(self, min_=None, max_=None): self.axes.set_ylim(min_, max_) - - def get_visible(self): - return self._visible - - def set_visible(self, visible): - self._visible = visible diff --git a/data_prototype/conversion_edge.py b/data_prototype/conversion_edge.py index 9091783..2a09cc1 100644 --- a/data_prototype/conversion_edge.py +++ b/data_prototype/conversion_edge.py @@ -295,6 +295,12 @@ def __ge__(self, other): def __gt__(self, other): return self.weight > other.weight + @property + def edges(self): + if self.prev_node is None: + return [self.edge] + return self.prev_node.edges + [self.edge] + q: PriorityQueue[Node] = PriorityQueue() q.put(Node(0, input)) @@ -308,6 +314,8 @@ def __gt__(self, other): best = n continue for e in sub_edges: + if e in n.edges: + continue if Desc.compatible(n.desc, e.input, aliases=self._aliases): d = n.desc | e.output w = n.weight + e.weight @@ -397,7 +405,7 @@ def node_format(x): ) try: - pos = nx.planar_layout(G) + pos = nx.shell_layout(G) except Exception: pos = nx.circular_layout(G) plt.figure() diff --git a/data_prototype/image.py b/data_prototype/image.py index c8821fe..746345c 100644 --- a/data_prototype/image.py +++ b/data_prototype/image.py @@ -14,11 +14,11 @@ def _interpolate_nearest(image, x, y): l, r = x width = int(((round(r) + 0.5) - (round(l) - 0.5)) * magnification) - xpix = np.digitize(np.arange(width), np.linspace(0, r - l, image.shape[1] + 1)) + xpix = np.digitize(np.arange(width), np.linspace(0, r - l, image.shape[1])) b, t = y height = int(((round(t) + 0.5) - (round(b) - 0.5)) * magnification) - ypix = np.digitize(np.arange(height), np.linspace(0, t - b, image.shape[0] + 1)) + ypix = np.digitize(np.arange(height), np.linspace(0, t - b, image.shape[0])) out = np.empty((height, width, 4)) @@ -53,7 +53,7 @@ def __init__(self, container, edges=None, norm=None, cmap=None, **kwargs): {"image": Desc(("O", "P", 4), coordinates="rgba_resampled")}, ) - self._edges += [ + edges = [ CoordinateEdge.from_coords("xycoords", {"x": "auto", "y": "auto"}, "data"), CoordinateEdge.from_coords( "image_coords", {"image": Desc(("M", "N"), "auto")}, "data" @@ -79,7 +79,7 @@ def __init__(self, container, edges=None, norm=None, cmap=None, **kwargs): self._interpolation_edge, ] - self._graph = Graph(self._edges, (("data", "data_resampled"),)) + self._graph = self._graph + Graph(edges, (("data", "data_resampled"),)) def draw(self, renderer, graph: Graph) -> None: if not self.get_visible(): @@ -111,3 +111,29 @@ def draw(self, renderer, graph: Graph) -> None: mtransforms.Bbox.from_extents(clipx[0], clipy[0], clipx[1], clipy[1]) ) renderer.draw_image(gc, x[0], y[0], image) # TODO vector backend transforms + + def contains(self, mouseevent, graph=None): + if graph is None: + return False, {} + g = graph + self._graph + conv = g.evaluator( + self._container.describe(), + { + "x": Desc(("X",), "display"), + "y": Desc(("Y",), "display"), + }, + ).inverse + query, _ = self._container.query(g) + xmin, xmax = query["x"] + ymin, ymax = query["y"] + x, y = conv.evaluate({"x": mouseevent.x, "y": mouseevent.y}).values() + + # This checks xmin <= x <= xmax *or* xmax <= x <= xmin. + inside = ( + x is not None + and (x - xmin) * (x - xmax) <= 0 + and y is not None + and (y - ymin) * (y - ymax) <= 0 + ) + + return inside, {} diff --git a/data_prototype/line.py b/data_prototype/line.py index 026cb10..8805e8b 100644 --- a/data_prototype/line.py +++ b/data_prototype/line.py @@ -9,6 +9,8 @@ from .description import Desc from .conversion_edge import Graph, CoordinateEdge, DefaultEdge +segment_hits = mlines.segment_hits + class Line(Artist): def __init__(self, container, edges=None, **kwargs): @@ -57,6 +59,68 @@ def __init__(self, container, edges=None, **kwargs): # - non-str markers # Each individually pretty easy, but relatively rare features, focusing on common cases + def contains(self, mouseevent, graph=None): + """ + Test whether *mouseevent* occurred on the line. + + An event is deemed to have occurred "on" the line if it is less + than ``self.pickradius`` (default: 5 points) away from it. Use + `~.Line2D.get_pickradius` or `~.Line2D.set_pickradius` to get or set + the pick radius. + + Parameters + ---------- + mouseevent : `~matplotlib.backend_bases.MouseEvent` + + Returns + ------- + contains : bool + Whether any values are within the radius. + details : dict + A dictionary ``{'ind': pointlist}``, where *pointlist* is a + list of points of the line that are within the pickradius around + the event position. + + TODO: sort returned indices by distance + """ + if graph is None: + return False, {} + + g = graph + self._graph + desc = Desc(("N",), "display") + scalar = Desc((), "display") # ... this needs thinking... + # Convert points to pixels + require = { + "x": desc, + "y": desc, + "linestyle": scalar, + } + conv = g.evaluator(self._container.describe(), require) + query, _ = self._container.query(g) + xt, yt, linestyle = conv.evaluate(query).values() + + # Convert pick radius from points to pixels + pixels = 5 # self._pickradius # TODO + + # The math involved in checking for containment (here and inside of + # segment_hits) assumes that it is OK to overflow, so temporarily set + # the error flags accordingly. + with np.errstate(all="ignore"): + # Check for collision + if linestyle in ["None", None]: + # If no line, return the nearby point(s) + (ind,) = np.nonzero( + (xt - mouseevent.x) ** 2 + (yt - mouseevent.y) ** 2 <= pixels**2 + ) + else: + # If line, return the nearby segment(s) + ind = segment_hits(mouseevent.x, mouseevent.y, xt, yt, pixels) + # if self._drawstyle.startswith("steps"): + # ind //= 2 + + # Return the point(s) within radius + return len(ind) > 0, dict(ind=ind) + def draw(self, renderer, graph: Graph) -> None: if not self.get_visible(): return From 4520cba569d28055623a634a27875dafc8690ea7 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 26 Jun 2024 14:02:14 -0500 Subject: [PATCH 3/5] Fix units example --- data_prototype/artist.py | 4 ++-- data_prototype/containers.py | 4 ++-- examples/units.py | 19 +++++++++++++++++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/data_prototype/artist.py b/data_prototype/artist.py index 4ff7a05..7beddc3 100644 --- a/data_prototype/artist.py +++ b/data_prototype/artist.py @@ -265,13 +265,13 @@ def axes(self, ax): ), FuncEdge.from_func( "xunits", - lambda: self._axes.xunits, + lambda: self._axes.xaxis.units, {}, {"xunits": Desc((), "units")}, ), FuncEdge.from_func( "yunits", - lambda: self._axes.yunits, + lambda: self._axes.yaxis.units, {}, {"yunits": Desc((), "units")}, ), diff --git a/data_prototype/containers.py b/data_prototype/containers.py index fd45106..0308dde 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -89,8 +89,8 @@ def __init__(self, coordinates: dict[str, str] | None = None, /, **data): self._desc = { k: ( Desc(v.shape, coordinates.get(k, "auto")) - if isinstance(v, np.ndarray) - else Desc(()) + if hasattr(v, "shape") + else Desc((), coordinates.get(k, "auto")) ) for k, v in data.items() } diff --git a/examples/units.py b/examples/units.py index deccbe5..eca80b0 100644 --- a/examples/units.py +++ b/examples/units.py @@ -15,6 +15,7 @@ from data_prototype.artist import CompatibilityAxes from data_prototype.containers import ArrayContainer from data_prototype.conversion_edge import FuncEdge +from data_prototype.description import Desc from data_prototype.line import Line @@ -44,9 +45,23 @@ ax.set_xlim(-0.5, 7) ax.set_ylim(0, 5) -xconv = FuncEdge.from_func("xconv", lambda x, xunits: x.to(xunits), "units", "data") -yconv = FuncEdge.from_func("yconv", lambda y, yunits: y.to(yunits), "units", "data") +scalar = Desc((), "units") +unit_vector = Desc(("N",), "units") + +xconv = FuncEdge.from_func( + "xconv", + lambda x, xunits: x.to(xunits), + {"x": unit_vector, "xunits": scalar}, + {"x": Desc(("N",), "data")}, +) +yconv = FuncEdge.from_func( + "yconv", + lambda y, yunits: y.to(yunits), + {"y": unit_vector, "yunits": scalar}, + {"y": Desc(("N",), "data")}, +) lw = Line(cont, [xconv, yconv]) + ax.add_artist(lw) nax.xaxis.set_units(ureg.feet) nax.yaxis.set_units(ureg.m) From 3b8662eff68e2b97c99f3e414548fa9e7c0eb79c Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 26 Jun 2024 14:12:51 -0500 Subject: [PATCH 4/5] remove dtype from tests --- data_prototype/tests/test_containers.py | 1 - examples/units.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/data_prototype/tests/test_containers.py b/data_prototype/tests/test_containers.py index fb2fc7d..3ecffa8 100644 --- a/data_prototype/tests/test_containers.py +++ b/data_prototype/tests/test_containers.py @@ -22,7 +22,6 @@ def _verify_describe(container): assert set(data) == set(desc) for k, v in data.items(): assert v.shape == desc[k].shape - assert v.dtype == desc[k].dtype def test_array_describe(ac): diff --git a/examples/units.py b/examples/units.py index eca80b0..9c167d5 100644 --- a/examples/units.py +++ b/examples/units.py @@ -63,7 +63,7 @@ lw = Line(cont, [xconv, yconv]) ax.add_artist(lw) -nax.xaxis.set_units(ureg.feet) -nax.yaxis.set_units(ureg.m) +nax.xaxis.set_units(ureg.m) +nax.yaxis.set_units(ureg.cm) plt.show() From f7c8a076c07be4fda4214f517d9177f9224ba7a7 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Fri, 5 Jul 2024 13:51:53 -0500 Subject: [PATCH 5/5] Fix units example, drop units using magnitude --- examples/units.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/units.py b/examples/units.py index 9c167d5..74a246f 100644 --- a/examples/units.py +++ b/examples/units.py @@ -50,20 +50,20 @@ xconv = FuncEdge.from_func( "xconv", - lambda x, xunits: x.to(xunits), + lambda x, xunits: x.to(xunits).magnitude, {"x": unit_vector, "xunits": scalar}, {"x": Desc(("N",), "data")}, ) yconv = FuncEdge.from_func( "yconv", - lambda y, yunits: y.to(yunits), + lambda y, yunits: y.to(yunits).magnitude, {"y": unit_vector, "yunits": scalar}, {"y": Desc(("N",), "data")}, ) lw = Line(cont, [xconv, yconv]) ax.add_artist(lw) -nax.xaxis.set_units(ureg.m) -nax.yaxis.set_units(ureg.cm) +nax.xaxis.set_units(ureg.ft) +nax.yaxis.set_units(ureg.m) plt.show()