-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Fix/copyfrom context #6954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix/copyfrom context #6954
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,6 +82,28 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place, | |
| #endif | ||
| } | ||
|
|
||
| /** | ||
| * @brief CopyFrom support CPU <-> CPU | ||
| */ | ||
|
|
||
| inline void CopyFrom(const Tensor& src, const platform::Place& dst_place, | ||
| Tensor* dst) { | ||
| src.check_memory_size(); | ||
| dst->Resize(src.dims()); | ||
|
|
||
| auto src_place = src.place(); | ||
| auto src_ptr = src.data<void>(); | ||
|
|
||
| auto dst_ptr = dst->mutable_data(dst_place, src.type()); | ||
|
|
||
| auto size = src.numel() * SizeOfType(src.type()); | ||
|
|
||
| if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { | ||
|
||
| memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, | ||
| boost::get<platform::CPUPlace>(src_place), src_ptr, size); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * @brief Copy the content of an external vector to a tensor. | ||
| * | ||
|
|
@@ -115,6 +137,21 @@ inline void CopyFromVector(const std::vector<T>& src, | |
| #endif | ||
| } | ||
|
|
||
| /** | ||
| * @brief CopyFromVector CPU vector -> CPU Tensor | ||
| */ | ||
| template <typename T> | ||
| inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) { | ||
| platform::CPUPlace dst_place = platform::CPUPlace(); | ||
| auto src_ptr = static_cast<const void*>(src.data()); | ||
| platform::CPUPlace src_place; | ||
| dst->Resize({static_cast<int64_t>(src.size())}); | ||
| auto dst_ptr = static_cast<void*>(dst->mutable_data<T>(dst_place)); | ||
| auto size = src.size() * sizeof(T); | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add place checker.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CopyToVector create a cpu_place inside. No need to check it. |
||
| memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Copy the content of a tensor to a vector | ||
| * | ||
|
|
@@ -147,6 +184,24 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx, | |
| } | ||
| #endif | ||
| } | ||
| /** | ||
| * @brief CopyToVector CPUTensor <-> CPU Vector | ||
| */ | ||
|
|
||
| template <typename T> | ||
| inline void CopyToVector(const Tensor& src, std::vector<T>* dst) { | ||
| auto src_ptr = static_cast<const void*>(src.data<T>()); | ||
| auto size = src.numel() * sizeof(T); | ||
|
|
||
| platform::CPUPlace dst_place; | ||
| dst->resize(src.numel()); | ||
| auto dst_ptr = static_cast<void*>(dst->data()); | ||
|
|
||
| if (platform::is_cpu_place(src.place())) { | ||
| memory::Copy(dst_place, dst_ptr, | ||
| boost::get<platform::CPUPlace>(src.place()), src_ptr, size); | ||
| } | ||
| } | ||
|
|
||
| } // namespace framework | ||
| } // namespace paddle | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Delete the redundant black line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.