diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index abbd50a47f395f..e56ce89c9e52bb 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -13,7 +13,6 @@ import typing import unittest import unittest.mock -import os import weakref import gc from weakref import proxy @@ -21,13 +20,13 @@ from test.support import import_helper from test.support import threading_helper -from test.support.script_helper import assert_python_ok import functools py_functools = import_helper.import_fresh_module('functools', blocked=['_functools']) -c_functools = import_helper.import_fresh_module('functools') +c_functools = import_helper.import_fresh_module('functools', + fresh=['_functools']) decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal']) @@ -202,7 +201,10 @@ def test_repr(self): kwargs = {'a': object(), 'b': object()} kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 'b={b!r}, a={a!r}'.format_map(kwargs)] - if self.partial in (c_functools.partial, py_functools.partial): + if self.partial in ( + getattr(c_functools, 'partial', object()), + py_functools.partial, + ): name = 'functools.partial' else: name = self.partial.__name__ @@ -224,7 +226,10 @@ def test_repr(self): for kwargs_repr in kwargs_reprs]) def test_recursive_repr(self): - if self.partial in (c_functools.partial, py_functools.partial): + if self.partial in ( + getattr(c_functools, 'partial', object()), + py_functools.partial, + ): name = 'functools.partial' else: name = self.partial.__name__ @@ -390,11 +395,13 @@ class TestPartialC(TestPartial, unittest.TestCase): if c_functools: partial = c_functools.partial - class AllowPickle: - def __enter__(self): - return self - def __exit__(self, type, value, tb): - return False + class AllowPickle: + def __init__(self): + self._cm = replaced_module("functools", c_functools) + def __enter__(self): + return self._cm.__enter__() + def __exit__(self, type, value, tb): + return self._cm.__exit__(type, value, tb) def test_attributes_unwritable(self): # attributes should not be writable @@ -1857,9 +1864,10 @@ def test_staticmethod(x): def py_cached_func(x, y): return 3 * x + y -@c_functools.lru_cache() -def c_cached_func(x, y): - return 3 * x + y +if c_functools: + @c_functools.lru_cache() + def c_cached_func(x, y): + return 3 * x + y class TestLRUPy(TestLRU, unittest.TestCase): @@ -1876,18 +1884,20 @@ def cached_staticmeth(x, y): return 3 * x + y +@unittest.skipUnless(c_functools, 'requires the C _functools module') class TestLRUC(TestLRU, unittest.TestCase): - module = c_functools - cached_func = c_cached_func, + if c_functools: + module = c_functools + cached_func = c_cached_func, - @module.lru_cache() - def cached_meth(self, x, y): - return 3 * x + y + @module.lru_cache() + def cached_meth(self, x, y): + return 3 * x + y - @staticmethod - @module.lru_cache() - def cached_staticmeth(x, y): - return 3 * x + y + @staticmethod + @module.lru_cache() + def cached_staticmeth(x, y): + return 3 * x + y class TestSingleDispatch(unittest.TestCase):