Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions paddle/fluid/framework/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,40 @@ limitations under the License. */
#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

头文件重复了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix.

#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h"

namespace paddle {
namespace framework {

// Command-line flag holding the device ids to initialize. Init() splits the
// value on ',' and converts each token to an int.
DEFINE_string(devices, "", "The devices to be used.");
// Command-line flag forwarded by Init() to InitDevices() as its init_p2p
// argument.
DEFINE_bool(init_p2p, true, "Whether to init p2p.");

// Guards gflags parsing; used with std::call_once in Init() and InitGflags().
std::once_flag gflags_init_flag;
// Guards the peer-to-peer setup performed in InitP2P().
std::once_flag p2p_init_flag;

using paddle::platform::DeviceContextPool;

void Init(int argc, char **argv) {
std::call_once(gflags_init_flag,
[&]() { google::ParseCommandLineFlags(&argc, &argv, true); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里调用InitGflags,然后修复下IntiGflags的问题?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Thx.


// init devices
std::vector<int> devices;
std::string token;
std::istringstream tokenStream(FLAGS_devices);
while (std::getline(tokenStream, token, ',')) {
devices.push_back(std::stoi(token));
}
InitDevices(FLAGS_init_p2p, devices);
}

void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() {
int argc = argv.size();
Expand Down Expand Up @@ -64,6 +85,30 @@ void InitP2P(int count) {
#endif
}

// Enables CUDA peer-to-peer access between every ordered pair of distinct
// device ids in `devices`.
//
// The work runs under p2p_init_flag, so only the first call has any effect.
// When PADDLE_WITH_CUDA is not defined the function does nothing.
//
// A failure to *query* peer capability is fatal (PADDLE_ENFORCE). A failure to
// *enable* access is reported with a warning and the remaining pairs are still
// processed; previously those return codes were dropped silently.
void InitP2P(std::vector<int> devices) {
#ifdef PADDLE_WITH_CUDA
  std::call_once(p2p_init_flag, [&]() {
    const int count = static_cast<int>(devices.size());
    for (int i = 0; i < count; ++i) {
      for (int j = 0; j < count; ++j) {
        if (devices[i] == devices[j]) continue;
        int can_access = -1;
        PADDLE_ENFORCE(
            cudaDeviceCanAccessPeer(&can_access, devices[i], devices[j]),
            "Failed to test P2P access.");
        if (can_access != 1) {
          LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
                       << " to " << devices[j];
          continue;
        }
        // Peer access is enabled from the current device, so switch first.
        cudaError_t err = cudaSetDevice(devices[i]);
        if (err != cudaSuccess) {
          LOG(WARNING) << "Cannot switch to device " << devices[i]
                       << " to enable P2P access: " << cudaGetErrorString(err);
          continue;
        }
        err = cudaDeviceEnablePeerAccess(devices[j], 0);
        // An already-enabled link is the desired end state, not a failure.
        if (err != cudaSuccess && err != cudaErrorPeerAccessAlreadyEnabled) {
          LOG(WARNING) << "Failed to enable P2P access from " << devices[i]
                       << " to " << devices[j] << ": "
                       << cudaGetErrorString(err);
        }
      }
    }
  });
#endif
}

void InitDevices(bool init_p2p) {
/* Init all available devices by default */

Expand Down Expand Up @@ -91,6 +136,34 @@ void InitDevices(bool init_p2p) {
platform::DeviceContextPool::Init(places);
}

// Initializes the device context pool for an explicit list of CUDA device ids.
//
// init_p2p: when true, InitP2P() is called for the accepted ids.
// devices:  requested CUDA device ids. An id is accepted only when
//           0 <= id < number of CUDA devices reported at runtime; others are
//           skipped with a warning. Without PADDLE_WITH_CUDA the device count
//           stays 0, so every id is rejected.
//
// A CPUPlace is always appended, so the pool is initialized even when no id is
// accepted.
//
// Review note: device ids must be non-negative and unique. Negative ids used
// to slip past the `>= count` check, and rejected ids were still forwarded to
// InitP2P(); both are fixed here. Duplicates are dropped from the P2P list.
// According to the review discussion DeviceContextPool de-duplicates places
// itself (not verified here), so `places` is left as requested.
void InitDevices(bool init_p2p, const std::vector<int> devices) {
  std::vector<platform::Place> places;
  int count = 0;
#ifdef PADDLE_WITH_CUDA
  try {
    count = platform::GetCUDADeviceCount();
  } catch (const std::exception &) {
    LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
  }
#else
  LOG(WARNING)
      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif

  // Ids that passed validation, without repeats; only these go to InitP2P().
  std::vector<int> p2p_devices;
  p2p_devices.reserve(devices.size());
  for (const int id : devices) {
    if (id < 0 || id >= count) {
      LOG(WARNING) << "Invalid devices id.";
      continue;
    }
    if (std::find(p2p_devices.begin(), p2p_devices.end(), id) ==
        p2p_devices.end()) {
      p2p_devices.push_back(id);
    }
    places.emplace_back(platform::CUDAPlace(id));
  }
  if (init_p2p) {
    InitP2P(p2p_devices);
  }
  places.emplace_back(platform::CPUPlace());
  platform::DeviceContextPool::Init(places);
}

void InitGLOG(const std::string &prog_name) {
// glog will not hold the ARGV[0] inside.
// Use strdup to alloc a new string.
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ limitations under the License. */
namespace paddle {
namespace framework {

// Parses gflags from argc/argv, then initializes the devices listed in the
// comma-separated --devices flag, passing --init_p2p through to InitDevices().
void Init(int argc, char **argv);

// Initializes gflags from an argument vector. The visible part of the
// definition runs under the same once-flag as Init(); the rest of its body is
// not shown here. TODO confirm details against init.cc.
void InitGflags(std::vector<std::string> &argv);

// Presumably initializes glog using prog_name as the program name; the
// definition is not fully visible here. TODO confirm.
void InitGLOG(const std::string &prog_name);

// Initializes all available devices by default.
void InitDevices(bool init_p2p);

// Initializes only the given CUDA device ids, plus a CPU place. Ids that are
// not below the runtime CUDA device count are skipped with a warning. When
// init_p2p is true, peer-to-peer access is set up for the given devices.
void InitDevices(bool init_p2p, const std::vector<int> devices);

} // namespace framework
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/inference/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace inference {
// linking the inference shared library.
void Init(bool init_p2p) {
  // Forward straight to the framework-level device initialization.
  framework::InitDevices(init_p2p);
}

// Command-line entry point for the inference library: hands argc/argv to
// framework::Init unchanged.
void Init(int argc, char** argv) {
  framework::Init(argc, argv);
}

void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/inference/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace inference {

// Forwards init_p2p to framework::InitDevices(init_p2p).
// NOTE(review): a reviewer proposed removing this overload in favour of
// Init(int argc, char** argv).
void Init(bool init_p2p);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个Init接口可以删掉了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// Forwards argc/argv to framework::Init(argc, argv).
void Init(int argc, char** argv);

void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program,
const std::string& dirname,
Expand Down