Skip to content

Commit a4226bc

Browse files
authored
fix(py/genkit): additional ty check fixes in core (#4243)
1 parent 70f9d01 commit a4226bc

File tree

14 files changed

+63
-46
lines changed

14 files changed

+63
-46
lines changed

py/bin/sanitize_schema_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def add_header(content: str) -> str:
340340
import sys # noqa
341341
342342
if sys.version_info < (3, 11): # noqa
343-
from strenum import StrEnum # noqa
343+
from strenum import StrEnum # type: ignore # noqa
344344
else: # noqa
345345
from enum import StrEnum # noqa
346346
"""

py/packages/genkit/src/genkit/core/action/_action.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
from collections.abc import AsyncIterator, Callable
9090
from contextvars import ContextVar
9191
from functools import cached_property
92-
from typing import Any
92+
from typing import Any, Awaitable
9393

9494
from pydantic import BaseModel, TypeAdapter
9595

@@ -228,7 +228,7 @@ def __init__(
228228
input_spec = inspect.getfullargspec(metadata_fn if metadata_fn else fn)
229229
action_args, arg_types = extract_action_args_and_types(input_spec)
230230
n_action_args = len(action_args)
231-
self._fn, self._afn = _make_tracing_wrappers(name, kind, span_metadata, n_action_args, fn)
231+
self._fn, self._afn = _make_tracing_wrappers(name, kind, span_metadata or {}, n_action_args, fn)
232232
self._initialize_io_schemas(action_args, arg_types, input_spec)
233233

234234
@property
@@ -248,7 +248,7 @@ def metadata(self) -> dict[str, Any]:
248248
return self._metadata
249249

250250
@cached_property
251-
def input_type(self) -> type | None:
251+
def input_type(self) -> TypeAdapter[Any] | None:
252252
return self._input_type
253253

254254
@cached_property
@@ -418,7 +418,7 @@ def stream(
418418
telemetry_labels=telemetry_labels,
419419
on_chunk=lambda c: stream.send(c),
420420
)
421-
stream.set_close_future(resp)
421+
stream.set_close_future(asyncio.create_task(resp))
422422

423423
result_future: asyncio.Future[ActionResponse] = asyncio.Future()
424424
stream.closed.add_done_callback(lambda _: result_future.set_result(stream.closed.result().response))
@@ -483,7 +483,7 @@ class ActionMetadata(BaseModel):
483483

484484

485485
_SyncTracingWrapper = Callable[[Any | None, ActionRunContext], ActionResponse]
486-
_AsyncTracingWrapper = Callable[[Any | None, ActionRunContext], ActionResponse]
486+
_AsyncTracingWrapper = Callable[[Any | None, ActionRunContext], Awaitable[ActionResponse]]
487487

488488

489489
def _make_tracing_wrappers(
@@ -539,7 +539,7 @@ async def async_tracing_wrapper(input: Any | None, ctx: ActionRunContext) -> Act
539539
) from e
540540

541541
record_output_metadata(span, output=output)
542-
return ActionResponse(response=output, trace_id=trace_id)
542+
return ActionResponse(response=output, traceId=trace_id)
543543

544544
def sync_tracing_wrapper(input: Any | None, ctx: ActionRunContext) -> ActionResponse:
545545
"""Wrap the function in a sync tracing wrapper.
@@ -580,6 +580,6 @@ def sync_tracing_wrapper(input: Any | None, ctx: ActionRunContext) -> ActionResp
580580
) from e
581581

582582
record_output_metadata(span, output=output)
583-
return ActionResponse(response=output, trace_id=trace_id)
583+
return ActionResponse(response=output, traceId=trace_id)
584584

585585
return sync_tracing_wrapper, async_tracing_wrapper

py/packages/genkit/src/genkit/core/action/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Any
2222

2323
if sys.version_info < (3, 11):
24-
from strenum import StrEnum
24+
from strenum import StrEnum # type: ignore
2525
else:
2626
from enum import StrEnum
2727

py/packages/genkit/src/genkit/core/environment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818

1919
import os
2020
import sys
21+
from typing import cast
2122

2223
if sys.version_info < (3, 11):
23-
from strenum import StrEnum
24+
from strenum import StrEnum # type: ignore
2425
else:
2526
from enum import StrEnum
2627

@@ -64,8 +65,8 @@ def get_current_environment() -> GenkitEnvironment:
6465
"""
6566
env = os.getenv(EnvVar.GENKIT_ENV)
6667
if env is None:
67-
return GenkitEnvironment.PROD
68+
return cast(GenkitEnvironment, GenkitEnvironment.PROD)
6869
try:
6970
return GenkitEnvironment(env)
7071
except ValueError:
71-
return GenkitEnvironment.PROD
72+
return cast(GenkitEnvironment, GenkitEnvironment.PROD)

py/packages/genkit/src/genkit/core/error.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""Base error classes and utilities for Genkit."""
1818

