# Source code for gxformat2.mermaid._builder

"""Build Mermaid flowchart diagrams from Galaxy workflows."""

from __future__ import annotations

from pathlib import Path
from typing import Any

from gxformat2._labels import Labels
from gxformat2.normalized import ensure_format2, NormalizedFormat2
from gxformat2.schema.gxformat2 import BaseInputParameter, FrameComment, GalaxyWorkflow

# Mermaid node-shape delimiters, expressed as (opening, closing) bracket pairs:
#   >label]     asymmetric "flag" shape  -> data / collection / generic inputs
#   {{label}}   hexagon                  -> parameter inputs (integer, float, text, ...)
#   [label]     rectangle                -> tool steps (also the fallback for steps)
#   [[label]]   subroutine               -> subworkflow steps
SHAPE_INPUT: tuple[str, str] = (">", "]")
SHAPE_PARAM: tuple[str, str] = ("{{", "}}")
SHAPE_TOOL: tuple[str, str] = ("[", "]")
SHAPE_SUBWORKFLOW: tuple[str, str] = ("[[", "]]")

# Input-type / step-type name -> shape used to draw the node.
# Unknown names are resolved by callers through ``.get(..., default)``.
STEP_TYPE_SHAPES: dict[str, tuple[str, str]] = {
    **dict.fromkeys(("data", "collection"), SHAPE_INPUT),
    **dict.fromkeys(("integer", "float", "text", "boolean", "color"), SHAPE_PARAM),
    "input": SHAPE_INPUT,
    "tool": SHAPE_TOOL,
    "subworkflow": SHAPE_SUBWORKFLOW,
}

# Tool-id prefix of the main Galaxy ToolShed; removed before a tool id is shown in a label.
MAIN_TS_PREFIX = "toolshed.g2.bx.psu.edu/repos/"


# Single-pass translation table: double quote -> ``#quot;``, and each bracket
# character -> its decimal Mermaid entity code (e.g. "(" -> "#40;").
_MERMAID_LABEL_ESCAPES = str.maketrans(
    {'"': "#quot;", **{ch: f"#{ord(ch)};" for ch in "()[]{}<>"}}
)


def _sanitize_label(label: str) -> str:
    """Return *label* with Mermaid-significant characters replaced by entity codes.

    Double quotes become ``#quot;``; parentheses, square brackets, curly
    braces and angle brackets become numeric codes such as ``#40;``.
    All other characters pass through unchanged.
    """
    return label.translate(_MERMAID_LABEL_ESCAPES)


def _input_type_str(inp: BaseInputParameter) -> str:
    """Return the type name used to label and shape a workflow input node.

    ``inp.type_`` may be missing or ``None``, a single type (a plain string or
    an enum-like object exposing ``.value``), or a list of such types.

    For a list, the first entry that is neither ``None`` nor ``"null"`` is
    used, so a list such as ``["null", "text"]`` reports ``"text"`` instead of
    ``"null"``. Returns ``"input"`` when no concrete type is available.
    """
    type_ = getattr(inp, "type_", None)
    candidates = type_ if isinstance(type_, list) else [type_]
    for candidate in candidates:
        if candidate is None:
            continue
        # Enum members expose the schema string via ``.value``; plain strings are used as-is.
        name = getattr(candidate, "value", candidate)
        if name != "null":
            return name
    return "input"


def _node_line(node_id: str, label: str, shape: tuple[str, str]) -> str:
    """Render one Mermaid node declaration.

    The result is *node_id*, followed by the double-quoted *label*, wrapped in
    the ``(opening, closing)`` delimiters given by *shape*, e.g. ``step_0["cat"]``.
    """
    opening, closing = shape
    return '{}{}"{}"{}'.format(node_id, opening, label, closing)


def _input_nodes(nf2: NormalizedFormat2) -> tuple[dict[str, str], dict[str, str]]:
    """Build node IDs and declaration lines for the workflow inputs.

    Returns ``(ids, decls)``. Both are keyed by the input's label, which is its
    ``id``, or its position when the id is empty.
    """
    ids: dict[str, str] = {}
    decls: dict[str, str] = {}
    for i, inp in enumerate(nf2.inputs):
        node_id = f"input_{i}"
        inp_label = inp.id or str(i)
        ids[inp_label] = node_id
        type_str = _input_type_str(inp)
        decls[inp_label] = _node_line(
            node_id,
            f"{_sanitize_label(inp_label)}<br/><i>{type_str}</i>",
            STEP_TYPE_SHAPES.get(type_str, SHAPE_INPUT),
        )
    return ids, decls


