From ebbe22930ed32502d557e52f5cb39cc1a8b7cbe9 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Tue, 8 Oct 2024 14:11:43 -0500 Subject: [PATCH] [3.13] Tee of tee was not producing n independent iterators (gh-123884) (gh-125081) --- Doc/library/itertools.rst | 49 +++-- Lib/test/test_itertools.py | 171 +++++++++++++++++- ...-09-24-22-38-51.gh-issue-123884.iEPTK4.rst | 4 + Modules/itertoolsmodule.c | 37 +--- 4 files changed, 214 insertions(+), 47 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 3fab46c3c0a5b4..047d805eda628a 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -676,24 +676,37 @@ loops that truncate the stream. Roughly equivalent to:: def tee(iterable, n=2): - iterator = iter(iterable) - shared_link = [None, None] - return tuple(_tee(iterator, shared_link) for _ in range(n)) - - def _tee(iterator, link): - try: - while True: - if link[1] is None: - link[0] = next(iterator) - link[1] = [None, None] - value, link = link - yield value - except StopIteration: - return - - Once a :func:`tee` has been created, the original *iterable* should not be - used anywhere else; otherwise, the *iterable* could get advanced without - the tee objects being informed. + if n < 0: + raise ValueError + if n == 0: + return () + iterator = _tee(iterable) + result = [iterator] + for _ in range(n - 1): + result.append(_tee(iterator)) + return tuple(result) + + class _tee: + + def __init__(self, iterable): + it = iter(iterable) + if isinstance(it, _tee): + self.iterator = it.iterator + self.link = it.link + else: + self.iterator = it + self.link = [None, None] + + def __iter__(self): + return self + + def __next__(self): + link = self.link + if link[1] is None: + link[0] = next(self.iterator) + link[1] = [None, None] + value, self.link = link + return value ``tee`` iterators are not threadsafe. A :exc:`RuntimeError` may be raised when simultaneously using iterators returned by the same :func:`tee` diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 3d20e70fc1b63f..b6404f4366ca0e 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1612,10 +1612,11 @@ def test_tee(self): self.assertEqual(len(result), n) self.assertEqual([list(x) for x in result], [list('abc')]*n) - # tee pass-through to copyable iterator + # tee objects are independent (see bug gh-123884) a, b = tee('abc') c, d = tee(a) - self.assertTrue(a is c) + e, f = tee(c) + self.assertTrue(len({a, b, c, d, e, f}) == 6) # test tee_new t1, t2 = tee('abc') @@ -2029,6 +2030,172 @@ def test_islice_recipe(self): self.assertEqual(next(c), 3) + def test_tee_recipe(self): + + # Begin tee() recipe ########################################### + + def tee(iterable, n=2): + if n < 0: + raise ValueError + if n == 0: + return () + iterator = _tee(iterable) + result = [iterator] + for _ in range(n - 1): + result.append(_tee(iterator)) + return tuple(result) + + class _tee: + + def __init__(self, iterable): + it = iter(iterable) + if isinstance(it, _tee): + self.iterator = it.iterator + self.link = it.link + else: + self.iterator = it + self.link = [None, None] + + def __iter__(self): + return self + + def __next__(self): + link = self.link + if link[1] is None: + link[0] = next(self.iterator) + link[1] = [None, None] + value, self.link = link + return value + + # End tee() recipe ############################################# + + n = 200 + + a, b = tee([]) # test empty iterator + self.assertEqual(list(a), []) + self.assertEqual(list(b), []) + + a, b = tee(irange(n)) # test 100% interleaved + self.assertEqual(lzip(a,b), lzip(range(n), range(n))) + + a, b = tee(irange(n)) # test 0% interleaved + self.assertEqual(list(a), list(range(n))) + self.assertEqual(list(b), list(range(n))) + + a, b = tee(irange(n)) # test dealloc of leading iterator + for i in range(100): + self.assertEqual(next(a), i) + del a + self.assertEqual(list(b), list(range(n))) + + a, b = tee(irange(n)) # test dealloc of trailing iterator + for i in range(100): + self.assertEqual(next(a), i) + del b + self.assertEqual(list(a), list(range(100, n))) + + for j in range(5): # test randomly interleaved + order = [0]*n + [1]*n + random.shuffle(order) + lists = ([], []) + its = tee(irange(n)) + for i in order: + value = next(its[i]) + lists[i].append(value) + self.assertEqual(lists[0], list(range(n))) + self.assertEqual(lists[1], list(range(n))) + + # test argument format checking + self.assertRaises(TypeError, tee) + self.assertRaises(TypeError, tee, 3) + self.assertRaises(TypeError, tee, [1,2], 'x') + self.assertRaises(TypeError, tee, [1,2], 3, 'x') + + # tee object should be instantiable + a, b = tee('abc') + c = type(a)('def') + self.assertEqual(list(c), list('def')) + + # test long-lagged and multi-way split + a, b, c = tee(range(2000), 3) + for i in range(100): + self.assertEqual(next(a), i) + self.assertEqual(list(b), list(range(2000))) + self.assertEqual([next(c), next(c)], list(range(2))) + self.assertEqual(list(a), list(range(100,2000))) + self.assertEqual(list(c), list(range(2,2000))) + + # test invalid values of n + self.assertRaises(TypeError, tee, 'abc', 'invalid') + self.assertRaises(ValueError, tee, [], -1) + + for n in range(5): + result = tee('abc', n) + self.assertEqual(type(result), tuple) + self.assertEqual(len(result), n) + self.assertEqual([list(x) for x in result], [list('abc')]*n) + + # tee objects are independent (see bug gh-123884) + a, b = tee('abc') + c, d = tee(a) + e, f = tee(c) + self.assertTrue(len({a, b, c, d, e, f}) == 6) + + # test tee_new + t1, t2 = tee('abc') + tnew = type(t1) + self.assertRaises(TypeError, tnew) + self.assertRaises(TypeError, tnew, 10) + t3 = tnew(t1) + self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc')) + + # test that tee objects are weak referencable + a, b = tee(range(10)) + p = weakref.proxy(a) + self.assertEqual(getattr(p, '__class__'), type(b)) + del a + gc.collect() # For PyPy or other GCs. + self.assertRaises(ReferenceError, getattr, p, '__class__') + + ans = list('abc') + long_ans = list(range(10000)) + + # Tests not applicable to the tee() recipe + if False: + # check copy + a, b = tee('abc') + self.assertEqual(list(copy.copy(a)), ans) + self.assertEqual(list(copy.copy(b)), ans) + a, b = tee(list(range(10000))) + self.assertEqual(list(copy.copy(a)), long_ans) + self.assertEqual(list(copy.copy(b)), long_ans) + + # check partially consumed copy + a, b = tee('abc') + take(2, a) + take(1, b) + self.assertEqual(list(copy.copy(a)), ans[2:]) + self.assertEqual(list(copy.copy(b)), ans[1:]) + self.assertEqual(list(a), ans[2:]) + self.assertEqual(list(b), ans[1:]) + a, b = tee(range(10000)) + take(100, a) + take(60, b) + self.assertEqual(list(copy.copy(a)), long_ans[100:]) + self.assertEqual(list(copy.copy(b)), long_ans[60:]) + self.assertEqual(list(a), long_ans[100:]) + self.assertEqual(list(b), long_ans[60:]) + + # Issue 13454: Crash when deleting backward iterator from tee() + forward, backward = tee(repeat(None, 2000)) # 20000000 + try: + any(forward) # exhaust the iterator + del backward + except: + del forward, backward + raise + + class TestGC(unittest.TestCase): def makecycle(self, iterator, container): diff --git a/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst new file mode 100644 index 00000000000000..55f1d4b41125c3 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst @@ -0,0 +1,4 @@ +Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee). +The output now has the promised *n* independent new iterators. Formerly, +the first iterator was identical (not independent) to the input iterator. +This would sometimes give surprising results. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index d42f9dd0768658..e87c753113563f 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1137,7 +1137,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n) /*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/ { Py_ssize_t i; - PyObject *it, *copyable, *copyfunc, *result; + PyObject *it, *to, *result; if (n < 0) { PyErr_SetString(PyExc_ValueError, "n must be >= 0"); @@ -1154,41 +1154,24 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n) return NULL; } - if (_PyObject_LookupAttr(it, &_Py_ID(__copy__), ©func) < 0) { - Py_DECREF(it); + (void)&_Py_ID(__copy__); // Retain a reference to __copy__ + itertools_state *state = get_module_state(module); + to = tee_fromiterable(state, it); + Py_DECREF(it); + if (to == NULL) { Py_DECREF(result); return NULL; } - if (copyfunc != NULL) { - copyable = it; - } - else { - itertools_state *state = get_module_state(module); - copyable = tee_fromiterable(state, it); - Py_DECREF(it); - if (copyable == NULL) { - Py_DECREF(result); - return NULL; - } - copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__)); - if (copyfunc == NULL) { - Py_DECREF(copyable); - Py_DECREF(result); - return NULL; - } - } - PyTuple_SET_ITEM(result, 0, copyable); + PyTuple_SET_ITEM(result, 0, to); for (i = 1; i < n; i++) { - copyable = _PyObject_CallNoArgs(copyfunc); - if (copyable == NULL) { - Py_DECREF(copyfunc); + to = tee_copy((teeobject *)to, NULL); + if (to == NULL) { Py_DECREF(result); return NULL; } - PyTuple_SET_ITEM(result, i, copyable); + PyTuple_SET_ITEM(result, i, to); } - Py_DECREF(copyfunc); return result; }