Skip to content

Commit 88d3665

Browse files
JasonLi1909liulehui
authored andcommitted
Revising test_jax_trainer flaky test (ray-project#56854)
Revisiting ray-project#56548 as test continues to be flaky on CI **Solution**: The previous attempt to deflake this test still used a `pip install jax` via the `ray.init` runtime_env args. Hence, the pip install related error persisted. This PR instead adds `jax` and `jaxlib` as a dependency of CI train tests, avoiding the need to `pip install jax` via the runtime_env. --------- Signed-off-by: JasonLi1909 <jasli1909@gmail.com> Signed-off-by: Jason Li <57246540+JasonLi1909@users.noreply.github.com>
1 parent 3237290 commit 88d3665

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

python/ray/train/v2/tests/test_jax_trainer.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,7 @@
1313

1414

1515
@pytest.fixture
16-
def jax_runtime_env():
17-
return {
18-
"pip": ["jax"],
19-
"env_vars": {
20-
"JAX_PLATFORMS": "cpu",
21-
},
22-
}
23-
24-
25-
@pytest.fixture
26-
def ray_tpu_single_host(monkeypatch, jax_runtime_env):
16+
def ray_tpu_single_host(monkeypatch):
2717
"""Start a mock single-host TPU Ray cluster with 2x4 v6e (8 chips per host)."""
2818
with _ray_start_cluster() as cluster:
2919
monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v6e-8")
@@ -34,14 +24,14 @@ def ray_tpu_single_host(monkeypatch, jax_runtime_env):
3424
resources={"TPU": 8},
3525
)
3626

37-
ray.init(address=cluster.address, runtime_env=jax_runtime_env)
27+
ray.init(address=cluster.address)
3828

3929
yield cluster
4030
ray.shutdown()
4131

4232

4333
@pytest.fixture
44-
def ray_tpu_multi_host(monkeypatch, jax_runtime_env):
34+
def ray_tpu_multi_host(monkeypatch):
4535
"""Start a simulated multi-host TPU Ray cluster."""
4636
with _ray_start_cluster() as cluster:
4737
monkeypatch.setenv("TPU_NAME", "test-slice-1")
@@ -59,7 +49,7 @@ def ray_tpu_multi_host(monkeypatch, jax_runtime_env):
5949
resources={"TPU": 4},
6050
)
6151

62-
ray.init(address=cluster.address, runtime_env=jax_runtime_env)
52+
ray.init(address=cluster.address)
6353

6454
yield cluster
6555
ray.shutdown()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
evaluate==0.4.3
22
mosaicml; python_version < "3.12"
33
sentencepiece==0.1.96
4+
jax==0.4.25
5+
jaxlib==0.4.25
46
s3torchconnector==1.4.3

python/requirements_compiled.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,10 @@ isoduration==20.11.0
859859
# via jsonschema
860860
itsdangerous==2.1.2
861861
# via flask
862+
jax==0.4.25
863+
# via -r python/requirements/ml/train-test-requirements.txt
864+
jaxlib==0.4.25
865+
# via -r python/requirements/ml/train-test-requirements.txt
862866
jedi==0.19.1
863867
# via ipython
864868
jinja2==3.1.6
@@ -1081,7 +1085,10 @@ mistune==0.8.4
10811085
ml-collections==0.1.1
10821086
# via open-spiel
10831087
ml-dtypes==0.3.2
1084-
# via tensorflow
1088+
# via
1089+
# jax
1090+
# jaxlib
1091+
# tensorflow
10851092
mlagents-envs==0.28.0
10861093
# via -r python/requirements/ml/rllib-test-requirements.txt
10871094
mlflow==2.22.0
@@ -1253,6 +1260,8 @@ numpy==1.26.4
12531260
# hpbandster
12541261
# hyperopt
12551262
# imageio
1263+
# jax
1264+
# jaxlib
12561265
# labmaze
12571266
# lightgbm
12581267
# matplotlib
@@ -1379,6 +1388,7 @@ opentelemetry-util-http==0.55b1
13791388
# opentelemetry-instrumentation-fastapi
13801389
opt-einsum==3.3.0
13811390
# via
1391+
# jax
13821392
# pyro-ppl
13831393
# tensorflow
13841394
optuna==4.1.0
@@ -2024,6 +2034,8 @@ scipy==1.11.4
20242034
# gpy
20252035
# hpbandster
20262036
# hyperopt
2037+
# jax
2038+
# jaxlib
20272039
# lightgbm
20282040
# linear-operator
20292041
# mlflow

0 commit comments

Comments
 (0)