Skip to content

feat: multi-arch CUDA Dockerfile and sm_121 (DGX Spark GB10)#840

Open
nazq wants to merge 1 commit intohuggingface:mainfrom
nazq:feat/arm64-cuda-blackwell
Open

feat: multi-arch CUDA Dockerfile and sm_121 (DGX Spark GB10)#840
nazq wants to merge 1 commit intohuggingface:mainfrom
nazq:feat/arm64-cuda-blackwell

Conversation

@nazq
Copy link

@nazq nazq commented Mar 4, 2026

Summary

Builds on #827 (ARM64 CPU Dockerfile) by extending CUDA support to ARM64 and adding the DGX Spark GB10's sm_121 compute capability. Also adds the CI matrix entries and README updates needed to ship ARM64 images.

Changes

Dockerfile-cuda (multi-arch)

  • Use TARGETARCH to select correct sccache binary (x86_64 or aarch64)
  • Use TARGETARCH to select correct protoc binary (x86_64 or aarch_64)
  • Add sm_121 to nvprune section for DGX Spark GB10

compute_cap.rs

  • (120..=121, 120) => true — sm_121 runtime is compatible with sm_120 compiled binaries
  • (121, 121) => true — exact match for native sm_121 builds
  • Full test coverage for sm_121 compatibility matrix

flash_attn.rs

  • Allow runtime_compute_cap == 121 to use flash attention v2 (same arch family as sm_120)

build.yaml

  • Use matrix.platforms with fallback to linux/amd64 — enables per-variant platform selection without breaking existing entries

matrix.json

  • Add blackwell-121 entry (linux/arm64, CUDA_COMPUTE_CAP=121) for DGX Spark GB10
  • Add cpu-arm64 entry (linux/arm64, Dockerfile-arm64) for ARM64 CPU-only hosts

README.md

  • Add Platform column to Docker Images table
  • Add cpu-arm64-1.9 and 121-1.9 image entries
  • Replace Apple-only ARM64 section with comprehensive aarch64 docs covering CPU-only and CUDA build paths (DGX Spark, Jetson)
  • Add sm_121 to CUDA compute capability examples

Motivation

The NVIDIA DGX Spark uses the GB10 SoC with compute capability 12.1 (sm_121). This is a Blackwell-family chip (Grace + Blackwell GPU) on ARM64. Without these changes, TEI cannot run on the DGX Spark with CUDA acceleration.

Testing

  • docker build -f Dockerfile-cuda --build-arg CUDA_COMPUTE_CAP=121 --platform linux/arm64 .
  • Unit tests pass for compute_cap_matching with sm_121
  • CI matrix produces 121-{version}-grpc and cpu-arm64-{version}-grpc images

Closes #769

@nazq nazq force-pushed the feat/arm64-cuda-blackwell branch 3 times, most recently from 44f1190 to 8cf4772 Compare March 4, 2026 16:38
@alvarobartt alvarobartt self-requested a review March 6, 2026 09:18
alvarobartt
alvarobartt previously approved these changes Mar 6, 2026
Copy link
Member

@alvarobartt alvarobartt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR @nazq, looks really clean!

Could you review and update also the table with the different images at https://github.com/huggingface/text-embeddings-inference/blob/main/docs/source/en/supported_models.md? Then I'll merge and validate that the CI is working as expected, hoping to release v1.9.3 next week.

And thanks for building on top of @z4y4ts PR and keeping them as co-author, much appreciated 🤗

@nazq
Copy link
Author

nazq commented Mar 6, 2026

Thanks a lot for the PR @nazq, looks really clean!

Could you review and update also the table with the different images at https://github.com/huggingface/text-embeddings-inference/blob/main/docs/source/en/supported_models.md? Then I'll merge and validate that the CI is working as expected, hoping to release v1.9.3 next week.

And thanks for building on top of @z4y4ts PR and keeping them as co-author, much appreciated 🤗

Updated supported_models.md. I did update the CI too but I've not run it so all done by inspection.

@stefan-it
Copy link

Hi @nazq thanks so much for that PR!

I tested the PR on my Spark and I got a build failure:

