
Commit 1cf3a06

Add mixtral support to new CLI
1 parent 7307541 commit 1cf3a06

File tree: 7 files changed, +150 −16 lines

  jetstream_pt/cli.py
  jetstream_pt/engine.py
  jetstream_pt/environment.py
  jetstream_pt/fetch_models.py
  jetstream_pt/hf_tokenizer.py
  jetstream_pt/third_party/llama/model_exportable.py
  jetstream_pt/third_party/mixtral/model.py

jetstream_pt/cli.py

Lines changed: 5 additions & 2 deletions
@@ -11,6 +11,7 @@
 from jetstream_pt import fetch_models
 from jetstream_pt import environment, engine, quantize_model, torchjax
 from jetstream_pt import config
+from transformers import AutoTokenizer


 FLAGS = flags.FLAGS
@@ -25,13 +26,13 @@

 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
-  for k, v in weight_shardings.items():
-    print("SHARDING", k, v)
   sharded = {}
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
     with jax.default_device(jax.devices("cpu")[0]):
       arr = torch_xla2.tensor.t2j(val)
+
+    print("SHARDING", key, sharding)
     arr = jax.device_put(arr, sharding)
     sharded[key] = torchjax.to_torch(arr)
   return sharded
@@ -48,7 +49,9 @@ def create_engine(devices):
       FLAGS.max_output_length,
       quant_config.enable_weight_quantization,
   )
+  tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
   env = environment.JetEngineEnvironment(env_data)
+  env.hf_tokenizer = tokenizer
   model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)

   weight_shardings = model.get_sharding_annotations()
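In short, create_engine now loads the HuggingFace tokenizer for FLAGS.model_id up front and attaches it to the environment so the engine can reuse it. A minimal sketch of that wiring, with a SimpleNamespace standing in for the JetEngineEnvironment built in cli.py and the model id used only as an example:

from types import SimpleNamespace
from transformers import AutoTokenizer

# Stand-in for the JetEngineEnvironment created in create_engine; only the
# hf_tokenizer attribute added by this commit matters for the sketch.
env = SimpleNamespace(hf_tokenizer=None)

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # example HF checkpoint id
env.hf_tokenizer = AutoTokenizer.from_pretrained(model_id)

print(env.hf_tokenizer.eos_token_id)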

jetstream_pt/engine.py

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt import torchjax
+from jetstream_pt.hf_tokenizer import HFTokenizerAdapter
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
 from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -705,6 +706,8 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
   def build_tokenizer(
       self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
   ) -> tokenizer_api.Tokenizer:
+    if self.env.hf_tokenizer is not None:
+      return HFTokenizerAdapter(self.env.hf_tokenizer)
     if "llama-3" in self.env.model_type:
       return token_utils.TikToken(metadata)

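Effectively, build_tokenizer now prefers a HuggingFace tokenizer attached to the environment over the model-specific tokenizer construction. A rough sketch of the resulting branch order, not the literal engine code; the imports and names are taken from the diff context and the remaining branches are elided:

from jetstream.engine import token_utils
from jetstream_pt.hf_tokenizer import HFTokenizerAdapter

def resolve_tokenizer(env, metadata):
  # New branch: a HuggingFace tokenizer attached to the environment wins.
  if env.hf_tokenizer is not None:
    return HFTokenizerAdapter(env.hf_tokenizer)
  # Pre-existing branch shown in the diff context.
  if "llama-3" in env.model_type:
    return token_utils.TikToken(metadata)
  ...  # remaining model-specific tokenizer paths, unchanged by this commit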

jetstream_pt/environment.py

Lines changed: 5 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.

 import dataclasses
-from typing import Tuple
+from typing import Tuple, Any

 import jax
 import jax.numpy as jnp
@@ -141,6 +141,10 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.testing_seed = self._data.testing_seed
     self.ring_buffer = self._data.ring_buffer

+    # If not None, then use this tokenizer without
+    # trying to create new ones.
+    self.hf_tokenizer = None
+
     if not self.ring_buffer:
       self.lazy_cache_update = True
       self.ragged_mha = True

jetstream_pt/fetch_models.py

Lines changed: 12 additions & 7 deletions
@@ -12,6 +12,7 @@
     QuantizationConfig,
 )
 from jetstream_pt.third_party.llama import model_exportable as llama_model
+from jetstream_pt.third_party.mixtral import model as mixtral_model

 FLAGS = flags.FLAGS

@@ -38,12 +39,15 @@ class ModelInfo:
   num_layers: int
   num_heads: int
   head_dim: int
+  n_reps: int  # repeatition for GQA


-_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128)
-_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128)
-_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128)
-_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128)
+_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
+_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
+_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
+_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
+
+_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)


 model_id_to_class = {
@@ -57,8 +61,8 @@ class ModelInfo:
     "google/gemma-2b-it": None,
     "google/gemma-7b": None,
     "google/gemma-7b-it": None,
-    "mistralai/Mixtral-8x7B-v0.1": None,
-    "mistralai/Mixtral-8x7B-Instruct-v0.1": None,
+    "mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
 }


@@ -107,6 +111,7 @@ def construct_env_data_from_model_id(
       else input_length + output_length
   )

+  model_info = model_id_to_class.get(repo_id)
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
@@ -119,8 +124,8 @@ def construct_env_data_from_model_id(
       bf16_enable=True,
       sharding_config_path="",
       shard_on_batch=shard_on_batch,
+      n_reps=model_info.n_reps,
   )
-  model_info = model_id_to_class.get(repo_id)
   env_data.cache_shape = (
       batch_size,
       model_info.num_heads,
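Per the inline comment, the new n_reps field records a repetition factor for grouped-query attention (GQA), where several query heads share each KV head. A generic illustration of how such a factor is typically applied to KV tensors; this is not the repo's attention code, just a sketch of the concept:

import torch

def repeat_kv(kv: torch.Tensor, n_reps: int) -> torch.Tensor:
  """Expand (batch, n_kv_heads, seq, head_dim) so each KV head serves
  n_reps query heads, as in grouped-query attention."""
  if n_reps == 1:
    return kv
  return kv.repeat_interleave(n_reps, dim=1)

k = torch.randn(1, 8, 16, 128)   # 8 KV heads, head_dim 128
print(repeat_kv(k, 4).shape)     # torch.Size([1, 32, 16, 128])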

jetstream_pt/hf_tokenizer.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+from jetstream.engine import tokenizer_api
+
+
+class HFTokenizerAdapter(tokenizer_api.Tokenizer):
+
+  def __init__(self, tokenizer):
+    self.tokenizer = tokenizer
+
+
+  def encode(
+      self, s: str, **kwargs
+  ):
+    """Tokenize a string.
+    Args:
+      s: String to tokenize.
+      **kwargs: Additional keyword arguments.
+    Returns:
+      tokens: Tokenized into integers.
+      true_length: Actual length of the non-padded sequence
+        if padding is used.
+    """
+    return self(s)
+
+  def decode(self, token_ids: list[int], **kwargs) -> str:
+    """Processess input token ids to generate a string.
+    Args:
+      token_ids: List of token ids.
+      **kwargs: Additional keyword arguments.
+    Returns:
+      str: String generated from the token ids.
+    """
+    return self.decode(token_ids)
+
+  @property
+  def pad_id(self) -> int:
+    """ID of the pad token."""
+    return (self.tokenizer.pad_token_id
+            if self.tokenizer.pad_token_id else 0)
+
+
+  @property
+  def eos_id(self) -> int:
+    """ID of EOS token."""
+    return self.tokenizer.eos_token_id
+
+  @property
+  def bos_id(self) -> int:
+    """ID of BOS token."""
+    return self.tokenizer.bos_token_id
+
+  @property
+  def stop_tokens(self) -> set[int]:
+    """ID of the stop token."""
+    return {self.eos_id, self.pad_id}
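A small usage sketch of the new adapter, assuming the transformers tokenizer for the checkpoint is available locally; the printed values depend on the tokenizer:

from transformers import AutoTokenizer
from jetstream_pt.hf_tokenizer import HFTokenizerAdapter

hf_tok = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
adapter = HFTokenizerAdapter(hf_tok)

print(adapter.bos_id, adapter.eos_id, adapter.pad_id)  # special-token ids
print(adapter.stop_tokens)                             # {eos_id, pad_id}

ids = hf_tok.encode("Hello Mixtral")  # token ids from the underlying HF tokenizer
print(hf_tok.decode(ids))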

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,11 @@ def __init__(
     self.annotate_sharding("w1.weight", 0)
     self.annotate_sharding("w2.weight", 1)
     self.annotate_sharding("w3.weight", 0)
+    if LinearLayer != torch.nn.Linear:
+      self.annotate_sharding("w1.weight_scaler", 0)
+      self.annotate_sharding("w2.weight_scaler", 0)
+      self.annotate_sharding("w3.weight_scaler", 0)
+

   def forward(self, x):
     result = self.w2(F.silu(self.w1(x)) * self.w3(x))
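The weight_scaler buffers only exist when a quantized LinearLayer is swapped in for torch.nn.Linear, which is why their sharding annotations sit behind that guard. For background, a toy sketch of per-channel int8 quantization, where each quantized weight carries its own scaler that has to travel with the weight when sharded; this is illustrative only, not the repo's quantize module:

import torch

def quantize_per_channel(w: torch.Tensor):
  """Toy per-output-channel int8 quantization: one scale per row of w."""
  scaler = (w.abs().amax(dim=1) / 127.0).clamp(min=1e-8)        # (out_features,)
  q = torch.clamp((w / scaler[:, None]).round(), -128, 127).to(torch.int8)
  return q, scaler

w = torch.randn(8, 16)
q, scaler = quantize_per_channel(w)
w_hat = q.float() * scaler[:, None]   # dequantize; each row's scale pairs with that row
print((w - w_hat).abs().max())        # small reconstruction error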

jetstream_pt/third_party/mixtral/model.py

Lines changed: 66 additions & 6 deletions
@@ -22,11 +22,12 @@
 from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_embedding_layer
+from jetstream_pt.model_base import ModuleBase

 import jax


-class Transformer(nn.Module):
+class Transformer(ModuleBase):

   def __init__(self, config: ModelArgs, env) -> None:
     super().__init__()
@@ -37,6 +38,7 @@ def __init__(self, config: ModelArgs, env) -> None:
     self.tok_embeddings = Embedding(
         config.vocab_size, config.dim, device=config.device
     )
+
     self.layers = nn.ModuleList(
         TransformerBlock(config, env, layer_id)
         for layer_id in range(config.n_layer)
@@ -47,6 +49,14 @@ def __init__(self, config: ModelArgs, env) -> None:
         config.dim, config.vocab_size, bias=False, device=config.device
     )

+    self.hf_name("norm", "model.norm")
+    self.hf_name("layers", "model.layers")
+    self.hf_name('output', 'lm_head')
+    self.hf_name('tok_embeddings', 'model.embed_tokens')
+
+    self.annotate_sharding("tok_embeddings.weight", 1)
+    self.annotate_sharding("output.weight", 0)
+
     self.max_batch_size = -1
     self.max_seq_length = -1

@@ -140,8 +150,20 @@ def get_weight_sharding_type():
         "output.weight": "ColumnParallelLinear",
     }

+  @classmethod
+  def from_hf_model_id(cls, model_id, env):
+    name = {
+        "mistralai/Mixtral-8x7B-v0.1": "Mixtral-8x7B-v0.1",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-v0.1",
+    }.get(model_id)
+    assert name
+    args = ModelArgs.from_name(name)
+    args.device = 'meta'
+    model = cls(args, env)
+    return model

-class TransformerBlock(nn.Module):
+
+class TransformerBlock(ModuleBase):

   def __init__(self, config: ModelArgs, env, layer_id) -> None:
     super().__init__()
@@ -154,10 +176,37 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
         device=config.device,
         layer_id=layer_id,
     )
+    self.hf_name("attention", "self_attn")
+    self.attention.hf_name("wq", "q_proj")
+    self.attention.hf_name("wk", "k_proj")
+    self.attention.hf_name("wv", "v_proj")
+    self.attention.hf_name("wo", "o_proj")
+
+    self.attention.annotate_sharding("wq", 0)
+    self.attention.annotate_sharding("wk", 0)
+    self.attention.annotate_sharding("wv", 0)
+    self.attention.annotate_sharding("wo", 1)
+
     self.block_sparse_moe = MOEFeedForward(config, config.device, env)
     self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
     self.attention_norm = RMSNorm(config.dim, config.norm_eps)

+    self.hf_name("attention_norm", "input_layernorm")
+    self.hf_name("ffn_norm", "post_attention_layernorm")
+    self._register_load_state_dict_pre_hook(self.load_hook)
+
+  def load_hook(self, state_dict, prefix, *args):
+    if prefix + "block_sparse_moe.experts" in state_dict:
+      w1s, w2s, w3s = [], [], []
+      for i in range(8):
+        exp_prefix = f"{prefix}block_sparse_moe.experts.{i}."
+        w1s.append(state_dict.pop(exp_prefix + ".w1"))
+        w2s.append(state_dict.pop(exp_prefix + ".w2"))
+        w3s.append(state_dict.pop(exp_prefix + ".w3"))
+      state_dict[prefix + "block_sparse_moe.cond_ffn.w1"] = torch.cat(w1s)
+      state_dict[prefix + "block_sparse_moe.cond_ffn.w2"] = torch.cat(w2s)
+      state_dict[prefix + "block_sparse_moe.cond_ffn.w3"] = torch.cat(w3s)
+
   def forward(
       self,
       x: Tensor,
@@ -189,7 +238,7 @@ def forward(
     return out


-class Int8ConditionalFeedForward(nn.Module):
+class Int8ConditionalFeedForward(ModuleBase):

   def __init__(self, config):
     super().__init__()
@@ -215,12 +264,20 @@ def __init__(self, config):
     self.register_buffer("w2", w2)
     self.register_buffer("w3", w3)

+    self.annotate_sharding("w1", 1)
+    self.annotate_sharding("w2", 2)
+    self.annotate_sharding("w3", 1)
+
     w1_scaler = torch.empty(config.num_experts, config.intermediate_size)
     w2_scaler = torch.empty(config.num_experts, config.dim)
     w3_scaler = torch.empty(config.num_experts, config.intermediate_size)
+
     self.register_buffer("w1_scaler", w1_scaler)
     self.register_buffer("w2_scaler", w2_scaler)
     self.register_buffer("w3_scaler", w3_scaler)
+    self.annotate_sharding("w1_scaler", 1)
+    self.annotate_sharding("w2_scaler", -1)
+    self.annotate_sharding("w3_scaler", 1)

   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     seq_len = x.shape[0]
@@ -266,7 +323,7 @@ def forward_for_long_seq_len(self, x, expert_indices):
     return expert_outs[seq_indexes, expert_indices]


-class ConditionalFeedForward(nn.Module):
+class ConditionalFeedForward(ModuleBase):

   def __init__(self, config):
     super().__init__()
@@ -280,6 +337,9 @@ def __init__(self, config):
     self.w3 = nn.Parameter(
         torch.empty(config.num_experts, config.intermediate_size, config.dim)
     )
+    self.annotate_sharding("w1", 1)
+    self.annotate_sharding("w2", 2)
+    self.annotate_sharding("w3", 1)

   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     seq_len = x.shape[0]
@@ -318,7 +378,7 @@ def forward_for_long_seq_len(self, x, expert_indices):
     return expert_outs[seq_indexes, expert_indices]


-class MOEFeedForward(nn.Module):
+class MOEFeedForward(ModuleBase):

   def __init__(self, config, device, env) -> None:
     super().__init__()
@@ -352,7 +412,7 @@ def forward(self, x: Tensor) -> Tensor:
     return expert_outs


-class RMSNorm(nn.Module):
+class RMSNorm(ModuleBase):

   def __init__(self, dim: int, eps: float = 1e-5):
     super().__init__()
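The most involved piece is the load_hook registered on each TransformerBlock: HuggingFace-style Mixtral checkpoints carry eight separate expert MLPs per layer (block_sparse_moe.experts.{i}.w1/w2/w3), while this model keeps them fused in single cond_ffn tensors indexed by expert. A standalone sketch of that fuse-then-gather idea with made-up shapes; the real hook concatenates the tensors it pops from the state_dict rather than stacking fresh ones:

import torch

num_experts, dim, inter = 8, 16, 32   # toy sizes, not Mixtral's real dims

# Per-expert weights as they appear in a HF-style checkpoint.
per_expert_w1 = [torch.randn(inter, dim) for _ in range(num_experts)]

# Fuse them into one tensor so the MoE layer can index by expert id.
w1 = torch.stack(per_expert_w1)        # (num_experts, inter, dim)

expert_indices = torch.tensor([3, 1])  # top-k experts picked for one token
selected = w1[expert_indices]          # (2, inter, dim), used in the expert matmul
print(selected.shape)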
