Skip to content

Commit 4eda7fe

Browse files
committed
update tests
1 parent 32310d8 commit 4eda7fe

File tree

1 file changed

+57
-4
lines changed

1 file changed

+57
-4
lines changed

tests/kernels/test_kernels.py

Lines changed: 57 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,17 +16,17 @@
1616
import copy
1717
import gc
1818
import types
19-
import unittest
2019
from unittest.mock import patch
2120

2221
from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
2322
from transformers.integrations.hub_kernels import (
24-
_KERNEL_MODULE_MAPPING,
2523
_HUB_KERNEL_MAPPING,
24+
_KERNEL_MODULE_MAPPING,
2625
is_kernel,
2726
lazy_load_kernel,
2827
load_and_register_attn_kernel,
2928
)
29+
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
3030
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
3131
from transformers.testing_utils import (
3232
TestCasePlus,
@@ -57,6 +57,29 @@ def setUp(self):
5757
self.input = "Hello"
5858

5959
def tearDown(self):
60+
# Delete large objects to drop references early
61+
for attr in [
62+
"model_kernelized",
63+
"model_not_kernelized",
64+
"tokenizer",
65+
]:
66+
if hasattr(self, attr):
67+
try:
68+
delattr(self, attr)
69+
except Exception:
70+
pass
71+
72+
# Clear any temporary kernel module cache entries populated by tests
73+
try:
74+
keys_to_remove = [
75+
k for k, v in list(_KERNEL_MODULE_MAPPING.items()) if v is None or isinstance(v, types.ModuleType)
76+
]
77+
for k in keys_to_remove:
78+
_KERNEL_MODULE_MAPPING.pop(k, None)
79+
except Exception:
80+
pass
81+
82+
# Free accelerator memory/cache and trigger GC
6083
gc.collect()
6184
backend_empty_cache(torch_device)
6285
gc.collect()
@@ -130,7 +153,7 @@ def test_kernelized_forward_is_the_same(self, model_1, model_2):
130153

131154
def test_kernelize(self):
132155
model = copy.deepcopy(self.model_not_kernelized)
133-
kernelize(model, mode=Mode.INFERENCE, device=Device(type=model.device.type))
156+
kernelize(model, mode=Mode.INFERENCE, device=Device(type=model.device.type)) # type: ignore[arg-type]
134157
self.test_kernelized_forward_is_different(model, self.model_not_kernelized)
135158
self.test_kernelized_forward_is_the_same(model, self.model_kernelized)
136159
del model
@@ -232,13 +255,17 @@ def fake_get_kernel(repo_id, revision=None):
232255
self.assertIs(mod2, sentinel)
233256
finally:
234257
setattr(kernels_pkg, "get_kernel", original_get_kernel)
258+
# Ensure cache is cleared to avoid holding onto module references across tests
259+
_KERNEL_MODULE_MAPPING.pop("causal-conv1d", None)
235260

236261
def test_lazy_load_kernel_unknown(self):
237262
name = "unknown-kernel-name"
238263
_KERNEL_MODULE_MAPPING.pop(name, None)
239264
mod = lazy_load_kernel(name)
240265
self.assertIsNone(mod)
241266
self.assertIn(name, _KERNEL_MODULE_MAPPING)
267+
# Cleanup cache entry to avoid growth across tests
268+
_KERNEL_MODULE_MAPPING.pop(name, None)
242269

243270
def test_lazy_load_kernel_version(self):
244271
HUB = _HUB_KERNEL_MAPPING
@@ -253,7 +280,7 @@ def test_lazy_load_kernel_version(self):
253280

254281
try:
255282
# Inject dict-style mapping with repo_id and version
256-
HUB[name] = {"repo_id": "kernels-community/causal-conv1d", "version": version_spec}
283+
HUB[name] = {"repo_id": "kernels-community/causal-conv1d", "version": version_spec} # type: ignore[assignment]
257284
_KERNEL_MODULE_MAPPING.pop(name, None)
258285

259286
def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
@@ -283,6 +310,7 @@ def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
283310
HUB[name] = original_entry
284311
_KERNEL_MODULE_MAPPING.pop(name, None)
285312

313+
286314
@require_kernels
287315
class TestAttentionKernelRegistration(TestCasePlus):
288316
def test_load_and_register_flash_attn_like_kernel(self):
@@ -295,6 +323,15 @@ def test_load_and_register_flash_attn_like_kernel(self):
295323
attn_impl = "org/model"
296324
load_and_register_attn_kernel(attn_impl)
297325
self.assertIn(attn_impl, ALL_ATTENTION_FUNCTIONS.valid_keys())
326+
# Cleanup registration to avoid leaking functions across tests
327+
try:
328+
ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
329+
except Exception:
330+
pass
331+
try:
332+
ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
333+
except Exception:
334+
pass
298335

299336
def test_load_and_register_named_function_kernel(self):
300337
def my_attention(*args, **kwargs):
@@ -305,6 +342,15 @@ def my_attention(*args, **kwargs):
305342
attn_impl = "org/model:my_func"
306343
load_and_register_attn_kernel(attn_impl)
307344
self.assertIn(attn_impl, ALL_ATTENTION_FUNCTIONS.valid_keys())
345+
# Cleanup registration to avoid leaking functions across tests
346+
try:
347+
ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
348+
except Exception:
349+
pass
350+
try:
351+
ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
352+
except Exception:
353+
pass
308354

309355

310356
@require_kernels
@@ -314,6 +360,13 @@ def setUp(self):
314360
self.model = AutoModelForCausalLM.from_pretrained(self.model_id, use_kernels=False, device_map=torch_device)
315361

316362
def tearDown(self):
363+
# Delete large objects to drop references early
364+
if hasattr(self, "model"):
365+
try:
366+
del self.model
367+
except Exception:
368+
pass
369+
# Free accelerator memory/cache and trigger GC
317370
gc.collect()
318371
backend_empty_cache(torch_device)
319372
gc.collect()

0 commit comments

Comments
 (0)