Skip to content

Commit fe328bb

Browse files
authored
Enable jax profiler server in run with ray (#112)
* add jax profiler server * update jetstream
1 parent a32be5d commit fe328bb

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

jetstream_pt/ray_engine.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,8 @@ def create_pytorch_ray_engine(
178178
is_disaggregated: bool = False,
179179
num_hosts: int = 0,
180180
decode_pod_slice_name: str = None,
181+
enable_jax_profiler: bool = False,
182+
jax_profiler_port: int = 9999,
181183
) -> Any:
182184

183185
# Return tuple as response: issues/107
@@ -218,6 +220,8 @@ def create_pytorch_ray_engine(
218220
quantize_kv=quantize_kv,
219221
max_cache_length=max_cache_length,
220222
sharding_config=sharding_config,
223+
enable_jax_profiler=enable_jax_profiler,
224+
jax_profiler_port=jax_profiler_port,
221225
)
222226
engine_workers.append(engine_worker)
223227

jetstream_pt/ray_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def __init__(
114114
quantize_kv=False,
115115
max_cache_length=1024,
116116
sharding_config=None,
117+
enable_jax_profiler: bool = False,
118+
jax_profiler_port: int = 9999,
117119
):
118120

119121
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
@@ -130,6 +132,10 @@ def __init__(
130132
f"---Jax device_count:{device_count}, local_device_count{local_device_count} "
131133
)
132134

135+
if enable_jax_profiler:
136+
jax.profiler.start_server(jax_profiler_port)
137+
print(f"Started JAX profiler server on port {jax_profiler_port}")
138+
133139
checkpoint_format = ""
134140
checkpoint_path = ""
135141

run_server_with_ray.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
flags.DEFINE_integer("prometheus_port", 0, "")
3535
flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips")
3636

37+
flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
38+
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
39+
3740

3841
def create_engine():
3942
"""create a pytorch engine"""
@@ -53,6 +56,8 @@ def create_engine():
5356
quantize_kv=FLAGS.quantize_kv_cache,
5457
max_cache_length=FLAGS.max_cache_length,
5558
sharding_config=FLAGS.sharding_config,
59+
enable_jax_profiler=FLAGS.enable_jax_profiler,
60+
jax_profiler_port=FLAGS.jax_profiler_port,
5661
)
5762

5863
print("Initialize engine", time.perf_counter() - start)

0 commit comments

Comments
 (0)