52
52
import contextlib
53
53
import distutils .ccompiler
54
54
import distutils .command .clean
55
+ import importlib .util
55
56
import os
56
57
import re
57
58
import requests
61
62
import tempfile
62
63
import zipfile
63
64
64
- import build_util
65
+ # This gloop imports build_util.py such that it works in Python 3.12's isolated
66
+ # build environment while also not contaminating sys.path which breaks bdist_wheel.
67
+ _PROJECT_DIR = os .path .dirname (os .path .abspath (__file__ ))
68
+ _build_util_path = os .path .join (_PROJECT_DIR , 'build_util.py' )
69
+ spec = importlib .util .spec_from_file_location ('build_util' , _build_util_path )
70
+ build_util = importlib .util .module_from_spec (spec )
71
+ spec .loader .exec_module (build_util )
65
72
66
73
import platform
67
74
@@ -151,7 +158,7 @@ def get_git_head_sha(base_dir):
151
158
152
159
153
160
def get_build_version (xla_git_sha ):
154
- version = os .getenv ('TORCH_XLA_VERSION' , '2.8 .0' )
161
+ version = os .getenv ('TORCH_XLA_VERSION' , '2.9 .0' )
155
162
if build_util .check_env_flag ('GIT_VERSIONED_XLA_BUILD' , default = 'TRUE' ):
156
163
try :
157
164
version += '+git' + xla_git_sha [:7 ]
@@ -270,40 +277,51 @@ def __init__(self, bazel_target):
270
277
class BuildBazelExtension (build_ext .build_ext ):
271
278
"""A command that runs Bazel to build a C/C++ extension."""
272
279
273
- def run (self ):
274
- for ext in self .extensions :
275
- self .bazel_build (ext )
276
- command .build_ext .build_ext .run (self ) # type: ignore
280
+ def build_extension (self , ext : Extension ) -> None :
281
+ """
282
+ This method is called by setuptools to build a single extension.
283
+ We override it to implement our custom Bazel build logic.
284
+ """
285
+ if not isinstance (ext , BazelExtension ):
286
+ # If it's not our custom extension type, let setuptools handle it.
287
+ super ().build_extension (ext )
288
+ return
277
289
278
- def bazel_build ( self , ext ):
290
+ # 1. Ensure the temporary build directory exists
279
291
if not os .path .exists (self .build_temp ):
280
292
os .makedirs (self .build_temp )
281
293
294
+ # 2. Prepare the Bazel command
282
295
bazel_argv = [
283
296
'bazel' , 'build' , ext .bazel_target ,
284
297
f"--symlink_prefix={ os .path .join (self .build_temp , 'bazel-' )} "
285
298
]
286
299
287
- build_cpp_tests = build_util .check_env_flag ('BUILD_CPP_TESTS' , default = '0' )
288
- if build_cpp_tests :
289
- bazel_argv .append ('//:cpp_tests' )
290
-
291
- import torch
292
- cxx_abi = os .getenv ('CXX_ABI' ) or getattr (torch ._C ,
293
- '_GLIBCXX_USE_CXX11_ABI' , None )
294
- if cxx_abi is not None :
295
- bazel_argv .append (f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={ int (cxx_abi )} ' )
300
+ cxx_abi = os .getenv ('CXX_ABI' )
301
+ if cxx_abi is None :
302
+ try :
303
+ import torch
304
+ cxx_abi = getattr (torch ._C , '_GLIBCXX_USE_CXX11_ABI' , None )
305
+ except :
306
+ pass
307
+ if cxx_abi is None :
308
+ # Default to building with C++11 ABI, which has been the case since PyTorch 2.7
309
+ cxx_abi = "1"
310
+ bazel_argv .append (f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={ int (cxx_abi )} ' )
296
311
297
312
bazel_argv .extend (build_util .bazel_options_from_env ())
298
313
314
+ # 3. Run the Bazel build
299
315
self .spawn (bazel_argv )
300
316
317
+ # 4. Copy the output file to the location setuptools expects
301
318
ext_bazel_bin_path = os .path .join (self .build_temp , 'bazel-bin' , ext .relpath ,
302
319
ext .target_name )
303
320
ext_dest_path = self .get_ext_fullpath (ext .name )
304
321
ext_dest_dir = os .path .dirname (ext_dest_path )
305
322
if not os .path .exists (ext_dest_dir ):
306
323
os .makedirs (ext_dest_dir )
324
+
307
325
shutil .copyfile (ext_bazel_bin_path , ext_dest_path )
308
326
309
327
@@ -313,17 +331,21 @@ def bazel_build(self, ext):
313
331
long_description = f .read ()
314
332
315
333
# Finds torch_xla and its subpackages
316
- packages_to_include = find_packages ( include = [ 'torch_xla*' ])
317
- # Explicitly add torchax
318
- packages_to_include . extend ( find_packages ( where = ' torchax' , include = [ 'torchax*' ]) )
334
+ # We specify 'torchax.torchax' to find the nested package correctly.
335
+ packages_to_include = find_packages (
336
+ include = [ 'torch_xla' , 'torch_xla.*' , ' torchax' , 'torchax. *' ])
319
337
320
338
# Map the top-level 'torchax' package name to its source location
321
- torchax_dir = os . path . join ( cwd , 'torchax' )
322
- package_dir_mapping = { 'torch_xla ' : os . path . join ( cwd , 'torch_xla' )}
323
- package_dir_mapping [ 'torchax' ] = os . path . join ( torchax_dir , 'torchax' )
339
+ package_dir_mapping = {
340
+ 'torchax ' : 'torchax/torchax' ,
341
+ }
324
342
325
343
326
344
class Develop (develop .develop ):
345
+ """
346
+ Custom develop command to build C++ extensions and create a .pth file
347
+ for a multi-package editable install.
348
+ """
327
349
328
350
def run (self ):
329
351
# Build the C++ extension
@@ -348,44 +370,42 @@ def link_packages(self):
348
370
(`python setup.py develop`). Nightly and release wheel builds work out of the box
349
371
without egg-link/pth.
350
372
"""
373
+ import glob
374
+
351
375
# Ensure paths like self.install_dir are set
352
376
self .ensure_finalized ()
353
377
354
- # Get the site-packages directory
355
- target_dir = self .install_dir
356
-
357
- # Remove the standard .egg-link file
358
- # It's usually named based on the distribution name
359
378
dist_name = self .distribution .get_name ()
360
- egg_link_file = os .path .join (target_dir , dist_name + '.egg-link' )
361
- if os .path .exists (egg_link_file ):
362
- print (f"Removing default egg-link file: { egg_link_file } " )
363
- try :
364
- os .remove (egg_link_file )
365
- except OSError as e :
366
- print (f"Warning: Could not remove { egg_link_file } : { e } " )
367
-
368
- # Create our custom .pth file with specific paths
369
- cwd = os .path .dirname (__file__ )
370
- # Path containing 'torch_xla' package source: ROOT
371
- path_for_torch_xla = os .path .abspath (cwd )
372
- # Path containing 'torchax' package source: ROOT/torchax
373
- path_for_torchax = os .path .abspath (os .path .join (cwd , 'torchax' ))
374
-
375
- paths_to_add = {path_for_torch_xla , path_for_torchax }
376
-
377
- # Construct a suitable .pth filename (PEP 660 style is good practice)
378
- version = self .distribution .get_version ()
379
- # Sanitize name and version for filename (replace runs of non-alphanumeric chars with '-')
380
- sanitized_name = re .sub (r"[^a-zA-Z0-9.]+" , "_" , dist_name )
381
- sanitized_version = re .sub (r"[^a-zA-Z0-9.]+" , "_" , version )
382
- pth_filename = os .path .join (
383
- target_dir , f"__editable_{ sanitized_name } _{ sanitized_version } .pth" )
384
-
385
- # Ensure site-packages exists
386
- os .makedirs (target_dir , exist_ok = True )
387
-
388
- # Write the paths to the .pth file, one per line
379
+ install_cmd = self .get_finalized_command ('install' )
380
+ target_dir = install_cmd .install_lib
381
+ assert target_dir is not None
382
+
383
+ # Use glob to robustly find and remove the conflicting files.
384
+ # This is safer than trying to guess the exact sanitized filename.
385
+ safe_name_part = re .sub (r"[^a-zA-Z0-9]+" , "_" , dist_name )
386
+
387
+ for pattern in [
388
+ # Remove `.pth` files generated in Python 3.12.
389
+ f"__editable__.*{ safe_name_part } *.pth" ,
390
+ f"__editable___*{ safe_name_part } *_finder.py" ,
391
+ # Also remove the legacy egg-link format.
392
+ f"{ dist_name } .egg-link"
393
+ ]:
394
+ for filepath in glob .glob (os .path .join (target_dir , pattern )):
395
+ print (f"Cleaning up conflicting install file: { filepath } " )
396
+ with contextlib .suppress (OSError ):
397
+ os .remove (filepath )
398
+
399
+ # Finally, create our own simple, multi-path .pth file.
400
+ # We name it simply, e.g., "torch_xla.pth".
401
+ pth_filename = os .path .join (target_dir , f"{ dist_name } .pth" )
402
+
403
+ project_root = os .path .dirname (os .path .abspath (__file__ ))
404
+ paths_to_add = {
405
+ project_root , # For `torch_xla`
406
+ os .path .abspath (os .path .join (project_root , 'torchax' )), # For `torchax`
407
+ }
408
+
389
409
with open (pth_filename , "w" , encoding = 'utf-8' ) as f :
390
410
for path in sorted (paths_to_add ):
391
411
f .write (path + "\n " )
@@ -403,7 +423,7 @@ def _get_jax_install_requirements():
403
423
jax = f'jax @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jax/jax-{ _jax_version } -py3-none-any.whl'
404
424
405
425
jaxlib = []
406
- for python_minor_version in [9 , 10 , 11 ]:
426
+ for python_minor_version in [9 , 10 , 11 , 12 ]:
407
427
jaxlib .append (
408
428
f'jaxlib @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jaxlib/jaxlib-{ _jaxlib_version } -cp3{ python_minor_version } -cp3{ python_minor_version } -manylinux2014_x86_64.whl ; python_version == "3.{ python_minor_version } "'
409
429
)
0 commit comments