File tree 3 files changed +15
-0
lines changed
3 files changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -178,6 +178,8 @@ def create_pytorch_ray_engine(
178
178
is_disaggregated : bool = False ,
179
179
num_hosts : int = 0 ,
180
180
decode_pod_slice_name : str = None ,
181
+ enable_jax_profiler : bool = False ,
182
+ jax_profiler_port : int = 9999 ,
181
183
) -> Any :
182
184
183
185
# Return tuple as reponse: issues/107
@@ -218,6 +220,8 @@ def create_pytorch_ray_engine(
218
220
quantize_kv = quantize_kv ,
219
221
max_cache_length = max_cache_length ,
220
222
sharding_config = sharding_config ,
223
+ enable_jax_profiler = enable_jax_profiler ,
224
+ jax_profiler_port = jax_profiler_port ,
221
225
)
222
226
engine_workers .append (engine_worker )
223
227
Original file line number Diff line number Diff line change @@ -114,6 +114,8 @@ def __init__(
114
114
quantize_kv = False ,
115
115
max_cache_length = 1024 ,
116
116
sharding_config = None ,
117
+ enable_jax_profiler : bool = False ,
118
+ jax_profiler_port : int = 9999 ,
117
119
):
118
120
119
121
jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
@@ -130,6 +132,10 @@ def __init__(
130
132
f"---Jax device_count:{ device_count } , local_device_count{ local_device_count } "
131
133
)
132
134
135
+ if enable_jax_profiler :
136
+ jax .profiler .start_server (jax_profiler_port )
137
+ print (f"Started JAX profiler server on port { jax_profiler_port } " )
138
+
133
139
checkpoint_format = ""
134
140
checkpoint_path = ""
135
141
Original file line number Diff line number Diff line change 34
34
flags .DEFINE_integer ("prometheus_port" , 0 , "" )
35
35
flags .DEFINE_integer ("tpu_chips" , 16 , "device tpu_chips" )
36
36
37
+ flags .DEFINE_bool ("enable_jax_profiler" , False , "enable jax profiler" )
38
+ flags .DEFINE_integer ("jax_profiler_port" , 9999 , "port of JAX profiler server" )
39
+
37
40
38
41
def create_engine ():
39
42
"""create a pytorch engine"""
@@ -53,6 +56,8 @@ def create_engine():
53
56
quantize_kv = FLAGS .quantize_kv_cache ,
54
57
max_cache_length = FLAGS .max_cache_length ,
55
58
sharding_config = FLAGS .sharding_config ,
59
+ enable_jax_profiler = FLAGS .enable_jax_profiler ,
60
+ jax_profiler_port = FLAGS .jax_profiler_port ,
56
61
)
57
62
58
63
print ("Initialize engine" , time .perf_counter () - start )
You can’t perform that action at this time.
0 commit comments