@@ -4,6 +4,7 @@
 import asyncio
 import copy
 import json
+import os
 import re
 from typing import Dict, List

@@ -14,6 +15,10 @@
 from ..proto.docarray import LLMParams
 from .constants import ServiceType
 from .dag import DAG
+from .logger import CustomLogger
+
+logger = CustomLogger("comps-core-orchestrator")
+LOGFLAG = os.getenv("LOGFLAG", False)


 class ServiceOrchestrator(DAG):
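Reviewer note on the new toggle: `os.getenv("LOGFLAG", False)` returns the raw string whenever the variable is set, so any non-empty value, including `"false"`, enables logging. If stricter parsing is ever wanted, here is a minimal sketch; the `env_flag` helper is hypothetical and not part of this diff:

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    """Parse an environment variable as a boolean flag.

    "1", "true", and "yes" (case-insensitive) count as True; anything
    else, or an unset variable, falls back to the default.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes")

LOGFLAG = env_flag("LOGFLAG")
```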
@@ -36,18 +41,22 @@ def flow_to(self, from_service, to_service):
             self.add_edge(from_service.name, to_service.name)
             return True
         except Exception as e:
-            print(e)
+            logger.error(e)
             return False

-    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
+    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
         result_dict = {}
         runtime_graph = DAG()
         runtime_graph.graph = copy.deepcopy(self.graph)
+        if LOGFLAG:
+            logger.info(initial_inputs)

         timeout = aiohttp.ClientTimeout(total=1000)
         async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
             pending = {
-                asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph, llm_parameters))
+                asyncio.create_task(
+                    self.execute(session, node, initial_inputs, runtime_graph, llm_parameters, **kwargs)
+                )
                 for node in self.ind_nodes()
             }
             ind_nodes = self.ind_nodes()
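For reference, a sketch of how a caller could exercise the new `**kwargs` pass-through on `schedule`. The import paths follow this repo's layout, but the service handles and the `request_id` kwarg are assumptions for illustration; any extra keyword argument simply travels into `execute` and the `align_*` hooks added further down:

```python
import asyncio

from comps import ServiceOrchestrator
from comps.cores.proto.docarray import LLMParams

# Hypothetical wiring: embedding_svc and llm_svc are MicroService handles
# registered elsewhere via orchestrator.add(...).
orchestrator = ServiceOrchestrator()
orchestrator.flow_to(embedding_svc, llm_svc)

result = asyncio.run(
    orchestrator.schedule(
        initial_inputs={"text": "What is OPEA?"},
        llm_parameters=LLMParams(streaming=False),
        request_id="abc123",  # forwarded through **kwargs to the align_* hooks
    )
)
```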
@@ -67,11 +76,12 @@ async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMPa
                            for downstream in reversed(downstreams):
                                try:
                                    if re.findall(black_node, downstream):
-                                        print(f"skip forwardding to {downstream}...")
+                                        if LOGFLAG:
+                                            logger.info(f"skip forwarding to {downstream}...")
                                        runtime_graph.delete_edge(node, downstream)
                                        downstreams.remove(downstream)
                                except re.error as e:
-                                    print("Pattern invalid! Operation cancelled.")
+                                    logger.error("Pattern invalid! Operation cancelled.")
                    if len(downstreams) == 0 and llm_parameters.streaming:
                        # turn the response to a StreamingResponse
                        # to make the response uniform to UI
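Worth noting that each `downstream_black_list` entry is applied as a regular expression against downstream node names via `re.findall`, not as an exact match. A quick illustration with hypothetical node names:

```python
import re

downstreams = ["opea_service@reranking", "opea_service@llm"]
black_node = "rerank"  # one entry from the upstream response's "downstream_black_list"

# re.findall returns a non-empty (truthy) list on any match, so a partial
# pattern is enough to drop a downstream node from the forwarding list.
skipped = [d for d in downstreams if re.findall(black_node, d)]
print(skipped)  # ['opea_service@reranking']
```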
@@ -90,7 +100,7 @@ def fake_stream(text):
                             inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
                             pending.add(
                                 asyncio.create_task(
-                                    self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
+                                    self.execute(session, d_node, inputs, runtime_graph, llm_parameters, **kwargs)
                                 )
                             )
             nodes_to_keep = []
@@ -121,21 +131,33 @@ async def execute(
         inputs: Dict,
         runtime_graph: DAG,
         llm_parameters: LLMParams = LLMParams(),
+        **kwargs,
     ):
         # send the cur_node request/reply
         endpoint = self.services[cur_node].endpoint_path
         llm_parameters_dict = llm_parameters.dict()
-        for field, value in llm_parameters_dict.items():
-            if inputs.get(field) != value:
-                inputs[field] = value
+        if self.services[cur_node].service_type == ServiceType.LLM:
+            for field, value in llm_parameters_dict.items():
+                if inputs.get(field) != value:
+                    inputs[field] = value
+
+        # pre-process
+        inputs = self.align_inputs(inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs)

         if (
             self.services[cur_node].service_type == ServiceType.LLM
             or self.services[cur_node].service_type == ServiceType.LVM
         ) and llm_parameters.streaming:
             # Still leave to sync requests.post for StreamingResponse
+            if LOGFLAG:
+                logger.info(inputs)
             response = requests.post(
-                url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
+                url=endpoint,
+                data=json.dumps(inputs),
+                headers={"Content-type": "application/json"},
+                proxies={"http": None},
+                stream=True,
+                timeout=1000,
             )
             downstream = runtime_graph.downstream(cur_node)
             if downstream:
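The new `align_inputs` call gives a megaservice subclass a pre-processing hook on every request payload before it is posted. A minimal sketch of an override; the `EMBEDDING` service type is assumed to exist in `ServiceType`, and the field names are hypothetical:

```python
class ExampleMegaService(ServiceOrchestrator):
    def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
        # Hypothetical adaptation: this embedding backend expects its payload
        # under "inputs" rather than the pipeline's generic "text" key.
        if self.services[cur_node].service_type == ServiceType.EMBEDDING:
            inputs["inputs"] = inputs.pop("text", "")
        return inputs
```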
@@ -169,11 +191,32 @@ def generate():
                         else:
                             yield chunk

-            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
+            return (
+                StreamingResponse(self.align_generator(generate(), **kwargs), media_type="text/event-stream"),
+                cur_node,
+            )
         else:
+            if LOGFLAG:
+                logger.info(inputs)
             async with session.post(endpoint, json=inputs) as response:
-                print(f"{cur_node}: {response.status}")
-                return await response.json(), cur_node
+                # Parse as JSON
+                data = await response.json()
+                # post-process
+                data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
+
+                return data, cur_node
+
+    def align_inputs(self, inputs, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return inputs
+
+    def align_outputs(self, data, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return data
+
+    def align_generator(self, gen, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return gen


     def dump_outputs(self, node, response, result_dict):
         result_dict[node] = response
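The three `align_*` hooks are deliberately identity functions here, so existing megaservices keep their behavior; subclasses override them to adapt payloads per node. A sketch of what that could look like (the class name, unwrapping key, and SSE framing are hypothetical):

```python
class ExampleMegaService(ServiceOrchestrator):
    def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
        # Hypothetical post-processing: unwrap a nested payload before it is
        # stored in result_dict and forwarded to downstream nodes.
        return data.get("result", data) if isinstance(data, dict) else data

    def align_generator(self, gen, **kwargs):
        # Hypothetical re-framing: normalize raw streamed chunks into
        # server-sent-event lines before they reach the client.
        for chunk in gen:
            text = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
            yield f"data: {text}\n\n"
```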