@@ -69,14 +69,15 @@ Buffer::Buffer(int rank,
   calc_ctx = reinterpret_cast<phi::GPUContext*>(
       reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
           ->GetDeviceContext(place, true));
-  // Task fifo memory
-  int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
-  int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
-  int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
+
+  // Metadata memory
+  int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
+  int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
+  int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

   // Common checks
   EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
-                 (num_nvl_bytes <= std::numeric_limits<int64_t>::max() ||
+                 (num_nvl_bytes <= std::numeric_limits<int>::max() ||
                   num_rdma_bytes == 0));
   EP_HOST_ASSERT(
       num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
@@ -90,40 +91,35 @@ Buffer::Buffer(int rank,
   EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode);

   // Get ranks
-  // CUDA_CHECK(cudaGetDevice(&device_id));
   rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
-  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS),
+  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
   num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);

   // Get device info
   cudaDeviceProp device_prop = {};
   CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));

   if (num_nvl_bytes > 0) {
-    // Local IPC: alloc local memory and set local IPC handle
-    CUDA_CHECK(cudaMalloc(
-        &buffer_ptrs[nvl_rank],
-        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
+    // Local IPC: alloc local memory and set local IPC handles
+    CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank],
+                          num_nvl_bytes + barrier_signal_bytes +
+                              buffer_ptr_bytes + barrier_signal_ptr_bytes));
     CUDA_CHECK(
         cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
-    buffer_ptrs_gpu = reinterpret_cast<void**>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        fifo_bytes);
-
-    // Set task fifo
-    EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
-    task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
-    task_fifo_ptrs_gpu = reinterpret_cast<int**>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        fifo_bytes + buffer_ptr_bytes);
+    buffer_ptrs_gpu =
+        reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) +
+                                 num_nvl_bytes + barrier_signal_bytes);
+
+    // Set barrier signals
+    barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(
+        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
+    barrier_signal_ptrs_gpu = reinterpret_cast<int**>(
+        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        barrier_signal_bytes + buffer_ptr_bytes);

     // No need to synchronize, will do a full device sync during `sync`
     CUDA_CHECK(cudaMemsetAsync(
-        buffer_ptrs[nvl_rank],
-        0,
-        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes,
-        comm_stream));
+        barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
   }

   // Create 32 MiB workspace
