Skip to content

Commit dd14f37

Browse files
committed
Actually make Python 3.12 work
This PR defaults the dev container to Python 3.12 because Python 3.10 will soon be unusable with future JAX versions. Under Python 3.12, a package installed for editing is built in an isolated pip environment. That means we need a pyproject.toml file declaring which packages setup.py depends on, which this PR adds. Also, setup.py depends on torch to get the C++11 ABI status, but that just won't work in Python 3.12, so that logic is adjusted accordingly. We also need a bunch of other fixes due to the isolated environment change in Python 3.12. Finally, we need to add `12` to the list of Python minor versions used when generating JAX dependencies. TEST: `scripts/build_developer.sh -a -t`
1 parent 9c8ae9f commit dd14f37

File tree

5 files changed

+132
-64
lines changed

5 files changed

+132
-64
lines changed

.devcontainer/tpu-contributor/devcontainer.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
{
22
"name": "tpu-contributor",
3-
"image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
3+
"image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm",
44
"runArgs": [
55
"--privileged",
66
"--net=host",
77
"--shm-size=16G"
88
],
9-
"initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
9+
"initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm",
1010
"customizations": {
1111
"vscode": {
1212
"extensions": [
@@ -23,4 +23,4 @@
2323
]
2424
}
2525
}
26-
}
26+
}

