-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add CPU/GPU Memcpy in memory folder #2970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
gangliao
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments
| // For more details, please check https://stackoverflow.com/a/43870188/724872. | ||
| #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) | ||
|
|
||
| template <typename T> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix special case of PADDLE_ENFORCE
PADDLE_ENFORCE(condition, "hello world"); // OK, if using old implementation
PADDLE_ENFORCE(condition) // Failed, if using old implementation. But, it's addressed.| const void* src, size_t num, | ||
| cudaStream_t stream) { | ||
| platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, use cudaMemcpyDeviceToHost and cudaMemcpyHostToDevice, cudaMemcpyDeviceToDevice. But the cudaMemcpyKind of cudaMemcpyDefault:
cudaMemcpyDefault = 4
Direction of the transfer is inferred from the pointer values. Requires unified virtual addressing
Can we use cudaMemcpyDefault to simplify code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. But, how to specialize one function to support both cases: (CPUPlace, GPUPlace),
(GPUPlace, CPUPlace).
One way may achieve that is std::enable_if, but it will dump too many annoying code.
|
|
||
| //!Get the memory usage of current GPU device. | ||
| void GpuMemoryUsage(size_t& available, size_t& total); | ||
| void GpuMemoryUsage(size_t &available, size_t &total); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, we should unify the code style
size_t& available, size_t& total;
& and * should close to type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't change this. Clang-format takes this job
| platform::CPUPlace src_place, | ||
| const void* src, size_t num, | ||
| cudaStream_t stream) { | ||
| platform::SetDeviceId(dst_place.device); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should use platform::GPUPlaceGuard here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's unnecessary to use the guard to implicitly roll back the device id.
For GPU device, it's better to explicitly set device id.
wangkuiyi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| #ifndef PADDLE_ONLY_CPU | ||
| template <typename DstPlace, typename SrcPlace> | ||
| void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, | ||
| cudaStream_t stream); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great to add a comment telling when would users call this second form of Copy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for replying too late due to we are on duty yesterday. Yeah, I will annotate this function. Thanks.
|
|
||
| template <typename... Args> | ||
| inline void throw_on_error(cudaError_t e, const Args&... args) { | ||
| inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!
| size_t GpuMaxChunkSize(); | ||
|
|
||
| //! Copy memory from address src to dst asynchronously. | ||
| void GpuMemcpyAsync(void *dst, const void *src, size_t count, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we move these copying functions into a new source file, say, copy.{h,cc}? I am not sure. Just mention it.
No description provided.