@@ -335,14 +335,21 @@ def build_extension(self, ext: Extension) -> None:
335
335
long_description = f .read ()
336
336
337
337
# Finds torch_xla and its subpackages
338
- # We specify 'torchax.torchax' to find the nested package correctly.
339
- packages_to_include = find_packages (
340
- include = ['torch_xla' , 'torch_xla.*' , 'torchax' , 'torchax.*' ])
341
-
342
- # Map the top-level 'torchax' package name to its source location
343
- package_dir_mapping = {
344
- 'torchax' : 'torchax/torchax' ,
345
- }
338
+ # 1. Find `torch_xla` and its subpackages automatically from the root.
339
+ packages_to_include = find_packages (include = ['torch_xla' , 'torch_xla.*' ])
340
+
341
+ # 2. Explicitly find the contents of the nested `torchax` package.
342
+ # Find all sub-packages within the torchax directory (e.g., 'ops').
343
+ torchax_source_dir = 'torchax/torchax'
344
+ torchax_subpackages = find_packages (where = torchax_source_dir )
345
+ # Construct the full list of packages, starting with the top-level
346
+ # 'torchax' and adding all the discovered sub-packages.
347
+ packages_to_include .extend (['torchax' ] +
348
+ ['torchax.' + pkg for pkg in torchax_subpackages ])
349
+
350
+ # 3. The package_dir mapping explicitly tells setuptools where the 'torchax'
351
+ # package's source code begins. `torch_xla` source code is inferred.
352
+ package_dir_mapping = {'torchax' : torchax_source_dir }
346
353
347
354
348
355
class Develop (develop .develop ):
0 commit comments