Skip to content

gh-123471: Make concurrent iteration over itertools.pairwise safe under free-threading #123848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions Lib/test/test_free_threading/test_itertools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest
from threading import Thread

from test.support import threading_helper

from itertools import pairwise

class PairwiseThreading(unittest.TestCase):
@staticmethod
def work(enum):
while True:
try:
next(enum)
except StopIteration:
break

@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_pairwise(self):
number_of_threads = 8
number_of_iterations = 40
n = 200
enum = pairwise(range(n))
for _ in range(number_of_iterations):
worker_threads = []
for ii in range(number_of_threads):
worker_threads.append(
Thread(
target=self.work,
args=[
enum,
],
)
)
for t in worker_threads:
t.start()
for t in worker_threads:
t.join()


if __name__ == "__main__":
unittest.main()
26 changes: 1 addition & 25 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,35 +902,11 @@ def __next__(self):
(([2], [3]), [4]),
([4], [5]),
])
check({2}, [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests pass for the equivalent Python implementation, so it is reasonable to expect them to pass for any correct C implementation.

Originally, the purpose of these tests was to catch bugs caused by using borrowed references. If we remove them, we cannot be sure that those bugs will not return. If you need to remove them, then perhaps the bugs have returned.

([1], ([1], [3])),
(([1], [3]), [4]),
([4], [5]),
])
check({3}, [
([1], [2]),
([2], ([2], [4])),
(([2], [4]), [5]),
([5], [6]),
])
check({1, 2}, [
((([3], [4]), [5]), [6]),
([6], [7]),
])
check({1, 3}, [
(([2], ([2], [4])), [5]),
([5], [6]),
])
check({1, 4}, [
(([2], [3]), (([2], [3]), [5])),
((([2], [3]), [5]), [6]),
([6], [7]),
])
check({2, 3}, [
([1], ([1], ([1], [4]))),
(([1], ([1], [4])), [5]),
([5], [6]),
])


def test_pairwise_reenter2(self):
def check(maxcount, expected):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make concurrent iterations over the same :func:`itertools.pairwise` iterator safe under free-threading.
77 changes: 72 additions & 5 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ typedef struct {
PyObject_HEAD
PyObject *it;
PyObject *old;
#ifdef Py_GIL_DISABLED
int iterator_exhausted;
#endif
PyObject *result;
} pairwiseobject;

Expand Down Expand Up @@ -294,6 +297,9 @@ pairwise_new_impl(PyTypeObject *type, PyObject *iterable)
}
po->it = it;
po->old = NULL;
#ifdef Py_GIL_DISABLED
po->iterator_exhausted = 0;
#endif
po->result = PyTuple_Pack(2, Py_None, Py_None);
if (po->result == NULL) {
Py_DECREF(po);
Expand Down Expand Up @@ -327,15 +333,23 @@ pairwise_traverse(pairwiseobject *po, visitproc visit, void *arg)
static PyObject *
pairwise_next(pairwiseobject *po)
{
PyObject *it = po->it;
PyObject *old = po->old;
PyObject *it = FT_ATOMIC_LOAD_PTR(po->it);
PyObject *old = FT_ATOMIC_LOAD_PTR(po->old);
PyObject *new, *result;
result = FT_ATOMIC_LOAD_PTR(po->result);

#ifndef Py_GIL_DISABLED
if (it == NULL) {
return NULL;
}
#else
if (_Py_atomic_load_int_relaxed(&po->iterator_exhausted)) {
return NULL;
}
#endif
if (old == NULL) {
old = (*Py_TYPE(it)->tp_iternext)(it);
#ifndef Py_GIL_DISABLED
Py_XSETREF(po->old, old);
if (old == NULL) {
Py_CLEAR(po->it);
Expand All @@ -346,7 +360,19 @@ pairwise_next(pairwiseobject *po)
Py_CLEAR(po->old);
return NULL;
}
#else
if (old == NULL) {
_Py_atomic_store_int_relaxed(&po->iterator_exhausted, 1);
return NULL;
}
PyObject *po_old = ( PyObject *)_Py_atomic_exchange_ptr(&po->old, old);
// we expect po_old to be zero, but it can have been set by
// a concurrent thread
Py_XDECREF(po_old);
#endif
}

#ifndef Py_GIL_DISABLED
Py_INCREF(old);
new = (*Py_TYPE(it)->tp_iternext)(it);
if (new == NULL) {
Expand All @@ -356,8 +382,8 @@ pairwise_next(pairwiseobject *po)
return NULL;
}

result = po->result;
if (Py_REFCNT(result) == 1) {
assert(result != NULL);
if (_PyObject_IsUniquelyReferenced(result)) {
Py_INCREF(result);
PyObject *last_old = PyTuple_GET_ITEM(result, 0);
PyObject *last_new = PyTuple_GET_ITEM(result, 1);
Expand All @@ -378,8 +404,49 @@ pairwise_next(pairwiseobject *po)
PyTuple_SET_ITEM(result, 1, Py_NewRef(new));
}
}

Py_XSETREF(po->old, new);

#else
// at this stage we know that po->old has been set, but we have to make
// sure that po->old is valid at every moment so we atomically swap old
// and new. for that we first need to acquire a new object
new = (*Py_TYPE(it)->tp_iternext)(it);
if (new == NULL) {
_Py_atomic_store_int_relaxed(&po->iterator_exhausted, 1);
return NULL;
}
// we need to incref new before handing it over to po->old
Py_INCREF(new);
old = ( PyObject *)_Py_atomic_exchange_ptr(&po->old, new);
// we have acquired old and we hold a reference to it

assert(result != NULL);
if (_PyObject_IsUniquelyReferenced(result)) {
Py_INCREF(result);
PyObject *last_old = PyTuple_GET_ITEM(result, 0);
PyObject *last_new = PyTuple_GET_ITEM(result, 1);
PyTuple_SET_ITEM(result, 0, Py_NewRef(old));
PyTuple_SET_ITEM(result, 1, new); // steal reference
Py_DECREF(last_old);
Py_DECREF(last_new);
// bpo-42536: The GC may have untracked this result tuple. Since we're
// recycling it, make sure it's tracked again:
if (!_PyObject_GC_IS_TRACKED(result)) {
_PyObject_GC_TRACK(result);
}
}
else {
result = PyTuple_New(2);
if (result != NULL) {
PyTuple_SET_ITEM(result, 0, Py_NewRef(old));
PyTuple_SET_ITEM(result, 1, new); // steal reference
}
else {
Py_DECREF(new);
}
}
#endif

Py_DECREF(old);
return result;
}
Expand Down
Loading