@@ -54,14 +54,16 @@ class TeamThreadPool {
 
   template <typename F>
   void parallel_region(std::size_t n, F&& fn) {
-    if (n == 0) return;
+    if (n == 0)
+      return;
 
     const std::size_t max_team = team_size();
     if (max_team <= 1) {
       fn(std::size_t{0});
       return;
     }
-    if (n > max_team) n = max_team;
+    if (n > max_team)
+      n = max_team;
     if (n <= 1) {
       fn(std::size_t{0});
       return;
@@ -71,7 +73,7 @@ class TeamThreadPool {
     // IMPORTANT: must still execute ALL tids to preserve correctness.
     if (in_worker_) {
       for (std::size_t tid = 0; tid < n; ++tid) {
-      fn(tid);
+        fn(tid);
       }
       return;
     }
@@ -109,16 +111,16 @@ class TeamThreadPool {
       fn_copy(0);
     } catch (...) {
       std::lock_guard<std::mutex> lk(exc_m_);
-      if (!eptr) eptr = std::current_exception();
+      if (!eptr)
+        eptr = std::current_exception();
     }
     in_worker_ = false;
 
     // Wait for workers 1..n-1.
     {
       std::unique_lock<std::mutex> lk(done_m_);
-      done_cv_.wait(lk, [&] {
-        return remaining_.load(std::memory_order_acquire) == 0;
-      });
+      done_cv_.wait(
+          lk, [&] { return remaining_.load(std::memory_order_acquire) == 0; });
     }
 
     // Hygiene: deactivate region state
@@ -132,10 +134,11 @@ class TeamThreadPool {
       exc_ptr_ = nullptr;
     }
 
-    if (eptr) std::rethrow_exception(eptr);
+    if (eptr)
+      std::rethrow_exception(eptr);
   }
 
-private:
+ private:
   using call_fn_t = void (*)(void*, std::size_t);
 
   template <typename Fn>
@@ -173,17 +176,18 @@ class TeamThreadPool {
   }
 
   TeamThreadPool()
-    : stop_(false),
-      epoch_(0),     // optional: keep for logging if you want
-      wake_gen_(0),  // NEW: protected by wake_m_
-      region_n_(0),
-      region_ctx_(nullptr),
-      region_call_(nullptr),
-      remaining_(0),
-      exc_ptr_(nullptr),
-      ready_count_(0) {
+      : stop_(false),
+        epoch_(0),     // optional: keep for logging if you want
+        wake_gen_(0),  // NEW: protected by wake_m_
+        region_n_(0),
+        region_ctx_(nullptr),
+        region_call_(nullptr),
+        remaining_(0),
+        exc_ptr_(nullptr),
+        ready_count_(0) {
     unsigned hw_u = std::thread::hardware_concurrency();
-    if (hw_u == 0) hw_u = 2;
+    if (hw_u == 0)
+      hw_u = 2;
     const std::size_t hw = static_cast<std::size_t>(hw_u);
 
     const std::size_t cap = configured_cap_(hw);
@@ -194,85 +198,88 @@ class TeamThreadPool {
       const std::size_t tid = i + 1;  // workers are 1..num_workers
 
       workers_.emplace_back([this, tid] {
-          // Per-worker AD tape initialized once.
-          static thread_local ChainableStack ad_tape;
-          (void)ad_tape;
-
-          in_worker_ = true;
-
-          // Startup barrier: ensure each worker has entered the wait loop once.
-          std::size_t seen_gen = 0;
-          {
-            std::lock_guard<std::mutex> lk(wake_m_);
-            ready_count_.fetch_add(1, std::memory_order_acq_rel);
-            seen_gen = wake_gen_;  // establish initial seen generation under the same mutex
-          }
-          ready_cv_.notify_one();
-
-          for (;;) {
-            // Wait for a new generation (or stop).
-            {
-              std::unique_lock<std::mutex> lk(wake_m_);
-              wake_cv_.wait(lk, [&] {
-                return stop_.load(std::memory_order_acquire) || wake_gen_ != seen_gen;
-              });
-              if (stop_.load(std::memory_order_acquire)) break;
-
-              // Consume the generation we were woken for.
-              seen_gen = wake_gen_;
-            }
-
-            const std::size_t n = region_n_.load(std::memory_order_acquire);
-            if (tid >= n) {
-              continue;  // not participating in this region
-            }
-
-            // Always decrement once for participating workers.
-            struct DoneGuard {
-              std::atomic<std::size_t>& rem;
-              std::mutex& m;
-              std::condition_variable& cv;
-              ~DoneGuard() {
-                if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
-                  std::lock_guard<std::mutex> lk(m);
-                  cv.notify_one();
-                }
-              }
-            } guard{remaining_, done_m_, done_cv_};
-
-            void* ctx = region_ctx_.load(std::memory_order_acquire);
-            call_fn_t call = region_call_.load(std::memory_order_acquire);
-
-            // If call is unexpectedly null, that's a serious publication bug.
-            // Don't decrement early without doing work; treat it as an error.
-            if (!call) {
-              std::lock_guard<std::mutex> lk(exc_m_);
-              if (exc_ptr_ && *exc_ptr_ == nullptr) {
-                *exc_ptr_ = std::make_exception_ptr(
-                    std::runtime_error("TeamThreadPool: region_call_ is null"));
-              }
-              continue;
-            }
-
-            try {
-              call(ctx, tid);
-            } catch (...) {
-              std::lock_guard<std::mutex> lk(exc_m_);
-              if (exc_ptr_ && *exc_ptr_ == nullptr) {
-                *exc_ptr_ = std::current_exception();
-              }
-            }
-          }
-
-          in_worker_ = false;
+        // Per-worker AD tape initialized once.
+        static thread_local ChainableStack ad_tape;
+        (void)ad_tape;
+
+        in_worker_ = true;
+
+        // Startup barrier: ensure each worker has entered the wait loop once.
+        std::size_t seen_gen = 0;
+        {
+          std::lock_guard<std::mutex> lk(wake_m_);
+          ready_count_.fetch_add(1, std::memory_order_acq_rel);
+          seen_gen = wake_gen_;  // establish initial seen generation under the
+                                 // same mutex
+        }
+        ready_cv_.notify_one();
+
+        for (;;) {
+          // Wait for a new generation (or stop).
+          {
+            std::unique_lock<std::mutex> lk(wake_m_);
+            wake_cv_.wait(lk, [&] {
+              return stop_.load(std::memory_order_acquire)
+                  || wake_gen_ != seen_gen;
+            });
+            if (stop_.load(std::memory_order_acquire))
+              break;
+
+            // Consume the generation we were woken for.
+            seen_gen = wake_gen_;
+          }
+
+          const std::size_t n = region_n_.load(std::memory_order_acquire);
+          if (tid >= n) {
+            continue;  // not participating in this region
+          }
+
+          // Always decrement once for participating workers.
+          struct DoneGuard {
+            std::atomic<std::size_t>& rem;
+            std::mutex& m;
+            std::condition_variable& cv;
+            ~DoneGuard() {
+              if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
+                std::lock_guard<std::mutex> lk(m);
+                cv.notify_one();
+              }
+            }
+          } guard{remaining_, done_m_, done_cv_};
+
+          void* ctx = region_ctx_.load(std::memory_order_acquire);
+          call_fn_t call = region_call_.load(std::memory_order_acquire);
+
+          // If call is unexpectedly null, that's a serious publication bug.
+          // Don't decrement early without doing work; treat it as an error.
+          if (!call) {
+            std::lock_guard<std::mutex> lk(exc_m_);
+            if (exc_ptr_ && *exc_ptr_ == nullptr) {
+              *exc_ptr_ = std::make_exception_ptr(
+                  std::runtime_error("TeamThreadPool: region_call_ is null"));
+            }
+            continue;
+          }
+
+          try {
+            call(ctx, tid);
+          } catch (...) {
+            std::lock_guard<std::mutex> lk(exc_m_);
+            if (exc_ptr_ && *exc_ptr_ == nullptr) {
+              *exc_ptr_ = std::current_exception();
+            }
+          }
+        }
+
+        in_worker_ = false;
       });
     }
 
     // Wait for all workers to reach the wait loop once before returning.
     {
       std::unique_lock<std::mutex> lk(wake_m_);
       ready_cv_.wait(lk, [&] {
-          return ready_count_.load(std::memory_order_acquire) == workers_.size();
+        return ready_count_.load(std::memory_order_acquire) == workers_.size();
       });
     }
   }
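
For orientation, here is a minimal usage sketch of the parallel_region member this patch reformats. It is a sketch under stated assumptions, not code from the repository: the TeamThreadPool::instance() accessor and the header name are hypothetical, and it assumes team_size() is publicly callable. What it does rely on is visible in the diff above: fn receives a tid in [0, n), tid 0 runs on the calling thread, n is clamped to the team size, and parallel_region returns only after every participating tid has finished.

// Usage sketch only -- not part of this patch. TeamThreadPool::instance() and
// the header name are hypothetical; team_size() is assumed to be public.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>
// #include "team_thread_pool.hpp"   // hypothetical header for TeamThreadPool

double parallel_sum_example(const std::vector<double>& x) {
  TeamThreadPool& pool = TeamThreadPool::instance();  // hypothetical accessor
  const std::size_t n = std::max<std::size_t>(1, pool.team_size());

  // One slot per tid: each worker writes a distinct element, so no locking.
  std::vector<double> partial(n, 0.0);

  pool.parallel_region(n, [&](std::size_t tid) {
    const std::size_t chunk = (x.size() + n - 1) / n;
    const std::size_t begin = std::min(x.size(), tid * chunk);
    const std::size_t end = std::min(x.size(), begin + chunk);
    partial[tid] = std::accumulate(x.begin() + begin, x.begin() + end, 0.0);
  });

  // parallel_region returns only after all participating tids have run
  // (tid 0 executes on the calling thread), so combining here is safe.
  return std::accumulate(partial.begin(), partial.end(), 0.0);
}

Sizing the request with team_size() up front avoids relying on the clamp inside parallel_region, so every chunk is guaranteed to be executed by some tid even on machines with fewer threads than chunks requested.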