Skip to content

Commit fa1f120

Browse files
authored
make sure GPU works (#130)
* make sure GPU works
1 parent: aa90b05 · commit: fa1f120

File tree

5 files changed

+32
-6
lines changed

5 files changed

+32
-6
lines changed

deps/xla

Submodule xla updated 52 files

install_everything.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@ git submodule update --init --recursive
3838
pip show google-jetstream && pip uninstall -y google-jetstream
3939
pip show torch_xla2 && pip uninstall -y torch_xla2
4040
pip install -e .
41-
pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
41+
pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
42+
pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu

install_everything_gpu.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@ pip show tensorboard && pip uninstall -y tensorboard
2424
pip show tensorflow-text && pip uninstall -y tensorflow-text
2525
pip show torch_xla2 && pip uninstall -y torch_xla2
2626

27-
pip install flax==0.8.3
28-
pip install -U "jax[cuda12]==0.4.28"
27+
pip install flax==0.8.4
2928
pip install tensorflow-text
3029
pip install tensorflow
3130

3231
pip install ray[default]==2.22.0
3332
# torch cpu
34-
pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
3533
pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
3634
pip install safetensors colorama coverage humanize
3735

3836
git submodule update --init --recursive
3937
pip show google-jetstream && pip uninstall -y google-jetstream
4038
pip show torch_xla2 && pip uninstall -y torch_xla2
4139
pip install -e .
40+
pip install -U jax[cuda12]==0.4.30
41+
pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu

run_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
import os
1717
from typing import Sequence
1818

19+
# import torch_xla2 first!
20+
import torch_xla2 # pylint: disable
1921
import jax
20-
import jetstream_pt
2122
from absl import app, flags
2223
from jetstream.core import server_lib
2324
from jetstream.core.config_lib import ServerConfig, MetricsServerConfig

tests/test_model_impl.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from jetstream_pt.third_party.llama import model_original
2424
from jetstream_pt.third_party.gemma import model_original as gemma_orig
2525
from jetstream_pt.third_party.gemma import model as gemma
26+
from jetstream_pt.third_party.mixtral import model as mixtral
27+
from jetstream_pt.third_party.mixtral import config as mixtral_config
2628
from jetstream_pt import torchjax
2729
from jetstream_pt import layers
2830
from jetstream_pt import cache_manager
@@ -360,6 +362,28 @@ def test_transformer(self):
360362
print("Transformer: Diff norm", (result_torch - expected_out).norm())
361363
self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
362364

365+
def test_mixtral_moe(self):
366+
config = mixtral_config.ModelArgs()
367+
config.intermediate_size = 16
368+
config.dim = 16
369+
m = mixtral.ConditionalFeedForward(config)
370+
# random init
371+
states = m.state_dict()
372+
for k, v in states.items():
373+
states[k].normal_()
374+
m.load_state_dict(states, assign=True)
375+
376+
seqlen = 3
377+
num_expert = 8
378+
num_active_expert = 2
379+
x = torch.randn(seqlen, config.dim)
380+
exp_index = torch.randint(0, num_expert, (seqlen, num_active_expert))
381+
382+
res1 = m.forward_for_short_seq_len(x, exp_index)
383+
res2 = m.forward_for_long_seq_len(x, exp_index)
384+
385+
torch.testing.assert_close(res1, res2)
386+
363387

364388
if __name__ == "__main__":
365389
unittest.main()

0 commit comments

Comments (0)