Skip to content

Commit 4365458

Browse files
KaimingOuyang authored and sjeaugey committed
Fix cudaMemcpyAsync bug
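The previous code used the result of a first cudaMemcpyAsync (a device-to-host copy of a devPeers pointer) as the destination address of a second cudaMemcpyAsync, with no synchronization in between, so the host-side pointer was not guaranteed to be valid when it was dereferenced. This patch fixes it by allocating a CPU-side array (devPeersHostPtr) that caches the device-side addresses, which removes the need for these back-to-back CUDA memory copies. Fixes #957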
1 parent 559b70f commit 4365458

File tree

5 files changed

+21
-13
lines changed

5 files changed

+21
-13
lines changed

src/channel.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,20 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) {
4242
/* channel->devPeers is not shared, so just free it when calling commFree() */
4343
NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, sharedRes->deviceStream.cudaStream));
4444
ncclCommPushCudaFree(comm, channel->devPeers);
45+
NCCLCHECK(ncclCalloc(&channel->devPeersHostPtr, nPeers));
4546
for (int r = 0; r < nRanks; r++) {
4647
uintptr_t addr = (uintptr_t)(comm->sharedRes->devPeers[channelId] + comm->topParentRanks[r]);
4748
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
49+
channel->devPeersHostPtr[r] = (struct ncclDevChannelPeer*)addr;
4850
}
4951
}
5052

5153
channel->ring.userRanks = ncclMemoryStackAlloc<int>(&comm->memPermanent, nRanks);
5254
NCCLCHECK(ncclCudaCallocAsync(&channel->devRingUserRanks, nRanks, sharedRes->deviceStream.cudaStream));
5355
ncclCommPushCudaFree(comm, channel->devRingUserRanks);
5456

57+
/* guarantee addr has been copied into channel->devPeers */
58+
NCCLCHECK(ncclStrongStreamSynchronize(&sharedRes->deviceStream));
5559
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));
5660

5761
return ncclSuccess;
@@ -77,6 +81,7 @@ ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclCo
7781
uintptr_t addr = (uintptr_t)(parent->channels[channelId].nvlsDevPeers + tr);
7882
channel->peers[comm->nRanks + 1 + r] = parent->channels[channelId].nvlsPeers + tr;
7983
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
84+
channel->devPeersHostPtr[comm->nRanks + 1 + r] = (struct ncclDevChannelPeer*)addr;
8085
ncclAtomicRefCountIncrement(&parent->channels[channelId].nvlsPeers[tr].refCount);
8186
}
8287
} else {
@@ -86,10 +91,12 @@ ncclResult_t initNvlsChannel(struct ncclComm* comm, int channelId, struct ncclCo
8691
uintptr_t addr = (uintptr_t)(channel->nvlsDevPeers + r);
8792
channel->peers[comm->nRanks + 1 + r] = channel->nvlsPeers + r;
8893
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks + 1 + r), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
94+
channel->devPeersHostPtr[comm->nRanks + 1 + r] = (struct ncclDevChannelPeer*)addr;
8995
ncclAtomicRefCountIncrement(&channel->nvlsPeers[r].refCount);
9096
}
9197
}
9298

99+
NCCLCHECK(ncclStrongStreamSynchronize(&sharedRes->deviceStream));
93100
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));
94101

95102
return ncclSuccess;
@@ -114,16 +121,19 @@ ncclResult_t initCollnetChannel(struct ncclComm* comm, int channelId, struct ncc
114121
addr = (uintptr_t)parent->channels[channelId].collnetDevPeers;
115122
channel->peers[comm->nRanks] = parent->channels[channelId].collnetPeers;
116123
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
124+
channel->devPeersHostPtr[comm->nRanks] = (struct ncclDevChannelPeer*)addr;
117125
ncclAtomicRefCountIncrement(&parent->channels[channelId].collnetPeers->refCount);
118126
} else {
119127
NCCLCHECK(ncclCalloc(&channel->collnetPeers, 1));
120128
NCCLCHECK(ncclCudaCallocAsync(&channel->collnetDevPeers, 1, sharedRes->deviceStream.cudaStream));
121129
addr = (uintptr_t)channel->collnetDevPeers;
122130
channel->peers[comm->nRanks] = channel->collnetPeers;
123131
NCCLCHECK(ncclCudaMemcpyAsync((uintptr_t*)(channel->devPeers + comm->nRanks), (uintptr_t*)&addr, 1, sharedRes->deviceStream.cudaStream));
132+
channel->devPeersHostPtr[comm->nRanks] = (struct ncclDevChannelPeer*)addr;
124133
ncclAtomicRefCountIncrement(&channel->collnetPeers->refCount);
125134
}
126135

136+
NCCLCHECK(ncclStrongStreamSynchronize(&sharedRes->deviceStream));
127137
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &sharedRes->deviceStream));
128138

129139
return ncclSuccess;
@@ -156,5 +166,7 @@ ncclResult_t freeChannel(struct ncclChannel* channel, int nRanks, int collnetNRa
156166
}
157167
}
158168
}
169+
170+
free(channel->devPeersHostPtr);
159171
return ncclSuccess;
160172
}

src/include/comm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ struct ncclSharedResources {
124124
struct ncclChannel {
125125
struct ncclChannelPeer** peers;
126126
struct ncclDevChannelPeer** devPeers;
127+
/* devPeer pointer array used for host side access */
128+
struct ncclDevChannelPeer** devPeersHostPtr;
127129
struct ncclRing ring;
128130
int* devRingUserRanks;
129131
struct ncclTree tree;

src/init.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) {
437437

438438
NCCLCHECKGOTO(ncclCudaMemcpyAsync(devCommAndChans, &tmpCommAndChans, 1, comm->sharedRes->deviceStream.cudaStream), ret, fail);
439439
exit:
440-
CUDACHECK(cudaStreamSynchronize(comm->sharedRes->deviceStream.cudaStream));
440+
NCCLCHECK(ncclStrongStreamSynchronize(&comm->sharedRes->deviceStream));
441441
NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->sharedRes->deviceStream));
442442
return ret;
443443
fail:

src/transport.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,9 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
147147
if (conn->connected == 0) {
148148
NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData[i] + sendDataOffset++, 1, comm->rank, conn), ret, fail);
149149
if (ret == ncclSuccess) {
150-
struct ncclDevChannelPeer* addr;
151150
conn->connected = 1;
152151
/* comm->channels[c].devPeers[sendPeer]->send[connIndex] is a device memory access. */
153-
CUDACHECKGOTO(cudaMemcpyAsync(&addr, &comm->channels[c].devPeers[sendPeer], sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost, comm->sharedRes->hostStream.cudaStream), ret, fail);
154-
CUDACHECKGOTO(cudaMemcpyAsync(&addr->send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), ret, fail);
152+
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeersHostPtr[sendPeer]->send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), ret, fail);
155153
} else if (ret == ncclInProgress) {
156154
allChannelsConnected = false;
157155
}
@@ -167,11 +165,9 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
167165
if (conn->connected == 0) {
168166
NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData[i] + recvDataOffset++, 1, comm->rank, conn), ret, fail);
169167
if (ret == ncclSuccess) {
170-
struct ncclDevChannelPeer* addr;
171168
conn->connected = 1;
172169
/* comm->channels[c].devPeers[recvPeer]->recv[connIndex] is a device memory access. */
173-
CUDACHECKGOTO(cudaMemcpyAsync(&addr, &comm->channels[c].devPeers[recvPeer], sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost, comm->sharedRes->hostStream.cudaStream), ret, fail);
174-
CUDACHECKGOTO(cudaMemcpyAsync(&addr->recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), ret, fail);
170+
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeersHostPtr[recvPeer]->recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), ret, fail);
175171
} else if (ret == ncclInProgress) {
176172
allChannelsConnected = false;
177173
}

src/transport/nvls.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,10 @@ ncclResult_t ncclNvlsSetup(struct ncclComm* comm, struct ncclComm* parent) {
359359
peer->send[0].conn.tail = (uint64_t*)(mem + buffSize + memSize / 2);
360360
peer->send[0].conn.flags |= NCCL_NVLS_MIN_POLL;
361361

362-
struct ncclDevChannelPeer* addr;
363-
CUDACHECKGOTO(cudaMemcpyAsync(&addr, comm->channels[c].devPeers + nvlsPeer, sizeof(struct ncclDevChannelPeer*), cudaMemcpyDeviceToHost, comm->sharedRes->hostStream.cudaStream), res, cleanup);
364-
CUDACHECKGOTO(cudaMemcpyAsync(&addr->send[0], &peer->send[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
365-
CUDACHECKGOTO(cudaMemcpyAsync(&addr->recv[0], &peer->recv[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
366-
CUDACHECKGOTO(cudaMemcpyAsync(&addr->send[1], &peer->send[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
367-
CUDACHECKGOTO(cudaMemcpyAsync(&addr->recv[1], &peer->recv[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
362+
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeersHostPtr[nvlsPeer]->send[0], &peer->send[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
363+
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeersHostPtr[nvlsPeer]->recv[0], &peer->recv[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
364+
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeersHostPtr[nvlsPeer]->send[1], &peer->send[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
365+
CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeersHostPtr[nvlsPeer]->recv[1], &peer->recv[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->sharedRes->hostStream.cudaStream), res, cleanup);
368366

369367
/*INFO(NCCL_INIT|NCCL_NVLS, "Peer %d Channel %d MC buff %p/%p UC Buff %p/%p",
370368
nvlsPeer, c,

0 commit comments

Comments
 (0)