Skip to content

Commit 1a2d9ed

Browse files
jmchilton and claude committed
Split WorkflowInputParameter into type-specific discriminated union
Add per-type parameter models (WorkflowDataParameter, WorkflowCollectionParameter, WorkflowIntegerParameter, WorkflowFloatParameter, WorkflowTextParameter, WorkflowBooleanParameter) discriminated on the `type` field via pydantic annotations. The catch-all WorkflowInputParameter is retained for Schema Salad codegen compatibility.

Schema changes:
- Add BaseInputParameter, BaseDataParameter, MinMax abstract records
- Add split parameter records with pydantic:type Literal overrides
- Add discriminator annotations on Process.inputs field
- Explicit JSON-LD predicates for shared fields (min, max, collection_type)

Code changes:
- Use BaseInputParameter for isinstance checks and type annotations
- Handle both GalaxyType enum and string Literal type values
- Use SerializeAsAny for proper subclass serialization
- Preserve input instances through expansion to avoid dict round-trip loss
- Bump schema-salad-plus-pydantic >= 0.1.8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8092536 commit 1a2d9ed

File tree

13 files changed

+5362
-142
lines changed

13 files changed

+5362
-142
lines changed

gxformat2/abstract.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any
88

99
from gxformat2.normalized import ensure_format2, NormalizedFormat2, NormalizedWorkflowStep
10-
from gxformat2.schema.gxformat2 import GalaxyType, WorkflowInputParameter, WorkflowOutputParameter, WorkflowStepOutput
10+
from gxformat2.schema.gxformat2 import BaseInputParameter, GalaxyType, WorkflowOutputParameter, WorkflowStepOutput
1111
from gxformat2.yaml import ordered_dump_to_path, ordered_load
1212

1313
CWL_VERSION = "v1.2"
@@ -105,7 +105,7 @@ def _step_outputs_to_abstract(step: NormalizedWorkflowStep):
105105
return [out.id for out in step.out if out.id is not None]
106106

107107

108-
def _inputs_to_abstract(inputs: list[WorkflowInputParameter]):
108+
def _inputs_to_abstract(inputs: list[BaseInputParameter]):
109109
"""Convert Format2 inputs to abstract CWL inputs."""
110110
abstract_inputs: dict[str, Any] = {}
111111
for inp in inputs:
@@ -114,8 +114,8 @@ def _inputs_to_abstract(inputs: list[WorkflowInputParameter]):
114114
continue
115115
input_def: dict[str, Any] = {}
116116

117-
# Convert type
118-
cwl_type = _galaxy_type_to_cwl(inp.type_)
117+
# Convert type (type_ lives on concrete subclasses, not BaseInputParameter)
118+
cwl_type = _galaxy_type_to_cwl(getattr(inp, "type_", None))
119119
if inp.optional:
120120
cwl_type += "?"
121121
input_def["type"] = cwl_type
@@ -134,7 +134,7 @@ def _inputs_to_abstract(inputs: list[WorkflowInputParameter]):
134134
return abstract_inputs
135135

136136

137-
def _galaxy_type_to_cwl(galaxy_type: GalaxyType | list[GalaxyType] | None) -> str:
137+
def _galaxy_type_to_cwl(galaxy_type: GalaxyType | str | list[GalaxyType | str] | None) -> str:
138138
"""Map a Galaxy/Format2 type to a CWL type string."""
139139
if galaxy_type is None:
140140
return "File"
@@ -144,12 +144,13 @@ def _galaxy_type_to_cwl(galaxy_type: GalaxyType | list[GalaxyType] | None) -> st
144144
if t != GalaxyType.null:
145145
return _galaxy_type_to_cwl(t) + "[]"
146146
return "File"
147-
if galaxy_type == GalaxyType.data:
147+
type_str = galaxy_type.value if isinstance(galaxy_type, GalaxyType) else str(galaxy_type)
148+
if type_str in ("data", "File"):
148149
return "File"
149-
if galaxy_type == GalaxyType.collection:
150+
if type_str == "collection":
150151
# TODO: handle nested collections, pairs, etc...
151152
return "File[]"
152-
return galaxy_type.value
153+
return type_str
153154

154155

155156
def _outputs_to_abstract(outputs: list[WorkflowOutputParameter]):

gxformat2/cytoscape/_builder.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any
77

88
from gxformat2.normalized import ensure_format2, NormalizedFormat2, NormalizedWorkflowStep
9-
from gxformat2.schema.gxformat2 import GalaxyWorkflow, WorkflowInputParameter
9+
from gxformat2.schema.gxformat2 import BaseInputParameter, GalaxyType, GalaxyWorkflow
1010

