
Commit 0bb69ac

Optimize mega flow by removing microservice wrapper (#582)
* refactor orchestrator
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove no_wrapper
* fix
* fix
* add align_gen
* add retriever and rerank params
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add fake test for customize params
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix dep

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: 3367b76

4 files changed: +168 additions, -15 deletions


comps/cores/mega/gateway.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -20,7 +20,7 @@
     EmbeddingRequest,
     UsageInfo,
 )
-from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
+from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
 from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
 from .micro_service import MicroService
 
@@ -167,8 +167,22 @@ async def handle_request(self, request: Request):
             streaming=stream_opt,
             chat_template=chat_request.chat_template if chat_request.chat_template else None,
         )
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={"text": prompt}, llm_parameters=parameters
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
         )
         for node, response in result_dict.items():
             if isinstance(response, StreamingResponse):
```
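One detail of the fallback pattern above is worth calling out: `value if value else default` treats every falsy field as unset, so an explicit `k=0`, `top_n=0`, or `score_threshold=0.0` in the request would also be replaced by the default. A minimal sketch of the pattern in isolation (the `build_search_params` helper and its `chat_request` argument are hypothetical stand-ins for the gateway code above):

```python
from comps.cores.proto.docarray import RerankerParms, RetrieverParms


def build_search_params(chat_request):
    # Hypothetical helper mirroring the gateway logic above.
    retriever_parameters = RetrieverParms(
        search_type=chat_request.search_type if chat_request.search_type else "similarity",
        k=chat_request.k if chat_request.k else 4,
    )
    reranker_parameters = RerankerParms(
        # Caveat: an explicit top_n=0 is falsy and falls back to 1.
        top_n=chat_request.top_n if chat_request.top_n else 1,
    )
    return retriever_parameters, reranker_parameters
```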

comps/cores/mega/orchestrator.py

Lines changed: 56 additions & 13 deletions
```diff
@@ -4,6 +4,7 @@
 import asyncio
 import copy
 import json
+import os
 import re
 from typing import Dict, List
 
@@ -14,6 +15,10 @@
 from ..proto.docarray import LLMParams
 from .constants import ServiceType
 from .dag import DAG
+from .logger import CustomLogger
+
+logger = CustomLogger("comps-core-orchestrator")
+LOGFLAG = os.getenv("LOGFLAG", False)
 
 
 class ServiceOrchestrator(DAG):
@@ -36,18 +41,22 @@ def flow_to(self, from_service, to_service):
             self.add_edge(from_service.name, to_service.name)
             return True
         except Exception as e:
-            print(e)
+            logger.error(e)
             return False
 
-    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
+    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
         result_dict = {}
         runtime_graph = DAG()
         runtime_graph.graph = copy.deepcopy(self.graph)
+        if LOGFLAG:
+            logger.info(initial_inputs)
 
         timeout = aiohttp.ClientTimeout(total=1000)
         async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
             pending = {
-                asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph, llm_parameters))
+                asyncio.create_task(
+                    self.execute(session, node, initial_inputs, runtime_graph, llm_parameters, **kwargs)
+                )
                 for node in self.ind_nodes()
             }
             ind_nodes = self.ind_nodes()
@@ -67,11 +76,12 @@ async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMPa
                         for downstream in reversed(downstreams):
                             try:
                                 if re.findall(black_node, downstream):
-                                    print(f"skip forwardding to {downstream}...")
+                                    if LOGFLAG:
+                                        logger.info(f"skip forwardding to {downstream}...")
                                     runtime_graph.delete_edge(node, downstream)
                                     downstreams.remove(downstream)
                             except re.error as e:
-                                print("Pattern invalid! Operation cancelled.")
+                                logger.error("Pattern invalid! Operation cancelled.")
                 if len(downstreams) == 0 and llm_parameters.streaming:
                     # turn the response to a StreamingResponse
                     # to make the response uniform to UI
@@ -90,7 +100,7 @@ def fake_stream(text):
                 inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
                 pending.add(
                     asyncio.create_task(
-                        self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
+                        self.execute(session, d_node, inputs, runtime_graph, llm_parameters, **kwargs)
                     )
                 )
             nodes_to_keep = []
@@ -121,21 +131,33 @@ async def execute(
         inputs: Dict,
         runtime_graph: DAG,
         llm_parameters: LLMParams = LLMParams(),
+        **kwargs,
     ):
         # send the cur_node request/reply
         endpoint = self.services[cur_node].endpoint_path
         llm_parameters_dict = llm_parameters.dict()
-        for field, value in llm_parameters_dict.items():
-            if inputs.get(field) != value:
-                inputs[field] = value
+        if self.services[cur_node].service_type == ServiceType.LLM:
+            for field, value in llm_parameters_dict.items():
+                if inputs.get(field) != value:
+                    inputs[field] = value
+
+        # pre-process
+        inputs = self.align_inputs(inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs)
 
         if (
             self.services[cur_node].service_type == ServiceType.LLM
             or self.services[cur_node].service_type == ServiceType.LVM
         ) and llm_parameters.streaming:
             # Still leave to sync requests.post for StreamingResponse
+            if LOGFLAG:
+                logger.info(inputs)
             response = requests.post(
-                url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
+                url=endpoint,
+                data=json.dumps(inputs),
+                headers={"Content-type": "application/json"},
+                proxies={"http": None},
+                stream=True,
+                timeout=1000,
             )
             downstream = runtime_graph.downstream(cur_node)
             if downstream:
@@ -169,11 +191,32 @@ def generate():
                     else:
                         yield chunk
 
-            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
+            return (
+                StreamingResponse(self.align_generator(generate(), **kwargs), media_type="text/event-stream"),
+                cur_node,
+            )
         else:
+            if LOGFLAG:
+                logger.info(inputs)
             async with session.post(endpoint, json=inputs) as response:
-                print(f"{cur_node}: {response.status}")
-                return await response.json(), cur_node
+                # Parse as JSON
+                data = await response.json()
+                # post process
+                data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
+
+                return data, cur_node
+
+    def align_inputs(self, inputs, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return inputs
+
+    def align_outputs(self, data, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return data
+
+    def align_generator(self, gen, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return gen
 
     def dump_outputs(self, node, response, result_dict):
         result_dict[node] = response
```
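Taken together, these orchestrator changes replace the removed microservice wrapper with three overridable hooks: `align_inputs` runs before each node's request is sent, `align_outputs` runs on each non-streaming response, and `align_generator` wraps the streaming generator. Whatever keyword arguments are handed to `schedule` (for example `retriever_parameters` and `reranker_parameters`) flow through `execute` into every hook via `**kwargs`. A minimal sketch of a subclass-based customization, modeled on the test file below (the class name and the `k`/`top_n` handling are illustrative, not part of this commit):

```python
from comps import ServiceOrchestrator
from comps.cores.mega.constants import ServiceType


class MyOrchestrator(ServiceOrchestrator):
    def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
        # Inject retriever knobs into requests bound for retriever nodes.
        if self.services[cur_node].service_type == ServiceType.RETRIEVER:
            params = kwargs.get("retriever_parameters")
            if params:
                inputs["k"] = params.k
        return inputs

    def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
        # Truncate reranker output to the requested top_n.
        if self.services[cur_node].service_type == ServiceType.RERANK:
            params = kwargs.get("reranker_parameters")
            if params:
                data["text"] = data["text"][: params.top_n]
        return data
```

Note also that `LOGFLAG = os.getenv("LOGFLAG", False)` returns a string whenever the variable is set, so any non-empty value, including `LOGFLAG=false`, enables the new logging.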

comps/cores/proto/docarray.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -173,6 +173,19 @@ class LLMParams(BaseDoc):
     )
 
 
+class RetrieverParms(BaseDoc):
+    search_type: str = "similarity"
+    k: int = 4
+    distance_threshold: Optional[float] = None
+    fetch_k: int = 20
+    lambda_mult: float = 0.5
+    score_threshold: float = 0.2
+
+
+class RerankerParms(BaseDoc):
+    top_n: int = 1
+
+
 class RAGASParams(BaseDoc):
     questions: DocList[TextDoc]
     answers: DocList[TextDoc]
```
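For reference, a small sketch of the new parameter docs in use; fields left unset keep the defaults declared above (the `"mmr"` search type is an assumption about what a retriever might accept, not something this commit defines):

```python
from comps.cores.proto.docarray import RerankerParms, RetrieverParms

retriever_parms = RetrieverParms(k=8, search_type="mmr")  # other fields keep their defaults
reranker_parms = RerankerParms()  # top_n defaults to 1

assert retriever_parms.fetch_k == 20
assert retriever_parms.lambda_mult == 0.5
assert reranker_parms.top_n == 1
```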
New test file
Lines changed: 83 additions & 0 deletions
```diff
@@ -0,0 +1,83 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import unittest
+
+from comps import (
+    EmbedDoc,
+    Gateway,
+    RerankedDoc,
+    ServiceOrchestrator,
+    TextDoc,
+    opea_microservices,
+    register_microservice,
+)
+from comps.cores.mega.constants import ServiceType
+from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
+
+
+@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add", service_type=ServiceType.RETRIEVER)
+async def s1_add(request: EmbedDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += f"opea top_k {req_dict['k']}"
+    return {"text": text}
+
+
+@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.RERANK)
+async def s2_add(request: TextDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += "project!"
+    return {"text": text}
+
+
+def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
+    if self.services[cur_node].service_type == ServiceType.RETRIEVER:
+        inputs["k"] = kwargs["retriever_parameters"].k
+
+    return inputs
+
+
+def align_outputs(self, outputs, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
+    if self.services[cur_node].service_type == ServiceType.RERANK:
+        top_n = kwargs["reranker_parameters"].top_n
+        outputs["text"] = outputs["text"][:top_n]
+    return outputs
+
+
+class TestServiceOrchestratorParams(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.s1 = opea_microservices["s1"]
+        self.s2 = opea_microservices["s2"]
+        self.s1.start()
+        self.s2.start()
+
+        ServiceOrchestrator.align_inputs = align_inputs
+        ServiceOrchestrator.align_outputs = align_outputs
+        self.service_builder = ServiceOrchestrator()
+
+        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
+        self.service_builder.flow_to(self.s1, self.s2)
+        self.gateway = Gateway(self.service_builder, port=9898)
+
+    def tearDown(self):
+        self.s1.stop()
+        self.s2.stop()
+        self.gateway.stop()
+
+    async def test_retriever_schedule(self):
+        result_dict, _ = await self.service_builder.schedule(
+            initial_inputs={"text": "hello, ", "embedding": [1.0, 2.0, 3.0]},
+            retriever_parameters=RetrieverParms(k=8),
+            reranker_parameters=RerankerParms(top_n=20),
+        )
+        self.assertEqual(len(result_dict[self.s2.name]["text"]), 20)  # Check reranker top_n is accessed
+        self.assertTrue("8" in result_dict[self.s2.name]["text"])  # Check retriever k is accessed
+
+
+if __name__ == "__main__":
+    unittest.main()
```
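The test covers `align_inputs` and `align_outputs` but leaves `align_generator` at its pass-through default. As a complement, a hedged sketch of what a streaming override could look like, assuming the upstream generator yields bytes or str chunks (the SSE payload shape here is illustrative, not prescribed by this commit):

```python
import json


def align_generator(self, gen, *args, **kwargs):
    """Illustrative override: re-wrap each streamed chunk as a JSON SSE event."""
    for chunk in gen:
        text = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
        yield f"data: {json.dumps({'text': text})}\n\n"
```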
