-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add init interface for customize devices. #10167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1bdea0a
48b7b54
f31bb14
a0b2582
e470856
a4b452a
3d96b38
ad3f6f4
848fb00
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,19 +15,40 @@ limitations under the License. */ | |
| #include <algorithm> | ||
| #include <stdexcept> | ||
| #include <string> | ||
| #include <vector> | ||
|
|
||
| #include "paddle/fluid/framework/init.h" | ||
| #include "paddle/fluid/framework/operator.h" | ||
| #include "paddle/fluid/platform/device_context.h" | ||
| #include "paddle/fluid/platform/device_context.h" | ||
| #include "paddle/fluid/platform/place.h" | ||
| #include "paddle/fluid/string/piece.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| DEFINE_string(devices, "", "The devices to be used."); | ||
| DEFINE_bool(init_p2p, true, "Whether to init p2p."); | ||
|
|
||
| std::once_flag gflags_init_flag; | ||
| std::once_flag p2p_init_flag; | ||
|
|
||
| using paddle::platform::DeviceContextPool; | ||
|
|
||
| void Init(int argc, char **argv) { | ||
| std::call_once(gflags_init_flag, | ||
| [&]() { google::ParseCommandLineFlags(&argc, &argv, true); }); | ||
|
||
|
|
||
| // init devices | ||
| std::vector<int> devices; | ||
| std::string token; | ||
| std::istringstream tokenStream(FLAGS_devices); | ||
| while (std::getline(tokenStream, token, ',')) { | ||
| devices.push_back(std::stoi(token)); | ||
| } | ||
| InitDevices(FLAGS_init_p2p, devices); | ||
| } | ||
|
|
||
| void InitGflags(std::vector<std::string> &argv) { | ||
| std::call_once(gflags_init_flag, [&]() { | ||
| int argc = argv.size(); | ||
|
|
@@ -64,6 +85,30 @@ void InitP2P(int count) { | |
| #endif | ||
| } | ||
|
|
||
| void InitP2P(std::vector<int> devices) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line67行的
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. Thx.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我的意思是:void InitP2P(int count) 和 你新加的
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @qingqing01 明白了,我改成在 |
||
| #ifdef PADDLE_WITH_CUDA | ||
| std::call_once(p2p_init_flag, [&]() { | ||
| int count = devices.size(); | ||
| for (int i = 0; i < count; ++i) { | ||
| for (int j = 0; j < count; ++j) { | ||
| if (devices[i] == devices[j]) continue; | ||
| int can_acess = -1; | ||
| PADDLE_ENFORCE( | ||
| cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]), | ||
| "Failed to test P2P access."); | ||
| if (can_acess != 1) { | ||
| LOG(WARNING) << "Cannot enable P2P access from " << devices[i] | ||
| << " to " << devices[j]; | ||
| } else { | ||
| cudaSetDevice(devices[i]); | ||
| cudaDeviceEnablePeerAccess(devices[j], 0); | ||
| } | ||
| } | ||
| } | ||
| }); | ||
| #endif | ||
| } | ||
|
|
||
| void InitDevices(bool init_p2p) { | ||
| /*Init all avaiable devices by default */ | ||
|
|
||
|
|
@@ -91,6 +136,34 @@ void InitDevices(bool init_p2p) { | |
| platform::DeviceContextPool::Init(places); | ||
| } | ||
|
|
||
| void InitDevices(bool init_p2p, const std::vector<int> devices) { | ||
| std::vector<platform::Place> places; | ||
| int count = 0; | ||
| #ifdef PADDLE_WITH_CUDA | ||
| try { | ||
| count = platform::GetCUDADeviceCount(); | ||
| } catch (const std::exception &exp) { | ||
| LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; | ||
| } | ||
| #else | ||
| LOG(WARNING) | ||
| << "'CUDA' is not supported, Please re-compile with WITH_GPU option"; | ||
| #endif | ||
|
|
||
| for (size_t i = 0; i < devices.size(); ++i) { | ||
| if (devices[i] >= count) { | ||
|
||
| LOG(WARNING) << "Invalid devices id."; | ||
| continue; | ||
| } | ||
| places.emplace_back(platform::CUDAPlace(devices[i])); | ||
| } | ||
| if (init_p2p) { | ||
| InitP2P(devices); | ||
| } | ||
| places.emplace_back(platform::CPUPlace()); | ||
| platform::DeviceContextPool::Init(places); | ||
| } | ||
|
|
||
| void InitGLOG(const std::string &prog_name) { | ||
| // glog will not hold the ARGV[0] inside. | ||
| // Use strdup to alloc a new string. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,8 @@ namespace inference { | |
|
|
||
| void Init(bool init_p2p); | ||
|
||
|
|
||
| void Init(int argc, char** argv); | ||
|
|
||
| void LoadPersistables(framework::Executor* executor, framework::Scope* scope, | ||
| const framework::ProgramDesc& main_program, | ||
| const std::string& dirname, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
头文件重复了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix.