
Commit 44cec35

[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
1 parent c954029 commit 44cec35

File tree: 1 file changed (+98 -91 lines)

stan/math/rev/core/team_thread_pool.hpp

Lines changed: 98 additions & 91 deletions
```diff
@@ -54,14 +54,16 @@ class TeamThreadPool {
 
   template <typename F>
   void parallel_region(std::size_t n, F&& fn) {
-    if (n == 0) return;
+    if (n == 0)
+      return;
 
     const std::size_t max_team = team_size();
     if (max_team <= 1) {
       fn(std::size_t{0});
       return;
     }
-    if (n > max_team) n = max_team;
+    if (n > max_team)
+      n = max_team;
     if (n <= 1) {
       fn(std::size_t{0});
       return;
```
```diff
@@ -71,7 +73,7 @@ class TeamThreadPool {
     // IMPORTANT: must still execute ALL tids to preserve correctness.
     if (in_worker_) {
       for (std::size_t tid = 0; tid < n; ++tid) {
-      fn(tid);
+        fn(tid);
       }
       return;
     }
```
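
The hunk above only re-indents `fn(tid);`, but the rule it sits inside is worth spelling out: a `parallel_region` launched from inside a worker must still execute every tid, serially on the calling thread, rather than hand off to a pool that has no idle workers. A minimal standalone sketch of that reentrancy rule, with hypothetical names (`in_worker`, `run_region`) standing in for the pool's real API:

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical stand-in for the pool's per-thread flag; illustration only.
thread_local bool in_worker = false;

template <typename F>
void run_region(std::size_t n, F&& fn) {
  if (in_worker) {
    // Nested region: no idle workers to hand off to, so the calling worker
    // runs ALL tids itself; every tid still executes exactly once.
    for (std::size_t tid = 0; tid < n; ++tid) {
      fn(tid);
    }
    return;
  }
  // Top-level call: the real pool would fan out to its worker team here;
  // this sketch just runs the tids serially.
  for (std::size_t tid = 0; tid < n; ++tid) {
    fn(tid);
  }
}

int main() {
  run_region(4, [](std::size_t tid) { std::cout << "tid " << tid << '\n'; });
}
```

Running the tids inline keeps the per-region invariant (each tid runs once) without risking deadlock on a fully occupied pool.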
```diff
@@ -109,16 +111,16 @@ class TeamThreadPool {
       fn_copy(0);
     } catch (...) {
       std::lock_guard<std::mutex> lk(exc_m_);
-      if (!eptr) eptr = std::current_exception();
+      if (!eptr)
+        eptr = std::current_exception();
     }
     in_worker_ = false;
 
     // Wait for workers 1..n-1.
     {
       std::unique_lock<std::mutex> lk(done_m_);
-      done_cv_.wait(lk, [&] {
-        return remaining_.load(std::memory_order_acquire) == 0;
-      });
+      done_cv_.wait(
+          lk, [&] { return remaining_.load(std::memory_order_acquire) == 0; });
     }
 
     // Hygiene: deactivate region state
```
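
For readers following the synchronization rather than the formatting: the `done_cv_` wait reformatted above is one half of a completion latch. The launcher blocks until `remaining_` reaches zero; each worker decrements once, and the last one takes the same mutex before notifying, so the notification cannot slip in between the waiter's predicate check and its sleep. A standalone toy of that handshake under those assumptions (not Stan's actual class):

```cpp
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  std::mutex done_m;
  std::condition_variable done_cv;
  std::atomic<std::size_t> remaining{4};  // one slot per worker

  std::vector<std::thread> ts;
  for (int i = 0; i < 4; ++i) {
    ts.emplace_back([&] {
      // ... region work would happen here ...
      if (remaining.fetch_sub(1, std::memory_order_acq_rel) == 1) {
        // Last finisher: lock the waiter's mutex before notifying so the
        // wakeup cannot be lost against the predicate re-check below.
        std::lock_guard<std::mutex> lk(done_m);
        done_cv.notify_one();
      }
    });
  }

  {
    std::unique_lock<std::mutex> lk(done_m);
    done_cv.wait(
        lk, [&] { return remaining.load(std::memory_order_acquire) == 0; });
  }
  for (auto& t : ts) t.join();
}
```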
```diff
@@ -132,10 +134,11 @@ class TeamThreadPool {
       exc_ptr_ = nullptr;
     }
 
-    if (eptr) std::rethrow_exception(eptr);
+    if (eptr)
+      std::rethrow_exception(eptr);
   }
 
-private:
+ private:
   using call_fn_t = void (*)(void*, std::size_t);
 
   template <typename Fn>
```
```diff
@@ -173,17 +176,18 @@ class TeamThreadPool {
   }
 
   TeamThreadPool()
-    : stop_(false),
-      epoch_(0), // optional: keep for logging if you want
-      wake_gen_(0), // NEW: protected by wake_m_
-      region_n_(0),
-      region_ctx_(nullptr),
-      region_call_(nullptr),
-      remaining_(0),
-      exc_ptr_(nullptr),
-      ready_count_(0) {
+      : stop_(false),
+        epoch_(0),     // optional: keep for logging if you want
+        wake_gen_(0),  // NEW: protected by wake_m_
+        region_n_(0),
+        region_ctx_(nullptr),
+        region_call_(nullptr),
+        remaining_(0),
+        exc_ptr_(nullptr),
+        ready_count_(0) {
     unsigned hw_u = std::thread::hardware_concurrency();
-    if (hw_u == 0) hw_u = 2;
+    if (hw_u == 0)
+      hw_u = 2;
     const std::size_t hw = static_cast<std::size_t>(hw_u);
 
     const std::size_t cap = configured_cap_(hw);
```
```diff
@@ -194,85 +198,88 @@ class TeamThreadPool {
       const std::size_t tid = i + 1;  // workers are 1..num_workers
 
       workers_.emplace_back([this, tid] {
-      // Per-worker AD tape initialized once.
-      static thread_local ChainableStack ad_tape;
-      (void)ad_tape;
-
-      in_worker_ = true;
-
-      // Startup barrier: ensure each worker has entered the wait loop once.
-      std::size_t seen_gen = 0;
-      {
-        std::lock_guard<std::mutex> lk(wake_m_);
-        ready_count_.fetch_add(1, std::memory_order_acq_rel);
-        seen_gen = wake_gen_;  // establish initial seen generation under the same mutex
-      }
-      ready_cv_.notify_one();
-
-      for (;;) {
-        // Wait for a new generation (or stop).
-        {
-          std::unique_lock<std::mutex> lk(wake_m_);
-          wake_cv_.wait(lk, [&] {
-            return stop_.load(std::memory_order_acquire) || wake_gen_ != seen_gen;
-          });
-          if (stop_.load(std::memory_order_acquire)) break;
-
-          // Consume the generation we were woken for.
-          seen_gen = wake_gen_;
-        }
-
-        const std::size_t n = region_n_.load(std::memory_order_acquire);
-        if (tid >= n) {
-          continue;  // not participating in this region
-        }
-
-        // Always decrement once for participating workers.
-        struct DoneGuard {
-          std::atomic<std::size_t>& rem;
-          std::mutex& m;
-          std::condition_variable& cv;
-          ~DoneGuard() {
-            if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
-              std::lock_guard<std::mutex> lk(m);
-              cv.notify_one();
-            }
-          }
-        } guard{remaining_, done_m_, done_cv_};
-
-        void* ctx = region_ctx_.load(std::memory_order_acquire);
-        call_fn_t call = region_call_.load(std::memory_order_acquire);
-
-        // If call is unexpectedly null, that's a serious publication bug.
-        // Don't decrement early without doing work; treat it as an error.
-        if (!call) {
-          std::lock_guard<std::mutex> lk(exc_m_);
-          if (exc_ptr_ && *exc_ptr_ == nullptr) {
-            *exc_ptr_ = std::make_exception_ptr(
-                std::runtime_error("TeamThreadPool: region_call_ is null"));
-          }
-          continue;
-        }
-
-        try {
-          call(ctx, tid);
-        } catch (...) {
-          std::lock_guard<std::mutex> lk(exc_m_);
-          if (exc_ptr_ && *exc_ptr_ == nullptr) {
-            *exc_ptr_ = std::current_exception();
-          }
-        }
-      }
-
-      in_worker_ = false;
+        // Per-worker AD tape initialized once.
+        static thread_local ChainableStack ad_tape;
+        (void)ad_tape;
+
+        in_worker_ = true;
+
+        // Startup barrier: ensure each worker has entered the wait loop once.
+        std::size_t seen_gen = 0;
+        {
+          std::lock_guard<std::mutex> lk(wake_m_);
+          ready_count_.fetch_add(1, std::memory_order_acq_rel);
+          seen_gen = wake_gen_;  // establish initial seen generation under the
+                                 // same mutex
+        }
+        ready_cv_.notify_one();
+
+        for (;;) {
+          // Wait for a new generation (or stop).
+          {
+            std::unique_lock<std::mutex> lk(wake_m_);
+            wake_cv_.wait(lk, [&] {
+              return stop_.load(std::memory_order_acquire)
+                     || wake_gen_ != seen_gen;
+            });
+            if (stop_.load(std::memory_order_acquire))
+              break;
+
+            // Consume the generation we were woken for.
+            seen_gen = wake_gen_;
+          }
+
+          const std::size_t n = region_n_.load(std::memory_order_acquire);
+          if (tid >= n) {
+            continue;  // not participating in this region
+          }
+
+          // Always decrement once for participating workers.
+          struct DoneGuard {
+            std::atomic<std::size_t>& rem;
+            std::mutex& m;
+            std::condition_variable& cv;
+            ~DoneGuard() {
+              if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) {
+                std::lock_guard<std::mutex> lk(m);
+                cv.notify_one();
+              }
+            }
+          } guard{remaining_, done_m_, done_cv_};
+
+          void* ctx = region_ctx_.load(std::memory_order_acquire);
+          call_fn_t call = region_call_.load(std::memory_order_acquire);
+
+          // If call is unexpectedly null, that's a serious publication bug.
+          // Don't decrement early without doing work; treat it as an error.
+          if (!call) {
+            std::lock_guard<std::mutex> lk(exc_m_);
+            if (exc_ptr_ && *exc_ptr_ == nullptr) {
+              *exc_ptr_ = std::make_exception_ptr(
+                  std::runtime_error("TeamThreadPool: region_call_ is null"));
+            }
+            continue;
+          }
+
+          try {
+            call(ctx, tid);
+          } catch (...) {
+            std::lock_guard<std::mutex> lk(exc_m_);
+            if (exc_ptr_ && *exc_ptr_ == nullptr) {
+              *exc_ptr_ = std::current_exception();
+            }
+          }
+        }
+
+        in_worker_ = false;
       });
     }
 
     // Wait for all workers to reach the wait loop once before returning.
     {
       std::unique_lock<std::mutex> lk(wake_m_);
       ready_cv_.wait(lk, [&] {
-      return ready_count_.load(std::memory_order_acquire) == workers_.size();
+        return ready_count_.load(std::memory_order_acquire) == workers_.size();
       });
     }
   }
```
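
The bulk of the final hunk is the worker loop being re-indented wholesale, and its wakeup protocol is easy to miss inside a diff. Below is a minimal self-contained sketch of that generation-counter pattern: the launcher bumps `wake_gen` under the mutex and notifies, and each worker sleeps until the generation differs from the one it last served (or `stop` is set). Because the wait predicate re-reads shared state, the wakeup is level-triggered and immune to notifications sent before a worker reaches `wait()`. Toy code, not the Stan pool itself; the real pool waits for region completion before publishing the next generation, where this sketch just sleeps.

```cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  std::mutex wake_m;
  std::condition_variable wake_cv;
  std::size_t wake_gen = 0;  // protected by wake_m
  std::atomic<bool> stop{false};

  std::vector<std::thread> workers;
  for (int tid = 1; tid <= 3; ++tid) {
    workers.emplace_back([&, tid] {
      std::size_t seen_gen = 0;  // last generation this worker served
      for (;;) {
        {
          std::unique_lock<std::mutex> lk(wake_m);
          // Level-triggered wait: the predicate re-checks state, so a notify
          // issued before this worker reached wait() is not lost.
          wake_cv.wait(lk, [&] {
            return stop.load(std::memory_order_acquire)
                   || wake_gen != seen_gen;
          });
          if (stop.load(std::memory_order_acquire))
            break;
          seen_gen = wake_gen;  // consume the generation we were woken for
        }
        std::printf("worker %d serves generation %zu\n", tid, seen_gen);
      }
    });
  }

  {
    std::lock_guard<std::mutex> lk(wake_m);
    ++wake_gen;  // publish a new region under the same mutex the workers use
  }
  wake_cv.notify_all();

  // Crude stand-in for the pool's completion wait, so the demo prints before
  // shutting down.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  {
    std::lock_guard<std::mutex> lk(wake_m);
    stop.store(true, std::memory_order_release);
  }
  wake_cv.notify_all();
  for (auto& t : workers) t.join();
}
```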
