Skip to content

Commit 4443f79

Browse files
authored
feat(runtimes): Support Distributed MLX on CUDA (#2790)
* feat(runtimes): Support Distributed MLX on CUDA
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Remove arm build from MLX runtime
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Update get_runtime_packages API
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Force to change vars in examples
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Remove LD_LIBRARY_PATH updates
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Add patch command to DeepSpeed example
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Cleanup apt packages
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
* Reduce MLX and DeepSpeed image size
  Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
---------
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
1 parent d705bb6 commit 4443f79

8 files changed

Lines changed: 1598 additions & 761 deletions

File tree

.github/workflows/build-and-push-images.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ on:
44
push:
55
branches:
66
- master
7-
- 'release-*'
7+
- "release-*"
88
tags:
9-
- 'v*'
9+
- "v*"
1010
pull_request:
1111

1212
jobs:
@@ -34,9 +34,10 @@ jobs:
3434
- component-name: deepspeed-runtime
3535
dockerfile: cmd/runtimes/deepspeed/Dockerfile
3636
platforms: linux/amd64,linux/arm64
37+
# TODO (andreyvelich): mlx[cuda] doesn't support arm at the moment: https://github.com/ml-explore/mlx/issues/2469
3738
- component-name: mlx-runtime
3839
dockerfile: cmd/runtimes/mlx/Dockerfile
39-
platforms: linux/arm64
40+
platforms: linux/amd64
4041
- component-name: torchtune-trainer
4142
dockerfile: cmd/trainers/torchtune/Dockerfile
4243
platforms: linux/amd64,linux/arm64

cmd/runtimes/deepspeed/Dockerfile

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
FROM mpioperator/base:v0.6.0 AS mpi
22
FROM nvidia/cuda:12.8.1-devel-ubuntu22.04
33

4-
# Disable interactive dialog from apt.
5-
ENV DEBIAN_FRONTEND noninteractive
6-
7-
# Install libraries required for OpenMPI to work.
8-
RUN apt-get update && apt install -y --no-install-recommends \
9-
cmake g++ gcc \
10-
wget vim \
11-
openssh-client openssh-server libcap2-bin \
12-
libopenmpi-dev openmpi-bin
4+
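# Install libraries required for OpenMPI to work, plus Python and pip. NOTE(review): the "OpenMPI 5.0.7" figure should be verified; the version installed is whatever the Ubuntu 22.04 apt repository ships.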
5+
RUN apt update && apt install -y --no-install-recommends \
6+
openssh-server openssh-client libcap2-bin \
7+
g++ libopenmpi-dev \
8+
python3-dev pip && rm -f /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && rm -rf /var/lib/apt/lists/*
139

1410
# Add capability to run sshd as non-root.
1511
RUN setcap CAP_NET_BIND_SERVICE=+eip /usr/sbin/sshd
@@ -24,10 +20,7 @@ COPY --from=mpi /etc/ssh/ssh_config /etc/ssh/ssh_config
2420
COPY --from=mpi /etc/ssh/sshd_config /etc/ssh/sshd_config
2521
COPY --from=mpi /home/mpiuser/.sshd_config /home/mpiuser/.sshd_config
2622

27-
# Install the required Python packages.
28-
RUN apt install -y python3-dev pip && rm -f /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python
29-
30-
ENV LD_LIBRARY_PATH=/usr/local/lib:/usr/local/mpi/lib:/usr/local/mpi/lib64:${LD_LIBRARY_PATH}
23+
# Set home directory for mpiuser.
3124
ENV HOME=/home/mpiuser
3225
ENV PATH=$HOME/.local/bin:$PATH
3326

cmd/runtimes/mlx/Dockerfile

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
FROM mpioperator/base:v0.6.0 AS mpi
2-
FROM debian:trixie
2+
FROM nvidia/cuda:12.8.1-devel-ubuntu22.04
33

4-
# Install libraries required for OpenMPI and MLX. This image installs OpenMPI 5.0.7
4+
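# Install libraries required for OpenMPI and MLX (BLAS/LAPACK), plus Python and pip. NOTE(review): "OpenMPI 5.0.7" described the previous debian:trixie base; on Ubuntu 22.04 the version comes from its apt repository — verify.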
55
RUN apt update && apt install -y --no-install-recommends \
66
openssh-server openssh-client libcap2-bin \
7-
libopenmpi-dev \
8-
git g++ libblas-dev liblapack-dev liblapacke-dev
7+
g++ libopenmpi-dev libblas-dev liblapack-dev liblapacke-dev \
8+
python3-dev pip && rm -f /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && rm -rf /var/lib/apt/lists/*
99

1010
# Add capability to run sshd as non-root.
1111
RUN setcap CAP_NET_BIND_SERVICE=+eip /usr/sbin/sshd
12+
RUN apt remove libcap2-bin -y
1213

1314
# Configure mpiuser and home directory.
1415
RUN useradd -m mpiuser
@@ -19,12 +20,14 @@ COPY --from=mpi /etc/ssh/ssh_config /etc/ssh/ssh_config
1920
COPY --from=mpi /etc/ssh/sshd_config /etc/ssh/sshd_config
2021
COPY --from=mpi /home/mpiuser/.sshd_config /home/mpiuser/.sshd_config
2122

22-
# Install the required Python packages. This image has Python 3.13
23-
RUN apt update && apt install -y python3 python3-pip && ln -s /usr/bin/python3 /usr/bin/python && apt clean
23+
# Set home directory for mpiuser.
24+
ENV HOME=/home/mpiuser
25+
ENV PATH=$HOME/.local/bin:$PATH
2426

25-
# We have to build MLX and MLX Data from source.
26-
RUN git clone https://github.com/ml-explore/mlx.git
27-
RUN cd mlx && git checkout f018e248cd75dbb65668f418d6afb67842ea28b7 && CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -v --break-system-packages .
27+
COPY cmd/runtimes/mlx/requirements.txt .
28+
RUN pip install --user -r requirements.txt
2829

29-
RUN git clone https://github.com/ml-explore/mlx-data.git
30-
RUN cd mlx-data && git checkout 79516daa75aa3e9fd72fc5e3fb5e9e629912feac && CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -v --break-system-packages .
30+
# Give mpiuser permission to download packages and HF models.
31+
# .cache directory is used by ML frameworks to download models.
32+
RUN chown -R mpiuser:mpiuser /home/mpiuser/.local
33+
RUN mkdir -p /home/mpiuser/.cache && chown -R mpiuser:mpiuser /home/mpiuser/.cache

cmd/runtimes/mlx/requirements.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# MLX libraries.
2+
mlx[cuda]==0.28.0
3+
mlx-data==0.1.0
4+
mlx-lm==0.26.3
5+
# HuggingFace libraries.
6+
datasets==4.0.0

examples/deepspeed/text-summarization/T5-Fine-Tuning.ipynb

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
"\n",
1515
"Wikihow dataset: https://huggingface.co/datasets/sentence-transformers/wikihow\n",
1616
"\n",
17-
"This Notebook will use **4 x A100 NVIDIA GPUs**, to fine-tune T5 model on 2 nodes (every node has 2 GPUs).\n",
18-
"\n",
19-
"**TODO (andreyvelich)**: Currently, to run this Notebook you have to manualy update the container resources in the ClusterTrainingRuntime, since we don't propogate TrainJob's `resources_per_node` to the JobSet"
17+
"This Notebook will use **4 x A100 NVIDIA GPUs**, to fine-tune T5 model on 2 nodes (every node has 2 GPUs)."
2018
]
2119
},
2220
{
@@ -35,7 +33,56 @@
3533
"id": "4900404c5d532bdf",
3634
"metadata": {},
3735
"outputs": [],
38-
"source": "# !pip install git+https://github.com/kubeflow/sdk.git@main"
36+
"source": [
37+
"# !pip install git+https://github.com/kubeflow/sdk.git@main"
38+
]
39+
},
40+
{
41+
"cell_type": "markdown",
42+
"id": "74f50fe7-5a01-468c-9efe-2913b3d251da",
43+
"metadata": {},
44+
"source": [
45+
"## Update the GPU Resources\n",
46+
"\n",
47+
"Currently, Kubeflow Trainer does not support configuring DeepSpeed resources directly through a\n",
48+
"TrainJob specification.\n",
49+
"\n",
50+
"To adjust GPU allocations (and other container resource settings), you must manually patch the ClusterTrainingRuntime.\n",
51+
"\n",
52+
"Progress for native resource configuration in TrainJob is being tracked here: [kubeflow/trainer#2650](https://github.com/kubeflow/trainer/issues/2650)\n"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": 1,
58+
"id": "d038d8cd-4e5a-4c4c-aa17-a5c575e2948a",
59+
"metadata": {},
60+
"outputs": [
61+
{
62+
"name": "stdout",
63+
"output_type": "stream",
64+
"text": [
65+
"clustertrainingruntime.trainer.kubeflow.org/deepspeed-distributed patched\n"
66+
]
67+
}
68+
],
69+
"source": [
70+
"patch = \"\"\"\n",
71+
"[\n",
72+
" {\n",
73+
" \"op\": \"add\",\n",
74+
" \"path\": \"/spec/template/spec/replicatedJobs/0/template/spec/template/spec/containers/0/resources\",\n",
75+
" \"value\": { \"limits\": { \"nvidia.com/gpu\": \"2\" } }\n",
76+
" },\n",
77+
" {\n",
78+
" \"op\": \"add\",\n",
79+
" \"path\": \"/spec/template/spec/replicatedJobs/1/template/spec/template/spec/containers/0/resources\",\n",
80+
" \"value\": { \"limits\": { \"nvidia.com/gpu\": \"2\" } }\n",
81+
" }\n",
82+
"]\n",
83+
"\"\"\"\n",
84+
"!kubectl patch clustertrainingruntime deepspeed-distributed --type='json' -p \"$patch\""
85+
]
3986
},
4087
{
4188
"cell_type": "markdown",
@@ -293,7 +340,7 @@
293340
"outputs": [],
294341
"source": [
295342
"MODEL_NAME = \"t5-base\"\n",
296-
"BUCKET_NAME = \"TODO: add your bucket here\""
343+
"# BUCKET_NAME = \"TODO: add your bucket here\""
297344
]
298345
},
299346
{

0 commit comments

Comments
 (0)