Skip to content

Add MTP drafter pipeline for npu executor to enable speculative decoding. #9197

Add MTP drafter pipeline for npu executor to enable speculative decoding.

Add MTP drafter pipeline for npu executor to enable speculative decoding. #9197

Workflow file for this run

# Workflow display name.
name: CI
on:
  # Pushes of version tags matching v*.*.*.
  push:
    tags:
      - 'v*.*.*'
  # Pull requests that target the main branch.
  pull_request:
    branches:
      - main
  schedule:
    # Run at 2am PST (10am UTC) every day to refresh the cache.
    - cron: '0 10 * * *'
  # Manual trigger.
  workflow_dispatch:
    inputs:
      REFRESH_CACHE:
        description: "Refresh cache to remove unused files"
        type: boolean
        default: true
# Runs are grouped by workflow name plus the PR head ref (or the pushed ref
# when there is no head ref); a newer run cancels one still in progress in
# the same group.
concurrency:
  group: "${{ github.workflow }}-${{ github.head_ref || github.ref }}"
  cancel-in-progress: true
jobs:
  # Single presubmit job: builds and tests on Linux, builds for Android,
  # uploads prebuilt binaries on tag pushes, and maintains the bazel disk cache.
  presubmit:
    name: "Presubmit"
    runs-on: LiteRT_Linux_x64
    permissions:
      actions: write  # For gh cache delete.
      contents: write  # For gh release upload.
    env:
      MODEL_KEY: gemma-3-1b-it-v1
      MODEL_PATH: ./models/gemma3-1b-it-int4.litertlm
      MODEL_URL: https://huggingface.co/litert-community/Gemma3-1B-IT/resolve/main/gemma3-1b-it-int4.litertlm
      GH_TOKEN: ${{ github.token }}  # For gh release upload and gh cache delete.
      # 'true' for scheduled runs, or for manual runs with the REFRESH_CACHE
      # input enabled; steps below compare this against the string 'true'.
      REFRESH_CACHE: ${{ github.event_name == 'schedule' ||
        (github.event_name == 'workflow_dispatch' && inputs.REFRESH_CACHE) }}
    steps:
      - name: Checkout code.
        uses: actions/checkout@v4
        with:
          lfs: true
      # Builds a hierarchy of cache keys, from least specific (workflow name)
      # to most specific (commit SHA), and exposes them as step outputs.
      - name: Set up cache keys.
        id: cache-keys
        run: |
          CACHE_RESTORE_KEY_2="${{ github.workflow }}"
          CACHE_RESTORE_KEY_1="$CACHE_RESTORE_KEY_2-${{ hashFiles('**/WORKSPACE', '**/.bazelrc') }}"
          CACHE_RESTORE_KEY_0="$CACHE_RESTORE_KEY_1-${{ hashFiles('**/BUILD*') }}"
          # If it's not a pull request, the base SHA is empty, so this will be
          # the same as "$CACHE_RESTORE_KEY_0-".
          CACHE_RESTORE_KEY_HEAD="$CACHE_RESTORE_KEY_0-${{ github.event.pull_request.base.sha }}"
          CACHE_KEY="$CACHE_RESTORE_KEY_0-${{ github.sha }}"
          echo "CACHE_RESTORE_KEY_2=$CACHE_RESTORE_KEY_2" >> "$GITHUB_OUTPUT"
          echo "CACHE_RESTORE_KEY_1=$CACHE_RESTORE_KEY_1" >> "$GITHUB_OUTPUT"
          echo "CACHE_RESTORE_KEY_0=$CACHE_RESTORE_KEY_0" >> "$GITHUB_OUTPUT"
          echo "CACHE_RESTORE_KEY_HEAD=$CACHE_RESTORE_KEY_HEAD" >> "$GITHUB_OUTPUT"
          echo "CACHE_KEY=$CACHE_KEY" >> "$GITHUB_OUTPUT"
      - name: Clean build outputs if cache is being refreshed.
        if: env.REFRESH_CACHE == 'true'
        run: bazel clean --expunge
      - name: Restore bazel cache if cache is not being refreshed.
        id: bazel-cache
        if: env.REFRESH_CACHE != 'true'
        uses: actions/cache/restore@v4
        with:
          path: |
            ~/.cache/bazel-linux
            ~/.cache/bazel-android
          key: ${{ steps.cache-keys.outputs.CACHE_KEY }}
          # Fallbacks, tried from most to least specific.
          restore-keys: |
            ${{ steps.cache-keys.outputs.CACHE_RESTORE_KEY_HEAD }}
            ${{ steps.cache-keys.outputs.CACHE_RESTORE_KEY_0 }}-
            ${{ steps.cache-keys.outputs.CACHE_RESTORE_KEY_1 }}-
            ${{ steps.cache-keys.outputs.CACHE_RESTORE_KEY_2 }}-
      # Logs the restore step's outputs; they are empty when that step was
      # skipped because the cache is being refreshed.
      - name: Check cache hit.
        run: |
          echo "Cache Hit: ${{ steps.bazel-cache.outputs.cache-hit }}"
          echo "Cache Primary Key: ${{ steps.bazel-cache.outputs.cache-primary-key }}"
          echo "Cache Matched Key: ${{ steps.bazel-cache.outputs.cache-matched-key }}"
      - name: Download Model
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        # MODEL_PATH and MODEL_URL come from the job-level env and are read as
        # quoted shell variables rather than spliced into the script text.
        run: |
          mkdir -p ./models
          echo "Downloading model from Hugging Face..."
          curl -L --retry 5 -f \
            -H "Authorization: Bearer $HF_TOKEN" \
            -o "$MODEL_PATH" \
            "$MODEL_URL"
          ls -lh "$MODEL_PATH"
      - name: Run bazel build on Linux.
        run: |
          bazel build --disk_cache=~/.cache/bazel-linux --config=linux_x86_64 \
            //... \
            //runtime/engine:litert_lm_main
      - name: Check if litert_lm_main doesn't link libLiteRt.so.
        # Return exit code 1 if libLiteRt.so is required.
        run: |
          ! readelf -d bazel-bin/runtime/engine/litert_lm_main | grep libLiteRt.so
      - name: Update litert_lm_main prebuilt for Linux if new version tag is pushed.
        if: github.ref_type == 'tag'
        # GITHUB_REF_NAME is the default environment variable equivalent of
        # github.ref_name; quoting it keeps the tag name a single shell word.
        run: |
          cp bazel-bin/runtime/engine/litert_lm_main litert_lm_main.linux_x86_64
          gh release upload "$GITHUB_REF_NAME" litert_lm_main.linux_x86_64 --clobber
      - name: Run bazel test on Linux.
        run: |
          bazel test --disk_cache=~/.cache/bazel-linux --config=linux_x86_64 \
            --test_output=errors \
            //...
      # NOTE(review): unlike the other bazel invocations in this job, this one
      # does not pass --disk_cache; confirm that is intentional.
      - name: Run bazel build on Linux with dynamic linking.
        run: |
          bazel build --config=linux_x86_64 \
            --define=litert_link_capi_so=true \
            --define=resolve_symbols_in_exec=false \
            //runtime/engine:litert_lm_main
      - name: Install pytest
        run: python3 -m pip install --break-system-packages pytest==8.3.4
      - name: Run pytest
        run: pytest tools/test/ --model-path="$MODEL_PATH" --build-system=bazel
      - name: Check if litert_lm_main has only LiteRt symbols undefined.
        # Return exit code 1 if litert_lm_main has LiteRt symbols that are not
        # undefined (UND), except for LiteRtTopK and some exceptions listed
        # explicitly here.
        # TODO b/453859132: Remove OpaqueOptions.
        run: |
          ! readelf -sW bazel-bin/runtime/engine/litert_lm_main \
            | grep " LiteRt" | grep -v " UND LiteRt" | grep -v " LiteRtTopK" \
            | grep -v -e LiteRtIsSameLayout -e LiteRtGetNumLayoutElements \
              -e "LiteRt.*Logger" -e "LiteRt.*Metric" -e "LiteRt.*OpaqueOptions" \
              -e "LiteRt.*EnvironmentOptions" -e LiteRtGetLogSeverityName \
              -e LiteRtCompareApiVersion -e LiteRtGetStatusString \
              -e LiteRtGetNumModelSignatures -e LiteRtGetModelSignature \
              -e LiteRtGetSignatureKey -e LiteRtGetSignatureOutputTensor \
              -e LiteRtGetQuantizationTypeId -e LiteRtGetPerTensorQuantization \
              -e TensorBufferRequirements
      - name: Setup Android NDK.
        uses: nttld/setup-ndk@v1
        id: setup-ndk
        with:
          ndk-version: r28b
          add-to-path: false
      - name: Run bazel build for Android.
        # Targets after the bare "--" that start with "-" are excluded from
        # the //... pattern.
        run: |
          bazel build --disk_cache=~/.cache/bazel-android --config=android_arm64 \
            //... \
            //runtime/engine:litert_lm_main \
            @litert//litert/vendors/mediatek/dispatch:dispatch_api_so \
            @litert//litert/vendors/qualcomm/dispatch:dispatch_api_so \
            -- \
            -//python/... \
            -//schema/py:* \
            -//kotlin/java/com/google/ai/edge/litertlm/example/...
        env:
          ANDROID_NDK_HOME: ${{ steps.setup-ndk.outputs.ndk-path }}
      - name: Update litert_lm_main prebuilt for Android if new version tag is pushed.
        if: github.ref_type == 'tag'
        run: |
          cp bazel-bin/runtime/engine/litert_lm_main litert_lm_main.android_arm64
          gh release upload "$GITHUB_REF_NAME" litert_lm_main.android_arm64 --clobber
      - name: Remove cache if cache is being refreshed.
        if: env.REFRESH_CACHE == 'true'
        continue-on-error: true  # Ignore errors when cache is not found.
        run: gh cache delete "${{ steps.cache-keys.outputs.CACHE_KEY }}"
      - name: Save bazel cache if it's new or being refreshed.
        uses: actions/cache/save@v4
        if: env.REFRESH_CACHE == 'true' || steps.bazel-cache.outputs.cache-hit != 'true'
        with:
          path: |
            ~/.cache/bazel-linux
            ~/.cache/bazel-android
          key: ${{ steps.cache-keys.outputs.CACHE_KEY }}