"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations

import os
import signal
import threading
import time
import traceback
import weakref

import numpy as np

from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.utils import console_logger, envs, llm_logger


class ExpertService:
    """
    Engine class responsible for managing Large Language Model (LLM)
    operations for a single data-parallel (expert) rank.

    Attributes:
        cfg (Config): Configuration object containing all the parameters.
        local_data_parallel_id (int): Local data parallel ID.
    """

    def __init__(self, cfg, local_data_parallel_id, start_queue=True):
        """
        Initializes the ExpertService with the provided configuration.

        Args:
            cfg (Config): Config object containing all the configuration parameters.
            local_data_parallel_id (int): Local data parallel ID of this service.
            start_queue (bool): Whether to start the engine worker queue.
        """
        self.cfg = cfg
        # Map this data-parallel rank onto a contiguous slice of the node's devices.
        start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
        end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
        if cfg.splitwise_role != "mixed":
            self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
        self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]

        llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
        self.cfg.disaggregate_info = None

        if cfg.splitwise_role != "mixed":
            # Give each data-parallel rank its own prefill/decode communication port.
            if len(self.cfg.cache_config.pd_comm_port) == 1:
                self.cfg.cache_config.pd_comm_port[0] = (
                    int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
                )
            else:
                self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]

        self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
        self.engine = EngineService(self.cfg, start_queue)

        if self.cfg.scheduler_config.name == "splitwise":
            self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")

        self._finalizer = weakref.finalize(self, self._exit_sub_services)

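    # Worked example of the device slicing above (illustrative numbers, not
    # taken from any real config): with tensor_parallel_size=2,
    # worker_num_per_node=8, and device_ids="0,1,2,3,4,5,6,7",
    # data-parallel rank 1 gets start_pos = (1 * 2) % 8 = 2 and end_pos = 4,
    # so local_device_ids is ["2", "3"]; rank 4 gives (4 * 2) % 8 = 0, i.e.
    # the slice wraps around and reuses device offsets 0..1 on the next node.
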
    def start(self, ipc_signal_suffix, local_data_parallel_id):
        """
        Initializes the engine and starts its sub-services.
        If `api_server_pid` is defined, a thread is launched to keep
        pulling requests from the zmq server.
        """
        # assert not self.is_started, "The engine is already started."
        start_time = time.time()
        self.engine.start()
        if ipc_signal_suffix is not None:
            self.api_server_pid = ipc_signal_suffix
            self.engine.start_zmq_service(ipc_signal_suffix)
        else:
            ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]

        llm_logger.info(f"start expert service {local_data_parallel_id}")
        if self.cfg.splitwise_role != "mixed":
            # In disaggregated (prefill/decode) mode, run the KV cache
            # transfer service and pull split-mode tasks.
            self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
            self.engine.split_mode_get_tasks()

        if self.cfg.scheduler_config.name == "splitwise":
            self.cfg.init_cache_info()
            role = self.cfg.splitwise_role
            host_ip = self.cfg.host_ip
            disaggregate = self.cfg.disaggregate_info
            self.engine.scheduler.start(role, host_ip, disaggregate)

        if self.cfg.splitwise_role != "mixed":
            self.splitwise_receive_thread = threading.Thread(
                target=self.engine.split_connector.start_receiver, args=()
            )
            self.splitwise_receive_thread.daemon = True
            self.splitwise_receive_thread.start()

        self.cfg.print()

        local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
        if not envs.FD_ENABLE_MULTI_API_SERVER:
            # Mark this rank as launched in the shared-memory signal so the
            # parent process can tell when every expert service is up.
            launched_expert_service_signal_data = np.zeros(
                shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
            )
            self.launched_expert_service_signal = IPCSignal(
                name="launched_expert_service_signal",
                array=launched_expert_service_signal_data,
                dtype=np.int32,
                suffix=ipc_signal_suffix,
                create=False,
            )
            self.launched_expert_service_signal.value[local_rank] = 1

        console_logger.info(
            f"Worker processes (rank {local_rank}) launched in {time.time() - start_time} seconds."
        )
        return True

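    # Launcher-side sketch (an assumption about the counterpart, not code
    # from this file): the parent process presumably creates the same
    # IPCSignal with create=True and polls it until every rank wrote 1:
    #
    #     ready = IPCSignal(name="launched_expert_service_signal",
    #                       array=np.zeros([dp_size_per_node], dtype=np.int32),
    #                       dtype=np.int32, suffix=ipc_signal_suffix, create=True)
    #     while not ready.value.all():
    #         time.sleep(1)
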
    def _exit_sub_services(self):
        """
        Exit all sub-services and release their shared resources.
        """
        if hasattr(self, "cache_manager_processes"):
            self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
            self.engine.resource_manager.cache_manager.cache_ready_signal.clear()
            for p in self.cache_manager_processes:
                llm_logger.info(f"Killing cache manager process {p.pid}")
                try:
                    # Terminate the whole process group; ignore processes
                    # that have already exited.
                    os.killpg(p.pid, signal.SIGTERM)
                except Exception:
                    pass

        if hasattr(self, "zmq_server") and self.zmq_server is not None:
            self.zmq_server.close()


def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None):
    """
    Start an expert service for the given local data-parallel rank.
    """
    expert_service = ExpertService(cfg, local_data_parallel_id, start_queue=False)

    try:
        expert_service.start(ipc_signal_suffix, local_data_parallel_id)

        def daemon_thread():
            # Keep the process alive; the real work happens in the
            # engine's background threads.
            while True:
                time.sleep(10)

        t_daemon = threading.Thread(target=daemon_thread, daemon=True)
        t_daemon.start()
    except Exception as e:
        llm_logger.exception(f"Expert service failed to start: {e}, {traceback.format_exc()}")
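

# A minimal launch sketch (illustrative, not part of the original module):
# one way a parent process might spawn an expert service per local
# data-parallel rank. `make_config()` is a hypothetical placeholder; in
# FastDeploy the Config object and IPC suffix come from the engine launcher.
#
# if __name__ == "__main__":
#     import multiprocessing
#
#     cfg = make_config()  # hypothetical: build the FastDeploy Config
#     procs = []
#     for dp_id in range(cfg.parallel_config.data_parallel_size // cfg.nnode):
#         p = multiprocessing.Process(
#             target=start_data_parallel_service,
#             args=(cfg, dp_id),
#             kwargs={"ipc_signal_suffix": None},
#         )
#         p.start()
#         procs.append(p)
#     for p in procs:
#         p.join()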