|
| 1 | +"""Build Mermaid flowchart diagrams from Galaxy workflows.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from pathlib import Path |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +from gxformat2.normalized import ensure_format2, NormalizedFormat2 |
| 9 | +from gxformat2.schema.gxformat2 import FrameComment, GalaxyWorkflow, WorkflowInputParameter |
| 10 | + |
# Mermaid node shapes, each stored as an (opening, closing) bracket pair that
# is wrapped around the quoted node label.
SHAPE_INPUT = (">", "]")  # >label]    asymmetric / flag (workflow inputs)
SHAPE_PARAM = ("{{", "}}")  # {{label}}  hexagon (parameter inputs)
SHAPE_TOOL = ("[", "]")  # [label]    rectangle (tool steps, the default)
SHAPE_SUBWORKFLOW = ("[[", "]]")  # [[label]]  subroutine (subworkflows)

# Lookup from an input type string or step type string to its node shape.
STEP_TYPE_SHAPES = {
    # dataset-like inputs
    "data": SHAPE_INPUT,
    "collection": SHAPE_INPUT,
    # parameter inputs
    "integer": SHAPE_PARAM,
    "float": SHAPE_PARAM,
    "text": SHAPE_PARAM,
    "boolean": SHAPE_PARAM,
    "color": SHAPE_PARAM,
    # inputs without a declared type
    "input": SHAPE_INPUT,
    # steps
    "tool": SHAPE_TOOL,
    "subworkflow": SHAPE_SUBWORKFLOW,
}

# Tool IDs starting with this prefix have it stripped before being displayed.
MAIN_TS_PREFIX = "toolshed.g2.bx.psu.edu/repos/"
| 34 | + |
| 35 | + |
| 36 | +def _sanitize_label(label: str) -> str: |
| 37 | + """Escape characters that have special meaning in Mermaid labels.""" |
| 38 | + label = label.replace('"', "#quot;") |
| 39 | + for ch in "()[]{}<>": |
| 40 | + label = label.replace(ch, f"#{ord(ch)};") |
| 41 | + return label |
| 42 | + |
| 43 | + |
| 44 | +def _input_type_str(inp: WorkflowInputParameter) -> str: |
| 45 | + if inp.type_ is None: |
| 46 | + return "input" |
| 47 | + if isinstance(inp.type_, list): |
| 48 | + if inp.type_: |
| 49 | + return inp.type_[0].value |
| 50 | + return "input" |
| 51 | + return inp.type_.value |
| 52 | + |
| 53 | + |
| 54 | +def _node_line(node_id: str, label: str, shape: tuple[str, str]) -> str: |
| 55 | + open_br, close_br = shape |
| 56 | + return f'{node_id}{open_br}"{label}"{close_br}' |
| 57 | + |
| 58 | + |
def workflow_to_mermaid(
    workflow: dict[str, Any] | str | Path | GalaxyWorkflow | NormalizedFormat2,
    *,
    comments: bool = False,
) -> str:
    """Render a Galaxy workflow as Mermaid ``graph LR`` flowchart text.

    *workflow* may be a ``NormalizedFormat2`` instance, which is used as
    is, or anything else, which is handed to ``ensure_format2()``.

    Inputs become nodes ``input_<n>`` and steps become nodes ``step_<n>``,
    numbered by position. One ``-->`` edge is emitted per distinct
    (source node, step node) pair whose source resolves to a known input
    or step.

    When *comments* is True, every ``FrameComment`` with a non-empty
    ``contains_steps`` is rendered as a ``subgraph`` block holding the
    nodes it references; all other nodes are declared at the top level.

    Returns the diagram as one newline-joined string.
    """
    nf2 = workflow if isinstance(workflow, NormalizedFormat2) else ensure_format2(workflow)

    # Inputs: map each input key to its Mermaid node id and declaration line.
    input_node_ids: dict[str, str] = {}
    input_decls: dict[str, str] = {}
    for index, inp in enumerate(nf2.inputs):
        key = inp.id or str(index)
        mermaid_id = f"input_{index}"
        input_node_ids[key] = mermaid_id
        kind = _input_type_str(inp)
        text = f"{_sanitize_label(key)}<br/><i>{kind}</i>"
        input_decls[key] = _node_line(mermaid_id, text, STEP_TYPE_SHAPES.get(kind, SHAPE_INPUT))

    # Steps: same two mappings, keyed by the step label (or id when unlabeled).
    step_node_ids: dict[str, str] = {}
    step_decls: dict[str, str] = {}
    for index, step in enumerate(nf2.steps):
        key = step.label or step.id
        mermaid_id = f"step_{index}"
        step_node_ids[key] = mermaid_id

        short_tool_id = step.tool_id
        if short_tool_id and short_tool_id.startswith(MAIN_TS_PREFIX):
            short_tool_id = short_tool_id[len(MAIN_TS_PREFIX) :]

        # Display text falls back to the tool id, then to the step position.
        text = _sanitize_label(key or (f"tool:{short_tool_id}" if short_tool_id else str(index)))
        kind = step.type_.value if step.type_ else "tool"
        step_decls[key] = _node_line(mermaid_id, text, STEP_TYPE_SHAPES.get(kind, SHAPE_TOOL))

    # Frame comments and the set of node keys they claim. nf2.comments is
    # only consulted when comment rendering was requested.
    frames: list[FrameComment] = []
    framed: set[str] = set()
    if comments:
        frames = [c for c in nf2.comments if isinstance(c, FrameComment) and c.contains_steps]
        framed = {str(ref) for frame in frames for ref in frame.contains_steps}

    lines = ["graph LR"]

    # Unframed nodes are declared at the top level: inputs first, then steps.
    lines.extend(f" {decl}" for key, decl in input_decls.items() if key not in framed)
    lines.extend(f" {decl}" for key, decl in step_decls.items() if key not in framed)

    # Each frame becomes a subgraph listing the declarations of its members.
    for index, frame in enumerate(frames):
        heading = _sanitize_label(frame.title or f"Group {index}")
        lines.append(f' subgraph sub_{index} ["{heading}"]')
        for ref in frame.contains_steps or []:
            member = str(ref)
            # An input wins when an input and a step share the same key;
            # references matching neither are skipped.
            if member in input_decls:
                lines.append(f" {input_decls[member]}")
            elif member in step_decls:
                lines.append(f" {step_decls[member]}")
        lines.append(" end")

    # Edges: one per distinct (source node, target step) pair.
    drawn: set[tuple[str, str]] = set()
    for index, step in enumerate(nf2.steps):
        target = f"step_{index}"
        for step_input in step.in_:
            raw = step_input.source
            if raw is None:
                continue
            for source in raw if isinstance(raw, list) else [raw]:
                origin_label = nf2.resolve_source(source).step_label
                origin = input_node_ids.get(origin_label) or step_node_ids.get(origin_label)
                if not origin:
                    continue  # source does not match any known input or step
                edge = (origin, target)
                if edge in drawn:
                    continue
                drawn.add(edge)
                lines.append(f" {origin} --> {target}")

    return "\n".join(lines)
0 commit comments