Commit ba8b262

Migrate to Hermetic CUDA.

- Update bazel files.
- Use `clang` for PyTorch/XLA only.
- Fix StableHLO tests.
- Compile debugging information with dwarf-4.

1 parent: 066e69e

File tree

9 files changed: +66, -30 lines


.bazelrc

Lines changed: 19 additions & 9 deletions
@@ -53,14 +53,8 @@ build --copt=-fexceptions
 # safer than o+rx.
 build --spawn_strategy=sandboxed
 
-# Use GCC for C/C++ compilation.
-build --action_env=CC=gcc
-build --action_env=CXX=g++
-
-###########################################################################
-
-build:clang --action_env=CC=/usr/bin/clang-17
-build:clang --action_env=CXX=/usr/bin/clang++-17
+build --action_env=CC=clang
+build --action_env=CXX=clang++
 
 ###########################################################################
 
@@ -85,8 +79,22 @@ build:cuda --repo_env TF_NEED_CUDA=1
 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
 build:cuda --@local_config_cuda//:enable_cuda
 build:cuda --define=xla_python_enable_gpu=true
+# Define XLA_CUDA for C++ files.
 build:cuda --cxxopt=-DXLA_CUDA=1
 
+# Default hermetic CUDA and CUDNN versions.
+build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
+build:cuda --repo_env HERMETIC_CUDA_VERSION="12.3.2"
+build:cuda --repo_env HERMETIC_CUDNN_VERSION="9.1.1"
+
+# Link NCCL statically, as it was with the non-hermetic build.
+build:cuda --repo_env TF_NCCL_USE_STUB=0
+
+# Use NVCC for compiling CUDA.
+build:cuda --action_env TF_NVCC_CLANG=1
+build:cuda --@local_config_cuda//:cuda_compiler=nvcc
+build:cuda --@local_config_cuda//cuda:include_cuda_libs=false
+
 # Coverage with cuda/gcc/nvcc requires manually setting coverage flags.
 coverage:cuda --per_file_copt=third_party/.*,torch_xla/.*@--coverage
 coverage:cuda --linkopt=-lgcov
@@ -254,8 +262,10 @@ build:linux --copt="-Werror=unused-result"
 build:linux --copt="-Wswitch"
 build:linux --copt="-Werror=switch"
 # Required for building with clang
-build:linux --copt="-Wno-error=unused-but-set-variable"
+build:linux --copt="-Qunused-arguments"
+build:linux --copt="-Wno-unused-command-line-argument"
 
 # Only include debug info for files not under XLA.
 build:dbg -c dbg
+build:dbg --copt=-gdwarf-4
 build:dbg --per_file_copt=external/xla/.*@-g0,-DNDEBUG

WORKSPACE

Lines changed: 36 additions & 5 deletions
@@ -56,7 +56,6 @@ http_archive(
     ],
     patch_tool = "patch",
     patches = [
-        "//openxla_patches:gpu_nvml.diff",
         "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:count_down.diff",
     ],
@@ -134,17 +133,49 @@ load("@xla//:workspace0.bzl", "xla_workspace0")
 
 xla_workspace0()
 
+load(
+    "@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
+    "cuda_json_init_repository",
+)
+
+cuda_json_init_repository()
+
+load(
+    "@cuda_redist_json//:distributions.bzl",
+    "CUDA_REDISTRIBUTIONS",
+    "CUDNN_REDISTRIBUTIONS",
+)
+load(
+    "@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
+    "cuda_redist_init_repositories",
+    "cudnn_redist_init_repository",
+)
+
+cuda_redist_init_repositories(
+    cuda_redistributions = CUDA_REDISTRIBUTIONS,
+)
+
+cudnn_redist_init_repository(
+    cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
+)
 
 load(
-    "@xla//third_party/gpus:cuda_configure.bzl",
-    "cuda_configure",
+    "@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
+    "cuda_configure",
 )
 
 cuda_configure(name = "local_config_cuda")
 
 load(
-    "@xla//third_party/nccl:nccl_configure.bzl",
-    "nccl_configure",
+    "@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
+    "nccl_redist_init_repository",
+)
+
+nccl_redist_init_repository()
+
+load(
+    "@xla//third_party/nccl/hermetic:nccl_configure.bzl",
+    "nccl_configure",
 )
 
 nccl_configure(name = "local_config_nccl")

infra/ansible/config/env.yaml

Lines changed: 4 additions & 3 deletions
@@ -2,11 +2,8 @@
 # They'll be accessible for all processes on the host, also in the development image.
 release_env:
   common:
-    # Force GCC because clang/bazel has issues.
     CC: gcc-10
     CXX: g++-10
-    # CC: "clang-{{ clang_version }}"
-    # CXX: "clang++-{{ clang_version }}"
     LD_LIBRARY_PATH: "$LD_LIBRARY_PATH:/usr/local/lib"
 
   tpu:
@@ -49,3 +46,7 @@ build_env:
     ACCELERATOR: tpu
     TPUVM_MODE: 1
    BUNDLE_LIBTPU: "{{ bundle_libtpu }}"
+
+clang_compiler:
+  CC: /usr/lib/{{ clang_version }}/bin/clang
+  CXX: /usr/lib/{{ clang_version }}/bin/clang++

infra/ansible/playbook.yaml

Lines changed: 3 additions & 1 deletion
@@ -85,6 +85,7 @@
         combine(build_env[arch] | default({}, true)) |
         combine(build_env[accelerator] | default({}, true))
       }}"
+    clang_compiler: "{{ clang_compiler }}"
   when: stage == "build"
   tags: build_srcs
 
@@ -94,7 +95,8 @@
     env_vars: "{{
       build_env.common | default({}, true) |
      combine(build_env[arch] | default({}, true)) |
-      combine(build_env[accelerator] | default({}, true))
+      combine(build_env[accelerator] | default({}, true)) |
+      combine(clang_compiler)
     }}"
   when: stage == "build_plugin"
   tags: build_plugin
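
Note on the `combine` chain above: Ansible's `combine` filter merges dictionaries left to right, with keys from later dictionaries overriding earlier ones, so the CC/CXX values from `clang_compiler` win over anything already in the merged build_env. A minimal Python sketch of the same merge (illustrative only; the paths are hypothetical, assuming `clang_version` resolves to something like `llvm-17`):

build_env = {"CC": "gcc-10", "CXX": "g++-10", "ACCELERATOR": "tpu"}
clang_compiler = {
    "CC": "/usr/lib/llvm-17/bin/clang",     # hypothetical resolved path
    "CXX": "/usr/lib/llvm-17/bin/clang++",  # hypothetical resolved path
}

# Later dicts override earlier ones, mirroring `env_vars | combine(clang_compiler)`.
env_vars = {**build_env, **clang_compiler}
assert env_vars["CC"] == "/usr/lib/llvm-17/bin/clang"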

infra/ansible/roles/build_srcs/tasks/main.yaml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@
   ansible.builtin.command:
     cmd: python setup.py bdist_wheel
     chdir: "{{ (src_root, 'pytorch/xla') | path_join }}"
-  environment: "{{ env_vars }}"
+  environment: "{{ env_vars | combine(clang_compiler) }}"
 
 - name: Find XLA *.whl files in pytorch/xla/dist
   ansible.builtin.find:

test/test_operations.py

Lines changed: 2 additions & 2 deletions
@@ -2209,8 +2209,8 @@ def test_inplace_mul_scalar_different_dtype(self):
    def fn(inp, s):
      return inp.mul_(s)
 
-    inp = torch.rand(10, dtype=torch.half)
-    s = torch.tensor(7, dtype=torch.double)
+    inp = torch.arange(10).to(torch.half)
+    s = torch.tensor(3, dtype=torch.double)
 
    Xinp = inp.to(xm.xla_device())
    Xs = s.to(xm.xla_device())
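
The new inputs are deterministic: `torch.arange(10)` produces small integers that are exactly representable in float16, unlike `torch.rand` values, so CPU and XLA results can be compared without precision-dependent flakiness. A short sketch of the dtype behavior the test exercises (my reading of it, not part of the commit):

import torch

inp = torch.arange(10).to(torch.half)    # exact fp16 values
s = torch.tensor(3, dtype=torch.double)  # 0-dim double tensor

# A 0-dim tensor does not promote a same-category (floating-point) operand,
# so the in-place mul_ keeps inp's half dtype.
out = inp.mul_(s)
assert out.dtype == torch.half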
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 import torch
 
-torch.library.register_kernel("aten::upsample_trilinear3d", "xla",
+torch.library.register_kernel("aten::upsample_trilinear3d", "XLA",
                               torch._decomp.decompositions.upsample_trilinear3d)
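
The dispatch key fix: the second argument is matched against the backend's dispatch key, and `XLA` (uppercase) is the key PyTorch registers for XLA tensors, so the lowercase `"xla"` presumably never matched. A hedged way to confirm the registration, using a private helper that exists in recent PyTorch builds but is not a stable API:

import torch

# Returns True once a kernel is registered under the XLA dispatch key.
print(torch._C._dispatch_has_kernel_for_dispatch_key(
    "aten::upsample_trilinear3d", "XLA"))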

torch_xla/csrc/runtime/stablehlo_composite_helper.cpp

Lines changed: 0 additions & 4 deletions
@@ -16,8 +16,6 @@
 namespace torch_xla {
 namespace runtime {
 
-namespace {
-
 using nlohmann::json;
 
 static bool IsXlaMarkTensorOp(mlir::Operation* op) {
@@ -529,8 +527,6 @@ class RemoveXlaMarkTensorOpsPass
   }
 };
 
-}  // namespace
-
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateBuildStableHLOCompositePass() {
   return std::make_unique<BuildStableHLOCompositePass>();

torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.cpp

Lines changed: 0 additions & 4 deletions
@@ -9,8 +9,6 @@
 namespace torch_xla {
 namespace runtime {
 
-namespace {
-
 // Defined in torch_xla/experimental/xla_mlir_debuginfo.py
 static constexpr char XLA_MLIR_DEBUGINFO_BEGIN[] = "<XLA_MLIR_DEBUGINFO_BEGIN>";
 static constexpr char XLA_MLIR_DEBUGINFO_END[] = "<XLA_MLIR_DEBUGINFO_END>";
@@ -81,8 +79,6 @@ class PrepareXlaMlirDebuginfoPass : public mlir::OperationPass<mlir::ModuleOp> {
   }
 };
 
-}  // namespace
-
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreatePrepareXlaMlirDebuginfoPass() {
   return std::make_unique<PrepareXlaMlirDebuginfoPass>();
