Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion iwf/command_results.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,61 @@
import typing
from dataclasses import dataclass
from typing import Any, Union

from iwf_api.models import (
ChannelRequestStatus,
CommandResults as IdlCommandResults,
TimerStatus,
)
from iwf_api.types import Unset

from iwf.object_encoder import ObjectEncoder


@dataclass
class TimerCommandResult:
status: TimerStatus
command_id: str


@dataclass
class InternalChannelCommandResult:
channel_name: str
value: Any
status: ChannelRequestStatus
command_id: str


@dataclass
class CommandResults:
pass
timer_commands: list[TimerCommandResult]
internal_channel_commands: list[InternalChannelCommandResult]


def from_idl_command_results(
idl_results: Union[Unset, IdlCommandResults],
internal_channel_types: dict[str, typing.Optional[type]],
object_encoder: ObjectEncoder,
) -> CommandResults:
results = CommandResults(list(), list())
if isinstance(idl_results, Unset):
return results
if not isinstance(idl_results.timer_results, Unset):
for timer in idl_results.timer_results:
results.timer_commands.append(
TimerCommandResult(timer.timer_status, timer.command_id)
)

if not isinstance(idl_results.inter_state_channel_results, Unset):
for inter in idl_results.inter_state_channel_results:
results.internal_channel_commands.append(
InternalChannelCommandResult(
inter.channel_name,
object_encoder.decode(
inter.value, internal_channel_types.get(inter.channel_name)
),
inter.request_status,
inter.command_id,
)
)
return results
42 changes: 41 additions & 1 deletion iwf/communication.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,42 @@
from typing import Any, Optional

from iwf_api.models import EncodedObject, InterStateChannelPublishing

from iwf.errors import WorkflowDefinitionError
from iwf.object_encoder import ObjectEncoder


class Communication:
pass
_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_to_publish_internal_channel: dict[str, list[EncodedObject]]

def __init__(
self, type_store: dict[str, Optional[type]], object_encoder: ObjectEncoder
):
self._object_encoder = object_encoder
self._type_store = type_store
self._to_publish_internal_channel = {}

def publish_to_internal_channel(self, channel_name: str, value: Any):
registered_type = self._type_store.get(channel_name)
if (
value is not None
and registered_type is not None
and not isinstance(value, registered_type)
):
raise WorkflowDefinitionError(
f"InternalChannel value is not of type {registered_type}"
)
vals = self._to_publish_internal_channel.get(channel_name)
if vals is None:
vals = []
vals.append(self._object_encoder.encode(value))
self._to_publish_internal_channel[channel_name] = vals

def get_to_publishing_internal_channel(self) -> list[InterStateChannelPublishing]:
pubs = []
for name, vals in self._to_publish_internal_channel.items():
for val in vals:
pubs.append(InterStateChannelPublishing(name, val))
return pubs
5 changes: 3 additions & 2 deletions iwf/object_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)

from iwf_api.models import EncodedObject
from iwf_api.types import Unset
from typing_extensions import Literal

