Commit 7d99123

Reword, address comments
1 parent 56da057 commit 7d99123

2 files changed: +13 −6 lines changed

README.md

Lines changed: 11 additions & 4 deletions
@@ -78,18 +78,18 @@ mistralai/Mixtral-8x7B-Instruct-v0.1
 To run jetstream-pytorch server with one model:
 
 ```
-jpt serve --model_id --model_id meta-llama/Meta-Llama-3-8B-Instruct
+jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
 ```
 
-If it the first time you run this model, it will download weights from
+If it's the first time you run this model, it will download weights from
 HuggingFace.
 
 HuggingFace's Llama3 weights are gated, so you need to either run
 `huggingface-cli login` to set your token, OR, pass your hf_token explicitly.
 
-To pass hf token, add `--hf_token` flag
+To pass hf token explicitly, add `--hf_token` flag
 ```
-jpt serve --model_id --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
+jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
 ```
 
 To login using huggingface hub, run:
@@ -109,6 +109,13 @@ Quantization will be done on the flight as the weight loads.
 Weights downloaded from HuggingFace will be stored by default in `checkpoints` folder.
 in the place where `jpt` is executed.
 
+You can change where the weights are stored with `--working_dir` flag.
+
+If you wish to use your own checkpoint, then, place them inside
+of the `checkpoints/<org>/<model>/hf_original` dir (or the corresponding subdir in `--working_dir`). For example,
+Llama3 checkpoints will be at `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified
+weights in HuggingFace format.
+
 
 # Run the server with ray
 Below are steps run server with ray:
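The checkpoint layout the README lines above describe can be sketched as a tiny path helper. This is illustrative only and not part of jetstream-pytorch; the `checkpoint_dir` function and the assumption that `--working_dir` simply re-roots the same `checkpoints/<org>/<model>/hf_original` layout are inferred from the README text, not from the code.

```python
from pathlib import Path

# Hypothetical helper mirroring the README's description of where weights land.
# model_id is the HuggingFace "<org>/<model>" identifier.
def checkpoint_dir(model_id: str, working_dir: str = ".") -> Path:
    return Path(working_dir) / "checkpoints" / model_id / "hf_original"

# Default location, relative to where `jpt` is executed:
print(checkpoint_dir("meta-llama/Meta-Llama-3-8B-Instruct"))
# With an explicit --working_dir:
print(checkpoint_dir("meta-llama/Meta-Llama-3-8B-Instruct", "/data/jpt"))
```

To use your own weights, you would drop `*.safetensors` files in HuggingFace format into that directory, replacing the downloaded ones.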

jetstream_pt/fetch_models.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ class ModelInfo:
   # information needed to allocate cache
   num_layers: int
   # number of kv heads
-  num_heads: int
+  num_kv_heads: int
 
   head_dim: int
   n_reps: int  # repeatition for GQA
@@ -139,7 +139,7 @@ def construct_env_data_from_model_id(
   )
   env_data.cache_shape = (
       batch_size,
-      model_info.num_heads,
+      model_info.num_kv_heads,
       max_cache_length,
       model_info.head_dim,
   )
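The rename from `num_heads` to `num_kv_heads` matters for grouped-query attention (GQA): the KV cache only needs to hold the key/value heads, which in GQA models is fewer than the number of query heads, and each cached head is reused `n_reps` times at attention time. A minimal NumPy sketch of why the cache shape uses the KV head count; the dimensions below are illustrative Llama-3-8B-like numbers, not values taken from this repo:

```python
import numpy as np

# Illustrative config (assumption, not from the diff above):
num_q_heads = 32       # query heads
num_kv_heads = 8       # key/value heads, shared across groups of query heads
head_dim = 128
batch_size = 2
max_cache_length = 16

# The KV cache is allocated with num_kv_heads, not num_q_heads,
# matching the cache_shape fix in construct_env_data_from_model_id.
cache_shape = (batch_size, num_kv_heads, max_cache_length, head_dim)
k_cache = np.zeros(cache_shape, dtype=np.float32)

# At attention time each cached KV head is repeated n_reps times
# so every query head has a matching key/value head.
n_reps = num_q_heads // num_kv_heads
k_expanded = np.repeat(k_cache, n_reps, axis=1)
print(k_expanded.shape)  # (2, 32, 16, 128)
```

Allocating the cache with `num_q_heads` would have made it `n_reps` times larger than necessary, which is exactly what the bug fixed here would have caused.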
