Commit 56da057

install cli into a command
1 parent 321f5aa commit 56da057

5 files changed, +83 −106 lines changed


README.md

Lines changed: 43 additions & 86 deletions
@@ -14,13 +14,9 @@ Commandline Flags might have changed between the release version to HEAD.
 1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
    a. Create a Cloud TPU VM if you haven’t
 2. Download jetstream-pytorch github repo
-3. Clone repo and install dependencies
-4. Download and convert weights
-5. Run checkpoint converter (quantizer)
-6. Local run
-7. Run the server
-8. Run benchmarks
-9. Typical Errors
+3. Run the server
+4. Run benchmarks
+5. Typical Errors
 
 # Ssh to Cloud TPU VM (using v5e-8 TPU VM)
 
@@ -49,108 +45,69 @@ cd jetstream-pytorch
 source install_everything.sh
 ```
 
-# Download and convert weights
 
-## LLaMA
-### Get official llama weights from meta-llama
+# Run jetstream-pytorch
 
-Following instructions here:
-* Llama-2: https://github.com/meta-llama/llama#download
-* Llama-3: https://github.com/meta-llama/llama3/#download
+## List out supported models
 
-After you have downloaded the weights, it will also download a `tokenizer.model` file that is
-the tokenizer that we will use.
-
-## Gemma
-### Get Gemma Checkpoint from HuggingFace
-
-Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.
-
-```bash
-# Install huggingface-cli and login if it's not set up.
-pip install -U "huggingface_hub[cli]"
-huggingface-cli login
-huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
 ```
-
-## Mixtral
-### Get Mixtral Checkpoint from HuggingFace
-
-Please sign agreement on Huggingface website to access Mixtral checkpoints. Download Mixtral PyTorch checkpoint using huggingface-cli. Mixtral Tokenizer is included in the checkpoint.
-
-```bash
-huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
+jpt list
 ```
 
-## Run weight safetensor convert
+This will print out the list of supported models and variants:
 
-There are limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS takes a long time and therefore storing checkpoints to local is recommended.
-
-```bash
-export input_ckpt_dir=Original llama weights directory
-export output_ckpt_dir=The output directory
-export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
-export quantize_weights=True # Whether to quantize weights
-export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, "int8_per_channel" is the default option if not specified.
-python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
 ```
-
-
-# Local run
-
-Set tokenizer path
-```bash
-export tokenizer_path=tokenizer model file path
+meta-llama/Llama-2-7b-chat-hf
+meta-llama/Llama-2-7b-hf
+meta-llama/Llama-2-13b-chat-hf
+meta-llama/Llama-2-13b-hf
+meta-llama/Llama-2-70b-hf
+meta-llama/Llama-2-70b-chat-hf
+meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3-8B-Instruct
+meta-llama/Meta-Llama-3-70B
+meta-llama/Meta-Llama-3-70B-Instruct
+google/gemma-2b
+google/gemma-2b-it
+google/gemma-7b
+google/gemma-7b-it
+mistralai/Mixtral-8x7B-v0.1
+mistralai/Mixtral-8x7B-Instruct-v0.1
 ```
 
-## Llama-2 7b
-```bash
-python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
-```
+To run the jetstream-pytorch server with one model:
 
-## Llama-2 13b
-```bash
-python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```
-
-## Llama-3 8b
-```bash
-python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
 ```
 
-## Llama-3 70b
-```bash
-python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
-```
+If it is the first time you run this model, it will download weights from
+HuggingFace.
 
-## Gemma 7b
-```bash
-python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
-```
+HuggingFace's Llama3 weights are gated, so you need to either run
+`huggingface-cli login` to set your token, or pass your hf_token explicitly.
 
-## Mixtral 8x7b
-```bash
-python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+To pass the hf token, add the `--hf_token` flag:
+```
+jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
 ```
 
+To log in using huggingface hub, run:
 
-# Run the server
-Here is an example to run the server with llama2 7B config.
-
-```bash
-python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
 ```
+pip install -U "huggingface_hub[cli]"
+huggingface-cli login
+```
+Then follow its prompt.
 
-Now you can fire gRPC to it.
+After the weights are downloaded,
+subsequent runs will no longer require the `--hf_token` flag.
 
-Optional flags:
-* `--shard_on_batch=1` This makes the model to shard on
-the batch dimension. I.e. this runs in data parallel mode instead of model
-parallel. This will ignore the sharding config. This is recommended for Gemma 2B
-model, because Gemma 2B is small enough to fit on a single TPU chip.
+To run this model in `int8` quantization, add `--quantize_weights=1`.
+Quantization will be done on the fly as the weights load.
 
-* `--sharding_config=<path>` This makes use of alternative sharding config instead of
-the ones in default_shardings directory.
+Weights downloaded from HuggingFace will be stored by default in the `checkpoints` folder
+in the directory where `jpt` is executed.
 
 
 # Run the server with ray
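
The new README text above has `jpt serve` fetch gated weights from HuggingFace into a local `checkpoints` folder, authenticating either through a cached `huggingface-cli login` or an explicit `--hf_token`. As a rough illustration of that download step (not the CLI's actual implementation), the public `huggingface_hub` API can do the same thing; the target directory below is only an assumed layout:

```python
# Illustrative sketch only: pull gated weights the way the README describes,
# using the public huggingface_hub API rather than jpt's internal code.
import os

from huggingface_hub import snapshot_download  # pip install huggingface_hub

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
local_dir = os.path.join("checkpoints", model_id)  # assumed path; README only names a `checkpoints` folder

snapshot_download(
    repo_id=model_id,
    local_dir=local_dir,
    token=os.environ.get("HF_TOKEN"),  # or rely on a prior `huggingface-cli login`
)
print("weights stored under", local_dir)
```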

jetstream_pt/cli.py

Lines changed: 20 additions & 17 deletions
@@ -33,10 +33,10 @@ def shard_weights(env, weights, weight_shardings):
   sharded = {}
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
+    print("SHARDING", key, sharding)
     with jax.default_device(jax.devices("cpu")[0]):
       arr = torch_xla2.tensor.t2j(val)
 
-    print("SHARDING", key, sharding)
     arr = jax.device_put(arr, sharding)
     sharded[key] = torchjax.to_torch(arr)
   return sharded
@@ -207,22 +207,25 @@ def interactive():
   print(tokenizer.decode(sampled_tokens_list))
 
 
-def main(argv):
-  """Entry point"""
-  if len(argv) < 2:
-    print("Invalid arguments. please specify 'list' or 'serve'")
-
-  if argv[1] == "list":
-    list_model()
-  elif argv[1] == "serve":
-    serve()
-  elif argv[1] == "interactive":
-    interactive()
-  else:
-    print(
-        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
-    )
+def main():
+  def main_real(argv):
+    """Entry point"""
+    if len(argv) < 2:
+      print("Invalid arguments. please specify 'list' or 'serve'")
+
+    if argv[1] == "list":
+      list_model()
+    elif argv[1] == "serve":
+      serve()
+    elif argv[1] == "interactive":
+      interactive()
+    else:
+      print(
+          "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
+      )
+  app.run(main_real)
+  return 0
 
 
 if __name__ == "__main__":
-  app.run(main)
+  main()
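
The restructuring above exists because a console-script entry point (added in pyproject.toml below) must reference a callable that takes no arguments, while absl's `app.run` expects a function that receives `argv` after flag parsing. A self-contained sketch of the same pattern, using only `absl-py` and hypothetical subcommand names:

```python
# Minimal sketch of the wrapper pattern: a zero-argument main() suitable for a
# console script, delegating to an argv-taking function via absl's app.run.
from absl import app


def _run(argv):
  """Receives [prog, *positional_args] once absl has consumed --flags."""
  if len(argv) < 2:
    print("please specify 'list' or 'serve'")
    return
  print("running subcommand:", argv[1])


def main():
  # Note: app.run calls sys.exit itself after _run returns, so code after this
  # line is normally not reached when invoked from the command line.
  app.run(_run)
  return 0


if __name__ == "__main__":
  main()
```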

jetstream_pt/fetch_models.py

Lines changed: 8 additions & 1 deletion
@@ -38,15 +38,18 @@ class ModelInfo:
   model_class: torch.nn.Module
   # information needed to allocate cache
   num_layers: int
+  # number of kv heads
   num_heads: int
+
   head_dim: int
   n_reps: int  # repeatition for GQA
 
 
 _llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
 _llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
-_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
+_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
 _llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
+_llama3_70 = _llama2_70
 
 _mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)
 
@@ -59,8 +62,12 @@ class ModelInfo:
     "meta-llama/Llama-2-7b-hf": _llama2_7,
     "meta-llama/Llama-2-13b-chat-hf": _llama2_13,
     "meta-llama/Llama-2-13b-hf": _llama2_13,
+    "meta-llama/Llama-2-70b-hf": _llama2_70,
+    "meta-llama/Llama-2-70b-chat-hf": _llama2_70,
     "meta-llama/Meta-Llama-3-8B": _llama3_8,
     "meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
+    "meta-llama/Meta-Llama-3-70B": _llama3_70,
+    "meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
     "google/gemma-2b": _gemma_2b,
     "google/gemma-2b-it": _gemma_2b,
     "google/gemma-7b": _gemma_7b,
jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 7 additions & 1 deletion
@@ -347,8 +347,12 @@ def from_hf_model_id(cls, model_id, env):
         "meta-llama/Llama-2-7b-hf": "llama-2-7b",
         "meta-llama/Llama-2-13b-chat-hf": "llama-2-13b",
         "meta-llama/Llama-2-13b-hf": "llama-2-13b",
+        "meta-llama/Llama-2-70b-hf": "llama-2-70b",
+        "meta-llama/Llama-2-70b-chat-hf": "llama-2-70b",
         "meta-llama/Meta-Llama-3-8B": "llama-3-8b",
         "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
+        "meta-llama/Meta-Llama-3-70B": "llama-3-70b",
+        "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
     }.get(model_id)
     assert name
     args = model_args.get_model_args(
@@ -380,4 +384,6 @@ def transform(val, n_heads):
       updated[key] = transform(
           value, self.params.n_kv_heads or self.params.n_heads
       )
-    return super().convert_hf_weights(updated)
+    res = super().convert_hf_weights(updated)
+    res['freqs_cis'] = self.freqs_cis
+    return res
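
The second hunk has `convert_hf_weights` return the model's own rotary table under the `freqs_cis` key, since HuggingFace checkpoints do not ship it as a tensor. For reference, the standard Llama-style precomputation of `freqs_cis` looks like the sketch below (this mirrors the widely used reference formula; the repo's helper may use different defaults for `theta` or sequence length):

```python
# Reference-style rotary frequency table (complex64, shape [end, dim // 2]);
# shown for context, not copied from this repository.
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
  t = torch.arange(end, device=freqs.device).float()
  freqs = torch.outer(t, freqs)  # [end, dim // 2]
  return torch.polar(torch.ones_like(freqs), freqs)  # e^{i * t * freq}


# head_dim 128 and a 2048-token cache, matching the ModelInfo table above
print(precompute_freqs_cis(128, 2048).shape)  # torch.Size([2048, 64])
```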

pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -18,8 +18,12 @@ dependencies = [
   "google-jetstream @ {root:uri}/deps/JetStream",
 ]
 
+
 requires-python = ">=3.10"
 license = {file = "LICENSE"}
 
+[project.scripts]
+jpt = "jetstream_pt.cli:main"
+
 [tool.hatch.metadata]
-allow-direct-references = true
+allow-direct-references = true
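
The new `[project.scripts]` table is what the commit title refers to: on `pip install`, the build backend generates a `jpt` executable that imports `jetstream_pt.cli` and calls its zero-argument `main()`, roughly like the sketch below (the real wrapper is emitted into the environment's `bin/` directory, not checked into the repo):

```python
# Approximation of the console-script wrapper generated for
# `jpt = "jetstream_pt.cli:main"`; not a file that exists in the repository.
import sys

from jetstream_pt.cli import main

if __name__ == "__main__":
  sys.exit(main())
```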