# StrEnum is available in 3.11+
Expand Down Expand Up @@ -502,7 +503,7 @@ def encode(

def decode(
self,
payload: Optional[EncodedObject],
payload: Union[Optional[EncodedObject], Unset],
type_hint: Optional[Type] = None,
) -> Any:
"""Decode payloads into values.
Expand All @@ -516,7 +517,7 @@ def decode(
Returns:
Decoded and converted value.
"""
if payload is None:
if payload is None or isinstance(payload, Unset):
return None
if self.payload_codec:
payload = self.payload_codec.decode(payload)
Expand Down
23 changes: 19 additions & 4 deletions iwf/registry.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
from typing import Optional

from iwf.errors import WorkflowDefinitionError, InvalidArgumentError
from iwf.communication_schema import CommunicationMethodType
from iwf.errors import InvalidArgumentError, WorkflowDefinitionError
from iwf.workflow import ObjectWorkflow, get_workflow_type
from iwf.workflow_state import get_state_id, WorkflowState
from iwf.workflow_state import WorkflowState, get_state_id


class Registry:
_workflow_store: dict[str, ObjectWorkflow]
_starting_state_store: dict[str, WorkflowState]
_state_store: dict[str, dict[str, WorkflowState]]
_internal_channel_type_store: dict[str, dict[str, Optional[type]]]

def __init__(self):
self._workflow_store = dict()
self._starting_state_store = dict()
self._state_store = dict()
self._internal_channel_type_store = dict()

def add_workflow(self, wf: ObjectWorkflow):
self._register_workflow(wf)
self._register_workflow_type(wf)
self._register_workflow_state(wf)
self._register_internal_channels(wf)

def add_workflows(self, *wfs: ObjectWorkflow):
for wf in wfs:
Expand Down Expand Up @@ -45,12 +49,23 @@ def get_workflow_state_with_check(
)
return state

def _register_workflow(self, wf):
def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
return self._internal_channel_type_store[wf_type]

def _register_workflow_type(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
if wf_type in self._workflow_store:
raise WorkflowDefinitionError("workflow type conflict: ", wf_type)
self._workflow_store[wf_type] = wf

def _register_internal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}
for method in wf.get_communication_schema().communication_methods:
if method.method_type == CommunicationMethodType.InternalChannel:
types[method.name] = method.value_type
self._internal_channel_type_store[wf_type] = types

def _register_workflow_state(self, wf):
wf_type = get_workflow_type(wf)
state_map = {}
Expand Down
10 changes: 8 additions & 2 deletions iwf/state_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ def single_next_state(
return StateDecision([StateMovement.create(state, state_input)])

@classmethod
def multi_next_states(cls, next_states: List[StateMovement]) -> StateDecision:
return StateDecision(next_states)
def multi_next_states(
cls, *next_states: Union[type[WorkflowState], StateMovement]
) -> StateDecision:
next_list = [
n if isinstance(n, StateMovement) else StateMovement.create(n)
for n in next_states
]
return StateDecision(next_list)


StateDecision.dead_end = StateDecision([StateMovement.dead_end])
Expand Down
132 changes: 132 additions & 0 deletions iwf/tests/test_internal_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import inspect
import time

from iwf_api.models import ChannelRequestStatus

from iwf.client import Client
from iwf.command_request import CommandRequest, InternalChannelCommand
from iwf.command_results import CommandResults, InternalChannelCommandResult
from iwf.communication import Communication
from iwf.communication_schema import CommunicationMethod, CommunicationSchema
from iwf.persistence import Persistence
from iwf.state_decision import StateDecision
from iwf.state_schema import StateSchema
from iwf.tests.worker_server import registry
from iwf.workflow import ObjectWorkflow
from iwf.workflow_context import WorkflowContext
from iwf.workflow_state import T, WorkflowState

test_channel_name1 = "test-internal-channel-1"
test_channel_name2 = "test-internal-channel-2"

test_channel_name3 = "test-internal-channel-3"
test_channel_name4 = "test-internal-channel-4"


class InitState(WorkflowState[None]):
def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
return StateDecision.multi_next_states(
WaitAnyWithPublishState, WaitAllThenPublishState
)


class WaitAnyWithPublishState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
communication.publish_to_internal_channel(test_channel_name3, 123)
communication.publish_to_internal_channel(test_channel_name4, "str-value")
return CommandRequest.for_any_command_completed(
InternalChannelCommand.by_name(test_channel_name1),
InternalChannelCommand.by_name(test_channel_name2),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
assert len(command_results.internal_channel_commands) == 2
assert command_results.internal_channel_commands[
0
] == InternalChannelCommandResult(
channel_name=test_channel_name1,
command_id="",
status=ChannelRequestStatus.WAITING,
value=None,
)
assert command_results.internal_channel_commands[
1
] == InternalChannelCommandResult(
channel_name=test_channel_name2,
command_id="",
status=ChannelRequestStatus.RECEIVED,
value=None,
)
return StateDecision.graceful_complete_workflow()


class WaitAllThenPublishState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
return CommandRequest.for_all_command_completed(
InternalChannelCommand.by_name(test_channel_name3),
InternalChannelCommand.by_name(test_channel_name4),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
communication.publish_to_internal_channel(test_channel_name2, None)
return StateDecision.dead_end


class InternalChannelWorkflow(ObjectWorkflow):
def get_workflow_states(self) -> StateSchema:
return StateSchema.with_starting_state(
InitState(), WaitAnyWithPublishState(), WaitAllThenPublishState()
)

def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.internal_channel_def(test_channel_name1, int),
CommunicationMethod.internal_channel_def(test_channel_name2, type(None)),
CommunicationMethod.internal_channel_def(test_channel_name3, int),
CommunicationMethod.internal_channel_def(test_channel_name4, str),
)


wf = InternalChannelWorkflow()
registry.add_workflow(wf)
client = Client(registry)


def test_internal_channel_workflow():
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"

client.start_workflow(InternalChannelWorkflow, wf_id, 100, None)
client.get_simple_workflow_result_with_wait(wf_id, None)
7 changes: 3 additions & 4 deletions iwf/tests/worker_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from threading import Thread

from flask import Flask, request
from iwf_api.models import WorkflowStateWaitUntilRequest, WorkflowStateExecuteRequest
from iwf_api.models import WorkflowStateExecuteRequest, WorkflowStateWaitUntilRequest

from iwf.registry import Registry
from iwf.worker_service import (
WorkerService,
)

debug_mode = False
debug_mode: bool = False

registry = Registry()

Expand Down Expand Up @@ -39,8 +39,7 @@ def handle_execute():

@_flask_app.errorhandler(Exception)
def internal_error(exception):
print("500 error caught")
print(traceback.format_exc())
# TODO: how to print to std in a different thread??
response = exception.get_response()
# replace the body with JSON
response.data = traceback.format_exc()
Expand Down
Loading