Skip to content

Commit f94b0c5

Browse files
authored
Use deformable_detr kernel from the Hub (#36853)
* Use `deformable_detr` kernel from the Hub: remove the `deformable_detr` kernel from `kernels/` and use the pre-built kernel from the Hub instead. * Add license header. * Add `kernels` as an extra `hub-kernels`; also add it to `testing`, so that the kernel replacement gets tested when using CUDA in CI.
1 parent 2638d54 commit f94b0c5

21 files changed

+405
-3834
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
130130
"keras>2.9,<2.16",
131131
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
132+
"kernels>=0.3.2,<0.4",
132133
"librosa",
133134
"natten>=0.14.6,<0.15.0",
134135
"nltk<=3.8.1",
@@ -301,8 +302,9 @@ def run(self):
301302
extras["optuna"] = deps_list("optuna")
302303
extras["ray"] = deps_list("ray[tune]")
303304
extras["sigopt"] = deps_list("sigopt")
305+
extras["hub-kernels"] = deps_list("kernels")
304306

305-
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
307+
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
306308

307309
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
308310
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"kenlm": "kenlm",
3636
"keras": "keras>2.9,<2.16",
3737
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
38+
"kernels": "kernels>=0.3.2,<0.4",
3839
"librosa": "librosa",
3940
"natten": "natten>=0.14.6,<0.15.0",
4041
"nltk": "nltk<=3.8.1",

src/transformers/integrations/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@
7070
"replace_with_higgs_linear",
7171
],
7272
"hqq": ["prepare_for_hqq_linear"],
73+
"hub_kernels": [
74+
"LayerRepository",
75+
"register_kernel_mapping",
76+
"replace_kernel_forward_from_hub",
77+
"use_kernel_forward_from_hub",
78+
],
7379
"integration_utils": [
7480
"INTEGRATION_TO_CALLBACK",
7581
"AzureMLCallback",
@@ -198,6 +204,12 @@
198204
)
199205
from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
200206
from .hqq import prepare_for_hqq_linear
207+
from .hub_kernels import (
208+
LayerRepository,
209+
register_kernel_mapping,
210+
replace_kernel_forward_from_hub,
211+
use_kernel_forward_from_hub,
212+
)
201213
from .integration_utils import (
202214
INTEGRATION_TO_CALLBACK,
203215
AzureMLCallback,
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Dict, Union
15+
16+
17+
# `kernels` is an optional dependency. Only the import itself is guarded so
# that an error raised while building or registering the mapping is not
# mistaken for "package not installed" and silently replaced by the stubs.
try:
    from kernels import (
        Device,
        LayerRepository,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
        use_kernel_forward_from_hub,
    )
except ImportError:
    # Stubs to make decorators in transformers work when `kernels`
    # is not installed.
    _hub_kernels_available = False

    def use_kernel_forward_from_hub(*args, **kwargs):
        """Return a no-op class decorator; all arguments are ignored."""

        def decorator(cls):
            # Hand the class back untouched: no kernel replacement happens.
            return cls

        return decorator

    class LayerRepository:
        """Placeholder whose construction always fails with an install hint."""

        def __init__(self, *args, **kwargs):
            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

    def replace_kernel_forward_from_hub(*args, **kwargs):
        """Always raise `RuntimeError`: this needs the real `kernels` package."""
        raise RuntimeError(
            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
        )

    def register_kernel_mapping(*args, **kwargs):
        """Always raise `RuntimeError`: this needs the real `kernels` package."""
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

else:
    _hub_kernels_available = True

    # Layer name -> device -> Hub repository that provides the kernel layer.
    _KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
        "MultiScaleDeformableAttention": {
            "cuda": LayerRepository(
                repo_id="kernels-community/deformable-detr",
                layer_name="MultiScaleDeformableAttention",
            )
        }
    }

    register_kernel_mapping(_KERNEL_MAPPING)
61+
62+
63+
def is_hub_kernels_available():
    """Return the module-level flag recording whether `kernels` was imported.

    The flag is set to `True` after the `kernels` import succeeds and to
    `False` in the `ImportError` fallback that defines the stubs.
    """
    available = _hub_kernels_available
    return available
65+
66+
67+
# Names exported by `from ... import *`; each one is defined whether or not
# `kernels` is installed (real objects or the fallback stubs).
__all__ = [
    "LayerRepository",
    "is_hub_kernels_available",
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
]

src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu

Lines changed: 0 additions & 159 deletions
This file was deleted.

0 commit comments

Comments
 (0)