Skip to content

Commit 910dd40

Browse files
mvanwyk and claude committed
feat: add comprehensive validation and error handling to TreeGrid
Add robust input validation and improve error messages for TreeGrid:

- Validate invalid child references with descriptive error messages
- Add input validation in __init__ for:
  - Grid dimensions (must be positive integers)
  - node_class (must be TreeNode subclass)
  - tree_structure (cannot be empty)
  - Node positions (must be within grid bounds)
  - Required 'position' key in node data
- Extract magic numbers to class constants:
  - CONNECTION_CURVE_RADIUS = 0.15
  - CONNECTION_LINE_WIDTH = 2
  - CONNECTION_LINE_COLOR = "black"
- Add 6 comprehensive validation tests with realistic retail data:
  - test_invalid_child_reference
  - test_invalid_grid_dimensions
  - test_invalid_node_class
  - test_empty_tree_structure
  - test_missing_position_key
  - test_out_of_bounds_position

All 33 tests passing with 82% coverage (improved from 79%).

Addresses code review issues #1, #2, #4, and #5 from PR #357.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 150912c commit 910dd40

2 files changed

Lines changed: 197 additions & 4 deletions

File tree

pyretailscience/plots/tree_diagram.py

Lines changed: 49 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,11 @@ def render(self, ax: Axes) -> None:
494494
class TreeGrid:
495495
"""Grid-based tree diagram renderer with configurable node types."""
496496

