@@ -42,16 +42,20 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
42
42
/* channel->devPeers is not shared, so just free it when calling commFree() */
43
43
NCCLCHECK (ncclCudaCallocAsync (&channel->devPeers , nPeers, sharedRes->deviceStream .cudaStream ));
44
44
ncclCommPushCudaFree (comm, channel->devPeers );
45
+ NCCLCHECK (ncclCalloc (&channel->devPeersHostPtr , nPeers));
45
46
for (int r = 0 ; r < nRanks; r++) {
46
47
uintptr_t addr = (uintptr_t )(comm->sharedRes ->devPeers [channelId] + comm->topParentRanks [r]);
47
48
NCCLCHECK (ncclCudaMemcpyAsync ((uintptr_t *)(channel->devPeers + r), (uintptr_t *)&addr, 1 , sharedRes->deviceStream .cudaStream ));
49
+ channel->devPeersHostPtr [r] = (struct ncclDevChannelPeer *)addr;
48
50
}
49
51
}
50
52
51
53
channel->ring .userRanks = ncclMemoryStackAlloc<int >(&comm->memPermanent , nRanks);
52
54
NCCLCHECK (ncclCudaCallocAsync (&channel->devRingUserRanks , nRanks, sharedRes->deviceStream .cudaStream ));
53
55
ncclCommPushCudaFree (comm, channel->devRingUserRanks );
54
56
57
+ /* Synchronize the device stream so that the asynchronous copies of addr into channel->devPeers have completed before the stream is released. */
58
+ NCCLCHECK (ncclStrongStreamSynchronize (&sharedRes->deviceStream ));
55
59
NCCLCHECK (ncclStrongStreamRelease (ncclCudaGraphNone (), &sharedRes->deviceStream ));
56
60
57
61
return ncclSuccess;
@@ -77,6 +81,7 @@ ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclCo
77
81
uintptr_t addr = (uintptr_t )(parent->channels [channelId].nvlsDevPeers + tr);
78
82
channel->peers [comm->nRanks + 1 + r] = parent->channels [channelId].nvlsPeers + tr;
79
83
NCCLCHECK (ncclCudaMemcpyAsync ((uintptr_t *)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t *)&addr, 1 , sharedRes->deviceStream .cudaStream ));
84
+ channel->devPeersHostPtr [comm->nRanks + 1 + r] = (struct ncclDevChannelPeer *)addr;
80
85
ncclAtomicRefCountIncrement (&parent->channels [channelId].nvlsPeers [tr].refCount );
81
86
}
82
87
} else {
@@ -86,10 +91,12 @@ ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclCo
86
91
uintptr_t addr = (uintptr_t )(channel->nvlsDevPeers + r);
87
92
channel->peers [comm->nRanks + 1 + r] = channel->nvlsPeers + r;
88
93
NCCLCHECK (ncclCudaMemcpyAsync ((uintptr_t *)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t *)&addr, 1 , sharedRes->deviceStream .cudaStream ));
94
+ channel->devPeersHostPtr [comm->nRanks + 1 + r] = (struct ncclDevChannelPeer *)addr;
89
95
ncclAtomicRefCountIncrement (&channel->nvlsPeers [r].refCount );
90
96
}
91
97
}
92
98
99
+ NCCLCHECK (ncclStrongStreamSynchronize (&sharedRes->deviceStream ));
93
100
NCCLCHECK (ncclStrongStreamRelease (ncclCudaGraphNone (), &sharedRes->deviceStream ));
94
101
95
102
return ncclSuccess;
@@ -114,16 +121,19 @@ ncclResult_t initCollnetChannel(struct ncclComm* comm, int channelId, struct ncc
114
121
addr = (uintptr_t )parent->channels [channelId].collnetDevPeers ;
115
122
channel->peers [comm->nRanks ] = parent->channels [channelId].collnetPeers ;
116
123
NCCLCHECK (ncclCudaMemcpyAsync ((uintptr_t *)(channel->devPeers + comm->nRanks ), (uintptr_t *)&addr, 1 , sharedRes->deviceStream .cudaStream ));
124
+ channel->devPeersHostPtr [comm->nRanks ] = (struct ncclDevChannelPeer *)addr;
117
125
ncclAtomicRefCountIncrement (&parent->channels [channelId].collnetPeers ->refCount );
118
126
} else {
119
127
NCCLCHECK (ncclCalloc (&channel->collnetPeers , 1 ));
120
128
NCCLCHECK (ncclCudaCallocAsync (&channel->collnetDevPeers , 1 , sharedRes->deviceStream .cudaStream ));
121
129
addr = (uintptr_t )channel->collnetDevPeers ;
122
130
channel->peers [comm->nRanks ] = channel->collnetPeers ;
123
131
NCCLCHECK (ncclCudaMemcpyAsync ((uintptr_t *)(channel->devPeers + comm->nRanks ), (uintptr_t *)&addr, 1 , sharedRes->deviceStream .cudaStream ));
132
+ channel->devPeersHostPtr [comm->nRanks ] = (struct ncclDevChannelPeer *)addr;
124
133
ncclAtomicRefCountIncrement (&channel->collnetPeers ->refCount );
125
134
}
126
135
136
+ NCCLCHECK (ncclStrongStreamSynchronize (&sharedRes->deviceStream ));
127
137
NCCLCHECK (ncclStrongStreamRelease (ncclCudaGraphNone (), &sharedRes->deviceStream ));
128
138
129
139
return ncclSuccess;
@@ -156,5 +166,7 @@ ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks, int collnetNRa
156
166
}
157
167
}
158
168
}
169
+
170
+ free (channel->devPeersHostPtr );
159
171
return ncclSuccess;
160
172
}
0 commit comments