1
1
import json
2
2
import logging
3
- import os
4
- import uuid
5
3
from dataclasses import dataclass
4
+ from typing import Optional
6
5
7
6
# Module-level logger named after this module; the engine wrapper below uses it to report failures.
logger = logging .getLogger (__name__ )
8
7
9
8
10
@dataclass
class MooncakeTransferEngineConfig:
    """Settings passed to the Mooncake transfer engine, loaded from JSON.

    Fields mirror the JSON keys of the same name. ``protocol`` defaults to
    ``"rdma"`` and ``device_name`` to ``""`` when the key is absent;
    ``local_hostname`` and ``metadata_server`` are required.
    """

    local_hostname: str
    metadata_server: str
    protocol: str
    device_name: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
        """Load the config from a JSON file.

        Args:
            file_path: Path of a JSON file whose top-level value is an object.

        Returns:
            The populated configuration.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            ValueError: If the top-level value is not an object, or if
                ``local_hostname`` / ``metadata_server`` is missing or null.
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(file_path, encoding="utf-8") as fin:
            config = json.load(fin)

        if not isinstance(config, dict):
            raise ValueError(
                f"Mooncake config file '{file_path}' must contain a JSON object."
            )

        # Fail here rather than handing None to fields annotated as ``str``,
        # which would only surface later as an unrelated error.
        required_keys = ("local_hostname", "metadata_server")
        missing = [key for key in required_keys if config.get(key) is None]
        if missing:
            raise ValueError(
                f"Mooncake config file '{file_path}' is missing required "
                f"key(s): {', '.join(missing)}."
            )

        return MooncakeTransferEngineConfig(
            local_hostname=config["local_hostname"],
            metadata_server=config["metadata_server"],
            protocol=config.get("protocol", "rdma"),
            device_name=config.get("device_name", ""),
        )

    @staticmethod
    def load_from_env() -> "MooncakeTransferEngineConfig":
        """Load config from a file specified in the environment variable.

        Reads the path from ``MOONCAKE_CONFIG_PATH`` and delegates to
        :meth:`from_file`.

        Raises:
            ValueError: If ``MOONCAKE_CONFIG_PATH`` is not set, or if the
                referenced file fails validation in :meth:`from_file`.
        """
        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        return MooncakeTransferEngineConfig.from_file(config_file_path)
38
-
39
-
40
9
class MooncakeTransferEngine :
41
10
42
- def __init__ (self ):
11
+ def __init__ (self , hostname : str , gpu_id : int , ib_device : Optional [ str ] = None ):
43
12
try :
44
13
from mooncake .engine import TransferEngine
45
14
except ImportError as e :
@@ -50,43 +19,43 @@ def __init__(self):
50
19
) from e
51
20
52
21
self .engine = TransferEngine ()
22
+ self .hostname = hostname
23
+ self .gpu_id = gpu_id
24
+ self .ib_device = ib_device
53
25
54
- try :
55
- self .config = MooncakeTransferEngineConfig .load_from_env ()
56
- logger .info ("Mooncake Configuration loaded successfully." )
57
- except ValueError as e :
58
- logger .error (e )
59
- raise
60
- except Exception as exc :
61
- logger .error ("An error occurred while loading the configuration: %s" , exc )
62
- raise
63
-
64
- self .config = MooncakeTransferEngineConfig .load_from_env ()
65
-
66
- session_suffix = "_" + str (uuid .uuid4 ())
67
- self .session_id = self .config .local_hostname + session_suffix
68
26
self .initialize (
69
- self .session_id ,
70
- self .config .metadata_server ,
71
- self .config .protocol ,
72
- self .config .device_name ,
27
+ hostname = self .hostname ,
28
+ device_name = self .ib_device ,
73
29
)
30
+ self .session_id = f"{ self .hostname } :{ self .engine .get_rpc_port ()} "
74
31
75
32
def register(self, ptr, length):
    """Register a memory region with the underlying engine.

    Forwards ``ptr`` and ``length`` to ``self.engine.register_memory`` and
    treats any non-zero return value as a failure.

    Raises:
        RuntimeError: If the engine returns a non-zero status.
    """
    status = self.engine.register_memory(ptr, length)
    if status == 0:
        return
    failure = "Mooncake memory registration failed."
    logger.error(failure)
    raise RuntimeError(failure)
77
37
78
38
def deregister(self, ptr):
    """Unregister a memory region from the underlying engine.

    Forwards ``ptr`` to ``self.engine.unregister_memory`` and treats any
    non-zero return value as a failure.

    Raises:
        RuntimeError: If the engine returns a non-zero status.
    """
    status = self.engine.unregister_memory(ptr)
    if status == 0:
        return
    failure = "Mooncake memory deregistration failed."
    logger.error(failure)
    raise RuntimeError(failure)
80
43
81
44
def initialize(
    self,
    hostname: str,
    device_name: Optional[str],
    *,
    metadata_server: str = "P2PHANDSHAKE",
    protocol: str = "rdma",
) -> None:
    """Initialize the mooncake instance.

    Args:
        hostname: Local hostname handed to the engine as its first argument.
        device_name: Device name to pass to the engine; ``None`` is
            forwarded as an empty string.
        metadata_server: Second argument of ``engine.initialize``. Defaults
            to ``"P2PHANDSHAKE"``, the value this wrapper previously
            hard-coded.
            NOTE(review): presumably selects a handshake mode instead of a
            metadata server address -- confirm against the Mooncake docs.
        protocol: Third argument of ``engine.initialize``. Defaults to
            ``"rdma"``, the value this wrapper previously hard-coded.

    Raises:
        RuntimeError: If the engine returns a non-zero status.
    """
    # The engine takes a string for the device, so map None to "".
    engine_device = "" if device_name is None else device_name
    ret_value = self.engine.initialize(
        hostname,
        metadata_server,
        protocol,
        engine_device,
    )
    if ret_value != 0:
        logger.error("Mooncake Transfer Engine initialization failed.")
        raise RuntimeError("Mooncake Transfer Engine initialization failed.")
90
59
91
60
def transfer_sync (
92
61
self , session_id : str , buffer : int , peer_buffer_address : int , length : int
@@ -97,12 +66,12 @@ def transfer_sync(
97
66
session_id , buffer , peer_buffer_address , length
98
67
)
99
68
if ret < 0 :
100
- logger .error ("Transfer Return Error" )
101
- raise Exception ( " Transfer Return Error" )
69
+ logger .error ("Mooncake Transfer Engine Return Error. " )
70
+ raise RuntimeError ( "Mooncake Transfer Engine Return Error. " )
102
71
return ret
103
72
104
73
def get_localhost(self):
    """Return the hostname stored on this engine instance."""
    local_host = self.hostname
    return local_host
106
75
107
76
def get_session_id(self):
    """Return the session identifier stored on this engine instance."""
    current_session = self.session_id
    return current_session
0 commit comments