diff --git a/deps/xla b/deps/xla
index c216d26c..c2753715 160000
--- a/deps/xla
+++ b/deps/xla
@@ -1 +1 @@
-Subproject commit c216d26c23a37eb85dd8f8152ffe1acdb6b484a0
+Subproject commit c27537153f3ea983a7ba9b0e1bfdae4b37ca5e9e
diff --git a/install_everything.sh b/install_everything.sh
index 220e6df2..e4366327 100644
--- a/install_everything.sh
+++ b/install_everything.sh
@@ -38,4 +38,5 @@ git submodule update --init --recursive
 pip show google-jetstream && pip uninstall -y google-jetstream
 pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install -e .
-pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
diff --git a/install_everything_gpu.sh b/install_everything_gpu.sh
index ebf53bab..b581c159 100644
--- a/install_everything_gpu.sh
+++ b/install_everything_gpu.sh
@@ -24,14 +24,12 @@ pip show tensorboard && pip uninstall -y tensorboard
 pip show tensorflow-text && pip uninstall -y tensorflow-text
 pip show torch_xla2 && pip uninstall -y torch_xla2
 
-pip install flax==0.8.3
-pip install -U "jax[cuda12]==0.4.28"
+pip install flax==0.8.4
 pip install tensorflow-text
 pip install tensorflow
 
 pip install ray[default]==2.22.0
 # torch cpu
-pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
 pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
 pip install safetensors colorama coverage humanize
 
@@ -39,3 +37,5 @@ git submodule update --init --recursive
 pip show google-jetstream && pip uninstall -y google-jetstream
 pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install -e .
+pip install -U jax[cuda12]==0.4.30
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
diff --git a/run_server.py b/run_server.py
index 1ed199a3..102b0156 100644
--- a/run_server.py
+++ b/run_server.py
@@ -16,8 +16,9 @@
 import os
 from typing import Sequence
 
+# import torch_xla2 first!
+import torch_xla2  # pylint: disable=unused-import
 import jax
-import jetstream_pt
 from absl import app, flags
 from jetstream.core import server_lib
 from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index 44ae6e31..65ac8913 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -23,6 +23,8 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt.third_party.mixtral import model as mixtral
+from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import torchjax
 from jetstream_pt import layers
 from jetstream_pt import cache_manager
@@ -360,6 +362,28 @@ def test_transformer(self):
     print("Transformer: Diff norm", (result_torch - expected_out).norm())
     self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
 
+  def test_mixtral_moe(self):
+    config = mixtral_config.ModelArgs()
+    config.intermediate_size = 16
+    config.dim = 16
+    m = mixtral.ConditionalFeedForward(config)
+    # random init
+    states = m.state_dict()
+    for k, v in states.items():
+      states[k].normal_()
+    m.load_state_dict(states, assign=True)
+
+    seqlen = 3
+    num_expert = 8
+    num_active_expert = 2
+    x = torch.randn(seqlen, config.dim)
+    exp_index = torch.randint(0, num_expert, (seqlen, num_active_expert))
+
+    res1 = m.forward_for_short_seq_len(x, exp_index)
+    res2 = m.forward_for_long_seq_len(x, exp_index)
+
+    torch.testing.assert_close(res1, res2)
+
 
 if __name__ == "__main__":
   unittest.main()