
Commit 813fb53

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into op_add

2 parents: b8ee0fc + f1ae790

40 files changed: +3333 −866 lines

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

Lines changed: 51 additions & 61 deletions
@@ -69,14 +69,15 @@ Buffer::Buffer(int rank,
   calc_ctx = reinterpret_cast<phi::GPUContext*>(
       reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
           ->GetDeviceContext(place, true));
-  // Task fifo memory
-  int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
-  int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
-  int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
+
+  // Metadata memory
+  int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
+  int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
+  int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

   // Common checks
   EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
-                 (num_nvl_bytes <= std::numeric_limits<int64_t>::max() ||
+                 (num_nvl_bytes <= std::numeric_limits<int>::max() ||
                   num_rdma_bytes == 0));
   EP_HOST_ASSERT(
       num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
@@ -90,40 +91,35 @@ Buffer::Buffer(int rank,
   EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode);

   // Get ranks
-  // CUDA_CHECK(cudaGetDevice(&device_id));
   rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
-  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS),
+  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
   num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);

   // Get device info
   cudaDeviceProp device_prop = {};
   CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));

   if (num_nvl_bytes > 0) {
-    // Local IPC: alloc local memory and set local IPC handle
-    CUDA_CHECK(cudaMalloc(
-        &buffer_ptrs[nvl_rank],
-        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
+    // Local IPC: alloc local memory and set local IPC handles
+    CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank],
+                          num_nvl_bytes + barrier_signal_bytes +
+                              buffer_ptr_bytes + barrier_signal_ptr_bytes));
     CUDA_CHECK(
         cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
-    buffer_ptrs_gpu = reinterpret_cast<void**>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        fifo_bytes);
-
-    // Set task fifo
-    EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
-    task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
-    task_fifo_ptrs_gpu = reinterpret_cast<int**>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        fifo_bytes + buffer_ptr_bytes);
+    buffer_ptrs_gpu =
+        reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) +
+                                 num_nvl_bytes + barrier_signal_bytes);
+
+    // Set barrier signals
+    barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(
+        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
+    barrier_signal_ptrs_gpu = reinterpret_cast<int**>(
+        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        barrier_signal_bytes + buffer_ptr_bytes);

     // No need to synchronize, will do a full device sync during `sync`
     CUDA_CHECK(cudaMemsetAsync(
-        buffer_ptrs[nvl_rank],
-        0,
-        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes,
-        comm_stream));
+        barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
   }

   // Create 32 MiB workspace
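For orientation: the single cudaMalloc above now packs four back-to-back regions, and the cudaMemsetAsync clears only the signal region instead of the whole allocation. A minimal host-side sketch of the offset arithmetic (illustrative buffer size; NUM_MAX_NVL_PEERS is 8 in DeepEP — a sketch, not code from this commit):

#include <cstdint>
#include <cstdio>

constexpr int NUM_MAX_NVL_PEERS = 8;  // DeepEP's NVLink peer limit

int main() {
  const int64_t num_nvl_bytes = 1 << 20;  // illustrative; alignment-checked upstream
  const int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
  const int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
  const int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

  // Layout inside the one allocation:
  //   [0, num_nvl_bytes)  communication buffer,
  //   then one int signal per peer (zeroed on the comm stream),
  //   then the peers' buffer pointers, then the peers' signal pointers
  //   (the latter two become buffer_ptrs_gpu / barrier_signal_ptrs_gpu).
  const int64_t signals_off = num_nvl_bytes;
  const int64_t buffer_ptrs_off = signals_off + barrier_signal_bytes;
  const int64_t signal_ptrs_off = buffer_ptrs_off + buffer_ptr_bytes;
  const int64_t total_bytes = signal_ptrs_off + barrier_signal_ptr_bytes;

  std::printf("signals @ %lld, buffer ptrs @ %lld, signal ptrs @ %lld, total %lld\n",
              (long long)signals_off, (long long)buffer_ptrs_off,
              (long long)signal_ptrs_off, (long long)total_bytes);
  return 0;
}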
@@ -165,8 +161,7 @@ Buffer::~Buffer() noexcept(false) {
   if (num_nvl_bytes > 0) {
     // Barrier
     intranode::barrier(
-        task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
-    move_fifo_slots();
+        barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
     CUDA_CHECK(cudaDeviceSynchronize());

     // Close remote IPC
@@ -197,10 +192,6 @@ Buffer::~Buffer() noexcept(false) {
   CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
 }

-void Buffer::move_fifo_slots(int num_slots) {
-  head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
-}
-
 bool Buffer::is_available() const { return available; }

 bool Buffer::is_internode_available() const {
@@ -249,7 +240,7 @@ void Buffer::sync(

   // Sync IPC handles
   if (num_nvl_bytes > 0) {
-    EP_HOST_ASSERT(num_ranks == static_cast<int64_t>(device_ids.size()));
+    EP_HOST_ASSERT(num_ranks == device_ids.size());
     EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
     for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks;
          ++i) {
@@ -261,22 +252,22 @@ void Buffer::sync(
             ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
         CUDA_CHECK(cudaIpcOpenMemHandle(
             &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
-        task_fifo_ptrs[i] = reinterpret_cast<int*>(
-            reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+        barrier_signal_ptrs[i] = reinterpret_cast<int*>(
+            static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
       } else {
         EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved,
                                    handle_str.c_str(),
                                    CUDA_IPC_HANDLE_SIZE) == 0);
       }
     }

-    // Copy all buffer and task pointers to GPU
+    // Copy all buffer and barrier signal pointers to GPU
     CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu,
                           buffer_ptrs,
                           sizeof(void*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
-    CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu,
-                          task_fifo_ptrs,
+    CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu,
+                          barrier_signal_ptrs,
                           sizeof(int*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
     CUDA_CHECK(cudaDeviceSynchronize());
@@ -520,7 +511,7 @@ Buffer::intranode_dispatch(

   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0;
+  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -529,6 +520,8 @@ Buffer::intranode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
+    scale_token_stride = static_cast<int>(x_scales->stride(0));
+    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }

   // Allocate all tensors on comm stream if set
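Capturing both strides means the dispatch kernels no longer have to assume x_scales is a contiguous row-major [num_tokens, num_scales] tensor; a transposed or otherwise strided view indexes the same way. Conceptually (a hypothetical helper, not part of this commit):

#include <cstdint>

// Element (token t, scale group g) of a strided 2-D FP8-scales view:
inline float scale_at(const float* x_scales_ptr, int64_t t, int64_t g,
                      int scale_token_stride, int scale_hidden_stride) {
  return x_scales_ptr[t * scale_token_stride + g * scale_hidden_stride];
}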
@@ -564,12 +557,10 @@ Buffer::intranode_dispatch(
     intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(),
                                       num_memset_int,
                                       buffer_ptrs_gpu,
-                                      task_fifo_ptrs_gpu,
-                                      head,
+                                      barrier_signal_ptrs_gpu,
                                       rank,
                                       num_ranks,
                                       comm_stream);
-    move_fifo_slots(2);
   } else {
     rank_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_ranks, num_ranks},
@@ -604,12 +595,10 @@ Buffer::intranode_dispatch(
                                num_memset_int,
                                expert_alignment,
                                buffer_ptrs_gpu,
-                               task_fifo_ptrs_gpu,
-                               head,
+                               barrier_signal_ptrs_gpu,
                                rank,
                                comm_stream,
                                num_channels);
-    move_fifo_slots(3);

     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -719,10 +708,13 @@ Buffer::intranode_dispatch(
       is_token_in_rank.data_ptr<bool>(),
       channel_prefix_matrix.data_ptr<int>(),
       num_tokens,
+      0,  // num_worst_tokens (not exposed)
       static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
       num_topk,
       num_experts,
       num_scales,
+      scale_token_stride,
+      scale_hidden_stride,
       buffer_ptrs_gpu,
       rank,
       num_ranks,
@@ -867,15 +859,11 @@ Buffer::intranode_combine(
                              num_channels,
                              num_recv_tokens,
                              num_channels * num_ranks * 2,
-                             task_fifo_ptrs_gpu,
-                             head,
+                             barrier_signal_ptrs_gpu,
                              rank,
                              num_ranks,
                              comm_stream);

-  // NOTES: this function uses two FIFO slots (barrier before and after)
-  move_fifo_slots(2);
-
   // Combine data
   auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
       {num_recv_tokens, hidden}, x.dtype(), x.place()));
@@ -895,6 +883,8 @@ Buffer::intranode_combine(
                      recv_topk_weights_ptr,
                      x.data_ptr(),
                      topk_weights_ptr,
+                     nullptr,  // bias_ptrs[0] (not exposed)
+                     nullptr,  // bias_ptrs[1] (not exposed)
                      src_idx.data_ptr<int>(),
                      rank_prefix_matrix.data_ptr<int>(),
                      channel_prefix_matrix.data_ptr<int>(),
@@ -1084,7 +1074,7 @@ Buffer::internode_dispatch(

   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0;
+  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -1093,6 +1083,8 @@ Buffer::internode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
+    scale_token_stride = static_cast<int>(x_scales->stride(0));
+    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }

   // Allocate all tensors on comm stream if set
@@ -1144,15 +1136,13 @@ Buffer::internode_dispatch(
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        task_fifo_ptrs_gpu,
-        head,
+        barrier_signal_ptrs_gpu,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         true,
         low_latency_mode);
-    move_fifo_slots(2);
   } else {
     rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_rdma_ranks, num_channels},
@@ -1196,14 +1186,12 @@ Buffer::internode_dispatch(
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        task_fifo_ptrs_gpu,
-        head,
+        barrier_signal_ptrs_gpu,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         low_latency_mode);
-    move_fifo_slots(3);

     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -1320,12 +1308,14 @@ Buffer::internode_dispatch(
         recv_rdma_rank_prefix_sum.data_ptr<int>(),
         gbl_channel_prefix_matrix.data_ptr<int>(),
         recv_gbl_rank_prefix_sum.data_ptr<int>(),
+        is_token_in_rank.data_ptr<bool>(),
         num_tokens,
         hidden_int4,
         num_scales,
         num_topk,
         num_experts,
-        is_token_in_rank.data_ptr<bool>(),
+        scale_token_stride,
+        scale_hidden_stride,
         rdma_buffer_ptr,
         config.num_max_rdma_chunked_send_tokens,
         config.num_max_rdma_chunked_recv_tokens,
@@ -1523,15 +1513,13 @@ Buffer::internode_combine(
       config.num_max_rdma_chunked_recv_tokens,
       buffer_ptrs_gpu,
       config.num_max_nvl_chunked_recv_tokens,
-      task_fifo_ptrs_gpu,
-      head,
+      barrier_signal_ptrs_gpu,
       rank,
       comm_stream,
       config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
       num_nvl_bytes,
       false,
       low_latency_mode);
-  move_fifo_slots(2);

   // Launch data combine
   auto combined_x =
@@ -1543,6 +1531,8 @@ Buffer::internode_combine(
       is_combined_token_in_rank.data_ptr<bool>(),
       x.data_ptr(),
       topk_weights_ptr,
+      nullptr,  // bias_ptrs[0] (not exposed)
+      nullptr,  // bias_ptrs[1] (not exposed)
       combined_rdma_head.data_ptr<int>(),
       combined_nvl_head.data_ptr<int>(),
       src_meta.data_ptr(),

paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp

Lines changed: 3 additions & 7 deletions
@@ -77,10 +77,9 @@ struct Buffer {
   // After IPC/NVSHMEM synchronization, this flag will be true
   bool available = false;

-  // Task fifo
-  int head = 0;
-  int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
-  int** task_fifo_ptrs_gpu = nullptr;
+  // Barrier signals
+  int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
+  int** barrier_signal_ptrs_gpu = nullptr;

   // Workspace
   void* workspace = nullptr;
@@ -97,9 +96,6 @@ struct Buffer {
   volatile int* moe_recv_rdma_counter = nullptr;
   int* moe_recv_rdma_counter_mapped = nullptr;

- private:
-  void move_fifo_slots(int num_slots = 1);
-
  public:
   Buffer(int rank,
          int num_ranks,
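With the FIFO gone, the header no longer tracks a head offset: a single int per NVL peer carries the synchronization state, and barrier_signal_ptrs_gpu hands kernels every peer's signal array. A minimal counting-barrier sketch of how such signals can be used (assumptions: this is not the actual intranode::barrier kernel; launch is one block with at least num_ranks threads; the atomic*_system intrinsics require compute capability 6.0+):

#include <cuda_runtime.h>

// Rank r bumps slot r in every peer's signal array, then waits to observe
// (and consume) a bump from each peer in its own array. Counting semantics
// make the slots reusable with no FIFO head pointer.
__global__ void barrier_sketch(int** barrier_signal_ptrs, int rank,
                               int num_ranks) {
  int peer = static_cast<int>(threadIdx.x);
  if (peer < num_ranks) {
    // Announce arrival to `peer` at system scope so its GPU observes it.
    atomicAdd_system(&barrier_signal_ptrs[peer][rank], 1);
    // Spin until `peer` has announced arrival to us, then consume the signal.
    while (atomicAdd_system(&barrier_signal_ptrs[rank][peer], 0) == 0) {
    }
    atomicSub_system(&barrier_signal_ptrs[rank][peer], 1);
  }
}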

paddle/fluid/distributed/collective/deep_ep/include/event_pool.h

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ namespace deep_ep::detail {

 class EventPool {
  public:
-  EventPool() = default;
+  EventPool();
   EventPool(const EventPool&) = delete;
   EventPool(EventPool&&) = delete;
   ~EventPool();
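Replacing = default with a plain declaration moves the constructor's definition out of the header and into the implementation file, where it can do nontrivial setup without dragging those dependencies into every includer. Callers are unchanged; only the definition site moves, e.g. (hypothetical body; the real definition lives elsewhere in this commit):

// event_pool.cc — sketch only
deep_ep::detail::EventPool::EventPool() {
  // e.g. eagerly reserve or pre-create pooled events
}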

paddle/fluid/distributed/collective/deep_ep/include/types.h

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ struct Tensor {
   }

   int64_t element_size() const { return phi::SizeOf(raw_tensor_.dtype()); }
+
+  int64_t stride(int64_t d) const { return raw_tensor_.strides().at(d); }
 };

 }  // namespace deep_ep::detail
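The new Tensor::stride accessor simply forwards to the underlying phi tensor's stride metadata for axis d; it is what the x_scales->stride(0) / stride(1) calls added in deep_ep.cpp above resolve to, and it rounds out the torch-like accessor set (element_size, size, dim, scalar_type, data_ptr) this thin wrapper already exposes.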
