4545import urllib .parse
4646from collections .abc import AsyncGenerator , Callable
4747from http .server import BaseHTTPRequestHandler
48- from typing import Any
48+ from typing import Any , cast
4949
5050import structlog
5151from starlette .applications import Starlette
5858from genkit .aio .loop import run_async
5959from genkit .codec import dump_dict , dump_json
6060from genkit .core .action import Action
61- from genkit .core .action .types import ActionKind
61+ from genkit .core .action .types import ActionKind , ActionResponse
6262from genkit .core .constants import DEFAULT_GENKIT_VERSION
6363from genkit .core .error import get_reflection_json
6464from genkit .core .registry import Registry
7777def _list_registered_actions (registry : Registry ) -> dict [str , Action ]:
7878 """Return all locally registered actions keyed as `/<kind>/<name>`."""
7979 registered : dict [str , Action ] = {}
80- for kind in ActionKind :
80+ for kind in ActionKind . __members__ . values () :
8181 for name , action in registry .get_actions_by_kind (kind ).items ():
8282 registered [f'/{ kind .value } /{ name } ' ] = action
8383 return registered
@@ -179,7 +179,7 @@ def log_message(self, format, *args):
179179 message = format % args
180180 address = self .address_string ()
181181 timestamp = self .log_date_time_string ()
182- control_chars = self . _control_char_table
182+ control_chars = getattr ( self , ' _control_char_table' , {})
183183 logger .debug (f'{ address } - - [{ timestamp } ] { message .translate (control_chars )} ' )
184184
185185 def do_GET (self ) -> None : # noqa: N802
@@ -246,6 +246,10 @@ async def get_action():
246246 return await registry .resolve_action_by_key (payload ['key' ])
247247
248248 action = run_async (loop , get_action )
249+ if not action :
250+ self .send_response (404 )
251+ self .end_headers ()
252+ return
249253 payload .get ('input' )
250254 context = payload ['context' ] if 'context' in payload else {}
251255
@@ -281,7 +285,7 @@ async def run_fn():
281285 context = context ,
282286 )
283287
284- output = run_async (loop , run_fn )
288+ output = cast ( ActionResponse , run_async (loop , run_fn ) )
285289
286290 self .wfile .write (
287291 bytes (
@@ -309,7 +313,7 @@ async def run_fn():
309313 async def run_fn ():
310314 return await action .arun_raw (raw_input = payload .get ('input' ), context = context )
311315
312- output = run_async (loop , run_fn )
316+ output = cast ( ActionResponse , run_async (loop , run_fn ) )
313317
314318 self .send_response (200 )
315319 self .send_header ('x-genkit-version' , DEFAULT_GENKIT_VERSION )
@@ -719,7 +723,7 @@ def wrapped_on_trace_start(tid):
719723 ],
720724 middleware = [
721725 Middleware (
722- CORSMiddleware ,
726+ CORSMiddleware , # type: ignore[arg-type]
723727 allow_origins = ['*' ],
724728 allow_methods = ['*' ],
725729 allow_headers = ['*' ],
@@ -728,5 +732,5 @@ def wrapped_on_trace_start(tid):
728732 on_startup = [on_app_startup ] if on_app_startup else [],
729733 on_shutdown = [on_app_shutdown ] if on_app_shutdown else [],
730734 )
731- app . active_actions = active_actions
735+ setattr ( app , ' active_actions' , active_actions )
732736 return app
0 commit comments