1111
from .models import (
1212
CytoscapeEdge,
@@ -57,17 +57,20 @@ def _to_position(step_position, order_index: int) -> CytoscapePosition:
5757
return CytoscapePosition(x=int(step_position.left), y=int(step_position.top))
5858

5959

60-
def _input_type_str(inp: WorkflowInputParameter) -> str:
61-
if inp.type_ is None:
60+
def _input_type_str(inp: BaseInputParameter) -> str:
61+
# type_ lives on concrete subclasses, not BaseInputParameter
62+
type_ = getattr(inp, "type_", None)
63+
if type_ is None:
6264
return "input"
63-
if isinstance(inp.type_, list):
64-
if inp.type_:
65-
return inp.type_[0].value + "[]"
65+
if isinstance(type_, list):
66+
if type_:
67+
t = type_[0]
68+
return (t.value if isinstance(t, GalaxyType) else str(t)) + "[]"
6669
return "input"
67-
return inp.type_.value
70+
return type_.value if isinstance(type_, GalaxyType) else str(type_)
6871

6972

70-
def _input_node(inp: WorkflowInputParameter, order_index: int) -> CytoscapeNode:
73+
def _input_node(inp: BaseInputParameter, order_index: int) -> CytoscapeNode:
7174
input_id = inp.id or str(order_index)
7275
type_str = _input_type_str(inp)
7376
return CytoscapeNode(

gxformat2/lint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def _validate_input_types(lint_context: LintContext, nf2: NormalizedFormat2):
159159
for inp in nf2.inputs:
160160
if inp.default is None:
161161
continue
162-
input_type = inp.type_
162+
# type_ lives on concrete subclasses, not BaseInputParameter
163+
input_type = getattr(inp, "type_", None)
163164
if isinstance(input_type, list):
164165
# Array type like [string] — skip default validation for now
165166
continue

gxformat2/normalize.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
from gxformat2.normalized import ensure_format2, NormalizedFormat2, NormalizedNativeWorkflow, NormalizedWorkflowStep
1616
from gxformat2.options import ConversionOptions
17-
from gxformat2.schema.gxformat2 import GalaxyWorkflow, WorkflowInputParameter, WorkflowOutputParameter
17+
from gxformat2.schema.gxformat2 import (
18+
BaseInputParameter,
19+
GalaxyWorkflow,
20+
WorkflowOutputParameter,
21+
)
1822
from gxformat2.schema.native import NativeGalaxyWorkflow
1923

2024
# Any input ensure_format2 accepts
@@ -37,7 +41,7 @@ def steps(
3741
workflow_path: str | PathLike | None = None,
3842
options: ConversionOptions | None = None,
3943
expand: bool = False,
40-
) -> list[WorkflowInputParameter | NormalizedWorkflowStep]:
44+
) -> list[BaseInputParameter | NormalizedWorkflowStep]:
4145
"""Return input parameters followed by steps as typed models."""
4246
nf2 = _ensure_format2(workflow_dict, workflow_path, options, expand)
4347
return list(nf2.inputs) + list(nf2.steps)
@@ -48,7 +52,7 @@ def inputs(
4852
workflow_path: str | PathLike | None = None,
4953
options: ConversionOptions | None = None,
5054
expand: bool = False,
51-
) -> list[WorkflowInputParameter]:
55+
) -> list[BaseInputParameter]:
5256
"""Return normalized inputs as typed models."""
5357
nf2 = _ensure_format2(workflow_dict, workflow_path, options, expand)
5458
return list(nf2.inputs)