41.46 error[E0521]: borrowed data escapes outside of associated function                                                                                                                  
41.46   --> /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metrics-0.23.0/src/recorder/mod.rs:77:9                                                                            
41.46    |                                                                                                                                                                                
41.46 71 |       fn new(recorder: &dyn Recorder) -> Self {                                                                                                                                
41.46    |              --------  - let's call the lifetime of this reference `'1`                                                                                                        
41.46    |              |                                                                    
41.46    |              `recorder` is a reference that is only valid in the associated function body                                                                                      
41.46 ...                                                                                                                                                                                 
41.46 77 | /         LOCAL_RECORDER.with(|local_recorder| {                                                                                                                               
41.46 78 | |             local_recorder.set(Some(recorder_ptr));                                                                                                                          
41.46 79 | |         });                                                                                                                                                                  
41.46    | |          ^                                                                                                                                                                   
41.46    | |          |                                                                                                                                                                   
41.46    | |__________`recorder` escapes the associated function body here                                                                                                                
41.46    |            argument requires that `'1` must outlive `'static`                                                                                                                  
41.46    |                                                                                                                                                                                
41.46 note: raw pointer casts of trait objects cannot extend lifetimes                                                                                                                    
41.46   --> /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metrics-0.23.0/src/recorder/mod.rs:75:60                                                                           
41.46    |                                                                                                                                                                                
41.46 75 |         let recorder_ptr = unsafe { NonNull::new_unchecked(recorder as *const _ as *mut _) };                                                                                  
41.46    |                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                      
41.46    = note: this was previously accepted by the compiler but was changed recently                                                                                                    
41.46    = help: see <https://github.com/rust-lang/rust/issues/141402> for more information                                                                                               
41.46                                                                                                                                                                                     
41.46 For more information about this error, try `rustc --explain E0521`.                                                                                                                 
41.47 error: could not compile `metrics` (lib) due to 1 previous error                                                                                                                    
41.47 warning: build failed, waiting for other jobs to finish...                                                                                                                          332.4                                                                                                                                                                                     332.4 thread 'main' (7) panicked at /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/cargo-chef-0.1.73/src/recipe.rs:218:27:                                                    
332.4 Exited with status code: 101                                                                                                                                                        
332.4 note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace                                                                                                       
------                                                                                                                                                                                    
Dockerfile-cuda:82                                                                                                                                                                        
--------------------                                                                                                                                                                      
  81 |                                                                                                                                                                                    
  82 | >>> RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \                                                                                                       
  83 | >>>     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \                                                                                                   
  84 | >>>     if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \                                                                                                         
  85 | >>>     then \                                                                                                                                                                     
  86 | >>>     cargo chef cook --release --features candle-cuda-turing --features static-linking --no-default-features --recipe-path recipe.json && sccache -s; \                         
  87 | >>>     else \                                                                                                                                                                     
  88 | >>>     cargo chef cook --release --features candle-cuda --features static-linking --no-default-features --recipe-path recipe.json && sccache -s; \                                
  89 | >>>     fi;                                                                                                                                                                        
  90 |                                                                                                                                                                                    
--------------------                                                                                                                                                                      
ERROR: failed to build: failed to solve: process "/bin/sh -c if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ];     then     cargo chef cook --release --features candle-cud
a-turing --features static-linking --no-default-features --recipe-path recipe.json && sccache -s;     else     cargo chef cook --release --features candle-cuda --features static-linking 
--no-default-features --recipe-path recipe.json && sccache -s;     fi;" did not complete successfully: exit code: 101

After searching a bit, I found out that this #842 PR should fix it. So I applied these changes and the build finished without any errors. So I guess only a rebase is needed.

@nazq
Copy link
Author

nazq commented Mar 15, 2026

Great. Thanks for this i didn't buy a Spark till i knew we could get this PR in. Happy to rebase it

- Add Dockerfile-cuda supporting both x86_64 and ARM64 (aarch64)
- Add sm_121 compute capability for NVIDIA GB10 (DGX Spark)
- Add cpu-arm64 image variant
- Update supported hardware documentation

Co-Authored-By: z4y4ts <z4y4ts@users.noreply.github.com>
@nazq nazq force-pushed the feat/arm64-cuda-blackwell branch from a9395f8 to ad55ed2 Compare March 15, 2026 13:24
@nazq
Copy link
Author

nazq commented Mar 15, 2026

Hey @stefan-it — rebased onto upstream main, which now includes #842. Should fix the metrics crate build failure you hit. Let me know if it works on your Spark!

@stefan-it
Copy link

Hi @nazq many thanks! I did a fresh clone of the rebased branch and built it with:

docker build . -f Dockerfile-cuda --no-cache   --build-arg CUDA_COMPUTE_CAP=121   --platform linux/arm64 -t text-embeddings-inference:121-1.9-pr

result was:

[+] Building 895.2s (32/32) FINISHED                                                                                                                                       docker:default
 => [internal] load build definition from Dockerfile-cuda                                                                                                                            0.0s
 => => transferring dockerfile: 6.46kB                                                                                                                                               0.0s
 => [internal] load metadata for docker.io/nvidia/cuda:12.9.1-runtime-ubuntu24.04                                                                                                    0.2s
 => [internal] load metadata for docker.io/nvidia/cuda:12.9.1-devel-ubuntu24.04                                                                                                      0.2s
 => [internal] load .dockerignore                                                                                                                                                    0.0s
 => => transferring context: 53B                                                                                                                                                     0.0s
 => [internal] load build context                                                                                                                                                    0.0s
 => => transferring context: 17.28kB                                                                                                                                                 0.0s
 => CACHED [base-builder 1/6] FROM docker.io/nvidia/cuda:12.9.1-devel-ubuntu24.04@sha256:020bc241a628776338f4d4053fed4c38f6f7f3d7eb5919fecb8de313bb8ba47c                            0.0s
 => CACHED [base 1/3] FROM docker.io/nvidia/cuda:12.9.1-runtime-ubuntu24.04@sha256:1287141d283b8f06f45681b56a48a85791398c615888b1f96bfb9fc981392d98                                  0.0s
 => [base-builder 2/6] RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends     curl     libssl-dev     pkg-config     && rm -rf /var/l  22.1s
 => [base 2/3] RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends     ca-certificates     libssl-dev     curl     cuda-compat-12-9     19.6s
 => [base 3/3] COPY --chmod=775 cuda-entrypoint.sh entrypoint.sh                                                                                                                     0.0s
 => [base-builder 3/6] RUN case "arm64" in     "amd64") SCCACHE_ARCH=x86_64-unknown-linux-musl ;;     "arm64") SCCACHE_ARCH=aarch64-unknown-linux-musl ;;     *) echo "Unsupported   2.9s 
 => [base-builder 4/6] COPY rust-toolchain.toml rust-toolchain.toml                                                                                                                  0.0s 
 => [base-builder 5/6] RUN curl https://sh.rustup.rs -sSf | bash -s -- -y                                                                                                           32.5s 
 => [base-builder 6/6] RUN cargo install cargo-chef --version 0.1.73 --locked                                                                                                       49.9s 
 => [planner 1/7] WORKDIR /usr/src                                                                                                                                                   0.0s 
 => [builder 2/9] RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN     if [ 121 -g  0.6s 
 => [planner 2/7] COPY backends backends                                                                                                                                             0.1s 
 => [planner 3/7] COPY core core                                                                                                                                                     0.1s 
 => [planner 4/7] COPY router router                                                                                                                                                 0.1s 
 => [planner 5/7] COPY Cargo.toml ./                                                                                                                                                 0.1s 
 => [planner 6/7] COPY Cargo.lock ./                                                                                                                                                 0.1s 
 => [planner 7/7] RUN cargo chef prepare  --recipe-path recipe.json                                                                                                                  0.2s
 => [builder 3/9] COPY --from=planner /usr/src/recipe.json recipe.json                                                                                                               0.1s
 => [builder 4/9] RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN     if [ 121   360.7s
 => [builder 5/9] COPY backends backends                                                                                                                                             0.1s 
 => [builder 6/9] COPY core core                                                                                                                                                     0.1s 
 => [builder 7/9] COPY router router                                                                                                                                                 0.1s 
 => [builder 8/9] COPY Cargo.toml ./                                                                                                                                                 0.1s 
 => [builder 9/9] COPY Cargo.lock ./                                                                                                                                                 0.1s 
 => [http-builder 1/1] RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN     if [  423.1s 
 => [stage-7 1/1] COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router                                                      0.4s 
 => exporting to image                                                                                                                                                               1.5s 
 => => exporting layers                                                                                                                                                              1.4s 
 => => writing image sha256:2018875deaebfac387abad481f0f2bb7979853ad2b607297aa8bdba5b1d67ef4                                                                                         0.0s 
 => => naming to docker.io/library/text-embeddings-inference:121-1.9-pr

So definitely working on a Spark 🥳

@nazq
Copy link
Author

nazq commented Mar 16, 2026

I'll put my order in then ;-)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ARM64 Support

3 participants