Skip to content

Commit 57a0317

Browse files
authored
Merge pull request #1699 from ericniebler/fix-cuda-stream-scheduler-performance-regression
fix performance regression in the nvexec maxwell examples
2 parents (0437fe5 + 95008fe); commit 57a0317

File tree

1 file changed: 18 additions (+18) and 9 deletions (-9).

examples/nvexec/maxwell/snr.cuh

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -279,11 +279,10 @@ namespace nvexec::_strm {
279279
return;
280280
}
281281

282-
auto sch = ex::get_scheduler(ex::get_env(op_state_.rcvr_));
283282
inner_op_state_t& inner_op_state = op_state_.inner_op_state_.emplace(
284283
ex::__emplace_from{[&]() noexcept {
285284
return ex::connect(
286-
ex::schedule(sch) | op_state_.closure_, receiver_2_t<OpT>{op_state_});
285+
op_state_.closure_(ex::schedule(op_state_.scheduler_)), receiver_2_t<OpT>{op_state_});
287286
}});
288287

289288
ex::start(inner_op_state);
@@ -318,11 +317,11 @@ namespace nvexec::_strm {
318317
using inner_op_state_t = OpT::inner_op_state_t;
319318

320319
if (op_state_.n_) {
321-
auto sch = ex::get_scheduler(ex::get_env(op_state_.rcvr_));
322320
inner_op_state_t& inner_op_state = op_state_.inner_op_state_.emplace(
323321
ex::__emplace_from{[&]() noexcept {
324322
return ex::connect(
325-
ex::schedule(sch) | op_state_.closure_, receiver_2_t<OpT>{op_state_});
323+
op_state_.closure_(ex::schedule(op_state_.scheduler_)),
324+
receiver_2_t<OpT>{op_state_});
326325
}});
327326

328327
ex::start(inner_op_state);
@@ -353,14 +352,18 @@ namespace nvexec::_strm {
353352
struct operation_state_t : operation_state_base_t<ReceiverId> {
354353
using PredSender = ex::__t<PredecessorSenderId>;
355354
using Receiver = ex::__t<ReceiverId>;
356-
using Scheduler = std::invoke_result_t<ex::get_scheduler_t, ex::env_of_t<Receiver>>;
355+
using Scheduler = std::invoke_result_t<
356+
ex::get_completion_scheduler_t<ex::set_value_t>,
357+
ex::env_of_t<PredSender>,
358+
ex::env_of_t<Receiver>
359+
>;
357360
using InnerSender = std::invoke_result_t<Closure, ex::schedule_result_t<Scheduler>>;
358361

359362
using predecessor_op_state_t =
360363
ex::connect_result_t<PredSender, receiver_1_t<operation_state_t>>;
361364
using inner_op_state_t = ex::connect_result_t<InnerSender, receiver_2_t<operation_state_t>>;
362365

363-
PredSender pred_sender_;
366+
Scheduler scheduler_;
364367
Closure closure_;
365368
std::optional<predecessor_op_state_t> pred_op_state_;
366369
std::optional<inner_op_state_t> inner_op_state_;
@@ -385,11 +388,14 @@ namespace nvexec::_strm {
385388
: operation_state_base_t<ReceiverId>(
386389
static_cast<Receiver&&>(rcvr),
387390
ex::get_completion_scheduler<ex::set_value_t>(ex::get_env(pred_sender)).context_state_)
388-
, pred_sender_{static_cast<PredSender&&>(pred_sender)}
391+
, scheduler_(
392+
ex::get_completion_scheduler<ex::set_value_t>(
393+
ex::get_env(pred_sender),
394+
ex::get_env(rcvr)))
389395
, closure_(closure)
390396
, n_(n) {
391397
pred_op_state_.emplace(ex::__emplace_from{[&]() noexcept {
392-
return ex::connect(static_cast<PredSender&&>(pred_sender_), receiver_1_t{*this});
398+
return ex::connect(static_cast<PredSender&&>(pred_sender), receiver_1_t{*this});
393399
}});
394400
}
395401
};
@@ -413,7 +419,10 @@ namespace nvexec::_strm {
413419
static auto connect(Self&& self, Receiver r)
414420
-> nvexec::_strm::repeat_n::operation_state_t<SenderId, Closure, ex::__id<Receiver>> {
415421
return nvexec::_strm::repeat_n::operation_state_t<SenderId, Closure, ex::__id<Receiver>>(
416-
static_cast<Sender&&>(self.sender_), self.closure_, static_cast<Receiver&&>(r), self.n_);
422+
static_cast<Self&&>(self).sender_,
423+
static_cast<Self&&>(self).closure_,
424+
static_cast<Receiver&&>(r),
425+
self.n_);
417426
}
418427

419428
[[nodiscard]]

0 commit comments

Comments (0)