gxformat2/normalized/_conversion.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@
3838
MAX_EXPANSION_DEPTH,
3939
)
4040
from ..schema.gxformat2 import (
41+
BaseInputParameter,
4142
CreatorOrganization,
4243
CreatorPerson,
4344
FrameComment,
4445
FreehandComment,
4546
GalaxyWorkflow,
47+
input_parameter_class,
4648
MarkdownComment,
4749
Report,
4850
)
@@ -443,7 +445,7 @@ def _build_format2_workflow(
443445
label_map[str(key)] = f"{UNLABELED_STEP_PREFIX}{step.id}"
444446

445447
# Separate inputs from non-input steps
446-
input_params: list[WorkflowInputParameter] = []
448+
input_params: list[BaseInputParameter] = []
447449
fmt2_steps: list[NormalizedWorkflowStep] = []
448450
labels = Labels()
449451

@@ -491,7 +493,7 @@ def _build_format2_workflow(
491493
)
492494

493495

494-
def _build_input_param(step: NormalizedNativeStep) -> WorkflowInputParameter:
496+
def _build_input_param(step: NormalizedNativeStep) -> BaseInputParameter:
495497
step_id = step.label if step.label is not None else f"{UNLABELED_INPUT_PREFIX}{step.id}"
496498
tool_state = step.tool_state
497499
input_type = native_input_to_format2_type({"type": step.type_}, tool_state)
@@ -521,7 +523,12 @@ def _build_input_param(step: NormalizedNativeStep) -> WorkflowInputParameter:
521523
if step.position:
522524
kwargs["position"] = _convert_position(step.position)
523525

524-
return WorkflowInputParameter(**kwargs)
526+
# Use the specific discriminated type when possible; fall back to
527+
# WorkflowInputParameter for list types (multiple inputs) since the
528+
# specific classes only accept scalar Literal type_ values.
529+
if isinstance(input_type, list):
530+
return WorkflowInputParameter(**kwargs)
531+
return input_parameter_class(input_type)(**kwargs)
525532

526533

527534
def _build_format2_step(
@@ -1094,13 +1101,14 @@ def _build_native_workflow(
10941101

10951102

10961103
def _build_input_step(
1097-
inp: WorkflowInputParameter,
1104+
inp: BaseInputParameter,
10981105
order_index: int,
10991106
ctx: _ConversionContext,
11001107
) -> NormalizedNativeStep:
11011108
raw_label = inp.id or f"Input {order_index}"
11021109
label = None if Labels.is_unlabeled(raw_label) else raw_label
1103-
input_type = inp.type_
1110+
# type_ lives on concrete subclasses, not BaseInputParameter
1111+
input_type = getattr(inp, "type_", None)
11041112
if isinstance(input_type, list):
11051113
if len(input_type) != 1:
11061114
raise Exception("Only simple arrays of workflow inputs are currently supported")
@@ -1135,10 +1143,15 @@ def _build_input_step(
11351143
tool_state["multiple"] = True
11361144
if inp.optional is not None:
11371145
tool_state["optional"] = inp.optional
1138-
if inp.format:
1139-
tool_state["format"] = inp.format
1140-
if inp.collection_type:
1141-
tool_state["collection_type"] = inp.collection_type
1146+
# getattr because inp is typed as BaseInputParameter but may be any subclass:
1147+
# format lives on BaseDataParameter, collection_type on WorkflowCollectionParameter
1148+
# and WorkflowInputParameter (catch-all). Non-data types (integer, text, etc.) lack these.
1149+
fmt = getattr(inp, "format", None)
1150+
if fmt:
1151+
tool_state["format"] = fmt
1152+
collection_type = getattr(inp, "collection_type", None)
1153+
if collection_type:
1154+
tool_state["collection_type"] = collection_type
11421155
if inp.default is not None:
11431156
tool_state["default"] = inp.default
11441157

@@ -1682,8 +1695,8 @@ def _expand_format2(wf: NormalizedFormat2, ctx: _ExpansionContext) -> ExpandedFo
16821695
step_data = step.model_dump(by_alias=True, exclude={"run"})
16831696
expanded_steps.append(ExpandedWorkflowStep(**step_data, run=expanded_run))
16841697

1685-
wf_data = wf.model_dump(by_alias=True, exclude={"steps"})
1686-
return ExpandedFormat2(**wf_data, steps=expanded_steps)
1698+
wf_data = wf.model_dump(by_alias=True, exclude={"steps", "inputs"})
1699+
return ExpandedFormat2(**wf_data, inputs=wf.inputs, steps=expanded_steps)
16871700

16881701

16891702
def _expand_native(wf: NormalizedNativeWorkflow, ctx: _ExpansionContext) -> ExpandedNativeWorkflow:

gxformat2/normalized/_format2.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515
from pathlib import Path
1616
from typing import Any, Literal, NamedTuple, Union
1717

18-
from pydantic import BaseModel, ConfigDict, Field, field_validator
18+
from pydantic import BaseModel, ConfigDict, Field, field_validator, SerializeAsAny
1919
from typing_extensions import TypeAlias
2020

2121
from gxformat2.schema.gxformat2 import (
22+
BaseInputParameter,
2223
CreatorOrganization,
2324
CreatorPerson,
2425
FrameComment,
2526
FreehandComment,
2627
GalaxyWorkflow,
28+
input_parameter_class,
2729
MarkdownComment,
2830
Report,
2931
StepPosition,
@@ -175,7 +177,7 @@ class NormalizedFormat2(_DictMixin, BaseModel):
175177
class_: Literal["GalaxyWorkflow"] = Field(default="GalaxyWorkflow", alias="class")
176178
label: str | None = Field(default=None)
177179
doc: str | None = Field(default=None, description="Annotation, joined if originally a list.")
178-
inputs: list[WorkflowInputParameter] = Field(
180+
inputs: list[SerializeAsAny[BaseInputParameter]] = Field(
179181
default_factory=list, description="Always a list, shorthands expanded."
180182
)
181183
outputs: list[WorkflowOutputParameter] = Field(default_factory=list, description="Always a list.")
@@ -257,6 +259,7 @@ def normalized_format2(
257259
if "steps" not in workflow:
258260
workflow = {**workflow, "steps": {}}
259261
workflow = _pre_clean_steps(workflow)
262+
workflow = _pre_normalize_input_types(workflow)
260263
workflow = GalaxyWorkflow.model_validate(workflow)
261264
assert isinstance(workflow, GalaxyWorkflow)
262265
return _normalize_workflow(workflow)
@@ -310,20 +313,28 @@ def _normalize_input_type(value: Any) -> Any:
310313
return value
311314

312315

316+
def _validate_input_dict(d: dict[str, Any]) -> BaseInputParameter:
317+
"""Validate an input dict using the specific discriminated type."""
318+
type_val = d.get("type")
319+
if isinstance(type_val, list):
320+
return WorkflowInputParameter.model_validate(d)
321+
return input_parameter_class(type_val).model_validate(d)
322+
323+
313324
def _normalize_inputs(
314-
inputs: list[WorkflowInputParameter] | dict[str, WorkflowInputParameter | str] | dict[str, Any],
315-
) -> list[WorkflowInputParameter]:
325+
inputs: list[BaseInputParameter] | dict[str, BaseInputParameter | str] | dict[str, Any] | Any,
326+
) -> list[BaseInputParameter]:
316327
if isinstance(inputs, list):
317-
result = []
328+
result: list[BaseInputParameter] = []
318329
for inp in inputs:
319-
if isinstance(inp, WorkflowInputParameter):
330+
if isinstance(inp, BaseInputParameter):
320331
result.append(inp)
321332
elif isinstance(inp, dict):
322333
if "type" in inp:
323334
inp = {**inp, "type": _normalize_input_type(inp["type"])}
324-
result.append(WorkflowInputParameter.model_validate(inp))
335+
result.append(_validate_input_dict(inp))
325336
else:
326-
result.append(WorkflowInputParameter.model_validate(inp))
337+
result.append(_validate_input_dict(inp))
327338
return result
328339

329340
# Dict form — keys are ids; values are BaseInputParameter instances (any subclass), type strings, or dicts
@@ -332,8 +343,8 @@ def _normalize_inputs(
332343
if isinstance(value, str):
333344
# Shorthand: input_name: "data"
334345
normalized_type = _normalize_input_type(value)
335-
result.append(WorkflowInputParameter.model_validate({"id": key, "type": normalized_type}))
336-
elif isinstance(value, WorkflowInputParameter):
346+
result.append(input_parameter_class(normalized_type)(id=key, type_=normalized_type))
347+
elif isinstance(value, BaseInputParameter):
337348
if value.id is None:
338349
value = value.model_copy(update={"id": key})
339350
result.append(value)
@@ -344,9 +355,9 @@ def _normalize_inputs(
344355
value = {**value, "type": _normalize_input_type(value["type"])}
345356
if "format" in value and isinstance(value["format"], str):
346357
value = {**value, "format": [value["format"]]}
347-
result.append(WorkflowInputParameter.model_validate(value))
358+
result.append(_validate_input_dict(value))
348359
else:
349-
result.append(WorkflowInputParameter(id=key))
360+
result.append(input_parameter_class(None)(id=key))
350361
return result
351362

352363

@@ -377,6 +388,33 @@ def _normalize_outputs(
377388
return result
378389

379390

391+
def _pre_normalize_input_types(workflow: dict[str, Any]) -> dict[str, Any]:
392+
"""Normalize input type aliases (File → data, etc.) before discriminator runs.
393+
394+
The discriminated union on ``Process.inputs`` routes based on the raw
395+
``type`` field, so alias normalization must happen before model validation.
396+
"""
397+
inputs = workflow.get("inputs")
398+
if inputs is None:
399+
return workflow
400+
401+
def norm_entry(entry: Any) -> Any:
402+
if isinstance(entry, dict) and "type" in entry:
403+
return {**entry, "type": _normalize_input_type(entry["type"])}
404+
if isinstance(entry, str):
405+
return _normalize_input_type(entry)
406+
return entry
407+
408+
new_inputs: dict[str, Any] | list[Any]
409+
if isinstance(inputs, dict):
410+
new_inputs = {k: norm_entry(v) for k, v in inputs.items()}
411+
elif isinstance(inputs, list):
412+
new_inputs = [norm_entry(v) for v in inputs]
413+
else:
414+
return workflow
415+
return {**workflow, "inputs": new_inputs}
416+
417+
380418
def _pre_clean_steps(workflow: dict[str, Any]) -> dict[str, Any]:
381419
"""Resolve ``$link`` entries in step state dicts before model validation.
382420

0 commit comments

Comments
 (0)