@@ -345,7 +345,7 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
 #endif
 }
 
-template <int32_t kWorldSize, bool has_acc>
+template <int32_t kWorldSize, bool has_acc, bool reduce_scatter>
 #if defined(USE_ROCM)
 __launch_bounds__(512) __global__ void two_shot_all_reduce(
 #else
@@ -425,13 +425,18 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
     }
 
     // Store to the local buffer.
-    *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
-        *reinterpret_cast<const uint4*>(&sums);
+    if constexpr (reduce_scatter) {
+      *reinterpret_cast<uint4*>(&output[i]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    } else {
+      *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    }
   }
 
   __syncthreads();
 
-  // barreris among the blocks with the same idx (release-acuqire semantics)
+  // barriers among the blocks with the same idx (release-acquire semantics)
   if (threadIdx.x < kWorldSize) {
     // The all blocks notifies the other ranks.
     int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;
@@ -445,6 +450,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
     } while (rank_barrier != flag);
   }
 
+  if constexpr (reduce_scatter) {
+    // For reduce_scatter we can stop here and skip the all-gather below.
+    return;
+  }
+
   __syncthreads();
 
   // Gather all needed elts from other intra-node ranks
@@ -628,7 +638,7 @@ void two_shot_car_allreduce(
 #define X(kWorldSize) \
   if (state->world_size_ == kWorldSize) { \
     if (z) { \
-      two_shot_all_reduce<kWorldSize, true> \
+      two_shot_all_reduce<kWorldSize, true, false> \
           <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
               state->rank_, \
               state->world_size_, \
@@ -641,7 +651,7 @@ void two_shot_car_allreduce(
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
       return; \
     } else { \
-      two_shot_all_reduce<kWorldSize, false> \
+      two_shot_all_reduce<kWorldSize, false, false> \
           <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
               state->rank_, \
               state->world_size_, \
@@ -667,4 +677,100 @@ void two_shot_car_allreduce(
   return;
 }
 
+void car_reduce_scatter(
+    at::Tensor y_reducescatter,
+    at::Tensor y,
+    std::optional<at::Tensor> z,
+    int64_t comm_idx) { // match the API with nccl_allreduce in
+                        // https://fburl.com/code/v538vig9
+  auto state = get_car_state();
+  c10::cuda::CUDAGuard gg(y_reducescatter.device());
+  TORCH_CHECK(y_reducescatter.is_contiguous());
+  TORCH_CHECK(y.is_contiguous());
+  TORCH_CHECK((state->world_size_ * y_reducescatter.numel()) == y.numel());
+  TORCH_CHECK(y.numel() % 8 == 0);
+  TORCH_CHECK(y.numel() < kMaxCAR);
+  const auto N = y.numel();
+  if (z) {
+    TORCH_CHECK(z->numel() == y.numel());
+  }
+  ++state->flag_;
+
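+  // Collect each rank's shared input buffer and barrier flag pointers.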
+  std::array<at::BFloat16*, 8> inputs;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    inputs[ii] = state->buffers_[ii].data_ptr<at::BFloat16>();
+  }
+
+  std::array<int32_t*, 8> barriers;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
+  }
+
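+  // Stage this rank's full input y into its own shared buffer so peer ranks can read it.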
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      inputs[state->rank_],
+      y.data_ptr<at::BFloat16>(),
+      y.numel() * y.element_size(),
+      cudaMemcpyDeviceToDevice,
+      at::cuda::getCurrentCUDAStream()));
+
+  constexpr int32_t N_per_thread = 8;
+  TORCH_CHECK(N % state->world_size_ == 0);
+  const auto N_per_rank = N / state->world_size_;
+
+  TORCH_CHECK(N_per_rank % N_per_thread == 0);
+  auto threads_per_rank = div_round_up(N_per_rank, N_per_thread);
+
+#if defined(USE_ROCM)
+  constexpr int32_t kThreadsPerBlock = 512;
+#else
+  constexpr int32_t kThreadsPerBlock = 1024;
+#endif
+
+  constexpr int32_t kMaxBlocks = 24;
+
+  auto blocks = std::min<int32_t>(
+      cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks);
+
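+  // Dispatch on world size; with reduce_scatter=true the kernel writes only this
+  // rank's N_per_rank shard to y_reducescatter and skips the all-gather.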
+#define X(kWorldSize) \
+  if (state->world_size_ == kWorldSize) { \
+    if (z) { \
+      two_shot_all_reduce<kWorldSize, true, true> \
+          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+              state->rank_, \
+              state->world_size_, \
+              state->flag_ * state->world_size_, \
+              barriers, \
+              inputs, \
+              z->data_ptr<at::BFloat16>(), \
+              y_reducescatter.data_ptr<at::BFloat16>(), \
+              N); \
+      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      return; \
+    } else { \
+      two_shot_all_reduce<kWorldSize, false, true> \
+          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+              state->rank_, \
+              state->world_size_, \
+              state->flag_ * state->world_size_, \
+              barriers, \
+              inputs, \
+              nullptr, \
+              y_reducescatter.data_ptr<at::BFloat16>(), \
+              N); \
+      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      return; \
+    } \
+  }
+
+  TORCH_CHECK(
+      state->world_size_ == 2 || state->world_size_ == 4 ||
+      state->world_size_ == 8);
+  X(2);
+  X(4);
+  X(8);
+
+#undef X
+  return;
+}
+
 } // namespace fbgemm_gpu