
Commit ba76d4d

xw285cornell authored and facebook-github-bot committed
custom reduce scatter (#3686)
Summary: X-link: facebookresearch/FBGEMM#763. Piggyback on the two-shot allreduce for the reduce-scatter: it is essentially the first half of the two-shot allreduce. Differential Revision: D69364062
1 parent 1b7789a commit ba76d4d
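
For orientation, here is a minimal sketch of the calling contract implied by the shape checks this commit adds in car.cu. The setup, sizes, and the example function name are illustrative assumptions, not part of the commit; only the op's signature and constraints come from the diff below.

// Illustrative sketch only: shape/dtype constraints mirror the TORCH_CHECKs
// in car_reduce_scatter (car.cu). Assumes the declaration from car.cpp is visible.
#include <ATen/ATen.h>

void reduce_scatter_example(int64_t world_size, int64_t comm_idx) {
  auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  // Full bf16 input on this rank; numel must be a multiple of 8 * world_size
  // and stay below kMaxCAR.
  at::Tensor src = at::ones({world_size * 1024}, opts);
  // Each rank keeps a 1/world_size shard of the elementwise sum across ranks.
  at::Tensor dst = at::empty({1024}, opts);
  fbgemm_gpu::car_reduce_scatter(dst, src, /*bias=*/std::nullopt, comm_idx);
}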

File tree: 3 files changed, +248 −69 lines changed


fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

Lines changed: 31 additions & 2 deletions
@@ -177,7 +177,11 @@ void nccl_alltoall(
   torch::cuda::nccl::all2all(dsts, srcs, *get_nccl_comm(comm_idx), stream);
 }
 
-void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
+void nccl_reducescatter(
+    at::Tensor dst,
+    at::Tensor src,
+    std::optional<at::Tensor> bias,
+    int64_t comm_idx) {
   using namespace c10d;
   TORCH_CHECK(src.is_contiguous());
   TORCH_CHECK(dst.is_contiguous());
@@ -194,6 +198,10 @@ void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
           *get_nccl_comm(comm_idx),
           at::cuda::getCurrentCUDAStream()),
       "ncclReduceScatter");
+
+  if (bias) {
+    dst.add_(*bias);
+  }
 }
 
 void nccl_allreduce(
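
Note that the registered schema below gives the new bias argument a default of None, so callers that reach nccl_reducescatter through the fbgemm op without a bias are unaffected; at the C++ level the parameter is explicit and, when present, is added in place after the NCCL collective. A hedged sketch of the two call forms (dst, src, and bias are assumed to be set up elsewhere):

// Without a bias: same behavior as the previous two-tensor form.
fbgemm_gpu::nccl_reducescatter(dst, src, /*bias=*/std::nullopt, /*comm_idx=*/0);
// With a bias: it must broadcast against dst, since it is applied as dst.add_(*bias).
fbgemm_gpu::nccl_reducescatter(dst, src, bias, /*comm_idx=*/0);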
@@ -259,6 +267,11 @@ void two_shot_car_allreduce(
     at::Tensor src,
     std::optional<at::Tensor> bias,
     int64_t comm_idx);
+void car_reduce_scatter(
+    at::Tensor dst,
+    at::Tensor src,
+    std::optional<at::Tensor> bias,
+    int64_t comm_idx);
 
 at::Tensor car_tensor();
 
@@ -282,7 +295,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "nccl_alltoall_single(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()");
   m.def("nccl_alltoall(Tensor(a!)[] dst, Tensor[] src, int comm_idx=0) -> ()");
 
-  m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");
+  m.def(
+      "nccl_reducescatter(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
 
   m.def(
       "nccl_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
@@ -302,6 +316,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 
   m.def(
       "two_shot_car_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
+
+  m.def(
+      "car_reduce_scatter(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
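
Once the schema above is registered, the op can also be reached through the PyTorch dispatcher instead of the C++ symbol. The sketch below is an assumption about usage (only the schema string "fbgemm::car_reduce_scatter" comes from this commit), and dst/src are placeholders.

#include <ATen/core/dispatch/Dispatcher.h>

// Look up the registered schema and call it; this is the same path
// torch.ops.fbgemm.car_reduce_scatter takes from Python.
auto op = c10::Dispatcher::singleton()
              .findSchemaOrThrow("fbgemm::car_reduce_scatter", "")
              .typed<void(
                  at::Tensor, at::Tensor, std::optional<at::Tensor>, int64_t)>();
op.call(dst, src, /*bias=*/std::nullopt, /*comm_idx=*/0);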
@@ -312,6 +329,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+  m.impl("car_reduce_scatter", car_reduce_scatter);
 }
 
 // Though it shouldnt be used, it is useful to define these functions for CPU to
@@ -324,6 +342,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+  m.impl("car_reduce_scatter", car_reduce_scatter);
 }
 
 // Shape registration functions for car operators.
@@ -360,6 +379,7 @@ void nccl_alltoall_meta(
 void nccl_reducescatter_meta(
     at::Tensor /* dst */,
     at::Tensor /* src */,
+    std::optional<at::Tensor> bias,
     int64_t /* comm_idx */) {
   return;
 }
@@ -380,6 +400,14 @@ void two_shot_car_allreduce_meta(
   return;
 }
 
+void car_reduce_scatter_meta(
+    at::Tensor /* dst */,
+    at::Tensor /* src */,
+    std::optional<at::Tensor> /* bias */,
+    int64_t /* comm_idx */) {
+  return;
+}
+
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("nccl_allreduce", nccl_allreduce_meta);
   m.impl("nccl_allgather", nccl_allgather_meta);
@@ -388,6 +416,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("nccl_reducescatter", nccl_reducescatter_meta);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);
+  m.impl("car_reduce_scatter", car_reduce_scatter_meta);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/comm/car.cu

Lines changed: 112 additions & 6 deletions
@@ -345,7 +345,7 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
 #endif
 }
 
-template <int32_t kWorldSize, bool has_acc>
+template <int32_t kWorldSize, bool has_acc, bool reduce_scatter>
 #if defined(USE_ROCM)
 __launch_bounds__(512) __global__ void two_shot_all_reduce(
 #else
@@ -425,13 +425,18 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
     }
 
     // Store to the local buffer.
-    *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
-        *reinterpret_cast<const uint4*>(&sums);
+    if constexpr (reduce_scatter) {
+      *reinterpret_cast<uint4*>(&output[i]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    } else {
+      *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    }
   }
 
   __syncthreads();
 
-  // barreris among the blocks with the same idx (release-acuqire semantics)
+  // barriers among the blocks with the same idx (release-acuqire semantics)
   if (threadIdx.x < kWorldSize) {
     // The all blocks notifies the other ranks.
     int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;
@@ -445,6 +450,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
     } while (rank_barrier != flag);
   }
 
+  if constexpr (reduce_scatter) {
+    // reduce scatter we can stop here and skip the allgather below
+    return;
+  }
+
   __syncthreads();
 
   // Gather all needed elts from other intra-node ranks
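
The new reduce_scatter template flag makes the kernel write each rank's reduced shard directly to output and return before the all-gather phase; because the branch is if constexpr, the existing allreduce instantiations compile exactly as before. As a purely illustrative reference (no communication, names invented here), this is the result the first half produces on a given rank:

#include <cstddef>
#include <vector>

// Rank `rank` ends up with the elementwise sum of every rank's `rank`-th shard.
std::vector<float> reduce_scatter_reference(
    const std::vector<std::vector<float>>& per_rank_input, // [world_size][N]
    int rank,
    int world_size) {
  const std::size_t shard = per_rank_input[0].size() / world_size;
  std::vector<float> out(shard, 0.0f);
  for (int r = 0; r < world_size; ++r) {
    for (std::size_t i = 0; i < shard; ++i) {
      out[i] += per_rank_input[r][rank * shard + i];
    }
  }
  return out;
}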
@@ -628,7 +638,7 @@ void two_shot_car_allreduce(
 #define X(kWorldSize) \
   if (state->world_size_ == kWorldSize) { \
     if (z) { \
-      two_shot_all_reduce<kWorldSize, true> \
+      two_shot_all_reduce<kWorldSize, true, false> \
          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
              state->rank_, \
              state->world_size_, \
@@ -641,7 +651,7 @@ void two_shot_car_allreduce(
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
       return; \
     } else { \
-      two_shot_all_reduce<kWorldSize, false> \
+      two_shot_all_reduce<kWorldSize, false, false> \
          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
              state->rank_, \
              state->world_size_, \
@@ -667,4 +677,100 @@ void two_shot_car_allreduce(
   return;
 }
 
+void car_reduce_scatter(
+    at::Tensor y_reducescatter,
+    at::Tensor y,
+    std::optional<at::Tensor> z,
+    int64_t comm_idx) { // match the API with nccl_allreduce in
+                        // https://fburl.com/code/v538vig9
+  auto state = get_car_state();
+  c10::cuda::CUDAGuard gg(y_reducescatter.device());
+  TORCH_CHECK(y_reducescatter.is_contiguous());
+  TORCH_CHECK(y.is_contiguous());
+  TORCH_CHECK((state->world_size_ * y_reducescatter.numel()) == y.numel());
+  TORCH_CHECK(y.numel() % 8 == 0);
+  TORCH_CHECK(y.numel() < kMaxCAR);
+  const auto N = y.numel();
+  if (z) {
+    TORCH_CHECK(z->numel() == y.numel());
+  }
+  ++state->flag_;
+
+  std::array<at::BFloat16*, 8> inputs;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    inputs[ii] = state->buffers_[ii].data_ptr<at::BFloat16>();
+  }
+
+  std::array<int32_t*, 8> barriers;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
+  }
+
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      inputs[state->rank_],
+      y.data_ptr<at::BFloat16>(),
+      y.numel() * y.element_size(),
+      cudaMemcpyDeviceToDevice,
+      at::cuda::getCurrentCUDAStream()));
+
+  constexpr int32_t N_per_thread = 8;
+  TORCH_CHECK(N % state->world_size_ == 0);
+  const auto N_per_rank = N / state->world_size_;
+
+  TORCH_CHECK(N_per_rank % N_per_thread == 0);
+  auto threads_per_rank = div_round_up(N_per_rank, N_per_thread);
+
+#if defined(USE_ROCM)
+  constexpr int32_t kThreadsPerBlock = 512;
+#else
+  constexpr int32_t kThreadsPerBlock = 1024;
+#endif
+
+  constexpr int32_t kMaxBlocks = 24;
+
+  auto blocks = std::min<int32_t>(
+      cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks);
+
+#define X(kWorldSize) \
+  if (state->world_size_ == kWorldSize) { \
+    if (z) { \
+      two_shot_all_reduce<kWorldSize, true, true> \
+         <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+             state->rank_, \
+             state->world_size_, \
+             state->flag_ * state->world_size_, \
+             barriers, \
+             inputs, \
+             z->data_ptr<at::BFloat16>(), \
+             y_reducescatter.data_ptr<at::BFloat16>(), \
+             N); \
+      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      return; \
+    } else { \
+      two_shot_all_reduce<kWorldSize, false, true> \
+         <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+             state->rank_, \
+             state->world_size_, \
+             state->flag_ * state->world_size_, \
+             barriers, \
+             inputs, \
+             nullptr, \
+             y_reducescatter.data_ptr<at::BFloat16>(), \
+             N); \
+      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      return; \
+    } \
+  }
+
+  TORCH_CHECK(
+      state->world_size_ == 2 || state->world_size_ == 4 ||
+      state->world_size_ == 8);
+  X(2);
+  X(4);
+  X(8);
+
+#undef X
+  return;
+}
+
 } // namespace fbgemm_gpu
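
As a sanity check on the launch sizing in car_reduce_scatter above, here is a worked example with assumed values (8 ranks, 65536 elements); only the per-thread width of 8, the cap of 24 blocks, and the 512/1024 threads-per-block split come from the commit.

#include <algorithm>
#include <cstdint>

constexpr int64_t kWorldSize = 8;
constexpr int64_t kN = 65536;                                 // y.numel()
constexpr int64_t kNPerThread = 8;                            // elements per thread
constexpr int64_t kNPerRank = kN / kWorldSize;                // 8192
constexpr int64_t kThreadsPerRank = kNPerRank / kNPerThread;  // 1024
constexpr int64_t kThreadsPerBlock = 1024;                    // 512 on ROCm
constexpr int64_t kBlocks = std::min<int64_t>(
    (kThreadsPerRank + kThreadsPerBlock - 1) / kThreadsPerBlock, 24);  // -> 1 block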
