We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 175d956 commit ec4b158Copy full SHA for ec4b158
run_server.py
@@ -32,6 +32,8 @@
32
"available servers",
33
)
34
flags.DEFINE_integer("prometheus_port", 0, "")
35
+flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
36
+flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")
37
38
39
# pylint: disable-next=all
@@ -62,6 +64,8 @@ def main(argv: Sequence[str]):
62
64
config=server_config,
63
65
devices=devices,
66
metrics_server_config=metrics_server_config,
67
+ enable_jax_profiler=FLAGS.enable_jax_profiler,
68
+ jax_profiler_port=FLAGS.jax_profiler_port,
69
70
print("Started jetstream_server....")
71
jetstream_server.wait_for_termination()
0 commit comments