Skip to content

Commit 75b65eb

Browse files
jmchiltonclaude
andcommitted
Add convert_tool_state and native_state_encoder callback protocols.
Allow schema-aware consumers to inject tool-definition-aware state conversion on both export (native→format2) and import (format2→native) paths without gxformat2 needing tool definitions. Rename encode_tool_state → encode_tool_state_json for clarity. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4d64b8f commit 75b65eb

File tree

5 files changed

+404
-21
lines changed

5 files changed

+404
-21
lines changed

gxformat2/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
PROJECT_NAME = "gxformat2"
66
PROJECT_OWNER = "galaxyproject"
77

8-
from .converter import ImportOptions, python_to_workflow # NOQA
9-
from .export import from_galaxy_native # NOQA
8+
from .converter import ImportOptions, NativeStateEncoderFn, python_to_workflow # NOQA
9+
from .export import ConvertToolStateFn, from_galaxy_native # NOQA
1010
from .interface import ImporterGalaxyInterface # NOQA
1111
from .main import convert_and_import_workflow # NOQA
1212

1313
__all__ = (
1414
"convert_and_import_workflow",
15+
"ConvertToolStateFn",
1516
"from_galaxy_native",
1617
"ImporterGalaxyInterface",
1718
"ImportOptions",
19+
"NativeStateEncoderFn",
1820
"python_to_workflow",
1921
)

gxformat2/converter.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,22 @@
33
import argparse
44
import copy
55
import json
6+
import logging
67
import os
78
import sys
89
import uuid
9-
from typing import Any, Optional
10+
from typing import Any, Callable, Dict, Optional
11+
12+
log = logging.getLogger(__name__)
13+
14+
NativeStateEncoderFn = Optional[Callable[[dict, Dict[str, Any]], Optional[Dict[str, Any]]]]
15+
"""Callback to encode format2 state back to native tool_state.
16+
17+
Accepts (step, state) where step is the partially-built native step dict
18+
and state is the format2 state dict after setup_connected_values processing.
19+
Returns {param_name: encoded_value} for native tool_state, or None to fall
20+
back to default json.dumps encoding.
21+
"""
1022

1123
from ._labels import Labels
1224
from .model import (
@@ -81,7 +93,8 @@ class ImportOptions:
8193

8294
def __init__(self):
8395
self.deduplicate_subworkflows = False
84-
self.encode_tool_state = True
96+
self.encode_tool_state_json = True
97+
self.native_state_encoder: NativeStateEncoderFn = None
8598

8699

87100
def yaml_to_workflow(has_yaml, galaxy_interface, workflow_directory, import_options=None):
@@ -370,7 +383,7 @@ def transform_input(context, step, default_name):
370383
if attrib in step:
371384
tool_state[attrib] = step[attrib]
372385

373-
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state)
386+
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state_json)
374387

375388

376389
def transform_pause(context, step, default_name="Pause for dataset review"):
@@ -398,7 +411,7 @@ def transform_pause(context, step, default_name="Pause for dataset review"):
398411

399412
connect = pop_connect_from_step_dict(step)
400413
_populate_input_connections(context, step, connect)
401-
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state)
414+
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state_json)
402415

403416

404417
def transform_pick_value(context, step, default_name="Pick Value"):
@@ -469,7 +482,7 @@ def transform_subworkflow(context, step):
469482

470483
connect = pop_connect_from_step_dict(step)
471484
_populate_input_connections(context, step, connect)
472-
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state)
485+
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state_json)
473486

474487

475488
def _runtime_value():
@@ -501,12 +514,27 @@ def transform_tool(context, step):
501514
# TODO: handle runtime inputs and state together.
502515
runtime_inputs = step.get("runtime_inputs", [])
503516
if "state" in step or runtime_inputs:
504-
encode = context.import_options.encode_tool_state
517+
encode = context.import_options.encode_tool_state_json
518+
encoder = context.import_options.native_state_encoder
505519
step_state = step.pop("state", {})
506520
step_state = setup_connected_values(step_state, append_to=connect)
507521

