Skip to content

Commit cb0e229

Browse files
murray-dsmvanwykclaude
authored
feat: add TreeGrid layout engine for tree diagrams (PR 3/6) (#357)
* feat: add TreeGrid layout engine for tree diagrams Add TreeGrid class for positioning nodes in a grid-based coordinate system and drawing curved connection lines between parent and child nodes. - TreeGrid manages tree layout with configurable spacing - Auto-calculates spacing if not provided (vertical: node_height + 0.6, horizontal: node_width - 1.0) - Simplified render() API to return only Axes (instead of tuple) - Draws curved connection lines using Bezier curves between parent and child nodes - Comprehensive unit tests: complete tree, axes management, edge cases - 25/25 tests passing, 79% coverage of tree_diagram.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * feat: add comprehensive validation and error handling to TreeGrid Add robust input validation and improve error messages for TreeGrid: - Validate invalid child references with descriptive error messages - Add input validation in __init__ for: - Grid dimensions (must be positive integers) - node_class (must be TreeNode subclass) - tree_structure (cannot be empty) - Node positions (must be within grid bounds) - Required 'position' key in node data - Extract magic numbers to class constants: - CONNECTION_CURVE_RADIUS = 0.15 - CONNECTION_LINE_WIDTH = 2 - CONNECTION_LINE_COLOR = "black" - Add 6 comprehensive validation tests with realistic retail data: - test_invalid_child_reference - test_invalid_grid_dimensions - test_invalid_node_class - test_empty_tree_structure - test_missing_position_key - test_out_of_bounds_position All 33 tests passing with 82% coverage (improved from 79%). Addresses code review issues #1, #2, #4, and #5 from PR #357. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * test: improve BaseRoundedBox test quality and coverage Enhance test assertions to verify actual behavior instead of trivial checks: - Improve test_box_creation_with_dimensions to verify actual width/height via bounding box instead of just checking object existence - Remove test_rendering_to_axes as it only tests basic matplotlib functionality, not package-specific behavior - Improve test_border_radius_top to verify correct vertex count (2 * 10 arc points + 2 straight + 1 close = 23) instead of just checking paths differ - Improve test_border_radius_bottom to verify correct vertex count (same formula as top) instead of just checking paths differ All tests now verify the claimed behavior with specific assertions based on the ARC_POINTS_PER_CORNER constant, following the pattern of the excellent test_zero_radius_creates_square_corners and test_nonzero_radius_creates_rounded_corners tests. 32/32 tests passing (removed 1 trivial test, all others improved). Addresses test quality issues from code review feedback. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Murray Vanwyk <murray.vanwyk@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent f168816 commit cb0e229

3 files changed

Lines changed: 559 additions & 224 deletions

File tree

pyretailscience/plots/tree_diagram.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import graphviz
1414
import matplotlib.patches as mpatches
15+
import matplotlib.pyplot as plt
1516
import numpy as np
1617
from matplotlib.axes import Axes
1718
from matplotlib.path import Path
@@ -488,3 +489,230 @@ def render(self, ax: Axes) -> None:
488489
fontsize=styling_context.fonts.label_size,
489490
color=text_color,
490491
)
492+
493+
494+
class TreeGrid:
495+
"""Grid-based tree diagram renderer with configurable node types."""
496+
497+
# Connection styling constants
498+
CONNECTION_CURVE_RADIUS = 0.15
499+
CONNECTION_LINE_WIDTH = 2
500+
CONNECTION_LINE_COLOR = "black"
501+
502+
def __init__(
503+
self,
504+
tree_structure: dict[str, dict],
505+
num_rows: int,
506+
num_cols: int,
507+
node_class: type[TreeNode],
508+
vertical_spacing: float | None = None,
509+
horizontal_spacing: float | None = None,
510+
) -> None:
511+
"""Initialize the tree grid.
512+
513+
Args:
514+
tree_structure: Dictionary mapping node IDs to node data with required keys
515+
depending on the node_class being used.
516+
num_rows: Number of rows in the grid.
517+
num_cols: Number of columns in the grid.
518+
node_class: The TreeNode subclass to use for rendering nodes.
519+
vertical_spacing: Vertical spacing between rows. If None, automatically calculated as
520+
node_height + 0.6 gap.
521+
horizontal_spacing: Horizontal spacing between columns. If None, automatically calculated as
522+
node_width - 1.0 overlap for compact layout.
523+
524+
Raises:
525+
ValueError: If grid dimensions are not positive, if tree_structure is empty,
526+
or if node positions are out of bounds.
527+
TypeError: If node_class is not a TreeNode subclass.
528+
529+
"""
530+
# Validate grid dimensions
531+
if num_rows <= 0 or num_cols <= 0:
532+
error_msg = f"Grid dimensions must be positive: num_rows={num_rows}, num_cols={num_cols}"
533+
raise ValueError(error_msg)
534+
535+
# Validate node_class is a TreeNode subclass
536+
if not issubclass(node_class, TreeNode):
537+
error_msg = f"node_class must be a TreeNode subclass, got {node_class}"
538+
raise TypeError(error_msg)
539+
540+
# Validate tree_structure is not empty
541+
if not tree_structure:
542+
raise ValueError("tree_structure cannot be empty")
543+
544+
self.tree_structure = tree_structure
545+
self.num_rows = num_rows
546+
self.num_cols = num_cols
547+
self.node_class = node_class
548+
549+
# Get node dimensions from the node class
550+
self.node_width = node_class.NODE_WIDTH
551+
self.node_height = node_class.NODE_HEIGHT
552+
553+
# Auto-calculate spacing if not provided
554+
self.vertical_spacing = vertical_spacing if vertical_spacing is not None else self.node_height + 0.6
555+
self.horizontal_spacing = horizontal_spacing if horizontal_spacing is not None else self.node_width - 1.0
556+
557+
# Validate positions are within grid bounds
558+
for node_id, node_data in tree_structure.items():
559+
if "position" not in node_data:
560+
error_msg = f"Node '{node_id}' is missing required 'position' key"
561+
raise ValueError(error_msg)
562+
563+
col_idx, row_idx = node_data["position"]
564+
if not (0 <= col_idx < num_cols):
565+
error_msg = f"Node '{node_id}' column index {col_idx} is out of bounds [0, {num_cols})"
566+
raise ValueError(error_msg)
567+
if not (0 <= row_idx < num_rows):
568+
error_msg = f"Node '{node_id}' row index {row_idx} is out of bounds [0, {num_rows})"
569+
raise ValueError(error_msg)
570+
571+
# Generate row and column positions
572+
self.row = {i: i * self.vertical_spacing for i in range(num_rows)}
573+
self.col = {i: i * self.horizontal_spacing for i in range(num_cols)}
574+
575+
def render(self, ax: Axes | None = None) -> Axes:
576+
"""Render the tree diagram.
577+
578+
Args:
579+
ax: Optional matplotlib axes object. If None, creates a new figure and axes.
580+
581+
Returns:
582+
The matplotlib axes object.
583+
584+
"""
585+
if ax is None:
586+
# Calculate plot dimensions based on layout
587+
plot_width = self.col[self.num_cols - 1] + self.node_width
588+
plot_height = self.row[self.num_rows - 1] + self.node_height
589+
590+
_, ax = plt.subplots(figsize=(plot_width, plot_height))
591+
ax.set_xlim(0, plot_width)
592+
ax.set_ylim(0, plot_height)
593+
ax.axis("off")
594+
595+
# First pass: create all nodes and store their centers
596+
node_centers = {}
597+
for node_id, node_data in self.tree_structure.items():
598+
# Convert grid coordinates to absolute positions
599+
col_idx, row_idx = node_data["position"]
600+
x = self.col[col_idx]
601+
y = self.row[row_idx]
602+
603+
# Extract data for the node (exclude position and children which are structural)
604+
data_dict = {k: v for k, v in node_data.items() if k not in ("position", "children")}
605+
606+
# Create and render the node
607+
node = self.node_class(
608+
data=data_dict,
609+
x=x,
610+
y=y,
611+
)
612+
node.render(ax)
613+
614+
# Store center positions for connections
615+
node_centers[node_id] = {
616+
"x": x + self.node_width / 2,
617+
"y_bottom": y,
618+
"y_top": y + self.node_height,
619+
}
620+
621+
# Second pass: draw connections
622+
for node_id, node_data in self.tree_structure.items():
623+
if "children" in node_data and len(node_data["children"]) > 0:
624+
parent = node_centers[node_id]
625+
for child_id in node_data["children"]:
626+
if child_id not in node_centers:
627+
available_nodes = list(node_centers.keys())
628+
error_msg = (
629+
f"Child node '{child_id}' referenced by '{node_id}' not found in tree_structure. "
630+
f"Available nodes: {available_nodes}"
631+
)
632+
raise ValueError(error_msg)
633+
child = node_centers[child_id]
634+
self._draw_connection(
635+
ax=ax,
636+
x1=parent["x"],
637+
y1=parent["y_bottom"],
638+
x2=child["x"],
639+
y2=child["y_top"],
640+
)
641+
642+
return ax
643+
644+
@staticmethod
645+
def _add_curve(
646+
verts: list[tuple[float, float]],
647+
codes: list[int],
648+
x: float,
649+
y: float,
650+
x_offset: float,
651+
y_offset: float,
652+
) -> None:
653+
"""Add Bezier curve control points to the path.
654+
655+
Args:
656+
verts: List of vertices to append to.
657+
codes: List of path codes to append to.
658+
x: X-coordinate of the curve start point.
659+
y: Y-coordinate of the curve start point.
660+
x_offset: X offset for the curve end point.
661+
y_offset: Y offset for the curve end point.
662+
663+
"""
664+
verts.append((x, y))
665+
codes.append(Path.CURVE3)
666+
verts.append((x + x_offset, y + y_offset))
667+
codes.append(Path.CURVE3)
668+
669+
@staticmethod
670+
def _draw_connection(ax: Axes, x1: float, y1: float, x2: float, y2: float) -> None:
671+
"""Draw connection line between nodes with curved corners.
672+
673+
Args:
674+
ax: Matplotlib axes object.
675+
x1: X-coordinate of first point.
676+
y1: Y-coordinate of first point.
677+
x2: X-coordinate of second point.
678+
y2: Y-coordinate of second point.
679+
680+
"""
681+
# Use class constants for connection styling
682+
curve_radius = TreeGrid.CONNECTION_CURVE_RADIUS
683+
line_width = TreeGrid.CONNECTION_LINE_WIDTH
684+
line_color = TreeGrid.CONNECTION_LINE_COLOR
685+
686+
mid_y = (y1 + y2) / 2
687+
curve_sign = 1 if x2 > x1 else -1
688+
689+
# Create path with curved corners using Bezier curves
690+
verts = []
691+
codes = []
692+
693+
# Start point (bottom of parent node)
694+
verts.append((x1, y1))
695+
codes.append(Path.MOVETO)
696+
697+
# Vertical line down to curve start
698+
verts.append((x1, mid_y + curve_radius))
699+
codes.append(Path.LINETO)
700+
701+
# Curve from vertical to horizontal (first corner)
702+
TreeGrid._add_curve(verts, codes, x1, mid_y, curve_sign * curve_radius, 0)
703+
704+
# Horizontal line
705+
verts.append((x2 - (curve_sign * curve_radius), mid_y))
706+
codes.append(Path.LINETO)
707+
708+
# Curve from horizontal to vertical (second corner)
709+
TreeGrid._add_curve(verts, codes, x2, mid_y, 0, -curve_radius)
710+
711+
# Vertical line up to child node
712+
verts.append((x2, y2))
713+
codes.append(Path.LINETO)
714+
715+
# Create and draw the path
716+
path = Path(verts, codes)
717+
patch = mpatches.PathPatch(path, facecolor="none", edgecolor=line_color, linewidth=line_width)
718+
ax.add_patch(patch)

0 commit comments

Comments
 (0)