Skip to content

Commit ec4b158

Browse files
authored
add enable jax profiler to run_server (#140)
1 parent 175d956 commit ec4b158

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

run_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
"available servers",
3333
)
3434
flags.DEFINE_integer("prometheus_port", 0, "")
35+
flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
36+
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
3537

3638

3739
# pylint: disable-next=all
@@ -62,6 +64,8 @@ def main(argv: Sequence[str]):
6264
config=server_config,
6365
devices=devices,
6466
metrics_server_config=metrics_server_config,
67+
enable_jax_profiler=FLAGS.enable_jax_profiler,
68+
jax_profiler_port=FLAGS.jax_profiler_port,
6569
)
6670
print("Started jetstream_server....")
6771
jetstream_server.wait_for_termination()

0 commit comments

Comments
 (0)