
Commit d3547f9: Add test for Mixtral model. (#131)

* Add test for Mixtral model.
* Fix per comments.

Parent: fa1f120

File tree: 4 files changed, +116 −10 lines

jetstream_pt/third_party/mixtral/config.py (11 additions, 0 deletions)

```diff
@@ -75,4 +75,15 @@ def from_name(cls, name: str):
         num_experts=8,
         num_activated_experts=2,
     ),
+    "Mixtral-tiny": dict(
+        block_size=128,
+        n_layer=3,
+        n_head=32,
+        n_local_heads=8,
+        dim=128,
+        intermediate_size=None,
+        rope_base=1000000.0,
+        num_experts=8,
+        num_activated_experts=2,
+    ),
 }
```
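For orientation, a minimal usage sketch (mine, not part of the commit): `from_name` resolves the string key against this config table, so the test suite can build a 3-layer toy model. With `intermediate_size=None`, gpt-fast-style configs usually derive the value from `dim` in `__post_init__`; that behavior is an assumption here, not shown in this diff.

```python
# Assumed usage; ModelArgs.from_name and the field values come from the
# diff above. The intermediate_size derivation is not part of this commit.
from jetstream_pt.third_party.mixtral import config as mixtral_config

args = mixtral_config.ModelArgs.from_name("Mixtral-tiny")
assert args.n_layer == 3
assert args.num_experts == 8
assert args.dim // args.n_head == 4  # head_dim implied by dim=128, n_head=32
```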

jetstream_pt/third_party/mixtral/model_original.py (15 additions, 8 deletions)

```diff
@@ -23,6 +23,12 @@
 from .config import ModelArgs, find_multiple
 
 
+def find_multiple(n: int, k: int) -> int:
+  if n % k == 0:
+    return n
+  return n + k - (n % k)
+
+
 class KVCache(nn.Module):
 
   def __init__(
@@ -31,7 +37,8 @@ def __init__(
       max_seq_length,
       n_heads,
       head_dim,
-      dtype=torch.bfloat16,
+      # dtype=torch.bfloat16,
+      dtype=torch.float32,
   ):
     super().__init__()
     cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
```
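Two things to note in this hunk. The new local `find_multiple` shadows the identical helper already imported from `.config`; it rounds `n` up to the next multiple of `k`. And the `KVCache` default dtype moves to `float32`, which matches the float32 comparison path used by the new test. A quick sketch of the rounding behavior (my example values, same function body as the hunk above):

```python
def find_multiple(n: int, k: int) -> int:
  # Round n up to the nearest multiple of k.
  if n % k == 0:
    return n
  return n + k - (n % k)

assert find_multiple(256, 256) == 256  # already aligned
assert find_multiple(300, 256) == 512  # 300 + 256 - 44
```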
```diff
@@ -191,22 +198,21 @@ class ConditionalFeedForward(nn.Module):
 
   def __init__(self, config):
     super().__init__()
+    # Replace the weight init of torch.empty with torch.rand for testing purpose
     self.w1 = nn.Parameter(
-        torch.empty(config.num_experts, config.intermediate_size, config.dim)
+        torch.rand(config.num_experts, config.intermediate_size, config.dim)
     )
     self.w2 = nn.Parameter(
-        torch.empty(config.num_experts, config.dim, config.intermediate_size)
+        torch.rand(config.num_experts, config.dim, config.intermediate_size)
     )
     self.w3 = nn.Parameter(
-        torch.empty(config.num_experts, config.intermediate_size, config.dim)
+        torch.rand(config.num_experts, config.intermediate_size, config.dim)
     )
 
   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-    # T = num_tokens, I = intermediate size, D = hidden dim, A = activated experts
     w1_weights = self.w1[expert_indices]  # [T, A, D, D]
     w3_weights = self.w3[expert_indices]  # [T, A, D, D]
     w2_weights = self.w2[expert_indices]  # [T, A, D, D]
-    # x: [T, D]
     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
```
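The deleted comments carried the shape legend, worth restating: T = num_tokens, A = activated experts, D = hidden dim, I = intermediate size (the inline `[T, A, D, D]` annotations are approximate; `w1`/`w3` gather to `[T, A, I, D]` and `w2` to `[T, A, D, I]`). A self-contained sketch of the gather-then-einsum pattern, with toy sizes of my choosing:

```python
import torch
import torch.nn.functional as F

T, A, E, D, I = 4, 2, 8, 16, 32         # tokens, activated, experts, dim, intermediate
x = torch.rand(T, D)                    # token activations [T, D]
w1 = torch.rand(E, I, D)                # per-expert up projection
w2 = torch.rand(E, D, I)                # per-expert down projection
w3 = torch.rand(E, I, D)                # per-expert gate projection
expert_indices = torch.randint(0, E, (T, A))

w1_w = w1[expert_indices]               # gather -> [T, A, I, D]
w3_w = w3[expert_indices]               # gather -> [T, A, I, D]
w2_w = w2[expert_indices]               # gather -> [T, A, D, I]
x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1_w))  # [T, A, I]
x3 = torch.einsum("ti,taoi->tao", x, w3_w)          # [T, A, I]
out = torch.einsum("tao,taio->tai", x1 * x3, w2_w)  # [T, A, D]
assert out.shape == (T, A, D)
```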
```diff
@@ -215,7 +221,7 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
 
 class MOEFeedForward(nn.Module):
 
-  def __init__(self, config, env=None) -> None:
+  def __init__(self, config) -> None:
     super().__init__()
     self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
     self.cond_ffn = ConditionalFeedForward(config)
@@ -261,7 +267,8 @@ def precompute_freqs_cis(
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-    return cache.to(dtype=torch.bfloat16)
+    # return cache.to(dtype=torch.bfloat16)
+    return cache
 
 
 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
```
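Keeping the rotary cache in float32 (rather than casting to bfloat16) removes one source of rounding noise when the test compares the original model with the XLA version at `atol=1e-4`. A sketch of what the function computes after this change; the signature and the frequency formula are assumed from the gpt-fast-style original, only the final lines appear in the hunk:

```python
import torch

def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000.0):
  # Sketch only: rotary frequencies, angles, then (cos, sin) pairs in float32.
  freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  t = torch.arange(seq_len).float()
  freqs = torch.outer(t, freqs)
  freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)

cache = precompute_freqs_cis(128, 4, base=1000000.0)  # block_size, head_dim
assert cache.shape == (128, 2, 2)
assert cache.dtype == torch.float32
```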

tests/helpers.py (27 additions, 2 deletions)

```diff
@@ -1,8 +1,8 @@
 import jax
 import torch
 import torch_xla2
-import jax
 from jetstream_pt.third_party.llama import model_args
+from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import environment
 
 
```
```diff
@@ -31,13 +31,38 @@ def make_env_tiny(bf16_enable=True):
   return env, config
 
 
+def make_mixtral_env(bf16_enable=True):
+  torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+  torch.set_default_dtype(torch_dtype)
+  jax.config.update("jax_dynamic_shapes", False)
+  jax.config.update("jax_traceback_filtering", "off")
+  config = mixtral_config.ModelArgs.from_name("Mixtral-tiny")
+  environment_data = environment.JetEngineEnvironmentData()
+  environment_data.max_input_sequence_length = 128
+  environment_data.cache_sequence_length = 128
+  environment_data.bf16_enable = bf16_enable
+  environment_data.model_type = "mixtral"
+  environment_data.batch_size = 1
+  environment_data.num_layers = config.n_layer
+  environment_data.cache_shape = (
+      1,
+      config.n_local_heads,
+      environment_data.cache_sequence_length,
+      config.dim // config.n_head,
+  )
+  env = environment.JetEngineEnvironment(environment_data)
+  env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
+  return env, config
+
+
 def to_xla_tensor(tree):
   return torch_xla2.default_env().to_xla(tree)
 
 
 def call_xla_model(model, weights, args):
   with jax.default_device(jax.devices("cpu")[0]):
     xla_weights, xla_inputs = to_xla_tensor((weights, args))
-    result = torch.func.functional_call(model, xla_weights, xla_inputs)
+    with torch_xla2.default_env():
+      result = torch.func.functional_call(model, xla_weights, xla_inputs)
     result_torch = torch_xla2.tensor.j2t(result._elem)
     return result_torch
```
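The new `with torch_xla2.default_env():` context makes the torch ops issued during `functional_call` dispatch on torch_xla2's tensor environment rather than eager CPU tensors. A hedged end-to-end sketch of how these helpers fit together; the mask construction and the use of the model's own `state_dict` are simplifications of mine, and the real test below converts weights from the original checkpoint instead:

```python
import torch
from tests import helpers
from jetstream_pt.third_party.mixtral import model as mixtral

env, config = helpers.make_mixtral_env(bf16_enable=False)
model = mixtral.Transformer(config, env)

seqlen = 32
tokens = torch.randint(0, 32000, (1, seqlen))
input_pos = torch.arange(0, seqlen)
caches = env.make_caches_prefill()
# Simple causal mask; the test builds its mask via a _prefill_mask helper.
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

out = helpers.call_xla_model(
    model, model.state_dict(), (tokens, input_pos, caches, mask)
)
```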

tests/test_model_impl.py (63 additions, 0 deletions)

```diff
@@ -23,6 +23,7 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt.third_party.mixtral import model_original as mixtral_orig
 from jetstream_pt.third_party.mixtral import model as mixtral
 from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import torchjax
```
```diff
@@ -362,6 +363,68 @@ def test_transformer(self):
     print("Transformer: Diff norm", (result_torch - expected_out).norm())
     self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
 
+  # pylint: disable-next=all
+  def test_mixtral_transformer(self):
+    """Test transformer diff between the original model and the xla_model."""
+    env, model_arg = helpers.make_mixtral_env(False)
+
+    model_orig = mixtral_orig.Transformer(model_arg)
+    model_orig.setup_caches(max_batch_size=1, max_seq_length=env.cache_len)
+
+    state_dict = dict(model_orig.state_dict())
+    state_dict["freqs_cis"] = model_orig.freqs_cis
+    new_dict = {}
+
+    for k, v in state_dict.items():
+      if "kv_cache" in k:
+        continue
+      if "wqkv" in k:
+        wq = k.replace("wqkv", "wq")
+        wk = k.replace("wqkv", "wk")
+        wv = k.replace("wqkv", "wv")
+        kv_size = model_arg.n_local_heads * model_arg.head_dim
+        wq_t, wk_t, wv_t = v.split([model_arg.dim, kv_size, kv_size], dim=0)
+
+        new_dict[wq] = wq_t
+        new_dict[wk] = wk_t
+        new_dict[wv] = wv_t
+        continue
+      # "freqs_cis" for the exported model is calculated differently, using a complex dtype
+      if "freqs_cis" in k:
+        new_dict[k] = mixtral.precompute_freqs_cis(
+            model_arg.block_size,
+            model_arg.dim // model_arg.n_head,
+            model_arg.rope_base,
+        )
+        continue
+      new_dict[k] = v
+
+    model_ours = mixtral.Transformer(model_arg, env)
+
+    # Invoke the original model
+    seqlen = 32
+    x = torch.randint(0, 32000, (1, seqlen))  # (batch, seqlen)
+    start_pos = 0
+    mask = self._prefill_mask(seqlen, start_pos)
+    input_pos = torch.arange(0, seqlen)
+    inputs_orig = (x, input_pos)
+
+    expected_out = model_orig(*inputs_orig)
+
+    # Invoke the exported model
+    caches = env.make_caches_prefill()
+
+    input_ours = (
+        x,
+        input_pos,
+        caches,
+        mask,
+    )
+    result_torch = helpers.call_xla_model(model_ours, new_dict, input_ours)
+
+    print("Transformer: Diff norm", (result_torch - expected_out).norm())
+    self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
+
   def test_mixtral_moe(self):
     config = mixtral_config.ModelArgs()
     config.intermediate_size = 16
```