Skip to content

Commit 508d71c

Browse files
authored
Set JAX_PLATFORMS to "tpu, cpu" for ray worker (#145)
set JAX_PLATFORMS for ray worker
1 parent 98a8e28 commit 508d71c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jetstream_pt/ray_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import ray
7+
from ray.runtime_env import RuntimeEnv
78
from ray.util.accelerators import tpu
89

910
from jetstream.engine import engine_api, tokenizer_pb2
@@ -241,7 +242,8 @@ def create_pytorch_ray_engine(
241242
), f"num_hosts (current value {num_hosts}) should be a positive number"
242243
# pylint: disable-next=all
243244
engine_worker_with_tpu_resource = PyTorchRayWorker.options(
244-
resources={"TPU": 4}
245+
resources={"TPU": 4},
246+
runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}),
245247
)
246248
engine_workers = []
247249
for _ in range(num_hosts):

0 commit comments

Comments
 (0)