@@ -14,13 +14,9 @@ Commandline Flags might have changed between the release version to HEAD.
1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
   a. Create a Cloud TPU VM if you haven’t
2. Download jetstream-pytorch github repo
- 3. Clone repo and install dependencies
- 4. Download and convert weights
- 5. Run checkpoint converter (quantizer)
- 6. Local run
- 7. Run the server
- 8. Run benchmarks
- 9. Typical Errors
+ 3. Run the server
+ 4. Run benchmarks
+ 5. Typical Errors
# Ssh to Cloud TPU VM (using v5e-8 TPU VM)
@@ -49,108 +45,69 @@ cd jetstream-pytorch
source install_everything.sh
```

- # Download and convert weights

- ## LLaMA
- ### Get official llama weights from meta-llama
+ # Run jetstream pytorch

- Following instructions here:
- * Llama-2: https://github.com/meta-llama/llama#download
- * Llama-3: https://github.com/meta-llama/llama3/#download
+ ## List out supported models

- After you have downloaded the weights, it will also download a `tokenizer.model` file that is
- the tokenizer that we will use.
-
- ## Gemma
- ### Get Gemma Checkpoint from HuggingFace
-
- Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.
-
- ```bash
- # Install huggingface-cli and login if it's not set up.
- pip install -U "huggingface_hub[cli]"
- huggingface-cli login
- huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```
-
- ## Mixtral
- ### Get Mixtral Checkpoint from HuggingFace
-
- Please sign agreement on Huggingface website to access Mixtral checkpoints. Download Mixtral PyTorch checkpoint using huggingface-cli. Mixtral Tokenizer is included in the checkpoint.
-
- ```bash
- huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
+ jpt list
```

- ## Run weight safetensor convert
+ This will print out the list of supported models and variants:

- There are limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS takes a long time and therefore storing checkpoints to local is recommended.
-
- ```bash
- export input_ckpt_dir=Original llama weights directory
- export output_ckpt_dir=The output directory
- export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
- export quantize_weights=True # Whether to quantize weights
- export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, "int8_per_channel" is the default option if not specified.
- python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
```
-
-
- # Local run
-
- Set tokenizer path
- ```bash
- export tokenizer_path=tokenizer model file path
+ meta-llama/Llama-2-7b-chat-hf
+ meta-llama/Llama-2-7b-hf
+ meta-llama/Llama-2-13b-chat-hf
+ meta-llama/Llama-2-13b-hf
+ meta-llama/Llama-2-70b-hf
+ meta-llama/Llama-2-70b-chat-hf
+ meta-llama/Meta-Llama-3-8B
+ meta-llama/Meta-Llama-3-8B-Instruct
+ meta-llama/Meta-Llama-3-70B
+ meta-llama/Meta-Llama-3-70B-Instruct
+ google/gemma-2b
+ google/gemma-2b-it
+ google/gemma-7b
+ google/gemma-7b-it
+ mistralai/Mixtral-8x7B-v0.1
+ mistralai/Mixtral-8x7B-Instruct-v0.1
```

- ## Llama-2 7b
- ```bash
- python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
- ```
+ To run jetstream-pytorch server with one model:

- ## Llama-2 13b
- ```bash
- python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
-
- ## Llama-3 8b
- ```bash
- python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+ jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```

- ## Llama-3 70b
- ```bash
- python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
- ```
+ If it is the first time you run this model, it will download the weights from
+ HuggingFace.

- ## Gemma 7b
- ```bash
- python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
- ```
+ HuggingFace's Llama3 weights are gated, so you need to either run
+ `huggingface-cli login` to set your token, or pass your hf_token explicitly.

- ## Mixtral 8x7b
- ```bash
- python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+ To pass your HuggingFace token explicitly, add the `--hf_token` flag:
+ ```
+ jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
```

+ To log in using the HuggingFace Hub, run:

- # Run the server
- Here is an example to run the server with llama2 7B config.
-
- ```bash
- python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```
+ pip install -U "huggingface_hub[cli]"
+ huggingface-cli login
+ ```
+ Then follow its prompts.

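+ 
+ If you prefer a non-interactive login (for example when scripting the TPU VM setup),
+ `huggingface-cli login` also accepts the token directly. A minimal sketch, assuming your
+ access token is stored in an `HF_TOKEN` environment variable:
+ ```
+ # Non-interactive login; HF_TOKEN is assumed to hold your HuggingFace access token.
+ huggingface-cli login --token "$HF_TOKEN"
+ ```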
- Now you can fire gRPC to it.
+ After the weights are downloaded,
+ the `--hf_token` flag will no longer be required on subsequent runs.

- Optional flags:
- * `--shard_on_batch=1` This makes the model to shard on
- the batch dimension. I.e. this runs in data parallel mode instead of model
- parallel. This will ignore the sharding config. This is recommended for Gemma 2B
- model, because Gemma 2B is small enough to fit on a single TPU chip.
+ To run this model in `int8` quantization, add `--quantize_weights=1`.
+ Quantization will be done on the fly as the weights load.
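+ 
+ For example, a quantized run simply combines the serve command shown above with this
+ flag (a sketch; the model id is just the one used in the earlier examples):
+ ```
+ jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=1
+ ```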

- * `--sharding_config=<path>` This makes use of alternative sharding config instead of
- the ones in default_shardings directory.
+ Weights downloaded from HuggingFace are stored by default in the `checkpoints` folder
+ in the directory where `jpt` is executed.
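+ 
+ For example, after a first `jpt serve` run you can verify the cached weights with
+ something like the following (assuming `jpt` was run from the repo root):
+ ```
+ # List the weights that jpt downloaded on the first run.
+ ls checkpoints/
+ ```
+ 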
# Run the server with ray