Skip to content

Commit c6308a9

Browse files
Metal backend: Add Whisper to CI workflow (pytorch#15685)
This PR refactors the model export and e2e scripts to support both CUDA and Metal backends, and updates the Metal CI workflow to generalize model export and e2e testing for multiple models and quantization options. It expands Metal CI model coverage to also include Whisper.
1 parent aba44fd commit c6308a9

File tree

4 files changed

+114
-128
lines changed

4 files changed

+114
-128
lines changed

.ci/scripts/export_model_cuda_artifact.sh renamed to .ci/scripts/export_model_artifact.sh

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# Export model to CUDA format with optional quantization
8+
# Export model to CUDA/Metal format with optional quantization
99

1010
show_help() {
1111
cat << EOF
12-
Usage: export_model_cuda_artifact.sh <hf_model> [quant_name] [output_dir]
12+
Usage: export_model_artifact.sh <device> <hf_model> [quant_name] [output_dir]
1313
14-
Export a HuggingFace model to CUDA format with optional quantization.
14+
Export a HuggingFace model to CUDA/Metal format with optional quantization.
1515
1616
Arguments:
17+
device cuda or metal (required)
18+
1719
hf_model HuggingFace model ID (required)
1820
Supported models:
1921
- mistralai/Voxtral-Mini-3B-2507
@@ -29,9 +31,9 @@ Arguments:
2931
output_dir Output directory for artifacts (optional, default: current directory)
3032
3133
Examples:
32-
export_model_cuda_artifact.sh "openai/whisper-small"
33-
export_model_cuda_artifact.sh "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
34-
export_model_cuda_artifact.sh "google/gemma-3-4b-it" "non-quantized" "./output"
34+
export_model_artifact.sh metal "openai/whisper-small"
35+
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
36+
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
3537
EOF
3638
}
3739

@@ -48,9 +50,22 @@ fi
4850

4951
set -eux
5052

51-
HF_MODEL="$1"
52-
QUANT_NAME="${2:-non-quantized}"
53-
OUTPUT_DIR="${3:-.}"
53+
DEVICE="$1"
54+
HF_MODEL="$2"
55+
QUANT_NAME="${3:-non-quantized}"
56+
OUTPUT_DIR="${4:-.}"
57+
58+
case "$DEVICE" in
59+
cuda)
60+
;;
61+
metal)
62+
;;
63+
*)
64+
echo "Error: Unsupported device '$DEVICE'"
65+
echo "Supported devices: cuda, metal"
66+
exit 1
67+
;;
68+
esac
5469

5570
# Determine model configuration based on HF model ID
5671
case "$HF_MODEL" in
@@ -75,6 +90,10 @@ case "$HF_MODEL" in
7590
fi
7691
;;
7792
google/gemma-3-4b-it)
93+
if [ "$DEVICE" = "metal" ]; then
94+
echo "Error: Export for device 'metal' is not yet tested for model '$HF_MODEL'"
95+
exit 1
96+
fi
7897
MODEL_NAME="gemma3"
7998
TASK="multimodal-text-to-text"
8099
MAX_SEQ_LEN="64"
@@ -95,9 +114,17 @@ case "$QUANT_NAME" in
95114
EXTRA_ARGS=""
96115
;;
97116
quantized-int4-tile-packed)
117+
if [ "$DEVICE" = "metal" ]; then
118+
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
119+
exit 1
120+
fi
98121
EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
99122
;;
100123
quantized-int4-weight-only)
124+
if [ "$DEVICE" = "metal" ]; then
125+
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
126+
exit 1
127+
fi
101128
EXTRA_ARGS="--qlinear_encoder 4w"
102129
;;
103130
*)
@@ -118,12 +145,18 @@ MAX_SEQ_LEN_ARG=""
118145
if [ -n "$MAX_SEQ_LEN" ]; then
119146
MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN"
120147
fi
148+
149+
DEVICE_ARG=""
150+
if [ "$DEVICE" = "cuda" ]; then
151+
DEVICE_ARG="--device cuda"
152+
fi
153+
121154
optimum-cli export executorch \
122155
--model "$HF_MODEL" \
123156
--task "$TASK" \
124-
--recipe "cuda" \
157+
--recipe "$DEVICE" \
125158
--dtype bfloat16 \
126-
--device cuda \
159+
${DEVICE_ARG} \
127160
${MAX_SEQ_LEN_ARG} \
128161
${EXTRA_ARGS} \
129162
--output_dir ./
@@ -137,18 +170,18 @@ if [ -n "$PREPROCESSOR_OUTPUT" ]; then
137170
fi
138171

139172
test -f model.pte
140-
test -f aoti_cuda_blob.ptd
173+
test -f aoti_${DEVICE}_blob.ptd
141174
if [ -n "$PREPROCESSOR_OUTPUT" ]; then
142175
test -f $PREPROCESSOR_OUTPUT
143176
fi
144177
echo "::endgroup::"
145178

146179
echo "::group::Store $MODEL_NAME Artifacts"
147180
mkdir -p "${OUTPUT_DIR}"
148-
cp model.pte "${OUTPUT_DIR}/"
149-
cp aoti_cuda_blob.ptd "${OUTPUT_DIR}/"
181+
mv model.pte "${OUTPUT_DIR}/"
182+
mv aoti_${DEVICE}_blob.ptd "${OUTPUT_DIR}/"
150183
if [ -n "$PREPROCESSOR_OUTPUT" ]; then
151-
cp $PREPROCESSOR_OUTPUT "${OUTPUT_DIR}/"
184+
mv $PREPROCESSOR_OUTPUT "${OUTPUT_DIR}/"
152185
fi
153186
ls -al "${OUTPUT_DIR}"
154187
echo "::endgroup::"

