# Jetstream-PyTorch

JetStream Engine implementation in PyTorch


# Outline

1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
   a. Create a Cloud TPU VM if you haven't already
2. Clone repo and install dependencies
3. Download and convert weights
4. Run checkpoint converter (quantizer)
5. Local run
6. Run the server
7. Run benchmarks
8. Typical Errors

# Ssh to Cloud TPU VM (using v5e-8 TPU VM)

```bash
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
```

## Create a Cloud TPU VM in a GCP project if you haven't already

Follow steps 1-9 in this guide:
* https://cloud.google.com/tpu/docs/v5e-inference#prepare-a-project
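
If you only need a quick starting point, the create command looks roughly like this (a sketch only; the accelerator type and runtime version below are assumptions, so confirm them against the guide above):

```bash
# Assumed values for a v5e-8 VM; verify against the v5e inference guide.
gcloud compute tpus tpu-vm create "your-tpu-vm" \
  --project "your-project" \
  --zone "your-project-zone" \
  --accelerator-type v5litepod-8 \
  --version v2-alpha-tpuv5-lite
```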

# Clone repo and install dependencies

## Get the jetstream-pytorch code

```bash
git clone https://github.com/google/jetstream-pytorch.git
```

(Optional) Create a virtual env using `venv` or `conda` and activate it.
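
For example, with `venv` (part of the Python standard library, so nothing extra to install):

```bash
# Create and activate an isolated environment in the current directory.
python -m venv .venv
source .venv/bin/activate
```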

## Run the installation script

```bash
cd jetstream-pytorch
source install_everything.sh
```

NOTE: the script exports PYTHONPATH, so it must be sourced (not executed) for the change to take effect in your current shell.
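
As a quick sanity check (not part of the original steps), you can confirm that the installed JAX build sees the TPU chips:

```bash
# On a healthy v5e-8 setup this should list eight TPU devices.
python -c "import jax; print(jax.devices())"
```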

# Download and convert weights

## Get official llama weights from meta-llama

Follow the instructions here: https://github.com/meta-llama/llama#download

After you have downloaded the weights, you will also have a `tokenizer.model` file; this is the tokenizer we will use.
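
Later commands refer to this file through a `tokenizer_path` variable (the variable name comes from the benchmark command near the end of this README; the path here is a placeholder):

```bash
export tokenizer_path=/path/to/tokenizer.model
```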

## Run the weights-to-safetensors converter

```bash
export input_ckpt_dir="/path/to/original/llama/weights"   # directory holding the official llama checkpoint
export output_ckpt_dir="/path/to/converted/output"        # where converted weights will be written
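# Hedged sketch: the converter entry point and flags below are assumptions
# inferred from the section title; they are not shown in this excerpt.
export quantize=True  # whether to also quantize the weights
python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize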
```

# Run benchmarks

```bash
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```

Please look at `deps/JetStream/benchmarks/README.md` for more information.


# Typical Errors

## Unexpected keyword argument 'device'

Fix:
* Uninstall the jax and jaxlib dependencies
* Reinstall using `source install_everything.sh`, as in the sketch below
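
A minimal sketch of that fix (assuming `pip`-managed installs):

```bash
# Remove the conflicting JAX packages, then reinstall everything.
pip uninstall -y jax jaxlib
source install_everything.sh
```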

## Out of memory

Fix:
* Use a smaller batch size
* Use quantization (see the sketch below)
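
Both mitigations are usually applied as server flags. For example (the script name and flag names here are assumptions; check the server's `--help` output for the real ones):

```bash
# Hypothetical flags for illustration only; consult --help for actual names.
python run_server.py --batch_size=32 --quantize_weights=True
```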