@@ -165,8 +161,7 @@ Buffer::~Buffer() noexcept(false) {
   if (num_nvl_bytes > 0) {
     // Barrier
     intranode::barrier(
-        task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
-    move_fifo_slots();
+        barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
     CUDA_CHECK(cudaDeviceSynchronize());

     // Close remote IPC
@@ -197,10 +192,6 @@ Buffer::~Buffer() noexcept(false) {
   CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
 }

-void Buffer::move_fifo_slots(int num_slots) {
-  head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
-}
-
 bool Buffer::is_available() const { return available; }

 bool Buffer::is_internode_available() const {
@@ -249,7 +240,7 @@ void Buffer::sync(

   // Sync IPC handles
   if (num_nvl_bytes > 0) {
-    EP_HOST_ASSERT(num_ranks == static_cast<int64_t>(device_ids.size()));
+    EP_HOST_ASSERT(num_ranks == device_ids.size());
     EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
     for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks;
          ++i) {
@@ -261,22 +252,22 @@ void Buffer::sync(
             ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
         CUDA_CHECK(cudaIpcOpenMemHandle(
             &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
-        task_fifo_ptrs[i] = reinterpret_cast<int*>(
-            reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+        barrier_signal_ptrs[i] = reinterpret_cast<int*>(
+            static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
       } else {
         EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved,
                                    handle_str.c_str(),
                                    CUDA_IPC_HANDLE_SIZE) == 0);
       }
     }

-    // Copy all buffer and task pointers to GPU
+    // Copy all buffer and barrier signal pointers to GPU
     CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu,
                           buffer_ptrs,
                           sizeof(void*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
-    CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu,
-                          task_fifo_ptrs,
+    CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu,
+                          barrier_signal_ptrs,
                           sizeof(int*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
     CUDA_CHECK(cudaDeviceSynchronize());
@@ -520,7 +511,7 @@ Buffer::intranode_dispatch(

   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0;
+  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -529,6 +520,8 @@ Buffer::intranode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
+    scale_token_stride = static_cast<int>(x_scales->stride(0));
+    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }

   // Allocate all tensors on comm stream if set
@@ -564,12 +557,10 @@ Buffer::intranode_dispatch(
     intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(),
                                       num_memset_int,
                                       buffer_ptrs_gpu,
-                                      task_fifo_ptrs_gpu,
-                                      head,
+                                      barrier_signal_ptrs_gpu,
                                       rank,
                                       num_ranks,
                                       comm_stream);
-    move_fifo_slots(2);
   } else {
     rank_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_ranks, num_ranks},
@@ -604,12 +595,10 @@ Buffer::intranode_dispatch(
                                num_memset_int,
                                expert_alignment,
                                buffer_ptrs_gpu,
-                               task_fifo_ptrs_gpu,
-                               head,
+                               barrier_signal_ptrs_gpu,
                                rank,
                                comm_stream,
                                num_channels);
-    move_fifo_slots(3);

     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -719,10 +708,13 @@ Buffer::intranode_dispatch(
                       is_token_in_rank.data_ptr<bool>(),
                       channel_prefix_matrix.data_ptr<int>(),
                       num_tokens,
+                      0,  // num_worst_tokens (not exposed)
                       static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
                       num_topk,
                       num_experts,
                       num_scales,
+                      scale_token_stride,
+                      scale_hidden_stride,
                       buffer_ptrs_gpu,
                       rank,
                       num_ranks,
@@ -867,15 +859,11 @@ Buffer::intranode_combine(
                                    num_channels,
                                    num_recv_tokens,
                                    num_channels * num_ranks * 2,
-                                   task_fifo_ptrs_gpu,
-                                   head,
+                                   barrier_signal_ptrs_gpu,
                                    rank,
                                    num_ranks,
                                    comm_stream);

-  // NOTES: this function uses two FIFO slots (barrier before and after)
-  move_fifo_slots(2);
-
   // Combine data
   auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
       {num_recv_tokens, hidden}, x.dtype(), x.place()));
@@ -895,6 +883,8 @@ Buffer::intranode_combine(
                      recv_topk_weights_ptr,
                      x.data_ptr(),
                      topk_weights_ptr,
+                     nullptr,  // bias_ptrs[0] (not exposed)
+                     nullptr,  // bias_ptrs[1] (not exposed)
                      src_idx.data_ptr<int>(),
                      rank_prefix_matrix.data_ptr<int>(),
                      channel_prefix_matrix.data_ptr<int>(),
@@ -1084,7 +1074,7 @@ Buffer::internode_dispatch(

   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0;
+  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -1093,6 +1083,8 @@ Buffer::internode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
+    scale_token_stride = static_cast<int>(x_scales->stride(0));
+    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }

   // Allocate all tensors on comm stream if set
@@ -1144,15 +1136,13 @@ Buffer::internode_dispatch(
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        task_fifo_ptrs_gpu,
-        head,
+        barrier_signal_ptrs_gpu,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         true,
         low_latency_mode);
-    move_fifo_slots(2);
   } else {
     rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_rdma_ranks, num_channels},
@@ -1196,14 +1186,12 @@ Buffer::internode_dispatch(
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        task_fifo_ptrs_gpu,
-        head,
+        barrier_signal_ptrs_gpu,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         low_latency_mode);
-    move_fifo_slots(3);

     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -1320,12 +1308,14 @@ Buffer::internode_dispatch(
       recv_rdma_rank_prefix_sum.data_ptr<int>(),
       gbl_channel_prefix_matrix.data_ptr<int>(),
       recv_gbl_rank_prefix_sum.data_ptr<int>(),
+      is_token_in_rank.data_ptr<bool>(),
       num_tokens,
       hidden_int4,
       num_scales,
       num_topk,
       num_experts,
-      is_token_in_rank.data_ptr<bool>(),
+      scale_token_stride,
+      scale_hidden_stride,
       rdma_buffer_ptr,
       config.num_max_rdma_chunked_send_tokens,
       config.num_max_rdma_chunked_recv_tokens,
@@ -1523,15 +1513,13 @@ Buffer::internode_combine(
       config.num_max_rdma_chunked_recv_tokens,
       buffer_ptrs_gpu,
       config.num_max_nvl_chunked_recv_tokens,
-      task_fifo_ptrs_gpu,
-      head,
+      barrier_signal_ptrs_gpu,
       rank,
       comm_stream,
       config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
      num_nvl_bytes,
       false,
       low_latency_mode);
-  move_fifo_slots(2);

   // Launch data combine
   auto combined_x =
@@ -1543,6 +1531,8 @@ Buffer::internode_combine(
      is_combined_token_in_rank.data_ptr<bool>(),
      x.data_ptr(),
      topk_weights_ptr,
+     nullptr,  // bias_ptrs[0] (not exposed)
+     nullptr,  // bias_ptrs[1] (not exposed)
      combined_rdma_head.data_ptr<int>(),
      combined_nvl_head.data_ptr<int>(),
      src_meta.data_ptr(),