Skip to content

Commit cf0f071

Browse files
authored
[kernels] Fix failling tests (#42953)
fix
1 parent b5eea34 commit cf0f071

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/kernels/test_kernels.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
import kernels as kernels_pkg
4545
from kernels import Device, Mode, kernelize
4646

47+
import transformers.integrations.hub_kernels as hub_kernels_pkg
48+
4749

4850
@require_kernels
4951
@slow
@@ -95,6 +97,7 @@ def test_forward(self):
9597

9698
self.EXPECTED_OUTPUT = set()
9799
self.EXPECTED_OUTPUT.add("Hello, I'm looking for a reliable and trustworthy online")
100+
self.EXPECTED_OUTPUT.add("Hello! I'm excited to be a part of this")
98101

99102
self.assertTrue(output in self.EXPECTED_OUTPUT)
100103

@@ -249,7 +252,7 @@ def fake_get_kernel(repo_id, revision=None, version=None):
249252
self.assertIn(repo_id, {"kernels-community/causal-conv1d"})
250253
return sentinel
251254

252-
setattr(kernels_pkg, "get_kernel", fake_get_kernel)
255+
setattr(hub_kernels_pkg, "get_kernel", fake_get_kernel)
253256
_KERNEL_MODULE_MAPPING.pop("causal-conv1d", None)
254257

255258
mod1 = lazy_load_kernel("causal-conv1d")
@@ -286,15 +289,15 @@ def test_lazy_load_kernel_version(self):
286289
HUB[name] = {"repo_id": "kernels-community/causal-conv1d", "version": version_spec} # type: ignore[assignment]
287290
_KERNEL_MODULE_MAPPING.pop(name, None)
288291

289-
def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None):
292+
def fake_get_kernel(repo_id, revision=None, version=None):
290293
call_count["n"] += 1
291294
self.assertEqual(repo_id, "kernels-community/causal-conv1d")
292295
self.assertIsNone(revision, "revision must not be set when version is provided")
293296
self.assertEqual(version, version_spec)
294297
return sentinel_mod
295298

296299
# Patch kernels.get_kernel so lazy_load_kernel picks it up on import
297-
setattr(kernels_pkg, "get_kernel", fake_get_kernel)
300+
setattr(hub_kernels_pkg, "get_kernel", fake_get_kernel)
298301

299302
# Act
300303
mod1 = lazy_load_kernel(name)

0 commit comments

Comments
 (0)