Skip to content

Commit 2e920e3

Browse files
JasonLi1909 and landscapepainter
authored and committed
Ray Train test_jax_trainer::test_minimal_multihost Flaky Test Fix (ray-project#56548)
A fix for the flaky test `test_jax_trainer.py::test_minimal_multihost`. https://buildkite.com/ray-project/postmerge/builds/12941#01993f89-cc62-4e31-8de2-8b18f81ac177 Issue: `test_minimal_multihost` introduces a race condition by attempting to initialize a virtualenv directory twice at the same directory path during worker runtime environment setup. The test would not fail in a true multi-host environment, but the tests simulate a multi-host environment on a single machine. This might be a Ray Core issue that results in `runtime_env` errors, but this PR will at least unblock the test so that it is no longer flaky. Fix: Move `worker_runtime_env` to the job level so that `pip install jax` happens only once. --------- Signed-off-by: JasonLi1909 <jasli1909@gmail.com>
1 parent 7dee2bc commit 2e920e3

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

python/ray/train/v2/tests/test_jax_trainer.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,27 @@
33
import ray
44
from ray.tests.conftest import _ray_start_cluster
55
from ray.train import RunConfig, ScalingConfig
6-
from ray.train.v2._internal.constants import HEALTH_CHECK_INTERVAL_S_ENV_VAR
6+
from ray.train.v2._internal.constants import (
7+
HEALTH_CHECK_INTERVAL_S_ENV_VAR,
8+
is_v2_enabled,
9+
)
710
from ray.train.v2.jax import JaxTrainer
811

12+
assert is_v2_enabled()
13+
14+
15+
@pytest.fixture
16+
def jax_runtime_env():
17+
return {
18+
"pip": ["jax"],
19+
"env_vars": {
20+
"JAX_PLATFORMS": "cpu",
21+
},
22+
}
23+
924

1025
@pytest.fixture
11-
def ray_tpu_single_host(monkeypatch):
26+
def ray_tpu_single_host(monkeypatch, jax_runtime_env):
1227
"""Start a mock single-host TPU Ray cluster with 2x4 v6e (8 chips per host)."""
1328
with _ray_start_cluster() as cluster:
1429
monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v6e-8")
@@ -19,14 +34,14 @@ def ray_tpu_single_host(monkeypatch):
1934
resources={"TPU": 8},
2035
)
2136

22-
ray.init(address=cluster.address)
37+
ray.init(address=cluster.address, runtime_env=jax_runtime_env)
2338

2439
yield cluster
2540
ray.shutdown()
2641

2742

2843
@pytest.fixture
29-
def ray_tpu_multi_host(monkeypatch):
44+
def ray_tpu_multi_host(monkeypatch, jax_runtime_env):
3045
"""Start a simulated multi-host TPU Ray cluster."""
3146
with _ray_start_cluster() as cluster:
3247
monkeypatch.setenv("TPU_NAME", "test-slice-1")
@@ -44,7 +59,7 @@ def ray_tpu_multi_host(monkeypatch):
4459
resources={"TPU": 4},
4560
)
4661

47-
ray.init(address=cluster.address)
62+
ray.init(address=cluster.address, runtime_env=jax_runtime_env)
4863

4964
yield cluster
5065
ray.shutdown()
@@ -78,12 +93,6 @@ def test_minimal_singlehost(ray_tpu_single_host, tmp_path):
7893
),
7994
run_config=RunConfig(
8095
storage_path=str(tmp_path),
81-
worker_runtime_env={
82-
"pip": ["jax"],
83-
"env_vars": {
84-
"JAX_PLATFORMS": "cpu",
85-
},
86-
},
8796
),
8897
)
8998
result = trainer.fit()
@@ -109,12 +118,6 @@ def test_minimal_multihost(ray_tpu_multi_host, tmp_path):
109118
),
110119
run_config=RunConfig(
111120
storage_path=str(tmp_path),
112-
worker_runtime_env={
113-
"pip": ["jax"],
114-
"env_vars": {
115-
"JAX_PLATFORMS": "cpu",
116-
},
117-
},
118121
),
119122
)
120123
result = trainer.fit()

0 commit comments

Comments
 (0)