Skip to content

Commit 74c85b5

Browse files
committed
fix(py/genkit): ty check fixes for genkit.ai
1 parent 87bc74d commit 74c85b5

File tree

3 files changed

+27
-27
lines changed

3 files changed

+27
-27
lines changed

py/packages/genkit/src/genkit/ai/_aio.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
class while customizing it with any plugins.
2121
"""
2222

23+
import asyncio
2324
import uuid
24-
from asyncio import Future
2525
from collections.abc import AsyncIterator
2626
from pathlib import Path
27-
from typing import Any
27+
from typing import Any, cast
2828

2929
from genkit.aio import Channel
3030
from genkit.blocks.document import Document
@@ -253,7 +253,7 @@ def generate_stream(
253253
messages: list[Message] | None = None,
254254
tools: list[str] | None = None,
255255
return_tool_requests: bool | None = None,
256-
tool_choice: ToolChoice = None,
256+
tool_choice: ToolChoice | None = None,
257257
config: GenerationCommonConfig | dict[str, Any] | None = None,
258258
max_turns: int | None = None,
259259
context: dict[str, Any] | None = None,
@@ -268,7 +268,7 @@ def generate_stream(
268268
timeout: float | None = None,
269269
) -> tuple[
270270
AsyncIterator[GenerateResponseChunkWrapper],
271-
Future[GenerateResponseWrapper],
271+
asyncio.Future[GenerateResponseWrapper],
272272
]:
273273
"""Streams generated text or structured data using a language model.
274274
@@ -351,7 +351,7 @@ def generate_stream(
351351
use=use,
352352
on_chunk=lambda c: stream.send(c),
353353
)
354-
stream.set_close_future(resp)
354+
stream.set_close_future(asyncio.create_task(resp))
355355

356356
return stream, stream.closed
357357

@@ -389,7 +389,7 @@ async def retrieve(
389389

390390
request_options = {**(retriever_config or {}), **(options or {})}
391391

392-
retrieve_action = await self.registry.resolve_action(ActionKind.RETRIEVER, retriever_name)
392+
retrieve_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.RETRIEVER), retriever_name)
393393
if retrieve_action is None:
394394
raise ValueError(f'Retriever "{retriever_name}" not found')
395395

@@ -430,7 +430,7 @@ async def index(
430430

431431
req_options = {**(indexer_config or {}), **(options or {})}
432432

433-
index_action = await self.registry.resolve_action(ActionKind.INDEXER, indexer_name)
433+
index_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.INDEXER), indexer_name)
434434
if index_action is None:
435435
raise ValueError(f'Indexer "{indexer_name}" not found')
436436

@@ -464,7 +464,7 @@ async def embed(
464464
# Merge options passed to embed() with config from EmbedderRef
465465
final_options = {**(embedder_config or {}), **(options or {})}
466466

467-
embed_action = await self.registry.resolve_action(ActionKind.EMBEDDER, embedder_name)
467+
embed_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.EMBEDDER), embedder_name)
468468
if embed_action is None:
469469
raise ValueError(f'Embedder "{embedder_name}" not found')
470470

@@ -501,7 +501,7 @@ async def evaluate(
501501

502502
final_options = {**(evaluator_config or {}), **(options or {})}
503503

504-
eval_action = await self.registry.resolve_action(ActionKind.EVALUATOR, evaluator_name)
504+
eval_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.EVALUATOR), evaluator_name)
505505
if eval_action is None:
506506
raise ValueError(f'Evaluator "{evaluator_name}" not found')
507507

py/packages/genkit/src/genkit/ai/_base_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,6 @@ def _make_reflection_server(registry: GenkitRegistry, spec: ServerSpec) -> uvico
246246
Returns:
247247
A uvicorn server instance.
248248
"""
249-
app = create_reflection_asgi_app(registry=registry)
249+
app = create_reflection_asgi_app(registry=registry.registry)
250250
config = uvicorn.Config(app, host=spec.host, port=spec.port, loop='asyncio')
251251
return uvicorn.Server(config)

py/packages/genkit/src/genkit/ai/_registry.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import uuid
4444
from collections.abc import AsyncIterator, Callable
4545
from functools import wraps
46-
from typing import TYPE_CHECKING, Any
46+
from typing import TYPE_CHECKING, Any, cast
4747

4848
if TYPE_CHECKING:
4949
from genkit.blocks.resource import ResourceFn, ResourceOptions
@@ -122,7 +122,7 @@ def __init__(self):
122122
"""Initialize the Genkit registry."""
123123
self.registry: Registry = Registry()
124124

125-
def flow(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]:
125+
def flow(self, name: str | None = None, description: str | None = None) -> Action:
126126
"""Decorator to register a function as a flow.
127127
128128
Args:
@@ -235,7 +235,7 @@ def define_schema(self, name: str, schema: type) -> type:
235235
define_schema(self.registry, name, schema)
236236
return schema
237237