.devcontainer/tpu-internal/devcontainer.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
{
22
"name": "tpu-internal",
3-
"image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
3+
"image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm",
44
"runArgs": [
55
"--privileged",
66
"--net=host",
77
"--shm-size=16G"
88
],
99
"containerEnv": {
1010
"BAZEL_REMOTE_CACHE": "1",
11-
"SILO_NAME": "cache-silo-${localEnv:USER}-tpuvm"
11+
"SILO_NAME": "cache-silo-${localEnv:USER}-tpuvm-312"
1212
},
13-
"initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
13+
"initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm",
1414
"customizations": {
1515
"vscode": {
1616
"extensions": [

pyproject.toml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
[build-system]
2+
# These are the packages required to run `setup.py` in an isolated environment.
3+
# Pip will install these *before* executing the build backend.
4+
requires = [
5+
"setuptools>=42",
6+
"wheel",
7+
"requests",
8+
"numpy",
9+
"pyyaml",
10+
]
11+
build-backend = "setuptools.build_meta"
12+
13+
[project]
14+
name = "torch-xla"
15+
description = "XLA bridge for PyTorch"
16+
readme = "README.md"
17+
authors = [
18+
{ name = "PyTorch/XLA Dev Team", email = "[email protected]" },
19+
]
20+
license = { file = "LICENSE" }
21+
requires-python = ">=3.10"
22+
classifiers = [
23+
"Development Status :: 5 - Production/Stable",
24+
"Intended Audience :: Developers",
25+
"Intended Audience :: Science/Research",
26+
"License :: OSI Approved :: BSD License",
27+
"Programming Language :: Python :: 3",
28+
"Programming Language :: Python :: 3.10",
29+
"Programming Language :: Python :: 3.11",
30+
"Programming Language :: Python :: 3.12",
31+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
32+
"Topic :: Software Development :: Libraries :: Python Modules",
33+
]
34+
keywords = ["pytorch", "xla", "tpu", "deep learning", "compiler"]
35+
36+
# This tells build tools to get this info from setup.py instead of this file.
37+
dynamic = [
38+
"version",
39+
"dependencies",
40+
"optional-dependencies",
41+
"entry-points",
42+
"scripts"
43+
]
44+
45+
[project.urls]
46+
Homepage = "https://github.com/pytorch/xla"
47+
Repository = "https://github.com/pytorch/xla"
48+
"Bug Tracker" = "https://github.com/pytorch/xla/issues"

setup.py

Lines changed: 77 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import contextlib
5353
import distutils.ccompiler
5454
import distutils.command.clean
55+
import importlib.util
5556
import os
5657
import re
5758
import requests
@@ -61,7 +62,13 @@
6162
import tempfile
6263
import zipfile
6364

64-
import build_util
65+
# This glue code imports build_util.py such that it works in Python 3.12's isolated
66+
# build environment while also not contaminating sys.path, which breaks bdist_wheel.
67+
_PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
68+
_build_util_path = os.path.join(_PROJECT_DIR, 'build_util.py')
69+
spec = importlib.util.spec_from_file_location('build_util', _build_util_path)
70+
build_util = importlib.util.module_from_spec(spec)
71+
spec.loader.exec_module(build_util)
6572

6673
import platform
6774

@@ -151,7 +158,7 @@ def get_git_head_sha(base_dir):
151158

152159

153160
def get_build_version(xla_git_sha):
154-
version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
161+
version = os.getenv('TORCH_XLA_VERSION', '2.9.0')
155162
if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
156163
try:
157164
version += '+git' + xla_git_sha[:7]
@@ -270,40 +277,51 @@ def __init__(self, bazel_target):
270277
class BuildBazelExtension(build_ext.build_ext):
271278
"""A command that runs Bazel to build a C/C++ extension."""
272279

273-
def run(self):
274-
for ext in self.extensions:
275-
self.bazel_build(ext)
276-
command.build_ext.build_ext.run(self) # type: ignore
280+
def build_extension(self, ext: Extension) -> None:
281+
"""
282+
This method is called by setuptools to build a single extension.
283+
We override it to implement our custom Bazel build logic.
284+
"""
285+
if not isinstance(ext, BazelExtension):
286+
# If it's not our custom extension type, let setuptools handle it.
287+
super().build_extension(ext)
288+
return
277289

278-
def bazel_build(self, ext):
290+
# 1. Ensure the temporary build directory exists
279291
if not os.path.exists(self.build_temp):
280292
os.makedirs(self.build_temp)
281293

294+
# 2. Prepare the Bazel command
282295
bazel_argv = [
283296
'bazel', 'build', ext.bazel_target,
284297
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
285298
]
286299

287-
build_cpp_tests = build_util.check_env_flag('BUILD_CPP_TESTS', default='0')
288-
if build_cpp_tests:
289-
bazel_argv.append('//:cpp_tests')
290-
291-
import torch
292-
cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
293-
'_GLIBCXX_USE_CXX11_ABI', None)
294-
if cxx_abi is not None:
295-
bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
300+
cxx_abi = os.getenv('CXX_ABI')
301+
if cxx_abi is None:
302+
try:
303+
import torch
304+
cxx_abi = getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
305+
except:
306+
pass
307+
if cxx_abi is None:
308+
# Default to building with C++11 ABI, which has been the case since PyTorch 2.7
309+
cxx_abi = "1"
310+
bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
296311

297312
bazel_argv.extend(build_util.bazel_options_from_env())
298313

314+
# 3. Run the Bazel build
299315
self.spawn(bazel_argv)
300316

317+
# 4. Copy the output file to the location setuptools expects
301318
ext_bazel_bin_path = os.path.join(self.build_temp, 'bazel-bin', ext.relpath,
302319
ext.target_name)
303320
ext_dest_path = self.get_ext_fullpath(ext.name)
304321
ext_dest_dir = os.path.dirname(ext_dest_path)
305322
if not os.path.exists(ext_dest_dir):
306323
os.makedirs(ext_dest_dir)
324+
307325
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)
308326

309327

@@ -313,17 +331,21 @@ def bazel_build(self, ext):
313331
long_description = f.read()
314332

315333
# Finds torch_xla and its subpackages
316-
packages_to_include = find_packages(include=['torch_xla*'])
317-
# Explicitly add torchax
318-
packages_to_include.extend(find_packages(where='torchax', include=['torchax*']))
334+
# We include the 'torchax' and 'torchax.*' patterns so the nested torchax package is found correctly.
335+
packages_to_include = find_packages(
336+
include=['torch_xla', 'torch_xla.*', 'torchax', 'torchax.*'])
319337

320338
# Map the top-level 'torchax' package name to its source location
321-
torchax_dir = os.path.join(cwd, 'torchax')
322-
package_dir_mapping = {'torch_xla': os.path.join(cwd, 'torch_xla')}
323-
package_dir_mapping['torchax'] = os.path.join(torchax_dir, 'torchax')
339+
package_dir_mapping = {
340+
'torchax': 'torchax/torchax',
341+
}
324342

325343

326344
class Develop(develop.develop):
345+
"""
346+
Custom develop command to build C++ extensions and create a .pth file
347+
for a multi-package editable install.
348+
"""
327349

328350
def run(self):
329351
# Build the C++ extension
@@ -348,44 +370,42 @@ def link_packages(self):
348370
(`python setup.py develop`). Nightly and release wheel builds work out of the box
349371
without egg-link/pth.
350372
"""
373+
import glob
374+
351375
# Ensure paths like self.install_dir are set
352376
self.ensure_finalized()
353377

354-
# Get the site-packages directory
355-
target_dir = self.install_dir
356-
357-
# Remove the standard .egg-link file
358-
# It's usually named based on the distribution name
359378
dist_name = self.distribution.get_name()
360-
egg_link_file = os.path.join(target_dir, dist_name + '.egg-link')
361-
if os.path.exists(egg_link_file):
362-
print(f"Removing default egg-link file: {egg_link_file}")
363-
try:
364-
os.remove(egg_link_file)
365-
except OSError as e:
366-
print(f"Warning: Could not remove {egg_link_file}: {e}")
367-
368-
# Create our custom .pth file with specific paths
369-
cwd = os.path.dirname(__file__)
370-
# Path containing 'torch_xla' package source: ROOT
371-
path_for_torch_xla = os.path.abspath(cwd)
372-
# Path containing 'torchax' package source: ROOT/torchax
373-
path_for_torchax = os.path.abspath(os.path.join(cwd, 'torchax'))
374-
375-
paths_to_add = {path_for_torch_xla, path_for_torchax}
376-
377-
# Construct a suitable .pth filename (PEP 660 style is good practice)
378-
version = self.distribution.get_version()
379-
# Sanitize name and version for filename (replace runs of non-alphanumeric chars with '-')
380-
sanitized_name = re.sub(r"[^a-zA-Z0-9.]+", "_", dist_name)
381-
sanitized_version = re.sub(r"[^a-zA-Z0-9.]+", "_", version)
382-
pth_filename = os.path.join(
383-
target_dir, f"__editable_{sanitized_name}_{sanitized_version}.pth")
384-
385-
# Ensure site-packages exists
386-
os.makedirs(target_dir, exist_ok=True)
387-
388-
# Write the paths to the .pth file, one per line
379+
install_cmd = self.get_finalized_command('install')
380+
target_dir = install_cmd.install_lib
381+
assert target_dir is not None
382+
383+
# Use glob to robustly find and remove the conflicting files.
384+
# This is safer than trying to guess the exact sanitized filename.
385+
safe_name_part = re.sub(r"[^a-zA-Z0-9]+", "_", dist_name)
386+
387+
for pattern in [
388+
# Remove `.pth` files generated in Python 3.12.
389+
f"__editable__.*{safe_name_part}*.pth",
390+
f"__editable___*{safe_name_part}*_finder.py",
391+
# Also remove the legacy egg-link format.
392+
f"{dist_name}.egg-link"
393+
]:
394+
for filepath in glob.glob(os.path.join(target_dir, pattern)):
395+
print(f"Cleaning up conflicting install file: {filepath}")
396+
with contextlib.suppress(OSError):
397+
os.remove(filepath)
398+
399+
# Finally, create our own simple, multi-path .pth file.
400+
# We name it simply, e.g., "torch_xla.pth".
401+
pth_filename = os.path.join(target_dir, f"{dist_name}.pth")
402+
403+
project_root = os.path.dirname(os.path.abspath(__file__))
404+
paths_to_add = {
405+
project_root, # For `torch_xla`
406+
os.path.abspath(os.path.join(project_root, 'torchax')), # For `torchax`
407+
}
408+
389409
with open(pth_filename, "w", encoding='utf-8') as f:
390410
for path in sorted(paths_to_add):
391411
f.write(path + "\n")
@@ -403,7 +423,7 @@ def _get_jax_install_requirements():
403423
jax = f'jax @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jax/jax-{_jax_version}-py3-none-any.whl'
404424

405425
jaxlib = []
406-
for python_minor_version in [9, 10, 11]:
426+
for python_minor_version in [9, 10, 11, 12]:
407427
jaxlib.append(
408428
f'jaxlib @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jaxlib/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
409429
)

torch_xla/experimental/gru.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(self, *args, **kwargs):
9999
super().__init__(*args, **kwargs)
100100

101101
def forward(self, input, hx=None):
102-
"""
102+
r"""
103103
Args:
104104
input: Tensor of shape (seq_len, batch, input_size)
105105
hx: Optional initial hidden state of shape (num_layers, batch, hidden_size).

0 commit comments

Comments
 (0)