 import copy
 import gc
 import types
-import unittest
 from unittest.mock import patch

 from transformers import AutoModelForCausalLM, AutoTokenizer, KernelConfig
 from transformers.integrations.hub_kernels import (
-    _KERNEL_MODULE_MAPPING,
     _HUB_KERNEL_MAPPING,
+    _KERNEL_MODULE_MAPPING,
     is_kernel,
     lazy_load_kernel,
     load_and_register_attn_kernel,
 )
+from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.testing_utils import (
     TestCasePlus,
@@ -57,6 +57,29 @@ def setUp(self):
         self.input = "Hello"

     def tearDown(self):
+        # Delete large objects to drop references early
+        for attr in [
+            "model_kernelized",
+            "model_not_kernelized",
+            "tokenizer",
+        ]:
+            if hasattr(self, attr):
+                try:
+                    delattr(self, attr)
+                except Exception:
+                    pass
+
+        # Clear any temporary kernel module cache entries populated by tests
+        try:
+            keys_to_remove = [
+                k for k, v in list(_KERNEL_MODULE_MAPPING.items()) if v is None or isinstance(v, types.ModuleType)
+            ]
+            for k in keys_to_remove:
+                _KERNEL_MODULE_MAPPING.pop(k, None)
+        except Exception:
+            pass
+
+        # Free accelerator memory/cache and trigger GC
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()
@@ -130,7 +153,7 @@ def test_kernelized_forward_is_the_same(self, model_1, model_2):

     def test_kernelize(self):
         model = copy.deepcopy(self.model_not_kernelized)
-        kernelize(model, mode=Mode.INFERENCE, device=Device(type=model.device.type))
+        kernelize(model, mode=Mode.INFERENCE, device=Device(type=model.device.type))  # type: ignore[arg-type]
         self.test_kernelized_forward_is_different(model, self.model_not_kernelized)
         self.test_kernelized_forward_is_the_same(model, self.model_kernelized)
         del model
@@ -232,13 +255,17 @@ def fake_get_kernel(repo_id, revision=None):
             self.assertIs(mod2, sentinel)
         finally:
             setattr(kernels_pkg, "get_kernel", original_get_kernel)
+            # Ensure cache is cleared to avoid holding onto module references across tests
+            _KERNEL_MODULE_MAPPING.pop("causal-conv1d", None)

     def test_lazy_load_kernel_unknown(self):
         name = "unknown-kernel-name"
         _KERNEL_MODULE_MAPPING.pop(name, None)
         mod = lazy_load_kernel(name)
         self.assertIsNone(mod)
         self.assertIn(name, _KERNEL_MODULE_MAPPING)
+        # Cleanup cache entry to avoid growth across tests
+        _KERNEL_MODULE_MAPPING.pop(name, None)

     def test_lazy_load_kernel_version(self):
         HUB = _HUB_KERNEL_MAPPING
@@ -253,7 +280,7 @@ def test_lazy_load_kernel_version(self):

         try:
             # Inject dict-style mapping with repo_id and version
-            HUB[name] = {"repo_id": "kernels-community/causal-conv1d", "version": version_spec}
+            HUB[name] = {"repo_id": "kernels-community/causal-conv1d", "version": version_spec}  # type: ignore[assignment]
             _KERNEL_MODULE_MAPPING.pop(name, None)

             def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
@@ -283,6 +310,7 @@ def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
             HUB[name] = original_entry
             _KERNEL_MODULE_MAPPING.pop(name, None)

+
 @require_kernels
 class TestAttentionKernelRegistration(TestCasePlus):
     def test_load_and_register_flash_attn_like_kernel(self):
@@ -295,6 +323,15 @@ def test_load_and_register_flash_attn_like_kernel(self):
         attn_impl = "org/model"
         load_and_register_attn_kernel(attn_impl)
         self.assertIn(attn_impl, ALL_ATTENTION_FUNCTIONS.valid_keys())
+        # Cleanup registration to avoid leaking functions across tests
+        try:
+            ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
+        except Exception:
+            pass
+        try:
+            ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
+        except Exception:
+            pass

     def test_load_and_register_named_function_kernel(self):
         def my_attention(*args, **kwargs):
@@ -305,6 +342,15 @@ def my_attention(*args, **kwargs):
         attn_impl = "org/model:my_func"
         load_and_register_attn_kernel(attn_impl)
         self.assertIn(attn_impl, ALL_ATTENTION_FUNCTIONS.valid_keys())
+        # Cleanup registration to avoid leaking functions across tests
+        try:
+            ALL_ATTENTION_FUNCTIONS.pop(attn_impl, None)
+        except Exception:
+            pass
+        try:
+            ALL_MASK_ATTENTION_FUNCTIONS.pop(attn_impl, None)
+        except Exception:
+            pass


 @require_kernels
@@ -314,6 +360,13 @@ def setUp(self):
         self.model = AutoModelForCausalLM.from_pretrained(self.model_id, use_kernels=False, device_map=torch_device)

     def tearDown(self):
+        # Delete large objects to drop references early
+        if hasattr(self, "model"):
+            try:
+                del self.model
+            except Exception:
+                pass
+        # Free accelerator memory/cache and trigger GC
         gc.collect()
         backend_empty_cache(torch_device)
         gc.collect()