238-
def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]:
238+
def tool(self, name: str | None = None, description: str | None = None) -> Action:
239239
"""Decorator to register a function as a tool.
240240
241241
Args:
@@ -315,10 +315,10 @@ def define_retriever(
315315
self,
316316
name: str,
317317
fn: RetrieverFn,
318-
config_schema: BaseModel | dict[str, Any] | None = None,
318+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
319319
metadata: dict[str, Any] | None = None,
320320
description: str | None = None,
321-
) -> Callable[[Callable], Callable]:
321+
) -> Action:
322322
"""Define a retriever action.
323323
324324
Args:
@@ -339,7 +339,7 @@ def define_retriever(
339339
retriever_description = get_func_description(fn, description)
340340
return self.registry.register_action(
341341
name=name,
342-
kind=ActionKind.RETRIEVER,
342+
kind=cast(ActionKind, ActionKind.RETRIEVER),
343343
fn=fn,
344344
metadata=retriever_meta,
345345
description=retriever_description,
@@ -349,10 +349,10 @@ def define_indexer(
349349
self,
350350
name: str,
351351
fn: IndexerFn,
352-
config_schema: BaseModel | dict[str, Any] | None = None,
352+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
353353
metadata: dict[str, Any] | None = None,
354354
description: str | None = None,
355-
) -> Callable[[Callable], Callable]:
355+
) -> Action:
356356
"""Define an indexer action.
357357
358358
Args:
@@ -374,7 +374,7 @@ def define_indexer(
374374
indexer_description = get_func_description(fn, description)
375375
return self.registry.register_action(
376376
name=name,
377-
kind=ActionKind.INDEXER,
377+
kind=cast(ActionKind, ActionKind.INDEXER),
378378
fn=fn,
379379
metadata=indexer_meta,
380380
description=indexer_description,
@@ -384,7 +384,7 @@ def define_reranker(
384384
self,
385385
name: str,
386386
fn: RerankerFn,
387-
config_schema: BaseModel | dict[str, Any] | None = None,
387+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
388388
metadata: dict[str, Any] | None = None,
389389
description: str | None = None,
390390
) -> Action:
@@ -482,7 +482,7 @@ def define_evaluator(
482482
definition: str,
483483
fn: EvaluatorFn,
484484
is_billed: bool = False,
485-
config_schema: BaseModel | dict[str, Any] | None = None,
485+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
486486
metadata: dict[str, Any] | None = None,
487487
description: str | None = None,
488488
) -> Action:
@@ -581,7 +581,7 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
581581

582582
return self.registry.register_action(
583583
name=name,
584-
kind=ActionKind.EVALUATOR,
584+
kind=cast(ActionKind, ActionKind.EVALUATOR),
585585
fn=eval_stepper_fn,
586586
metadata=evaluator_meta,
587587
description=evaluator_description,
@@ -594,10 +594,10 @@ def define_batch_evaluator(
594594
definition: str,
595595
fn: BatchEvaluatorFn,
596596
is_billed: bool = False,
597-
config_schema: BaseModel | dict[str, Any] | None = None,
597+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
598598
metadata: dict[str, Any] | None = None,
599599
description: str | None = None,
600-
) -> Callable[[Callable], Callable]:
600+
) -> Action:
601601
"""Define a batch evaluator action.
602602
603603
This action runs the callback function on the entire dataset.
@@ -627,7 +627,7 @@ def define_batch_evaluator(
627627
evaluator_description = get_func_description(fn, description)
628628
return self.registry.register_action(
629629
name=name,
630-
kind=ActionKind.EVALUATOR,
630+
kind=cast(ActionKind, ActionKind.EVALUATOR),
631631
fn=fn,
632632
metadata=evaluator_meta,
633633
description=evaluator_description,
@@ -666,7 +666,7 @@ def define_model(
666666
model_description = get_func_description(fn, description)
667667
return self.registry.register_action(
668668
name=name,
669-
kind=ActionKind.MODEL,
669+
kind=cast(ActionKind, ActionKind.MODEL),
670670
fn=fn,
671671
metadata=model_meta,
672672
description=model_description,
@@ -706,7 +706,7 @@ def define_embedder(
706706
embedder_description = get_func_description(fn, description)
707707
return self.registry.register_action(
708708
name=name,
709-
kind=ActionKind.EMBEDDER,
709+
kind=cast(ActionKind, ActionKind.EMBEDDER),
710710
fn=fn,
711711
metadata=embedder_meta,
712712
description=embedder_description,

0 commit comments

Comments
 (0)