26 changes: 16 additions & 10 deletions paddle/fluid/framework/tensor_util.cc
@@ -20,7 +20,7 @@ namespace paddle {
 namespace framework {
 
 void TensorCopy(const Tensor& src, const platform::Place& dst_place,
-                const platform::DeviceContext& ctx, Tensor* dst) {
+                const platform::DeviceContext& ctx, Tensor* dst, bool sync) {
   VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
           << dst_place;
   src.check_memory_size();
@@ -47,9 +47,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
     PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
     auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
     PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
-    memory::Copy(
-        dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+    auto stream =
+        sync ? nullptr
+             : reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+                   .stream();
+    memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
   } else if (platform::is_cpu_place(src_place) &&
              platform::is_gpu_place(dst_place)) {
     auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
@@ -58,18 +60,22 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
     PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
     auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
     PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
-    memory::Copy(
-        dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+    auto stream =
+        sync ? nullptr
+             : reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+                   .stream();
+    memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
   } else if (platform::is_gpu_place(src_place) &&
              platform::is_gpu_place(dst_place)) {
     auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
     auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
     auto ctx_place = ctx.GetPlace();
     PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
-    memory::Copy(
-        dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+    auto stream =
+        sync ? nullptr
+             : reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+                   .stream();
+    memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
   }
 #endif
 }
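Note: the sync flag does not call a different copy routine directly; it only passes a null stream to memory::Copy, and (as the memcpy.cc changes below show) a null stream makes memory::Copy take the blocking cudaMemcpy / cudaMemcpyPeer path instead of the stream-ordered asynchronous variants.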
3 changes: 2 additions & 1 deletion paddle/fluid/framework/tensor_util.h
@@ -24,7 +24,8 @@ namespace paddle {
 namespace framework {
 
 void TensorCopy(const Tensor& src, const platform::Place& dst_place,
-                const platform::DeviceContext& ctx, Tensor* dst);
+                const platform::DeviceContext& ctx, Tensor* dst,
+                bool sync = false);
 void TensorCopy(const Tensor& src, const platform::Place& dst_place,
                 Tensor* dst);
 
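For illustration only, a minimal sketch of how a caller might use the new parameter; the function name, tensor arguments, and device setup below are hypothetical and not part of this PR:

    // Hypothetical usage sketch: copy a CPU tensor to the GPU, once
    // asynchronously on the context's stream and once synchronously.
    #include "paddle/fluid/framework/tensor_util.h"
    #include "paddle/fluid/platform/device_context.h"

    void CopyExample(const paddle::framework::Tensor& cpu_tensor,
                     paddle::framework::Tensor* gpu_tensor) {
      paddle::platform::CUDAPlace gpu_place(0);
      paddle::platform::CUDADeviceContext ctx(gpu_place);

      // Default (sync = false): the copy is enqueued on ctx's CUDA stream;
      // the caller must keep cpu_tensor alive until the stream is synchronized.
      paddle::framework::TensorCopy(cpu_tensor, gpu_place, ctx, gpu_tensor);

      // sync = true: the copy takes the blocking cudaMemcpy path and has
      // finished by the time TensorCopy returns.
      paddle::framework::TensorCopy(cpu_tensor, gpu_place, ctx, gpu_tensor,
                                    true /* sync */);
    }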
39 changes: 32 additions & 7 deletions paddle/fluid/memory/memcpy.cc
@@ -32,15 +32,23 @@ void Copy<platform::CPUPlace, platform::CUDAPlace>(
     platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
     const void* src, size_t num, cudaStream_t stream) {
   platform::SetDeviceId(src_place.device);
-  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+  if (stream) {
+    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+  } else {
+    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
+  }
 }
 
 template <>
 void Copy<platform::CUDAPlace, platform::CPUPlace>(
     platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
     const void* src, size_t num, cudaStream_t stream) {
   platform::SetDeviceId(dst_place.device);
-  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+  if (stream) {
+    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+  } else {
+    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
+  }
 }
 
 template <>
@@ -49,10 +57,19 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
     const void* src, size_t num, cudaStream_t stream) {
   if (dst_place == src_place) {
     platform::SetDeviceId(src_place.device);
-    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
+    if (stream) {
+      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
+    } else {
+      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
+    }
   } else {
-    platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
-                            stream);
+    if (stream) {
+      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
+                                   num, stream);
+    } else {
+      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
+                                  num);
+    }
   }
 }
 
@@ -83,7 +100,11 @@ void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
     platform::CUDAPlace src_place, const void* src, size_t num,
     cudaStream_t stream) {
   platform::SetDeviceId(src_place.device);
-  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+  if (stream) {
+    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+  } else {
+    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
+  }
 }
 
 template <>
@@ -92,7 +113,11 @@ void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
     platform::CUDAPinnedPlace src_place, const void* src, size_t num,
     cudaStream_t stream) {
   platform::SetDeviceId(dst_place.device);
-  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+  if (stream) {
+    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+  } else {
+    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
+  }
 }
 
 #endif
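The pattern in memcpy.cc is uniform: a non-null stream means "enqueue the copy asynchronously on that stream", while a null stream now means "perform a blocking copy". A standalone sketch of that idiom in plain CUDA, independent of Paddle's wrappers (the function name is illustrative):

    #include <cuda_runtime.h>

    // Minimal sketch of the stream-as-switch idiom used above: if a stream is
    // provided, enqueue an asynchronous copy on it; otherwise block until the
    // copy has completed.
    inline cudaError_t CopyHostToDevice(void* dst, const void* src, size_t num,
                                        cudaStream_t stream) {
      if (stream) {
        return cudaMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
      }
      return cudaMemcpy(dst, src, num, cudaMemcpyHostToDevice);
    }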
@@ -180,7 +180,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
       auto* gpu_ctx = ctxs_[cached_tensor_id].get();
       gpu_batch.resize(cpu_batch.size());
       for (size_t i = 0; i < cpu_batch.size(); ++i) {
-        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
+        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i],
+                              true);
         gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
     }
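Passing true here presumably guarantees that the host-side cpu_batch tensors have been fully transferred to the GPU before the prefetch loop moves on and their buffers can be reused, at the cost of blocking the prefetch thread for the duration of the copy.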
19 changes: 16 additions & 3 deletions paddle/fluid/platform/gpu_info.cc
@@ -127,11 +127,24 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                  "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
 }
 
-void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
-                   size_t count, cudaStream_t stream) {
+void GpuMemcpySync(void *dst, const void *src, size_t count,
+                   enum cudaMemcpyKind kind) {
+  PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind),
+                 "cudaMemcpy failed in paddle::platform::GpuMemcpySync");
+}
+
+void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
+                        int src_device, size_t count, cudaStream_t stream) {
   PADDLE_ENFORCE(
       cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
-      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer");
+      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync");
 }
 
+void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
+                       int src_device, size_t count) {
+  PADDLE_ENFORCE(
+      cudaMemcpyPeer(dst, dst_device, src, src_device, count),
+      "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync");
+}
+
 void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
14 changes: 11 additions & 3 deletions paddle/fluid/platform/gpu_info.h
@@ -57,9 +57,17 @@ size_t GpuMaxChunkSize();
 void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                     enum cudaMemcpyKind kind, cudaStream_t stream);
 
-//! Copy memory from one device to another device.
-void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
-                   size_t count, cudaStream_t stream);
+//! Copy memory from address src to dst synchronously.
+void GpuMemcpySync(void *dst, const void *src, size_t count,
+                   enum cudaMemcpyKind kind);
+
+//! Copy memory from one device to another device asynchronously.
+void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
+                        int src_device, size_t count, cudaStream_t stream);
+
+//! Copy memory from one device to another device synchronously.
+void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
+                       int src_device, size_t count);
 
 //! Set memory dst with value count size asynchronously
 void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
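For completeness, a hypothetical call site for the new synchronous wrappers; the buffer names, sizes, and device ids below are made up for illustration:

    #include <cuda_runtime.h>
    #include "paddle/fluid/platform/gpu_info.h"

    void SyncCopyExample(void* dst_on_host, const void* src_on_gpu, size_t bytes,
                         void* dst_on_gpu1, const void* src_on_gpu0) {
      // Blocking device-to-host copy; the data has arrived when this returns.
      paddle::platform::GpuMemcpySync(dst_on_host, src_on_gpu, bytes,
                                      cudaMemcpyDeviceToHost);
      // Blocking peer copy from GPU 0 to GPU 1.
      paddle::platform::GpuMemcpyPeerSync(dst_on_gpu1, /*dst_device=*/1,
                                          src_on_gpu0, /*src_device=*/0, bytes);
    }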