508-
for key, value in step_state.items():
509-
tool_state[key] = json.dumps(value) if encode else value
522+
encoded = None
523+
if encoder is not None:
524+
try:
525+
encoded = encoder(step, step_state)
526+
except Exception:
527+
log.warning(
528+
"native_state_encoder callback failed for %s, falling back to default",
529+
step.get("tool_id"),
530+
exc_info=True,
531+
)
532+
533+
if encoded is not None:
534+
tool_state.update(encoded)
535+
else:
536+
for key, value in step_state.items():
537+
tool_state[key] = json.dumps(value) if encode else value
510538
for runtime_input in runtime_inputs:
511539
tool_state[runtime_input] = json.dumps(_runtime_value()) if encode else _runtime_value()
512540
elif "tool_state" in step:
@@ -515,7 +543,7 @@ def transform_tool(context, step):
515543
# Fill in input connections
516544
_populate_input_connections(context, step, connect)
517545

518-
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state)
546+
_populate_tool_state(step, tool_state, encode=context.import_options.encode_tool_state_json)
519547

520548
# Handle outputs.
521549
out = step.pop("out", None)

gxformat2/export.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,19 @@
22

33
import argparse
44
import json
5+
import logging
56
import sys
67
from collections import OrderedDict
8+
from typing import Any, Callable, Dict, Optional
9+
10+
log = logging.getLogger(__name__)
11+
12+
ConvertToolStateFn = Optional[Callable[[dict], Optional[Dict[str, Any]]]]
13+
"""Callback to convert a native tool step's tool_state to format2 state.
14+
15+
Accepts a native step dict (with tool_id, tool_version, tool_state).
16+
Returns a format2 state dict, or None to fall back to default tool_state passthrough.
17+
"""
718