1919
import traceback
20-
from typing import Any
20+
from typing import Any, cast
2121

2222
from pydantic import BaseModel, ConfigDict, Field
2323

@@ -61,6 +61,8 @@ class HttpErrorWireFormat(BaseModel):
6161
class GenkitError(Exception):
6262
"""Base error class for Genkit errors."""
6363

64+
status: StatusName
65+
6466
def __init__(
6567
self,
6668
*,
@@ -81,12 +83,13 @@ def __init__(
8183
trace_id: A unique identifier for tracing the action execution.
8284
source: Optional source of the error.
8385
"""
84-
self.status = status
85-
if not self.status and isinstance(cause, GenkitError):
86-
self.status = cause.status
87-
88-
if not self.status:
89-
self.status = 'INTERNAL'
86+
if status:
87+
temp_status: StatusName = status
88+
elif isinstance(cause, GenkitError):
89+
temp_status = cause.status
90+
else:
91+
temp_status = 'INTERNAL'
92+
self.status = temp_status
9093

9194
source_prefix = f'{source}: ' if source else ''
9295
super().__init__(f'{source_prefix}{self.status}: {message}')
@@ -129,7 +132,7 @@ def to_serializable(self) -> GenkitReflectionApiErrorWireFormat:
129132
# This error type is used by 3P authors with the field "details",
130133
# but the actual Callable protocol value is "details"
131134
return GenkitReflectionApiErrorWireFormat(
132-
details=self.details,
135+
details=GenkitReflectionApiDetailsWireFormat(**self.details) if self.details else None,
133136
code=StatusCodes[self.status].value,
134137
message=repr(self.cause) if self.cause else self.original_message,
135138
)
@@ -201,7 +204,7 @@ def get_reflection_json(error: Any) -> GenkitReflectionApiErrorWireFormat:
201204
return GenkitReflectionApiErrorWireFormat(
202205
message=str(error),
203206
code=StatusCodes.INTERNAL.value,
204-
details={'stack': get_error_stack(error)},
207+
details=GenkitReflectionApiDetailsWireFormat(stack=get_error_stack(error)),
205208
)
206209

207210

py/packages/genkit/src/genkit/core/flows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ async def handle_standard_flow(
330330
routes=routes,
331331
middleware=[
332332
Middleware(
333-
CORSMiddleware,
333+
CORSMiddleware, # type: ignore[arg-type]
334334
allow_origins=['*'],
335335
allow_methods=['*'],
336336
allow_headers=['*'],

py/packages/genkit/src/genkit/core/plugin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""
2222

2323
import abc
24+
from typing import cast
2425

2526
from genkit.core.action import Action, ActionMetadata
2627
from genkit.core.action.types import ActionKind
@@ -92,7 +93,7 @@ async def model(self, name: str) -> Action:
9293
ValueError: If the model is not found.
9394
"""
9495
target = name if '/' in name else f'{self.name}/{name}'
95-
action = await self.resolve(ActionKind.MODEL, target)
96+
action = await self.resolve(cast(ActionKind, ActionKind.MODEL), target)
9697
if action is None:
9798
raise ValueError(f'Model not found: {target}')
9899
return action

py/packages/genkit/src/genkit/core/reflection.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
import urllib.parse
4646
from collections.abc import AsyncGenerator, Callable
4747
from http.server import BaseHTTPRequestHandler
48-
from typing import Any
48+
from typing import Any, cast
4949

5050
import structlog
5151
from starlette.applications import Starlette
@@ -58,7 +58,7 @@
5858
from genkit.aio.loop import run_async
5959
from genkit.codec import dump_dict, dump_json
6060
from genkit.core.action import Action
61-
from genkit.core.action.types import ActionKind
61+
from genkit.core.action.types import ActionKind, ActionResponse
6262
from genkit.core.constants import DEFAULT_GENKIT_VERSION
6363
from genkit.core.error import get_reflection_json
6464
from genkit.core.registry import Registry
@@ -77,7 +77,7 @@
7777
def _list_registered_actions(registry: Registry) -> dict[str, Action]:
7878
"""Return all locally registered actions keyed as `/<kind>/<name>`."""
7979
registered: dict[str, Action] = {}
80-
for kind in ActionKind:
80+
for kind in ActionKind.__members__.values():
8181
for name, action in registry.get_actions_by_kind(kind).items():
8282
registered[f'/{kind.value}/{name}'] = action
8383
return registered
@@ -179,7 +179,7 @@ def log_message(self, format, *args):
179179
message = format % args
180180
address = self.address_string()
181181
timestamp = self.log_date_time_string()
182-
control_chars = self._control_char_table
182+
control_chars = getattr(self, '_control_char_table', {})
183183
logger.debug(f'{address} - - [{timestamp}] {message.translate(control_chars)}')
184184

185185
def do_GET(self) -> None: # noqa: N802
@@ -246,6 +246,10 @@ async def get_action():
246246
return await registry.resolve_action_by_key(payload['key'])
247247

248248
action = run_async(loop, get_action)
249+
if not action:
250+
self.send_response(404)
251+
self.end_headers()
252+
return
249253
payload.get('input')
250254
context = payload['context'] if 'context' in payload else {}
251255

@@ -281,7 +285,7 @@ async def run_fn():
281285
context=context,
282286
)
283287

284-
output = run_async(loop, run_fn)
288+
output = cast(ActionResponse, run_async(loop, run_fn))
285289

286290
self.wfile.write(
287291
bytes(
@@ -309,7 +313,7 @@ async def run_fn():
309313
async def run_fn():
310314
return await action.arun_raw(raw_input=payload.get('input'), context=context)
311315

312-
output = run_async(loop, run_fn)
316+
output = cast(ActionResponse, run_async(loop, run_fn))
313317

314318
self.send_response(200)
315319
self.send_header('x-genkit-version', DEFAULT_GENKIT_VERSION)
@@ -719,7 +723,7 @@ def wrapped_on_trace_start(tid):
719723
],
720724
middleware=[
721725
Middleware(
722-
CORSMiddleware,
726+
CORSMiddleware, # type: ignore[arg-type]
723727
allow_origins=['*'],
724728
allow_methods=['*'],
725729
allow_headers=['*'],
@@ -728,5 +732,5 @@ def wrapped_on_trace_start(tid):
728732
on_startup=[on_app_startup] if on_app_startup else [],
729733
on_shutdown=[on_app_shutdown] if on_app_shutdown else [],
730734
)
731-
app.active_actions = active_actions
735+
setattr(app, 'active_actions', active_actions)
732736
return app

py/packages/genkit/src/genkit/core/registry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import asyncio
3131
import threading
3232
from collections.abc import Callable
33-
from typing import Any
33+
from typing import Any, cast
3434

3535
import structlog
3636
from dotpromptz.dotprompt import Dotprompt
@@ -383,7 +383,11 @@ async def resolve_action(self, kind: ActionKind, name: str) -> Action | None:
383383
# Skip if we're looking up a dynamic action provider itself to avoid recursion
384384
if kind != ActionKind.DYNAMIC_ACTION_PROVIDER:
385385
with self._lock:
386-
providers = list(self._entries.get(ActionKind.DYNAMIC_ACTION_PROVIDER, {}).values())
386+
if ActionKind.DYNAMIC_ACTION_PROVIDER in self._entries:
387+
providers_dict = self._entries[ActionKind.DYNAMIC_ACTION_PROVIDER]
388+
else:
389+
providers_dict = {}
390+
providers = list(providers_dict.values())
387391
for provider in providers:
388392
try:
389393
response = await provider.arun({'kind': kind, 'name': name})

py/packages/genkit/src/genkit/core/trace/default_exporter.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,16 @@ def extract_span_data(span: ReadableSpan) -> dict[str, Any]:
5151
This function extracts the span data from a ReadableSpan object and returns
5252
a dictionary containing the span data.
5353
"""
54-
span_data = {'traceId': f'{span.context.trace_id}', 'spans': {}}
55-
span_data['spans'][span.context.span_id] = {
56-
'spanId': f'{span.context.span_id}',
54+
span_data: dict[str, Any] = {'traceId': f'{span.context.trace_id}', 'spans': {}}
55+
span_id = span.context.span_id
56+
start_time = (span.start_time / 1000000) if span.start_time is not None else 0
57+
end_time = (span.end_time / 1000000) if span.end_time is not None else 0
58+
59+
span_data['spans'][span_id] = {
60+
'spanId': f'{span_id}',
5761
'traceId': f'{span.context.trace_id}',
58-
'startTime': span.start_time / 1000000,
59-
'endTime': span.end_time / 1000000,
62+
'startTime': start_time,
63+
'endTime': end_time,
6064
'attributes': {**span.attributes},
6165
'displayName': span.name,
6266
# "links": span.links,
@@ -75,8 +79,8 @@ def extract_span_data(span: ReadableSpan) -> dict[str, Any]:
7579
'version': 'v1',
7680
},
7781
}
78-
if not span_data['spans'][span.context.span_id]['parentSpanId']: # type: ignore
79-
del span_data['spans'][span.context.span_id]['parentSpanId'] # type: ignore
82+
if not span_data['spans'][span.context.span_id]['parentSpanId']:
83+
del span_data['spans'][span.context.span_id]['parentSpanId']
8084

8185
if not span.parent:
8286
span_data['displayName'] = span.name
@@ -127,7 +131,7 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
127131
for span in spans:
128132
client.post(
129133
urljoin(self.telemetry_server_url, self.telemetry_server_endpoint),
130-
data=json.dumps(extract_span_data(span)),
134+
json=extract_span_data(span),
131135
headers={
132136
'Content-Type': 'application/json',
133137
'Accept': 'application/json',

0 commit comments

Comments
 (0)