497+
# Connection styling constants
498+
CONNECTION_CURVE_RADIUS = 0.15
499+
CONNECTION_LINE_WIDTH = 2
500+
CONNECTION_LINE_COLOR = "black"
501+
497502
def __init__(
498503
self,
499504
tree_structure: dict[str, dict],
@@ -516,7 +521,26 @@ def __init__(
516521
horizontal_spacing: Horizontal spacing between columns. If None, automatically calculated as
517522
node_width - 1.0 overlap for compact layout.
518523
524+
Raises:
525+
ValueError: If grid dimensions are not positive, if tree_structure is empty,
526+
or if node positions are out of bounds.
527+
TypeError: If node_class is not a TreeNode subclass.
528+
519529
"""
530+
# Validate grid dimensions
531+
if num_rows <= 0 or num_cols <= 0:
532+
error_msg = f"Grid dimensions must be positive: num_rows={num_rows}, num_cols={num_cols}"
533+
raise ValueError(error_msg)
534+
535+
# Validate node_class is a TreeNode subclass
536+
if not issubclass(node_class, TreeNode):
537+
error_msg = f"node_class must be a TreeNode subclass, got {node_class}"
538+
raise TypeError(error_msg)
539+
540+
# Validate tree_structure is not empty
541+
if not tree_structure:
542+
raise ValueError("tree_structure cannot be empty")
543+
520544
self.tree_structure = tree_structure
521545
self.num_rows = num_rows
522546
self.num_cols = num_cols
@@ -530,6 +554,20 @@ def __init__(
530554
self.vertical_spacing = vertical_spacing if vertical_spacing is not None else self.node_height + 0.6
531555
self.horizontal_spacing = horizontal_spacing if horizontal_spacing is not None else self.node_width - 1.0
532556

557+
# Validate positions are within grid bounds
558+
for node_id, node_data in tree_structure.items():
559+
if "position" not in node_data:
560+
error_msg = f"Node '{node_id}' is missing required 'position' key"
561+
raise ValueError(error_msg)
562+
563+
col_idx, row_idx = node_data["position"]
564+
if not (0 <= col_idx < num_cols):
565+
error_msg = f"Node '{node_id}' column index {col_idx} is out of bounds [0, {num_cols})"
566+
raise ValueError(error_msg)
567+
if not (0 <= row_idx < num_rows):
568+
error_msg = f"Node '{node_id}' row index {row_idx} is out of bounds [0, {num_rows})"
569+
raise ValueError(error_msg)
570+
533571
# Generate row and column positions
534572
self.row = {i: i * self.vertical_spacing for i in range(num_rows)}
535573
self.col = {i: i * self.horizontal_spacing for i in range(num_cols)}
@@ -585,6 +623,13 @@ def render(self, ax: Axes | None = None) -> Axes:
585623
if "children" in node_data and len(node_data["children"]) > 0:
586624
parent = node_centers[node_id]
587625
for child_id in node_data["children"]:
626+
if child_id not in node_centers:
627+
available_nodes = list(node_centers.keys())
628+
error_msg = (
629+
f"Child node '{child_id}' referenced by '{node_id}' not found in tree_structure. "
630+
f"Available nodes: {available_nodes}"
631+
)
632+
raise ValueError(error_msg)
588633
child = node_centers[child_id]
589634
self._draw_connection(
590635
ax=ax,
@@ -633,10 +678,10 @@ def _draw_connection(ax: Axes, x1: float, y1: float, x2: float, y2: float) -> No
633678
y2: Y-coordinate of second point.
634679
635680
"""
636-
# Connection styling constants
637-
curve_radius = 0.15
638-
line_width = 2
639-
line_color = "black"
681+
# Use class constants for connection styling
682+
curve_radius = TreeGrid.CONNECTION_CURVE_RADIUS
683+
line_width = TreeGrid.CONNECTION_LINE_WIDTH
684+
line_color = TreeGrid.CONNECTION_LINE_COLOR
640685

641686
mid_y = (y1 + y2) / 2
642687
curve_sign = 1 if x2 > x1 else -1

tests/plots/test_tree_diagram.py

Lines changed: 148 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -466,3 +466,151 @@ def test_wide_tree(self):
466466
num_connections = 5
467467
expected_patches = num_nodes * patches_per_node + num_connections
468468
assert len(ax.patches) == expected_patches
469+
470+
def test_invalid_child_reference(self):
471+
"""Test that referencing a non-existent child raises ValueError."""
472+
tree_structure = {
473+
"total_revenue": {
474+
"header": "Total Revenue",
475+
"percent": 15.3,
476+
"value1": "£1.2M",
477+
"value2": "£1.04M",
478+
"position": (0, 0),
479+
"children": ["nonexistent_child"],
480+
},
481+
}
482+
483+
grid = TreeGrid(
484+
tree_structure=tree_structure,
485+
num_rows=1,
486+
num_cols=1,
487+
node_class=SimpleTreeNode,
488+
)
489+
490+
with pytest.raises(ValueError, match="not found in tree_structure"):
491+
grid.render()
492+
493+
def test_invalid_grid_dimensions(self):
494+
"""Test that invalid grid dimensions raise ValueError."""
495+
tree_structure = {
496+
"total_revenue": {
497+
"header": "Total Revenue",
498+
"percent": 8.5,
499+
"value1": "£850K",
500+
"value2": "£784K",
501+
"position": (0, 0),
502+
"children": [],
503+
},
504+
}
505+
506+
# Test negative rows
507+
with pytest.raises(ValueError, match="Grid dimensions must be positive"):
508+
TreeGrid(
509+
tree_structure=tree_structure,
510+
num_rows=-1,
511+
num_cols=1,
512+
node_class=SimpleTreeNode,
513+
)
514+
515+
# Test zero columns
516+
with pytest.raises(ValueError, match="Grid dimensions must be positive"):
517+
TreeGrid(
518+
tree_structure=tree_structure,
519+
num_rows=1,
520+
num_cols=0,
521+
node_class=SimpleTreeNode,
522+
)
523+
524+
def test_invalid_node_class(self):
525+
"""Test that invalid node_class raises TypeError."""
526+
tree_structure = {
527+
"customer_count": {
528+
"header": "Customer Count",
529+
"percent": 12.4,
530+
"value1": "25,450",
531+
"value2": "22,640",
532+
"position": (0, 0),
533+
"children": [],
534+
},
535+
}
536+
537+
with pytest.raises(TypeError, match="must be a TreeNode subclass"):
538+
TreeGrid(
539+
tree_structure=tree_structure,
540+
num_rows=1,
541+
num_cols=1,
542+
node_class=str, # Not a TreeNode subclass
543+
)
544+
545+
def test_empty_tree_structure(self):
546+
"""Test that empty tree_structure raises ValueError."""
547+
with pytest.raises(ValueError, match="tree_structure cannot be empty"):
548+
TreeGrid(
549+
tree_structure={},
550+
num_rows=1,
551+
num_cols=1,
552+
node_class=SimpleTreeNode,
553+
)
554+
555+
def test_missing_position_key(self):
556+
"""Test that missing position key raises ValueError."""
557+
tree_structure = {
558+
"avg_basket": {
559+
"header": "Average Basket Value",
560+
"percent": 6.2,
561+
"value1": "£45.80",
562+
"value2": "£43.12",
563+
# Missing 'position' key
564+
"children": [],
565+
},
566+
}
567+
568+
with pytest.raises(ValueError, match="missing required 'position' key"):
569+
TreeGrid(
570+
tree_structure=tree_structure,
571+
num_rows=1,
572+
num_cols=1,
573+
node_class=SimpleTreeNode,
574+
)
575+
576+
def test_out_of_bounds_position(self):
577+
"""Test that out of bounds positions raise ValueError."""
578+
# Test column out of bounds (trying to use column 1 when only column 0 exists)
579+
tree_structure = {
580+
"transaction_freq": {
581+
"header": "Transaction Frequency",
582+
"percent": 4.8,
583+
"value1": "3.2",
584+
"value2": "3.05",
585+
"position": (1, 0), # Column 1 is out of bounds for 1 column grid (0-indexed)
586+
"children": [],
587+
},
588+
}
589+
590+
with pytest.raises(ValueError, match="column index .* is out of bounds"):
591+
TreeGrid(
592+
tree_structure=tree_structure,
593+
num_rows=1,
594+
num_cols=1,
595+
node_class=SimpleTreeNode,
596+
)
597+
598+
# Test row out of bounds (trying to use row 1 when only row 0 exists)
599+
tree_structure = {
600+
"items_per_basket": {
601+
"header": "Items per Basket",
602+
"percent": -2.3,
603+
"value1": "4.8",
604+
"value2": "4.9",
605+
"position": (0, 1), # Row 1 is out of bounds for 1 row grid (0-indexed)
606+
"children": [],
607+
},
608+
}
609+
610+
with pytest.raises(ValueError, match="row index .* is out of bounds"):
611+
TreeGrid(
612+
tree_structure=tree_structure,
613+
num_rows=1,
614+
num_cols=1,
615+
node_class=SimpleTreeNode,
616+
)

0 commit comments

Comments
 (0)