|
44 | 44 | import kernels as kernels_pkg |
45 | 45 | from kernels import Device, Mode, kernelize |
46 | 46 |
|
| 47 | + import transformers.integrations.hub_kernels as hub_kernels_pkg |
| 48 | + |
47 | 49 |
|
48 | 50 | @require_kernels |
49 | 51 | @slow |
@@ -95,6 +97,7 @@ def test_forward(self): |
95 | 97 |
|
96 | 98 | self.EXPECTED_OUTPUT = set() |
97 | 99 | self.EXPECTED_OUTPUT.add("Hello, I'm looking for a reliable and trustworthy online") |
| 100 | + self.EXPECTED_OUTPUT.add("Hello! I'm excited to be a part of this") |
98 | 101 |
|
99 | 102 | self.assertTrue(output in self.EXPECTED_OUTPUT) |
100 | 103 |
|
@@ -249,7 +252,7 @@ def fake_get_kernel(repo_id, revision=None, version=None): |
249 | 252 | self.assertIn(repo_id, {"kernels-community/causal-conv1d"}) |
250 | 253 | return sentinel |
251 | 254 |
|
252 | | - setattr(kernels_pkg, "get_kernel", fake_get_kernel) |
| 255 | + setattr(hub_kernels_pkg, "get_kernel", fake_get_kernel) |
253 | 256 | _KERNEL_MODULE_MAPPING.pop("causal-conv1d", None) |
254 | 257 |
|
255 | 258 | mod1 = lazy_load_kernel("causal-conv1d") |
@@ -286,15 +289,15 @@ def test_lazy_load_kernel_version(self): |
286 | 289 | HUB[name] = {"repo_id": "kernels-community/causal-conv1d", "version": version_spec} # type: ignore[assignment] |
287 | 290 | _KERNEL_MODULE_MAPPING.pop(name, None) |
288 | 291 |
|
289 | | - def fake_get_kernel(repo_id, revision=None, version=None, user_agent=None): |
| 292 | + def fake_get_kernel(repo_id, revision=None, version=None): |
290 | 293 | call_count["n"] += 1 |
291 | 294 | self.assertEqual(repo_id, "kernels-community/causal-conv1d") |
292 | 295 | self.assertIsNone(revision, "revision must not be set when version is provided") |
293 | 296 | self.assertEqual(version, version_spec) |
294 | 297 | return sentinel_mod |
295 | 298 |
|
296 | 299 | # Patch kernels.get_kernel so lazy_load_kernel picks it up on import |
297 | | - setattr(kernels_pkg, "get_kernel", fake_get_kernel) |
| 300 | + setattr(hub_kernels_pkg, "get_kernel", fake_get_kernel) |
298 | 301 |
|
299 | 302 | # Act |
300 | 303 | mod1 = lazy_load_kernel(name) |
|
0 commit comments