819
from ._labels import Labels, UNLABELED_INPUT_PREFIX, UNLABELED_STEP_PREFIX
920
from .model import (
@@ -29,10 +40,18 @@ def _copy_common_properties(from_native_step, to_format2_step, compact=False):
2940
to_format2_step[prop] = value
3041

3142

32-
def from_galaxy_native(native_workflow_dict, tool_interface=None, json_wrapper=False, compact=False):
43+
def from_galaxy_native(native_workflow_dict, tool_interface=None, json_wrapper=False, compact=False, convert_tool_state: ConvertToolStateFn = None):
3344
"""Convert native .ga workflow definition to a format2 workflow.
3445
3546
This is highly experimental and currently broken.
47+
48+
If ``convert_tool_state`` is provided it should be a callable accepting a
49+
native step dict and returning an optional dict representing the format2
50+
``state`` for that step. When the callable returns a dict, the step will
51+
carry ``state`` instead of ``tool_state``; when it returns ``None`` the
52+
default ``tool_state`` passthrough is used. This allows schema-aware
53+
consumers to inject tool-definition-aware value conversion without
54+
gxformat2 needing to know about tool definitions.
3655
"""
3756
data = OrderedDict()
3857
data["class"] = "GalaxyWorkflow"
@@ -143,7 +162,11 @@ def from_galaxy_native(native_workflow_dict, tool_interface=None, json_wrapper=F
143162
else:
144163
subworkflow_native_dict = step["subworkflow"]
145164
subworkflow = from_galaxy_native(
146-
subworkflow_native_dict, tool_interface=tool_interface, json_wrapper=False, compact=compact
165+
subworkflow_native_dict,
166+
tool_interface=tool_interface,
167+
json_wrapper=False,
168+
compact=compact,
169+
convert_tool_state=convert_tool_state,
147170
)
148171
step_dict["run"] = subworkflow
149172
steps.append(step_dict)
@@ -155,10 +178,24 @@ def from_galaxy_native(native_workflow_dict, tool_interface=None, json_wrapper=F
155178
_copy_properties(step, step_dict, optional_props, required_props)
156179
_copy_common_properties(step, step_dict, compact=compact)
157180

158-
tool_state = _tool_state(step)
159-
tool_state.pop("__page__", None)
160-
tool_state.pop("__rerun_remap_job_id__", None)
161-
step_dict["tool_state"] = tool_state
181+
converted_state = None
182+
if convert_tool_state is not None:
183+
try:
184+
converted_state = convert_tool_state(step)
185+
except Exception:
186+
log.warning(
187+
"convert_tool_state callback failed for %s, falling back to default",
188+
step.get("tool_id"),
189+
exc_info=True,
190+
)
191+
192+
if converted_state is not None:
193+
step_dict["state"] = converted_state
194+
else:
195+
tool_state = _tool_state(step)
196+
tool_state.pop("__page__", None)
197+
tool_state.pop("__rerun_remap_job_id__", None)
198+
step_dict["tool_state"] = tool_state
162199

163200
_convert_input_connections(step, step_dict, label_map)
164201
_convert_post_job_actions(step, step_dict)

tests/test_to_format2.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,196 @@ def test_unlabeled_input_connections_round_trip():
102102
assert input_conn["id"] == 0
103103

104104

105+
def test_convert_tool_state_callback_called():
106+
"""Test that convert_tool_state callback is called for tool steps on export."""
107+
sars_example = os.path.join(TEST_PATH, "sars-cov-2-variant-calling.ga")
108+
with open(sars_example) as f:
109+
native_wf = json.load(f)
110+
111+
called_tool_ids = []
112+
113+
def _callback(native_step):
114+
called_tool_ids.append(native_step.get("tool_id"))
115+
return {"custom_key": "custom_value"}
116+
117+
result = from_galaxy_native(native_wf, convert_tool_state=_callback)
118+
119+
# Callback should have been called for each tool step
120+
assert len(called_tool_ids) > 0
121+
# Tool steps should have "state" not "tool_state"
122+
for step in _tool_steps(result):
123+
assert "state" in step
124+
assert "tool_state" not in step
125+
assert step["state"] == {"custom_key": "custom_value"}
126+
127+
128+
def test_convert_tool_state_callback_none_fallback():
129+
"""Test that returning None from callback falls back to default tool_state."""
130+
sars_example = os.path.join(TEST_PATH, "sars-cov-2-variant-calling.ga")
131+
with open(sars_example) as f:
132+
native_wf = json.load(f)
133+
134+
def _callback(native_step):
135+
return None
136+
137+
result = from_galaxy_native(native_wf, convert_tool_state=_callback)
138+
139+
# All tool steps should have tool_state (default path)
140+
for step in _tool_steps(result):
141+
assert "tool_state" in step
142+
assert "state" not in step
143+
144+
145+
def test_convert_tool_state_callback_exception_fallback():
146+
"""Test that callback exceptions fall back to default tool_state."""
147+
sars_example = os.path.join(TEST_PATH, "sars-cov-2-variant-calling.ga")
148+
with open(sars_example) as f:
149+
native_wf = json.load(f)
150+
151+
def _callback(native_step):
152+
raise ValueError("conversion failed")
153+
154+
result = from_galaxy_native(native_wf, convert_tool_state=_callback)
155+
156+
# All tool steps should have tool_state (fallback on exception)
157+
for step in _tool_steps(result):
158+
assert "tool_state" in step
159+
assert "state" not in step
160+
161+
162+
def test_convert_tool_state_callback_selective():
163+
"""Test that callback can convert some steps and fall back on others."""
164+
sars_example = os.path.join(TEST_PATH, "sars-cov-2-variant-calling.ga")
165+
with open(sars_example) as f:
166+
native_wf = json.load(f)
167+
168+
target_tool_id = "__MERGE_COLLECTION__"
169+
170+
def _callback(native_step):
171+
if native_step.get("tool_id") == target_tool_id:
172+
return {"converted": True}
173+
return None
174+
175+
result = from_galaxy_native(native_wf, convert_tool_state=_callback)
176+
177+
tool_steps = list(_tool_steps(result))
178+
converted_count = 0
179+
fallback_count = 0
180+
for step in tool_steps:
181+
if step.get("state") == {"converted": True}:
182+
converted_count += 1
183+
assert "tool_state" not in step
184+
else:
185+
fallback_count += 1
186+
assert "tool_state" in step
187+
assert "state" not in step
188+
assert converted_count >= 1
189+
assert fallback_count >= 1
190+
191+
192+
def test_convert_tool_state_connections_always_present():
193+
"""Test that _convert_input_connections runs regardless of callback."""
194+
sars_example = os.path.join(TEST_PATH, "sars-cov-2-variant-calling.ga")
195+
with open(sars_example) as f:
196+
native_wf = json.load(f)
197+
198+
def _callback(native_step):
199+
return {"converted": True}
200+
201+
result = from_galaxy_native(native_wf, convert_tool_state=_callback)
202+
203+
# Tool steps with connections should still have "in" populated by _convert_input_connections
204+
has_connections = False
205+
for step in _tool_steps(result):
206+
if step.get("in"):
207+
has_connections = True
208+
break
209+
assert has_connections, "Expected at least one tool step with input connections"
210+
211+
212+
def test_convert_tool_state_no_callback_default_unchanged():
213+
"""Test that omitting convert_tool_state preserves original behavior."""
214+
sars_example = os.path.join(TEST_PATH, "sars-cov-2-variant-calling.ga")
215+
with open(sars_example) as f:
216+
native_wf = json.load(f)
217+
218+
result_default = from_galaxy_native(copy.deepcopy(native_wf))
219+
result_none = from_galaxy_native(copy.deepcopy(native_wf), convert_tool_state=None)
220+
221+
# Should be identical
222+
assert json.dumps(result_default, sort_keys=True) == json.dumps(result_none, sort_keys=True)
223+
224+
225+
def test_convert_tool_state_subworkflow_recursion():
226+
"""Test that convert_tool_state callback is passed through to subworkflows."""
227+
from gxformat2.yaml import ordered_load
228+
from gxformat2.converter import python_to_workflow
229+
230+
nested_f2 = """
231+
class: GalaxyWorkflow
232+
inputs:
233+
outer_input: data
234+
steps:
235+
first_cat:
236+
tool_id: cat1
237+
in:
238+
input1: outer_input
239+
nested_workflow:
240+
run:
241+
class: GalaxyWorkflow
242+
inputs:
243+
inner_input: data
244+
steps:
245+
inner_cat:
246+
tool_id: cat1
247+
in:
248+
input1: inner_input
249+
in:
250+
inner_input: first_cat/out_file1
251+
"""
252+
# Build a native workflow with a subworkflow
253+
f2 = ordered_load(nested_f2)
254+
native_wf = python_to_workflow(f2, MockGalaxyInterface(), None)
255+
256+
called_tool_ids = []
257+
258+
def _callback(native_step):
259+
called_tool_ids.append(native_step.get("tool_id"))
260+
return {"from_callback": True}
261+
262+
result = from_galaxy_native(native_wf, convert_tool_state=_callback)
263+
264+
# Should have been called for outer tool AND inner subworkflow tool
265+
assert len(called_tool_ids) == 2
266+
assert all(tid == "cat1" for tid in called_tool_ids)
267+
268+
# Check inner subworkflow step also got state from callback
269+
subworkflow_step = None
270+
for step in result.get("steps", {}).values() if isinstance(result.get("steps"), dict) else result.get("steps", []):
271+
if isinstance(step.get("run"), dict):
272+
subworkflow_step = step
273+
break
274+
assert subworkflow_step is not None
275+
inner_steps = subworkflow_step["run"].get("steps", {})
276+
if isinstance(inner_steps, dict):
277+
inner_tool = list(inner_steps.values())[0]
278+
else:
279+
inner_tool = inner_steps[0]
280+
assert inner_tool.get("state") == {"from_callback": True}
281+
282+
283+
def _tool_steps(format2_wf):
284+
"""Yield tool steps from a format2 workflow (handles both dict and list steps)."""
285+
steps = format2_wf.get("steps", {})
286+
if isinstance(steps, dict):
287+
step_list = steps.values()
288+
else:
289+
step_list = steps
290+
for step in step_list:
291+
if step.get("tool_id"):
292+
yield step
293+
294+
105295
def _run_example_path(path, compact=False):
106296
out = _examples_path_for(path)
107297
argv = [path, out]

0 commit comments

Comments
 (0)