
Commit c5a86d9

Add mixtral support to new CLI
1 parent 7307541 commit c5a86d9

9 files changed: +194, -23 lines

jetstream_pt/cli.py

Lines changed: 8 additions & 7 deletions

@@ -11,6 +11,7 @@
 from jetstream_pt import fetch_models
 from jetstream_pt import environment, engine, quantize_model, torchjax
 from jetstream_pt import config
+from transformers import AutoTokenizer


 FLAGS = flags.FLAGS
@@ -25,13 +26,13 @@

 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
-  for k, v in weight_shardings.items():
-    print("SHARDING", k, v)
   sharded = {}
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
     with jax.default_device(jax.devices("cpu")[0]):
       arr = torch_xla2.tensor.t2j(val)
+
+    print("SHARDING", key, sharding)
     arr = jax.device_put(arr, sharding)
     sharded[key] = torchjax.to_torch(arr)
   return sharded
@@ -48,17 +49,17 @@ def create_engine(devices):
       FLAGS.max_output_length,
       quant_config.enable_weight_quantization,
   )
+  tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
   env = environment.JetEngineEnvironment(env_data)
+  env.hf_tokenizer = tokenizer
   model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
+  if quant_config.enable_weight_quantization:
+    quantize_model.quantize_model(model, quant_config)

+  import pdb; pdb.set_trace()
   weight_shardings = model.get_sharding_annotations()
   sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)

-  if quant_config.enable_weight_quantization:
-    model.load_state_dict(sharded_weights, assign=True, strict=False)
-    quantize_model.quantize_model(model, quant_config)
-    sharded_weights = model.state_dict()
-
   return engine.PyTorchEngine(
       pt_model=model,
       env=env,
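For readers unfamiliar with the sharding call chain above: env.sharding_by_axis maps an annotation (a dimension index, or -1 for "replicate") to a JAX sharding, and jax.device_put then places each converted tensor with that layout. Below is a minimal, self-contained sketch of that pattern using plain jax.sharding APIs; the helper name shard_along_axis and the 1-D mesh are illustrative assumptions, not code from this commit.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over whatever devices are available (TPU chips, or a single CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def shard_along_axis(arr, axis):
  # Illustrative stand-in for env.sharding_by_axis + jax.device_put:
  # axis == -1 replicates, any other axis is split across the mesh.
  if axis == -1:
    spec = PartitionSpec()
  else:
    spec = PartitionSpec(*([None] * axis + ["x"]))
  return jax.device_put(arr, NamedSharding(mesh, spec))

w = jax.numpy.ones((8, 4))
print(shard_along_axis(w, 0).sharding)   # dim 0 split across devices
print(shard_along_axis(w, -1).sharding)  # fully replicated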

jetstream_pt/engine.py

Lines changed: 3 additions & 0 deletions

@@ -36,6 +36,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt import torchjax
+from jetstream_pt.hf_tokenizer import HFTokenizerAdapter
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
 from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -705,6 +706,8 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
   def build_tokenizer(
       self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
   ) -> tokenizer_api.Tokenizer:
+    if self.env.hf_tokenizer is not None:
+      return HFTokenizerAdapter(self.env.hf_tokenizer)
     if "llama-3" in self.env.model_type:
       return token_utils.TikToken(metadata)


jetstream_pt/environment.py

Lines changed: 5 additions & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.

 import dataclasses
-from typing import Tuple
+from typing import Tuple, Any

 import jax
 import jax.numpy as jnp
@@ -141,6 +141,10 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.testing_seed = self._data.testing_seed
     self.ring_buffer = self._data.ring_buffer

+    # If not None, then use this tokenizer without
+    # trying to create new ones.
+    self.hf_tokenizer = None
+
     if not self.ring_buffer:
       self.lazy_cache_update = True
       self.ragged_mha = True

jetstream_pt/fetch_models.py

Lines changed: 12 additions & 7 deletions

@@ -12,6 +12,7 @@
     QuantizationConfig,
 )
 from jetstream_pt.third_party.llama import model_exportable as llama_model
+from jetstream_pt.third_party.mixtral import model as mixtral_model

 FLAGS = flags.FLAGS

@@ -38,12 +39,15 @@ class ModelInfo:
   num_layers: int
   num_heads: int
   head_dim: int
+  n_reps: int  # repeatition for GQA


-_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128)
-_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128)
-_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128)
-_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128)
+_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
+_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
+_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
+_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
+
+_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)


 model_id_to_class = {
@@ -57,8 +61,8 @@ class ModelInfo:
     "google/gemma-2b-it": None,
     "google/gemma-7b": None,
     "google/gemma-7b-it": None,
-    "mistralai/Mixtral-8x7B-v0.1": None,
-    "mistralai/Mixtral-8x7B-Instruct-v0.1": None,
+    "mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
 }


@@ -107,6 +111,7 @@ def construct_env_data_from_model_id(
     else input_length + output_length
   )

+  model_info = model_id_to_class.get(repo_id)
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
@@ -119,8 +124,8 @@ def construct_env_data_from_model_id(
       bf16_enable=True,
       sharding_config_path="",
       shard_on_batch=shard_on_batch,
+      n_reps=model_info.n_reps,
   )
-  model_info = model_id_to_class.get(repo_id)
   env_data.cache_shape = (
       batch_size,
       model_info.num_heads,
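The new n_reps field records the grouped-query-attention (GQA) factor: how many query heads share each KV head. For Mixtral-8x7B the registry entry above gives 8 KV heads and n_reps = 4, i.e. 32 query heads. A hedged, self-contained sketch of the repetition this refers to; the tensor names and shapes are illustrative and not taken from the Mixtral model code in this repo.

import torch

batch, seq, n_kv_heads, head_dim, n_reps = 1, 5, 8, 128, 4

# One key tensor per KV head ...
k = torch.randn(batch, n_kv_heads, seq, head_dim)

# ... repeated so it lines up with n_kv_heads * n_reps = 32 query heads
# before attention scores are computed.
k_expanded = torch.repeat_interleave(k, repeats=n_reps, dim=1)
print(k_expanded.shape)  # torch.Size([1, 32, 5, 128])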

jetstream_pt/hf_tokenizer.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+from jetstream.engine import tokenizer_api
+
+
+class HFTokenizerAdapter(tokenizer_api.Tokenizer):
+
+  def __init__(self, tokenizer):
+    self.tokenizer = tokenizer
+
+
+  def encode(
+      self, s: str, **kwargs
+  ):
+    """Tokenize a string.
+    Args:
+      s: String to tokenize.
+      **kwargs: Additional keyword arguments.
+    Returns:
+      tokens: Tokenized into integers.
+      true_length: Actual length of the non-padded sequence
+        if padding is used.
+    """
+    return self(s)
+
+  def decode(self, token_ids: list[int], **kwargs) -> str:
+    """Processess input token ids to generate a string.
+    Args:
+      token_ids: List of token ids.
+      **kwargs: Additional keyword arguments.
+    Returns:
+      str: String generated from the token ids.
+    """
+    return self.decode(token_ids)
+
+  @property
+  def pad_id(self) -> int:
+    """ID of the pad token."""
+    return (self.tokenizer.pad_token_id
+            if self.tokenizer.pad_token_id else 0)
+
+
+  @property
+  def eos_id(self) -> int:
+    """ID of EOS token."""
+    return self.tokenizer.eos_token_id
+
+  @property
+  def bos_id(self) -> int:
+    """ID of BOS token."""
+    return self.tokenizer.bos_token_id
+
+  @property
+  def stop_tokens(self) -> set[int]:
+    """ID of the stop token."""
+    return {self.eos_id, self.pad_id}
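A quick usage sketch of the new adapter, assuming only that a Hugging Face tokenizer is available locally; "gpt2" is used here purely as a small public stand-in for the Mixtral tokenizer the CLI actually loads.

from transformers import AutoTokenizer
from jetstream_pt.hf_tokenizer import HFTokenizerAdapter

hf_tok = AutoTokenizer.from_pretrained("gpt2")  # any HF checkpoint works
adapter = HFTokenizerAdapter(hf_tok)

print(adapter.eos_id)       # eos_token_id of the underlying tokenizer
print(adapter.pad_id)       # falls back to 0 when pad_token_id is unset
print(adapter.stop_tokens)  # {eos_id, pad_id}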

jetstream_pt/layers.py

Lines changed: 14 additions & 0 deletions

@@ -319,6 +319,20 @@ def get_quantized_embedding_layer(config: "QuantizationConfig"):
   else:
     return Int8Embedding

+def create_quantized_from_nn_embedding(
+    float_embedding: nn.Embedding, config: "QuantizationConfig"
+):
+  clazz_ = get_quantized_embedding_layer(config)
+  obj = clazz_(
+      float_embedding.num_embeddings,
+      float_embedding.embedding_dim,
+  )
+  weights, scaler, _ = quantize_tensor(
+      float_embedding.weight, 1)
+  obj.weight = weights
+  obj.scaler = scaler
+  return obj
+

 class RMSNorm(torch.nn.Module):
   """RMSNorm module."""

jetstream_pt/quantize_model.py

Lines changed: 11 additions & 2 deletions

@@ -1,15 +1,24 @@
 import torch
-from .layers import create_quantized_from_nn_linear
+from .layers import (create_quantized_from_nn_linear,
+                     create_quantized_from_nn_embedding)


 def quantize_model(float_model, config):
   """Apply quantization to linear layers."""

   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
-      if isinstance(mod, torch.nn.Linear):
+      new_mod = None
+      if hasattr(mod, 'get_quantized_version'):
+        new_mod = mod.get_quantized_version()
+      elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
+      elif isinstance(mod, torch.nn.Embedding):
+        new_mod = create_quantized_from_nn_embedding(mod, config)
+
+      if new_mod:
         setattr(float_model, name, new_mod)
+

   float_model.apply(quantize_nn_mod)
   return float_model
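The change above turns quantize_nn_mod into a general module-swap pass: prefer a module's own get_quantized_version() hook, fall back to the Linear/Embedding converters, and reassign by name. A self-contained sketch of the same swap pattern on a toy model; QuantStub and the collect-then-replace loop are simplifications for illustration, not the repo's converters.

import torch

class QuantStub(torch.nn.Module):
  # Placeholder standing in for the real quantized replacement module.
  def __init__(self, mod):
    super().__init__()
    self.wrapped = mod

  def forward(self, x):
    return self.wrapped(x)

model = torch.nn.Sequential(
    torch.nn.Embedding(10, 4),
    torch.nn.Linear(4, 4),
)

# Collect replacements first, then swap them in by attribute name.
replacements = {
    name: QuantStub(mod)
    for name, mod in model.named_modules()
    if isinstance(mod, (torch.nn.Linear, torch.nn.Embedding))
}
for name, new_mod in replacements.items():
  setattr(model, name, new_mod)

print(model)  # both children are now wrapped in QuantStub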

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 5 additions & 0 deletions

@@ -75,6 +75,11 @@ def __init__(
     self.annotate_sharding("w1.weight", 0)
     self.annotate_sharding("w2.weight", 1)
     self.annotate_sharding("w3.weight", 0)
+    if LinearLayer != torch.nn.Linear:
+      self.annotate_sharding("w1.weight_scaler", 0)
+      self.annotate_sharding("w2.weight_scaler", 0)
+      self.annotate_sharding("w3.weight_scaler", 0)
+

   def forward(self, x):
     result = self.w2(F.silu(self.w1(x)) * self.w3(x))
