Skip to content

Commit 37c8eb1

Browse files
Style improvements. (#9410)
Co-authored-by: Zhanyong Wan <[email protected]>
1 parent cafff83 commit 37c8eb1

File tree

5 files changed

+70
-99
lines changed

5 files changed

+70
-99
lines changed

torch_xla/csrc/runtime/env_vars.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1 @@
11
#include "torch_xla/csrc/runtime/env_vars.h"
2-
3-
namespace torch_xla {
4-
namespace runtime {
5-
namespace env {
6-
7-
const char* const kEnvNumTpu = "TPU_NUM_DEVICES";
8-
const char* const kEnvNumGpu = "GPU_NUM_DEVICES";
9-
const char* const kEnvNumCpu = "CPU_NUM_DEVICES";
10-
const char* const kEnvTpuvmMode = "TPUVM_MODE";
11-
const char* const kEnvPjRtDevice = "PJRT_DEVICE";
12-
const char* const kEnvPjRtTpuMaxInflightComputations =
13-
"PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
14-
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
15-
const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT";
16-
const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
17-
const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH";
18-
const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH";
19-
const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH";
20-
const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR";
21-
const char* const kEnvPjRtLocalProcessCount = "PJRT_LOCAL_PROCESS_COUNT";
22-
const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK";
23-
const char* const kEnvPjrtAllocatorCudaAsync = "PJRT_ALLOCATOR_CUDA_ASYNC";
24-
const char* const kEnvPjrtAllocatorPreallocate = "PJRT_ALLOCATOR_PREALLOCATE";
25-
const char* const kEnvPjrtAllocatorFraction = "PJRT_ALLOCATOR_FRACTION";
26-
const char* const kEnvPjrtDynamicPlugins = "PJRT_DYNAMIC_PLUGINS";
27-
const char* const kEnvDistSvcHeartbeatIntervalInSec =
28-
"DIST_SERVICE_HEARTBEAT_INTERVAL_IN_SEC";
29-
const char* const kEnvDistSvcMaxMissingHeartbeats =
30-
"DIST_SERVICE_MAX_MISSING_HEARTBEATS";
31-
const char* const kEnvDistSvcShutdownTimeoutInMin =
32-
"DIST_SERVICE_SHUTDOWN_TIMEOUT_IN_MIN";
33-
34-
} // namespace env
35-
} // namespace runtime
36-
} // namespace torch_xla

torch_xla/csrc/runtime/env_vars.h

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,42 @@
1+
// Names of environment variables.
2+
13
#ifndef XLA_CLIENT_ENV_VARS_H_
24
#define XLA_CLIENT_ENV_VARS_H_
35

46
namespace torch_xla {
57
namespace runtime {
68
namespace env {
79

8-
extern const char* const kEnvNumTpu;
9-
extern const char* const kEnvNumGpu;
10-
extern const char* const kEnvNumCpu;
11-
extern const char* const kEnvLocalWorker;
12-
extern const char* const kEnvTpuConfig;
13-
extern const char* const kEnvDeviceMap;
14-
extern const char* const kEnvWorkers;
15-
extern const char* const kEnvMeshService;
16-
extern const char* const kEnvWorldSize;
17-
extern const char* const kEnvMpDevice;
18-
extern const char* const kEnvHostOrdinal;
19-
extern const char* const kEnvShardOrdinal;
20-
extern const char* const kEnvStartService;
21-
extern const char* const kEnvTpuvmMode;
22-
extern const char* const kEnvPjRtDevice;
23-
extern const char* const kEnvPjRtTpuMaxInflightComputations;
24-
extern const char* const kEnvPjrtAsyncCpuClient;
25-
extern const char* const kEnvPjrtAsyncGpuClient;
26-
extern const char* const kEnvTpuLibraryPath;
27-
extern const char* const kEnvInferredTpuLibraryPath;
28-
extern const char* const kEnvXpuLibraryPath;
29-
extern const char* const kEnvNeuronLibraryPath;
30-
extern const char* const kEnvPjrtDistServiceAddr;
31-
extern const char* const kEnvPjRtLocalProcessCount;
32-
extern const char* const kEnvPjRtLocalRank;
33-
extern const char* const kEnvPjrtAllocatorCudaAsync;
34-
extern const char* const kEnvPjrtAllocatorPreallocate;
35-
extern const char* const kEnvPjrtAllocatorFraction;
36-
extern const char* const kEnvPjrtDynamicPlugins;
37-
extern const char* const kEnvDistSvcHeartbeatIntervalInSec;
38-
extern const char* const kEnvDistSvcMaxMissingHeartbeats;
39-
extern const char* const kEnvDistSvcShutdownTimeoutInMin;
10+
inline constexpr char kEnvLocalWorker[] = "LOCAL_WORKER";
11+
inline constexpr char kEnvTpuConfig[] = "TPU_CONFIG";
12+
inline constexpr char kEnvNumTpu[] = "TPU_NUM_DEVICES";
13+
inline constexpr char kEnvNumGpu[] = "GPU_NUM_DEVICES";
14+
inline constexpr char kEnvNumCpu[] = "CPU_NUM_DEVICES";
15+
inline constexpr char kEnvTpuvmMode[] = "TPUVM_MODE";
16+
inline constexpr char kEnvPjRtDevice[] = "PJRT_DEVICE";
17+
inline constexpr char kEnvPjRtTpuMaxInflightComputations[] =
18+
"PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
19+
inline constexpr char kEnvPjrtAsyncCpuClient[] = "PJRT_CPU_ASYNC_CLIENT";
20+
inline constexpr char kEnvPjrtAsyncGpuClient[] = "PJRT_GPU_ASYNC_CLIENT";
21+
inline constexpr char kEnvTpuLibraryPath[] = "TPU_LIBRARY_PATH";
22+
inline constexpr char kEnvInferredTpuLibraryPath[] = "PTXLA_TPU_LIBRARY_PATH";
23+
inline constexpr char kEnvXpuLibraryPath[] = "XPU_LIBRARY_PATH";
24+
inline constexpr char kEnvNeuronLibraryPath[] = "NEURON_LIBRARY_PATH";
25+
inline constexpr char kEnvPjrtDistServiceAddr[] = "PJRT_DIST_SERVICE_ADDR";
26+
inline constexpr char kEnvPjRtLocalProcessCount[] = "PJRT_LOCAL_PROCESS_COUNT";
27+
inline constexpr char kEnvPjRtLocalRank[] = "PJRT_LOCAL_PROCESS_RANK";
28+
inline constexpr char kEnvPjrtAllocatorCudaAsync[] =
29+
"PJRT_ALLOCATOR_CUDA_ASYNC";
30+
inline constexpr char kEnvPjrtAllocatorPreallocate[] =
31+
"PJRT_ALLOCATOR_PREALLOCATE";
32+
inline constexpr char kEnvPjrtAllocatorFraction[] = "PJRT_ALLOCATOR_FRACTION";
33+
inline constexpr char kEnvPjrtDynamicPlugins[] = "PJRT_DYNAMIC_PLUGINS";
34+
inline constexpr char kEnvDistSvcHeartbeatIntervalInSec[] =
35+
"DIST_SERVICE_HEARTBEAT_INTERVAL_IN_SEC";
36+
inline constexpr char kEnvDistSvcMaxMissingHeartbeats[] =
37+
"DIST_SERVICE_MAX_MISSING_HEARTBEATS";
38+
inline constexpr char kEnvDistSvcShutdownTimeoutInMin[] =
39+
"DIST_SERVICE_SHUTDOWN_TIMEOUT_IN_MIN";
4040

4141
} // namespace env
4242
} // namespace runtime

torch_xla/csrc/runtime/runtime.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
#include "torch_xla/csrc/runtime/runtime.h"
2+
13
#include <torch/csrc/lazy/backend/backend_device.h>
24

35
#include "absl/log/absl_check.h"
4-
#include "torch_xla/csrc/device.h"
56
#include "torch_xla/csrc/runtime/computation_client.h"
67
#include "torch_xla/csrc/runtime/env_vars.h"
78
#include "torch_xla/csrc/runtime/ifrt_computation_client.h"
@@ -10,12 +11,17 @@
1011

1112
namespace torch_xla::runtime {
1213

13-
std::atomic<bool> g_computation_client_initialized(false);
14+
static std::atomic<bool> g_computation_client_initialized(false);
1415

1516
// Creates a new instance of a `ComputationClient` (e.g.
16-
// `PjRtComputationClient`), and initializes the computation client
17+
// `PjRtComputationClient`), and initializes the computation client.
18+
// Can only be called when g_computation_client_initialized is false.
1719
static absl::StatusOr<ComputationClient * absl_nonnull>
1820
InitializeComputationClient() {
21+
ABSL_CHECK(!g_computation_client_initialized)
22+
<< "InitializeComputationClient() can only be called once.";
23+
g_computation_client_initialized = true;
24+
1925
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
2026
tsl::testing::InstallStacktraceHandler();
2127
}
@@ -25,27 +31,24 @@ InitializeComputationClient() {
2531
//
2632
// static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
2733
const bool use_ifrt = false;
28-
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
29-
auto* client =
30-
(use_ifrt)
31-
? static_cast<ComputationClient*>(new IfrtComputationClient())
32-
: static_cast<ComputationClient*>(new PjRtComputationClient());
33-
g_computation_client_initialized = true;
34-
return client;
35-
} else {
34+
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") == "") {
3635
return absl::FailedPreconditionError("$PJRT_DEVICE is not set.");
3736
}
37+
38+
if (use_ifrt) {
39+
return new IfrtComputationClient();
40+
}
41+
return new PjRtComputationClient();
3842
}
3943

40-
absl::StatusOr<ComputationClient * absl_nonnull> GetComputationClient() {
44+
const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {
4145
// Reference to singleton Status-wrapped ComputationClient instance.
4246
//
4347
// Since we only allow a single initialization, as soon as this function is
4448
// called, we store the initialization result in this trivially destructible
4549
// reference.
46-
static auto& maybe_client =
47-
*new absl::StatusOr<ComputationClient * absl_nonnull>(
48-
InitializeComputationClient());
50+
static const auto& maybe_client =
51+
*new absl::StatusOr<ComputationClient*>(InitializeComputationClient());
4952
return maybe_client;
5053
}
5154

torch_xla/csrc/runtime/runtime.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#ifndef XLA_CLIENT_RUNTIME_H_
22
#define XLA_CLIENT_RUNTIME_H_
33

4+
#include "absl/base/attributes.h"
5+
#include "absl/status/statusor.h"
46
#include "torch_xla/csrc/runtime/computation_client.h"
57

68
namespace torch_xla::runtime {
79

810
// Returns the ComputationClient singleton.
9-
absl::StatusOr<ComputationClient * absl_nonnull> GetComputationClient();
11+
const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient();
1012

1113
ABSL_DEPRECATED(
1214
"Use status::GetComputationClient(), instead. "
@@ -15,12 +17,13 @@ ABSL_DEPRECATED(
1517
"safer.")
1618
ComputationClient* absl_nonnull GetComputationClientOrDie();
1719

18-
// Returns the ComputationClient singleton, if successfully initialized.
19-
// Returns a nullptr, if the ComputationClient wasn't initialized yet, or
20-
// if there was an error on initialization.
20+
// Returns the ComputationClient singleton if it was successfully initialized.
21+
// Returns a nullptr if the ComputationClient wasn't initialized yet.
22+
// Throws an exception if the ComputationClient was initialized but the
23+
// initialization failed.
2124
ComputationClient* GetComputationClientIfInitialized();
2225

23-
// Run the XRT local service, this will block the caller unitl the server
26+
// Runs the XRT local service, this will block the caller unitl the server
2427
// being stopped.
2528
void RunLocalService(uint64_t service_port);
2629

torch_xla/csrc/runtime/sys_util.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ namespace torch_xla {
1010
namespace runtime {
1111
namespace sys_util {
1212

13-
std::string GetEnvString(const char* name, const std::string& defval) {
14-
const char* env = std::getenv(name);
13+
std::string GetEnvString(const char* const name, const std::string& defval) {
14+
const char* const env = std::getenv(name);
1515
return env != nullptr ? env : defval;
1616
}
1717

18-
std::string GetEnvOrdinalPath(const char* name, const std::string& defval,
18+
std::string GetEnvOrdinalPath(const char* const name, const std::string& defval,
1919
const int64_t ordinal) {
2020
std::string path = GetEnvString(name, defval);
2121
if (!path.empty()) {
@@ -26,23 +26,23 @@ std::string GetEnvOrdinalPath(const char* name, const std::string& defval,
2626
return path;
2727
}
2828

29-
std::string GetEnvOrdinalPath(const char* name, const std::string& defval,
30-
const char* ordinal_env) {
29+
std::string GetEnvOrdinalPath(const char* const name, const std::string& defval,
30+
const char* const ordinal_env) {
3131
return GetEnvOrdinalPath(name, defval, GetEnvInt(ordinal_env, -1));
3232
}
3333

34-
int64_t GetEnvInt(const char* name, int64_t defval) {
35-
const char* env = std::getenv(name);
34+
int64_t GetEnvInt(const char* const name, const int64_t defval) {
35+
const char* const env = std::getenv(name);
3636
return env != nullptr ? std::atol(env) : defval;
3737
}
3838

39-
double GetEnvDouble(const char* name, double defval) {
40-
const char* env = std::getenv(name);
39+
double GetEnvDouble(const char* const name, const double defval) {
40+
const char* const env = std::getenv(name);
4141
return env != nullptr ? std::atof(env) : defval;
4242
}
4343

44-
bool GetEnvBool(const char* name, bool defval) {
45-
const char* env = std::getenv(name);
44+
bool GetEnvBool(const char* const name, const bool defval) {
45+
const char* const env = std::getenv(name);
4646
if (env == nullptr) {
4747
return defval;
4848
}
@@ -56,7 +56,7 @@ bool GetEnvBool(const char* name, bool defval) {
5656
}
5757

5858
int64_t NowNs() {
59-
auto now = std::chrono::high_resolution_clock::now();
59+
const auto now = std::chrono::high_resolution_clock::now();
6060
return std::chrono::duration_cast<std::chrono::nanoseconds>(
6161
now.time_since_epoch())
6262
.count();

0 commit comments

Comments
 (0)