Skip to content

Commit 22ca368

Browse files
committed
Update submodules, prepare for leasing v0.2.4
1 parent 97aaeae commit 22ca368

File tree

7 files changed

+15
-5
lines changed

7 files changed

+15
-5
lines changed

benchmarks/prefill_offline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import os
1717
import time
1818

19+
# import torch_xla2 first!
20+
# pylint: disable-next=all
21+
import torch_xla2
1922
import humanize
2023
import jax
2124
import numpy as np

benchmarks/run_offline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import os
1717
import time
1818

19+
# import torch_xla2 first!
20+
# pylint: disable-next=all
21+
import torch_xla2
1922
import jax
2023
import jax.numpy as jnp
2124
# pylint: disable-next=all

deps/JetStream

deps/xla

Submodule xla updated 231 files

install_everything.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@ pip show tensorboard && pip uninstall -y tensorboard
2424
pip show tensorflow-text && pip uninstall -y tensorflow-text
2525
pip show torch_xla2 && pip uninstall -y torch_xla2
2626

27-
pip install flax==0.8.3
28-
pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
27+
pip install flax
2928
pip install tensorflow-text
3029
pip install tensorflow
3130

3231
pip install ray[default]==2.22.0
3332
# torch cpu
34-
pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
33+
pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
3534
pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
3635
pip install safetensors colorama coverage humanize
3736

3837
git submodule update --init --recursive
3938
pip show google-jetstream && pip uninstall -y google-jetstream
4039
pip show torch_xla2 && pip uninstall -y torch_xla2
4140
pip install -e .
41+
pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

run_interactive.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import time
1818
from typing import List
1919

20+
# import torch_xla2 first!
21+
import torch_xla2 # pylint: disable
2022
import jax
2123
import numpy as np
2224
from absl import app, flags

run_server_with_ray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from typing import Sequence
1919
from absl import app, flags
2020

21+
# import torch_xla2 first!
22+
import torch_xla2 # pylint: disable
2123
import jax
2224
from jetstream.core import server_lib
2325
from jetstream.core.config_lib import ServerConfig

0 commit comments

Comments
 (0)