def _step_nodes(nf2: NormalizedFormat2) -> tuple[dict[str, str], dict[str, str]]:
    """Build node IDs and declaration lines for the workflow steps.

    Returns ``(ids, decls)``. Both are keyed by the step's label, falling back
    to its ``id``.
    """
    ids: dict[str, str] = {}
    decls: dict[str, str] = {}
    for i, step in enumerate(nf2.steps):
        node_id = f"step_{i}"
        step_label = step.label or step.id
        ids[step_label] = node_id
        tool_id = step.tool_id
        if tool_id and tool_id.startswith(MAIN_TS_PREFIX):
            # Drop the main ToolShed prefix so the "tool:" fallback label stays short.
            tool_id = tool_id[len(MAIN_TS_PREFIX) :]
        # Ids flagged by Labels.is_unlabeled (presumably auto-generated) are not displayed.
        display_id = step.id if step.id and not Labels.is_unlabeled(step.id) else None
        fallback = f"tool:{tool_id}" if tool_id else str(i)
        label = _sanitize_label(step.label or display_id or fallback)
        step_type = step.type_.value if step.type_ else "tool"
        decls[step_label] = _node_line(node_id, label, STEP_TYPE_SHAPES.get(step_type, SHAPE_TOOL))
    return ids, decls


def _collect_frames(nf2: NormalizedFormat2) -> tuple[list[FrameComment], set[str]]:
    """Return the frame comments that contain steps, plus the set of labels they reference."""
    frames: list[FrameComment] = []
    framed: set[str] = set()
    for comment in nf2.comments:
        if isinstance(comment, FrameComment) and comment.contains_steps:
            frames.append(comment)
            framed.update(str(ref) for ref in comment.contains_steps)
    return frames, framed


def _subgraph_lines(
    frames: list[FrameComment],
    input_lines: dict[str, str],
    step_lines: dict[str, str],
) -> list[str]:
    """Render each frame as a Mermaid ``subgraph`` containing its referenced node declarations."""
    out: list[str] = []
    for i, frame in enumerate(frames):
        title = _sanitize_label(frame.title or f"Group {i}")
        out.append(f'    subgraph sub_{i} ["{title}"]')
        for ref in frame.contains_steps or []:
            ref_str = str(ref)
            # References matching neither an input nor a step are skipped.
            if ref_str in input_lines:
                out.append(f"        {input_lines[ref_str]}")
            elif ref_str in step_lines:
                out.append(f"        {step_lines[ref_str]}")
        out.append("    end")
    return out


def _edge_lines(
    nf2: NormalizedFormat2,
    input_ids: dict[str, str],
    step_ids: dict[str, str],
) -> list[str]:
    """Render one ``source --> target`` line per distinct (source node, target step) pair."""
    out: list[str] = []
    seen: set[tuple[str, str]] = set()
    for i, step in enumerate(nf2.steps):
        target = f"step_{i}"
        for step_input in step.in_:
            if step_input.source is None:
                continue
            sources = step_input.source if isinstance(step_input.source, list) else [step_input.source]
            for source in sources:
                source_ref = nf2.resolve_source(source)
                # Input nodes are looked up before step nodes for the same label.
                source_id = input_ids.get(source_ref.step_label) or step_ids.get(source_ref.step_label)
                if not source_id:
                    continue  # source did not resolve to a known node: no edge
                edge = (source_id, target)
                if edge not in seen:
                    seen.add(edge)
                    out.append(f"    {source_id} --> {target}")
    return out


def workflow_to_mermaid(
    workflow: dict[str, Any] | str | Path | GalaxyWorkflow | NormalizedFormat2,
    *,
    comments: bool = False,
) -> str:
    """Convert a Galaxy workflow to a Mermaid flowchart string.

    Accepts anything ``ensure_format2()`` supports, plus an already
    normalized ``NormalizedFormat2`` instance.

    When *comments* is True, FrameComment objects are rendered as Mermaid
    subgraphs that group their contained steps.

    Returns the diagram text: a ``graph LR`` header, then the node
    declarations, then any subgraphs, then the deduplicated edges, joined
    with newlines.
    """
    nf2 = workflow if isinstance(workflow, NormalizedFormat2) else ensure_format2(workflow)

    input_ids, input_lines = _input_nodes(nf2)
    step_ids, step_lines = _step_nodes(nf2)
    # nf2.comments is only consulted when frame rendering was requested.
    frames, framed = _collect_frames(nf2) if comments else ([], set())

    lines = ["graph LR"]
    # Framed nodes are declared inside their subgraph rather than at top level.
    lines.extend(f"    {decl}" for label, decl in input_lines.items() if label not in framed)
    lines.extend(f"    {decl}" for label, decl in step_lines.items() if label not in framed)
    lines.extend(_subgraph_lines(frames, input_lines, step_lines))
    lines.extend(_edge_lines(nf2, input_ids, step_ids))
    return "\n".join(lines)