4343import uuid
4444from collections .abc import AsyncIterator , Callable
4545from functools import wraps
46- from typing import TYPE_CHECKING , Any
46+ from typing import TYPE_CHECKING , Any , cast
4747
4848if TYPE_CHECKING :
4949 from genkit .blocks .resource import ResourceFn , ResourceOptions
@@ -122,7 +122,7 @@ def __init__(self):
122122 """Initialize the Genkit registry."""
123123 self .registry : Registry = Registry ()
124124
125- def flow (self , name : str | None = None , description : str | None = None ) -> Callable [[ Callable ], Callable ] :
125+ def flow (self , name : str | None = None , description : str | None = None ) -> Action :
126126 """Decorator to register a function as a flow.
127127
128128 Args:
@@ -235,7 +235,7 @@ def define_schema(self, name: str, schema: type) -> type:
235235 define_schema (self .registry , name , schema )
236236 return schema
237237
238- def tool (self , name : str | None = None , description : str | None = None ) -> Callable [[ Callable ], Callable ] :
238+ def tool (self , name : str | None = None , description : str | None = None ) -> Action :
239239 """Decorator to register a function as a tool.
240240
241241 Args:
@@ -315,10 +315,10 @@ def define_retriever(
315315 self ,
316316 name : str ,
317317 fn : RetrieverFn ,
318- config_schema : BaseModel | dict [str , Any ] | None = None ,
318+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
319319 metadata : dict [str , Any ] | None = None ,
320320 description : str | None = None ,
321- ) -> Callable [[ Callable ], Callable ] :
321+ ) -> Action :
322322 """Define a retriever action.
323323
324324 Args:
@@ -339,7 +339,7 @@ def define_retriever(
339339 retriever_description = get_func_description (fn , description )
340340 return self .registry .register_action (
341341 name = name ,
342- kind = ActionKind .RETRIEVER ,
342+ kind = cast ( ActionKind , ActionKind .RETRIEVER ) ,
343343 fn = fn ,
344344 metadata = retriever_meta ,
345345 description = retriever_description ,
@@ -349,10 +349,10 @@ def define_indexer(
349349 self ,
350350 name : str ,
351351 fn : IndexerFn ,
352- config_schema : BaseModel | dict [str , Any ] | None = None ,
352+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
353353 metadata : dict [str , Any ] | None = None ,
354354 description : str | None = None ,
355- ) -> Callable [[ Callable ], Callable ] :
355+ ) -> Action :
356356 """Define an indexer action.
357357
358358 Args:
@@ -374,7 +374,7 @@ def define_indexer(
374374 indexer_description = get_func_description (fn , description )
375375 return self .registry .register_action (
376376 name = name ,
377- kind = ActionKind .INDEXER ,
377+ kind = cast ( ActionKind , ActionKind .INDEXER ) ,
378378 fn = fn ,
379379 metadata = indexer_meta ,
380380 description = indexer_description ,
@@ -384,7 +384,7 @@ def define_reranker(
384384 self ,
385385 name : str ,
386386 fn : RerankerFn ,
387- config_schema : BaseModel | dict [str , Any ] | None = None ,
387+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
388388 metadata : dict [str , Any ] | None = None ,
389389 description : str | None = None ,
390390 ) -> Action :
@@ -482,7 +482,7 @@ def define_evaluator(
482482 definition : str ,
483483 fn : EvaluatorFn ,
484484 is_billed : bool = False ,
485- config_schema : BaseModel | dict [str , Any ] | None = None ,
485+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
486486 metadata : dict [str , Any ] | None = None ,
487487 description : str | None = None ,
488488 ) -> Action :
@@ -581,7 +581,7 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
581581
582582 return self .registry .register_action (
583583 name = name ,
584- kind = ActionKind .EVALUATOR ,
584+ kind = cast ( ActionKind , ActionKind .EVALUATOR ) ,
585585 fn = eval_stepper_fn ,
586586 metadata = evaluator_meta ,
587587 description = evaluator_description ,
@@ -594,10 +594,10 @@ def define_batch_evaluator(
594594 definition : str ,
595595 fn : BatchEvaluatorFn ,
596596 is_billed : bool = False ,
597- config_schema : BaseModel | dict [str , Any ] | None = None ,
597+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
598598 metadata : dict [str , Any ] | None = None ,
599599 description : str | None = None ,
600- ) -> Callable [[ Callable ], Callable ] :
600+ ) -> Action :
601601 """Define a batch evaluator action.
602602
603603 This action runs the callback function on the entire dataset.
@@ -627,7 +627,7 @@ def define_batch_evaluator(
627627 evaluator_description = get_func_description (fn , description )
628628 return self .registry .register_action (
629629 name = name ,
630- kind = ActionKind .EVALUATOR ,
630+ kind = cast ( ActionKind , ActionKind .EVALUATOR ) ,
631631 fn = fn ,
632632 metadata = evaluator_meta ,
633633 description = evaluator_description ,
@@ -666,7 +666,7 @@ def define_model(
666666 model_description = get_func_description (fn , description )
667667 return self .registry .register_action (
668668 name = name ,
669- kind = ActionKind .MODEL ,
669+ kind = cast ( ActionKind , ActionKind .MODEL ) ,
670670 fn = fn ,
671671 metadata = model_meta ,
672672 description = model_description ,
@@ -706,7 +706,7 @@ def define_embedder(
706706 embedder_description = get_func_description (fn , description )
707707 return self .registry .register_action (
708708 name = name ,
709- kind = ActionKind .EMBEDDER ,
709+ kind = cast ( ActionKind , ActionKind .EMBEDDER ) ,
710710 fn = fn ,
711711 metadata = embedder_meta ,
712712 description = embedder_description ,
0 commit comments