Skip to content

Commit 92d03f0

Browse files
zac-lijina-bot
andauthored
feat: fix batch handling in sagemaker for clip (#6216)
Co-authored-by: Jina Dev Bot <[email protected]>
1 parent 2f65ce1 commit 92d03f0

File tree

5 files changed

+30
-29
lines changed

5 files changed

+30
-29
lines changed

jina/orchestrate/flow/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,10 +1776,8 @@ def build(self, copy_flow: bool = False, **kwargs) -> 'Flow':
17761776
op_flow._deployment_nodes[GATEWAY_NAME].args.graph_description = json.dumps(
17771777
op_flow._get_graph_representation()
17781778
)
1779-
op_flow._deployment_nodes[
1780-
GATEWAY_NAME
1781-
].args.deployments_addresses = json.dumps(
1782-
op_flow._get_deployments_addresses()
1779+
op_flow._deployment_nodes[GATEWAY_NAME].args.deployments_addresses = (
1780+
json.dumps(op_flow._get_deployments_addresses())
17831781
)
17841782

17851783
op_flow._deployment_nodes[GATEWAY_NAME].update_pod_args()

jina/serve/executors/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -401,21 +401,21 @@ def __init__(
401401
self._init_monitoring()
402402
self._init_workspace = workspace
403403
if __dry_run_endpoint__ not in self.requests:
404-
self.requests[
405-
__dry_run_endpoint__
406-
] = _FunctionWithSchema.get_function_with_schema(
407-
self.__class__._dry_run_func
404+
self.requests[__dry_run_endpoint__] = (
405+
_FunctionWithSchema.get_function_with_schema(
406+
self.__class__._dry_run_func
407+
)
408408
)
409409
else:
410410
self.logger.warning(
411411
f' Endpoint {__dry_run_endpoint__} is defined by the Executor. Be aware that this endpoint is usually reserved to enable health checks from the Client through the gateway.'
412412
f' So it is recommended not to expose this endpoint. '
413413
)
414414
if type(self) == BaseExecutor:
415-
self.requests[
416-
__default_endpoint__
417-
] = _FunctionWithSchema.get_function_with_schema(
418-
self.__class__._dry_run_func
415+
self.requests[__default_endpoint__] = (
416+
_FunctionWithSchema.get_function_with_schema(
417+
self.__class__._dry_run_func
418+
)
419419
)
420420

421421
self._lock = contextlib.AsyncExitStack()
@@ -595,14 +595,14 @@ def _add_requests(self, _requests: Optional[Dict]):
595595
_func = getattr(self.__class__, func)
596596
if callable(_func):
597597
# the target function is not decorated with `@requests` yet
598-
self.requests[
599-
endpoint
600-
] = _FunctionWithSchema.get_function_with_schema(_func)
598+
self.requests[endpoint] = (
599+
_FunctionWithSchema.get_function_with_schema(_func)
600+
)
601601
elif typename(_func) == 'jina.executors.decorators.FunctionMapper':
602602
# the target function is already decorated with `@requests`, need unwrap with `.fn`
603-
self.requests[
604-
endpoint
605-
] = _FunctionWithSchema.get_function_with_schema(_func.fn)
603+
self.requests[endpoint] = (
604+
_FunctionWithSchema.get_function_with_schema(_func.fn)
605+
)
606606
else:
607607
raise TypeError(
608608
f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}'

jina/serve/runtimes/worker/http_csp_app.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,14 @@ def construct_model_from_line(
182182
# Handle list of nested models
183183
elif get_origin(field_type) is list:
184184
list_item_type = get_args(field_type)[0]
185-
parsed_list = json.loads(field_str)
186-
if issubclass(list_item_type, BaseModel):
187-
parsed_fields[field_name] = parse_obj_as(
188-
List[list_item_type], parsed_list
189-
)
190-
else:
191-
parsed_fields[field_name] = parsed_list
185+
if field_str:
186+
parsed_list = json.loads(field_str)
187+
if issubclass(list_item_type, BaseModel):
188+
parsed_fields[field_name] = parse_obj_as(
189+
List[list_item_type], parsed_list
190+
)
191+
else:
192+
parsed_fields[field_name] = parsed_list
192193
# General parsing attempt for other types
193194
else:
194195
if field_str:

tests/integration/docarray_v2/csp/SampleClipExecutor/executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, List
22

33
import numpy as np
44
from docarray import BaseDoc, DocList
@@ -13,6 +13,8 @@ class TextAndImageDoc(BaseDoc):
1313
text: Optional[str] = None
1414
url: Optional[AnyUrl] = None
1515
bytes: Optional[ImageBytes] = None
16+
num_tokens: Optional[int] = None
17+
input_ids: Optional[List[int]] = None
1618

1719

1820
class EmbeddingResponseModel(TextAndImageDoc):

tests/integration/docarray_v2/csp/valid_clip_input.csv

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)