Commit 663c102

Jetstream + RayServe deployment for interleave mode (#146)

* kuberay manifests and dockerfile
* sample ray_serve
* Single host interleave
* update image
* Gcsfuse and jax platform fix
* multihost
* Cleanup
* Cleanup
* Parameterize tpu head type
* Format
* revert
* revert
* update readme
* fix format
* lint

1 parent 50a6d10 commit 663c102

File tree

5 files changed: +543 −0 lines changed

README.md

Lines changed: 86 additions & 0 deletions
@@ -199,6 +199,92 @@ python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts

Please look at `deps/JetStream/benchmarks/README.md` for more information.

## Run server with Ray Serve

### Prerequisites

If running on GKE:

1. Follow the instructions in [this guide](https://github.com/GoogleCloudPlatform/ai-on-gke/tree/main/ray-on-gke/guides/tpu) to set up a GKE cluster and the TPU webhook.
2. Follow the instructions [here](https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver) to enable GCSFuse for your cluster. This is needed to store the converted weights.
3. Deploy one of the sample Kuberay cluster configurations:

   ```bash
   kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml
   ```

   or

   ```bash
   kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml
   ```
### Start a Ray Serve deployment

Single-host (Llama2 7B):

```bash
export RAY_ADDRESS=http://localhost:8265

kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &

ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=4 --num_hosts=1 --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
```

Multi-host (Llama2 70B):

```bash
export RAY_ADDRESS=http://localhost:8265

kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &

ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py --tpu_chips=8 --num_hosts=2 --size=70b --model_name=llama-2 --batch_size=8 --max_cache_length=2048 --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"
```
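The two invocations differ only in a handful of per-size flags. A small helper (hypothetical, not part of the repo) can assemble the flag list and makes the single-host vs. multi-host differences explicit:

```python
# Hypothetical helper: builds the run_ray_serve_interleave.py flag list.
# Flag names and values mirror the two example commands above.
def serve_flags(size: str) -> list[str]:
    # Per-size settings taken from the 7B (single-host) and 70B
    # (multi-host) examples.
    presets = {
        "7b": {"tpu_chips": 4, "num_hosts": 1, "batch_size": 32},
        "70b": {"tpu_chips": 8, "num_hosts": 2, "batch_size": 8},
    }
    p = presets[size]
    return [
        f"--tpu_chips={p['tpu_chips']}",
        f"--num_hosts={p['num_hosts']}",
        f"--size={size}",
        "--model_name=llama-2",
        f"--batch_size={p['batch_size']}",
        "--max_cache_length=2048",
        "--tokenizer_path=/llama/tokenizer.model",
        "--checkpoint_path=/llama/ckpt",
        "--quantize_weights=True",
        "--quantize_type=int8_per_channel",
        "--quantize_kv_cache=True",
        "--sharding_config=default_shardings/llama.yaml",
    ]

cmd = ["ray", "job", "submit",
       '--runtime-env-json={"working_dir": "."}',
       "--", "python", "run_ray_serve_interleave.py"] + serve_flags("70b")
print(" ".join(cmd))
```

Note that the larger model trades batch size for chips: 70B runs with `--batch_size=8` across 8 chips on 2 hosts, while 7B fits 32 requests on a single 4-chip host.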
### Sending an inference request

Port-forward to port 8888 for gRPC:

```bash
kubectl port-forward svc/example-cluster-kuberay-head-svc 8888:8888 &
```

Sample Python script:

```python
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc

prompt = "What are the top 5 languages?"

channel = grpc.insecure_channel("localhost:8888")
stub = jetstream_pb2_grpc.OrchestratorStub(channel)

request = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text=prompt),
    priority=0,
    max_tokens=2000,
)

# Decode returns a response stream; collect the text chunks as they arrive.
response = stub.Decode(request)
output = []
for resp in response:
    output.extend(resp.stream_content.samples[0].text)

text_output = "".join(output)
print(f"Prompt: {prompt}")
print(f"Response: {text_output}")
```
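The accumulation loop in the script can be sanity-checked offline with stand-in chunk objects (hypothetical stubs, not the real protobuf classes), which shows how the streamed pieces are joined into the final response:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the streamed protobuf messages: each chunk
# carries one sample whose `text` field holds a list of decoded pieces.
def chunk(pieces):
    sample = SimpleNamespace(text=pieces)
    return SimpleNamespace(stream_content=SimpleNamespace(samples=[sample]))

stream = [chunk(["Hello"]), chunk([", "]), chunk(["world"])]

# Same accumulation pattern as the script above.
output = []
for resp in stream:
    output.extend(resp.stream_content.samples[0].text)

text_output = "".join(output)
print(text_output)  # Hello, world
```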
# Typical Errors

## Unexpected keyword argument 'device'

kuberay/image/Dockerfile

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@

```dockerfile
FROM rayproject/ray:2.22.0-py310

RUN pip install flax==0.8.3
RUN pip install jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install tensorflow-text
RUN pip install tensorflow

RUN pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
RUN pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
RUN pip install safetensors colorama coverage humanize

RUN git clone https://github.com/google/jetstream-pytorch
WORKDIR jetstream-pytorch

RUN git submodule update --init --recursive
RUN pip install -e .
```
kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@

```yaml
# This template contains a Kuberay cluster using a 2x2x2 TPU v4 PodSlice.
# To get access to TPU resources, please follow instructions in this link:
# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: example-cluster-kuberay
spec:
  headGroupSpec:
    rayStartParams: {}
    template:
      spec:
        imagePullSecrets: []
        serviceAccountName: ray-ksa
        containers:
          - volumeMounts:
              - name: gcs-fuse-checkpoint
                mountPath: /llama
                readOnly: true
              - mountPath: /tmp/ray
                name: ray-logs
            name: ray-head
            image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240709
            imagePullPolicy: IfNotPresent
            resources:
              limits:
                cpu: "4"
                ephemeral-storage: 30Gi
                memory: 40G
              requests:
                cpu: "4"
                ephemeral-storage: 30Gi
                memory: 40G
            securityContext: {}
            env:
              - name: JAX_PLATFORMS
                value: "cpu"
              - name: RAY_memory_monitor_refresh_ms
                value: "0"
              - name: RAY_GRAFANA_IFRAME_HOST
                value: http://${grafana_host}
              - name: RAY_GRAFANA_HOST
                value: http://grafana:80
              - name: RAY_PROMETHEUS_HOST
                value: http://frontend:9090
            ports:
              - containerPort: 6379
                name: gcs
              - containerPort: 8265
                name: dashboard
              - containerPort: 10001
                name: client
              - containerPort: 8000
                name: serve
              - containerPort: 8471
                name: slicebuilder
              - containerPort: 8081
                name: mxla
              - containerPort: 8888
                name: grpc
        volumes:
          - emptyDir: {}
            name: ray-logs
          - name: gcs-fuse-checkpoint
            csi:
              driver: gcsfuse.csi.storage.gke.io
              readOnly: true
              volumeAttributes:
                bucketName: ricliu-llama2-70b-chat
                mountOptions: "implicit-dirs"
      metadata:
        annotations:
          gke-gcsfuse/volumes: "true"
        labels:
          cloud.google.com/gke-ray-node-type: head
          app.kubernetes.io/name: kuberay
          app.kubernetes.io/instance: example-cluster

  workerGroupSpecs:
    - rayStartParams: {}
      replicas: 1
      minReplicas: 1
      maxReplicas: 1
      numOfHosts: 2
      groupName: workergroup
      template:
        spec:
          imagePullSecrets: []
          serviceAccountName: ray-ksa
          containers:
            - volumeMounts:
                - mountPath: /tmp/ray
                  name: ray-logs
                - name: gcs-fuse-checkpoint
                  mountPath: /llama
                  readOnly: true
              name: ray-worker
              image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240709
              imagePullPolicy: IfNotPresent
              resources:
                limits:
                  cpu: "8"
                  ephemeral-storage: 30Gi
                  google.com/tpu: "4"
                  memory: 200G
                requests:
                  cpu: "8"
                  ephemeral-storage: 30Gi
                  google.com/tpu: "4"
                  memory: 200G
              securityContext: {}
              env:
                - name: JAX_PLATFORMS
                  value: "cpu"
              ports: null
          volumes:
            - emptyDir: {}
              name: ray-logs
            - name: gcs-fuse-checkpoint
              csi:
                driver: gcsfuse.csi.storage.gke.io
                readOnly: true
                volumeAttributes:
                  bucketName: ricliu-llama2-70b-chat
                  mountOptions: "implicit-dirs"
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
            cloud.google.com/gke-tpu-topology: 2x2x2
            iam.gke.io/gke-metadata-server-enabled: "true"
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
          labels:
            cloud.google.com/gke-ray-node-type: worker
            app.kubernetes.io/name: kuberay
            app.kubernetes.io/instance: example-cluster
```
kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@

```yaml
# This template contains a Kuberay cluster using a 2x2x1 TPU v4 PodSlice.
# To get access to TPU resources, please follow instructions in this link:
# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: example-cluster-kuberay
spec:
  headGroupSpec:
    rayStartParams: {}
    template:
      spec:
        imagePullSecrets: []
        serviceAccountName: ray-ksa
        containers:
          - volumeMounts:
              - name: gcs-fuse-checkpoint
                mountPath: /llama
                readOnly: true
              - mountPath: /tmp/ray
                name: ray-logs
            name: ray-head
            image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240709
            imagePullPolicy: IfNotPresent
            resources:
              limits:
                cpu: "4"
                ephemeral-storage: 30Gi
                memory: 40G
              requests:
                cpu: "4"
                ephemeral-storage: 30Gi
                memory: 40G
            securityContext: {}
            env:
              - name: JAX_PLATFORMS
                value: "cpu"
              - name: RAY_memory_monitor_refresh_ms
                value: "0"
              - name: RAY_GRAFANA_IFRAME_HOST
                value: http://${grafana_host}
              - name: RAY_GRAFANA_HOST
                value: http://grafana:80
              - name: RAY_PROMETHEUS_HOST
                value: http://frontend:9090
            ports:
              - containerPort: 6379
                name: gcs
              - containerPort: 8265
                name: dashboard
              - containerPort: 10001
                name: client
              - containerPort: 8000
                name: serve
              - containerPort: 8888
                name: grpc
        volumes:
          - emptyDir: {}
            name: ray-logs
          - name: gcs-fuse-checkpoint
            csi:
              driver: gcsfuse.csi.storage.gke.io
              readOnly: true
              volumeAttributes:
                bucketName: ricliu-llama2
                mountOptions: "implicit-dirs"
      metadata:
        annotations:
          gke-gcsfuse/volumes: "true"
        labels:
          cloud.google.com/gke-ray-node-type: head
          app.kubernetes.io/name: kuberay
          app.kubernetes.io/instance: example-cluster

  workerGroupSpecs:
    - rayStartParams: {}
      replicas: 1
      minReplicas: 1
      maxReplicas: 1
      numOfHosts: 1
      groupName: workergroup
      template:
        spec:
          imagePullSecrets: []
          serviceAccountName: ray-ksa
          containers:
            - volumeMounts:
                - mountPath: /tmp/ray
                  name: ray-logs
                - name: gcs-fuse-checkpoint
                  mountPath: /llama
                  readOnly: true
              name: ray-worker
              image: gcr.io/tpu-vm-gke-testing/ricliu-jetstream:20240709
              imagePullPolicy: IfNotPresent
              resources:
                limits:
                  cpu: "8"
                  ephemeral-storage: 30Gi
                  google.com/tpu: "4"
                  memory: 200G
                requests:
                  cpu: "8"
                  ephemeral-storage: 30Gi
                  google.com/tpu: "4"
                  memory: 200G
              securityContext: {}
              env:
                - name: JAX_PLATFORMS
                  value: "cpu"
              ports: null
          volumes:
            - emptyDir: {}
              name: ray-logs
            - name: gcs-fuse-checkpoint
              csi:
                driver: gcsfuse.csi.storage.gke.io
                readOnly: true
                volumeAttributes:
                  bucketName: ricliu-llama2
                  mountOptions: "implicit-dirs"
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
            cloud.google.com/gke-tpu-topology: 2x2x1
            iam.gke.io/gke-metadata-server-enabled: "true"
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
          labels:
            cloud.google.com/gke-ray-node-type: worker
            app.kubernetes.io/name: kuberay
            app.kubernetes.io/instance: example-cluster
```
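Across the two manifests, `numOfHosts` follows from the slice topology: a TPU v4 host exposes 4 chips (matching the `google.com/tpu: "4"` resource requests), so a 2x2x1 slice (4 chips) needs 1 worker host and a 2x2x2 slice (8 chips) needs 2. A quick sketch of that arithmetic (helper names are illustrative, not from the repo):

```python
# Illustrative helper: derive chip count and numOfHosts from a TPU v4
# topology string such as "2x2x2". Assumes 4 chips per v4 host, matching
# the google.com/tpu: "4" resource requests in the worker specs above.
CHIPS_PER_V4_HOST = 4

def hosts_for_topology(topology: str) -> tuple[int, int]:
    chips = 1
    for dim in topology.split("x"):
        chips *= int(dim)
    # Ceiling division, though v4 slices come in whole-host multiples.
    num_hosts = -(-chips // CHIPS_PER_V4_HOST)
    return chips, num_hosts

print(hosts_for_topology("2x2x1"))  # (4, 1) -> singlehost manifest
print(hosts_for_topology("2x2x2"))  # (8, 2) -> multihost manifest
```

These counts also line up with the Ray Serve commands in the README: `--tpu_chips=4 --num_hosts=1` for the 7B model and `--tpu_chips=8 --num_hosts=2` for 70B.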