.ci/scripts/test_model_cuda_e2e.sh renamed to .ci/scripts/test_model_e2e.sh

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# Test CUDA model end-to-end, need to run .ci/scripts/export_model_cuda_artifact.sh first
8+
# Test CUDA/Metal model end-to-end, need to run .ci/scripts/export_model_artifact.sh first
99

1010
show_help() {
1111
cat << EOF
12-
Usage: test_model_cuda_e2e.sh <hf_model> <quant_name> [model_dir]
12+
Usage: test_model_e2e.sh <device> <hf_model> <quant_name> [model_dir]
1313
14-
Build and run end-to-end tests for CUDA models.
14+
Build and run end-to-end tests for CUDA/Metal models.
1515
1616
Arguments:
17+
device cuda or metal (required)
18+
1719
hf_model HuggingFace model ID (required)
1820
Supported models:
1921
- mistralai/Voxtral-Mini-3B-2507
@@ -27,12 +29,12 @@ Arguments:
2729
- quantized-int4-weight-only
2830
2931
model_dir Directory containing model artifacts (optional, default: current directory)
30-
Expected files: model.pte, aoti_cuda_blob.ptd
32+
Expected files: model.pte, aoti_cuda_blob.ptd/aoti_metal_blob.ptd
3133
Tokenizers and test files will be downloaded to this directory
3234
3335
Examples:
34-
test_model_cuda_e2e.sh "openai/whisper-small" "non-quantized"
35-
test_model_cuda_e2e.sh "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output"
36+
test_model_e2e.sh metal "openai/whisper-small" "non-quantized"
37+
test_model_e2e.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output"
3638
EOF
3739
}
3840

@@ -55,20 +57,21 @@ fi
5557

5658
set -eux
5759

58-
HF_MODEL="$1"
59-
QUANT_NAME="$2"
60+
DEVICE="$1"
61+
HF_MODEL="$2"
62+
QUANT_NAME="$3"
6063
# Download tokenizers, audio, and image files to this directory
61-
MODEL_DIR="${3:-.}"
64+
MODEL_DIR="${4:-.}"
6265

6366
echo "Testing model: $HF_MODEL (quantization: $QUANT_NAME)"
6467

65-
# Make sure model.pte and aoti_cuda_blob.ptd exist
68+
# Make sure model.pte and aoti_${DEVICE}_blob.ptd exist
6669
if [ ! -f "$MODEL_DIR/model.pte" ]; then
6770
echo "Error: model.pte not found in $MODEL_DIR"
6871
exit 1
6972
fi
70-
if [ ! -f "$MODEL_DIR/aoti_cuda_blob.ptd" ]; then
71-
echo "Error: aoti_cuda_blob.ptd not found in $MODEL_DIR"
73+
if [ ! -f "$MODEL_DIR/aoti_${DEVICE}_blob.ptd" ]; then
74+
echo "Error: aoti_${DEVICE}_blob.ptd not found in $MODEL_DIR"
7275
exit 1
7376
fi
7477
# Locate EXECUTORCH_ROOT from the directory of this script
@@ -152,14 +155,24 @@ ls -al
152155
echo "::endgroup::"
153156

154157
echo "::group::Build $MODEL_NAME Runner"
158+
159+
if [ "$DEVICE" = "cuda" ]; then
160+
BUILD_BACKEND="EXECUTORCH_BUILD_CUDA"
161+
elif [ "$DEVICE" = "metal" ]; then
162+
BUILD_BACKEND="EXECUTORCH_BUILD_METAL"
163+
else
164+
echo "Error: Unsupported device '$DEVICE'. Must be 'cuda' or 'metal'."
165+
exit 1
166+
fi
167+
155168
cmake --preset llm \
156-
-DEXECUTORCH_BUILD_CUDA=ON \
169+
-D${BUILD_BACKEND}=ON \
157170
-DCMAKE_INSTALL_PREFIX=cmake-out \
158171
-DCMAKE_BUILD_TYPE=Release \
159172
-Bcmake-out -S.
160173
cmake --build cmake-out -j$(nproc) --target install --config Release
161174

162-
cmake -DEXECUTORCH_BUILD_CUDA=ON \
175+
cmake -D${BUILD_BACKEND}=ON \
163176
-DCMAKE_BUILD_TYPE=Release \
164177
-Sexamples/models/$RUNNER_PATH \
165178
-Bcmake-out/examples/models/$RUNNER_PATH/
@@ -168,11 +181,13 @@ echo "::endgroup::"
168181

169182
echo "::group::Run $MODEL_NAME Runner"
170183
set +e
171-
export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
184+
if [ "$DEVICE" = "cuda" ]; then
185+
export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
186+
fi
172187

173188
# Build runner command with common arguments
174189
RUNNER_BIN="cmake-out/examples/models/$RUNNER_PATH/$RUNNER_TARGET"
175-
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd --temperature 0"
190+
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_${DEVICE}_blob.ptd --temperature 0"
176191

177192
# Add model-specific arguments
178193
case "$MODEL_NAME" in

.github/workflows/cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ jobs:
142142
pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
143143
echo "::endgroup::"
144144
145-
source .ci/scripts/export_model_cuda_artifact.sh "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
145+
source .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
146146
147147
benchmark-model-cuda:
148148
name: benchmark-model-cuda
@@ -249,4 +249,4 @@ jobs:
249249
download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-cuda-${{ matrix.quant }}
250250
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
251251
script: |
252-
source .ci/scripts/test_model_cuda_e2e.sh "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"
252+
source .ci/scripts/test_model_e2e.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"

0 commit comments

Comments (0)