Skip to content

Commit b0f95f3

Browse files
committed
Add a test for the MoE layer; modify the install scripts so GPU runs work
1 parent c3293c4 commit b0f95f3

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

install_everything.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ pip show google-jetstream && pip uninstall -y google-jetstream
3939
pip show torch_xla2 && pip uninstall -y torch_xla2
4040
pip install -e .
4141
pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
42+
pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu

install_everything_gpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ pip install tensorflow
3030

3131
pip install ray[default]==2.22.0
3232
# torch cpu
33-
pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
3433
pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
3534
pip install safetensors colorama coverage humanize
3635

@@ -39,3 +38,4 @@ pip show google-jetstream && pip uninstall -y google-jetstream
3938
pip show torch_xla2 && pip uninstall -y torch_xla2
4039
pip install -e .
4140
pip install -U jax[cuda12]==0.4.30
41+
pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu

tests/test_model_impl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from jetstream_pt.third_party.llama import model_original
2424
from jetstream_pt.third_party.gemma import model_original as gemma_orig
2525
from jetstream_pt.third_party.gemma import model as gemma
26-
from jetstream_pt.third_party.mixtral import model as mixtral
26+
from jetstream_pt.third_party.mixtral import model as mixtral
2727
from jetstream_pt.third_party.mixtral import config as mixtral_config
2828
from jetstream_pt import torchjax
2929
from jetstream_pt import layers
@@ -370,19 +370,19 @@ def test_mixtral_moe(self):
370370
# random init
371371
states = m.state_dict()
372372
for k, v in states.items():
373-
states[k].normal_()
373+
states[k].normal_()
374374
m.load_state_dict(states, assign=True)
375375

376376
seqlen = 3
377-
num_expert = 8
377+
num_expert = 8
378378
num_active_expert = 2
379-
x = torch.randn(10, config.dim)
379+
x = torch.randn(seqlen, config.dim)
380380
exp_index = torch.randint(0, num_expert, (seqlen, num_active_expert))
381381

382382
res1 = m.forward_for_short_seq_len(x, exp_index)
383383
res2 = m.forward_for_long_seq_len(x, exp_index)
384384

385-
torch.testing.assert_close(res1, res2, atol=1e-4, rtol=1e-4)
385+
torch.testing.assert_close(res1, res2)
386386

387387

388388
if __name__ == "__main__":

0 commit comments

Comments (0)