Update README for new CLI #178
Changes from 1 commit
@@ -14,13 +14,9 @@ Commandline Flags might have changed between the release version to HEAD.
1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
   a. Create a Cloud TPU VM if you haven’t
2. Download jetstream-pytorch github repo
3. Clone repo and install dependencies
4. Download and convert weights
5. Run checkpoint converter (quantizer)
6. Local run
7. Run the server
8. Run benchmarks
9. Typical Errors
3. Run the server
4. Run benchmarks
5. Typical Errors
# Ssh to Cloud TPU VM (using v5e-8 TPU VM)
@@ -49,108 +45,69 @@ cd jetstream-pytorch
source install_everything.sh
```
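For orientation, the step leading into this hunk clones the repository before installing. A minimal sketch (the clone URL is an assumption inferred from the project name, not taken from this diff):

```bash
# Assumed clone URL; substitute the actual repository location if it differs.
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
source install_everything.sh
```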
# Download and convert weights

## LLaMA
### Get official llama weights from meta-llama
# Run jetstream pytorch
Follow the instructions here:
* Llama-2: https://github.com/meta-llama/llama#download
* Llama-3: https://github.com/meta-llama/llama3/#download

## List out supported models
Downloading the weights also fetches a `tokenizer.model` file, which is the tokenizer we will use.
## Gemma
### Get Gemma Checkpoint from HuggingFace
Please sign the agreement on the Hugging Face website to access the Gemma checkpoints, then download the Gemma PyTorch checkpoint using huggingface-cli. The Gemma tokenizer is included in the checkpoint.
```bash
# Install huggingface-cli and login if it's not set up.
pip install -U "huggingface_hub[cli]"
huggingface-cli login
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```
## Mixtral
### Get Mixtral Checkpoint from HuggingFace
Please sign the agreement on the Hugging Face website to access the Mixtral checkpoints, then download the Mixtral PyTorch checkpoint using huggingface-cli. The Mixtral tokenizer is included in the checkpoint.
```bash
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
jpt list
```
## Run weight safetensor convert
This will print out the list of supported models and variants:
There is limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS is slow, so storing checkpoints locally is recommended.
```bash
export input_ckpt_dir=<original llama weights directory>
export output_ckpt_dir=<the output directory>
export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
export quantize_weights=True # Whether to quantize weights
export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default if not specified.
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
```
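For illustration, a hypothetical invocation with concrete local paths and the blockwise int8 variant might look like the following (the directories are placeholders, not taken from the original):

```bash
# Hypothetical paths; point these at your own checkpoint directories.
export input_ckpt_dir="$HOME/llama-3-8b/original"
export output_ckpt_dir="$HOME/llama-3-8b/converted"
export model_name="llama-3"
export quantize_weights=True
export quantize_type="int8_blockwise"   # any of {int8,int4}_{per_channel,blockwise}
python -m convert_checkpoints \
  --model_name=$model_name \
  --input_checkpoint_dir=$input_ckpt_dir \
  --output_checkpoint_dir=$output_ckpt_dir \
  --quantize_type=$quantize_type \
  --quantize_weights=$quantize_weights
```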
# Local run
Set the tokenizer path:
```bash
export tokenizer_path=<tokenizer model file path>
```
```
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
```
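Since `jpt list` prints one model id per line, the output can be narrowed with ordinary shell tools, for example:

```bash
# Show only the Llama-3 variants from the supported-model list.
jpt list | grep -i llama-3
```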
## Llama-2 7b
```bash
python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
To run the jetstream-pytorch server with one model:
## Llama-2 13b
```bash
python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
## Llama-3 8b
```bash
python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```
## Llama-3 70b
```bash
python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
If it is the first time you run this model, it will download the weights from HuggingFace.
## Gemma 7b
```bash
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```
HuggingFace's Llama3 weights are gated, so you need to either run `huggingface-cli login` to set your token, or pass your hf_token explicitly.
## Mixtral 8x7b
```bash
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```
To pass the hf token, add the `--hf_token` flag:
```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
```
To log in using the Hugging Face Hub, run:
# Run the server
Here is an example that runs the server with the llama2 7B config.
```bash
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```
```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```
Then follow its prompt. | ||
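If you already have an access token exported (the `HF_TOKEN` variable below is an assumption, not part of the original instructions), `huggingface-cli` also accepts it non-interactively:

```bash
# Non-interactive login; assumes $HF_TOKEN holds your Hugging Face access token.
huggingface-cli login --token "$HF_TOKEN"
```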
Now you can fire gRPC requests at it.
After the weights are downloaded, the next time you run this the `--hf_token` flag will no longer be required.
Optional flags:
* `--shard_on_batch=1` shards the model on the batch dimension, i.e. it runs in data-parallel mode instead of model-parallel mode. This ignores the sharding config. It is recommended for the Gemma 2B model, because Gemma 2B is small enough to fit on a single TPU chip.
To run this model with `int8` quantization, add `--quantize_weights=1`. Quantization is done on the fly as the weights load (see the sketch after the flag list below).
* `--sharding_config=<path>` uses an alternative sharding config instead of the ones in the `default_shardings` directory.
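A sketch that combines the CLI and flags mentioned above (the model id is just one entry from the supported list; treat the exact combination as an assumption):

```bash
# Serve a supported model with on-the-fly int8 weight quantization.
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=1
```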
Weights downloaded from HuggingFace will be stored by default in the `checkpoints` folder in the place where `jpt` is executed.

Review comments:
* Do we have the option to store the weights separately? We even have problems storing the weights directly on the GCP VM.
* For a GCS bucket, the checkpoint needs to be brought local or mounted using FUSE. The working dir can be edited; added a paragraph to describe that.
* It'll be great if you can add how to change the working dir, because we also need to point it at an external SSD. I will approve the PR to unblock you for now.
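As the review thread above suggests, a checkpoint living in a GCS bucket must either be copied locally or mounted with Cloud Storage FUSE before use. A minimal sketch with `gcsfuse` (the bucket name and mount point are placeholders, not from the original):

```bash
# Placeholder bucket and mount point; adjust to your environment.
mkdir -p "$HOME/gcs-ckpt"
gcsfuse my-checkpoint-bucket "$HOME/gcs-ckpt"
# Point the checkpoint directory (or jpt's working dir) at the mount.
export output_ckpt_dir="$HOME/gcs-ckpt"
```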
# Run the server with ray