Skip to content

Commit fd40065

Browse files
authored
Merge pull request #53 from apoorvkh/pytorch-compatibility
handle older PyTorch versions
2 parents 392b0ef + 2a08afd commit fd40065

File tree

8 files changed

+4604
-1596
lines changed

8 files changed

+4604
-1596
lines changed

.github/actions/setup-env/action.yml

Lines changed: 0 additions & 16 deletions
This file was deleted.

.github/workflows/main.yml

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,54 +9,82 @@ on:
99

1010
jobs:
1111

12-
setup-dev-env:
13-
runs-on: ubuntu-latest
14-
steps:
15-
- uses: actions/checkout@v4
16-
- uses: ./.github/actions/setup-env
17-
with:
18-
environment: dev
19-
2012
check:
2113
runs-on: ubuntu-latest
22-
needs: setup-dev-env
2314
steps:
2415
- uses: actions/checkout@v4
25-
- uses: ./.github/actions/setup-env
16+
- uses: prefix-dev/setup-[email protected]
2617
with:
27-
environment: dev
18+
pixi-version: v0.27.1
19+
frozen: true
20+
cache: true
21+
environments: dev
22+
activate-environment: dev
2823
- run: ruff check
2924
if: success() || failure()
3025
- run: ruff format --check
3126
if: success() || failure()
3227

3328
##
3429

35-
setup-default-env:
36-
runs-on: ubuntu-latest
37-
steps:
38-
- uses: actions/checkout@v4
39-
- uses: ./.github/actions/setup-env
40-
with:
41-
environment: default
42-
4330
typecheck:
4431
runs-on: ubuntu-latest
45-
needs: setup-default-env
4632
steps:
4733
- uses: actions/checkout@v4
48-
- uses: ./.github/actions/setup-env
34+
- uses: prefix-dev/setup-[email protected]
4935
with:
50-
environment: default
36+
pixi-version: v0.27.1
37+
frozen: true
38+
cache: true
39+
environments: default
40+
activate-environment: default
5141
- run: pyright
5242
if: success() || failure()
5343

44+
##
45+
46+
get-pytorch-versions:
47+
runs-on: ubuntu-latest
48+
outputs:
49+
versions: ${{ steps.get-pytorch-versions.outputs.versions }}
50+
steps:
51+
- name: Get PyTorch versions
52+
id: get-pytorch-versions
53+
run: |
54+
VERSIONS=$(
55+
curl -s https://pypi.org/pypi/torch/json | jq -r '.releases | keys[]' |
56+
# remove versions <2.0; strip "patch" from versions
57+
grep -v '^1\.' | grep -E '\.[0]+$' | sort -V | sed 's/\.0$//' |
58+
# to JSON array
59+
jq -R . | jq -sc .
60+
)
61+
echo "versions=$VERSIONS" >> $GITHUB_OUTPUT
62+
# e.g. ["2.0","2.1","2.2","2.3","2.4"]
63+
5464
test:
5565
runs-on: ubuntu-latest
56-
needs: setup-default-env
66+
needs: get-pytorch-versions
67+
strategy:
68+
fail-fast: false
69+
matrix:
70+
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
71+
pytorch: ${{fromJson(needs.get-pytorch-versions.outputs.versions)}}
72+
5773
steps:
5874
- uses: actions/checkout@v4
59-
- uses: ./.github/actions/setup-env
75+
76+
- name: Inject (python / pytorch) test deps into pixi.toml / pyproject.toml
77+
run: |
78+
sed -i 's/python = "3.8.1"/python = "${{ matrix.python }}.*"/' pixi.toml
79+
sed -i 's/torch>=2.0.0/torch~=${{ matrix.pytorch }}/' pyproject.toml
80+
81+
- uses: prefix-dev/[email protected]
6082
with:
61-
environment: default
83+
pixi-version: v0.27.1
84+
locked: false
85+
frozen: false
86+
cache: false
87+
environments: default
88+
activate-environment: default
89+
6290
- run: pytest tests

docs/source/advanced.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ The :mod:`torchrunx.launch` ``env_vars`` argument allows the user to specify whi
3535
:mod:`torchrunx.launch` also accepts the ``env_file`` argument, which is designed to expose more advanced environmental configuration to the user. When a file is provided as this argument, the launcher will source the file on each node before executing the agent. This allows for custom bash scripts to be provided in the environmental variables, and allows for node-specific environmental variables to be set.
3636

3737
..
38-
TODO: example env_file
38+
TODO: example env_file
39+
40+
Numpy >= 2.0
41+
------------
42+
only supported if `torch>=2.3`

examples/slurm_poc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_launch():
2121
print("PASS")
2222

2323

24-
def simple_matmul(test):
24+
def simple_matmul():
2525
rank = int(os.environ["RANK"])
2626
local_rank = int(os.environ["LOCAL_RANK"])
2727
device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")
@@ -35,7 +35,6 @@ def simple_matmul(test):
3535

3636
i = torch.rand((500, 100), device=device) # batch, dim
3737
o = torch.matmul(i, w)
38-
print(test)
3938
dist.all_reduce(o, op=dist.ReduceOp.SUM)
4039
print(i)
4140
return o.detach().cpu()

0 commit comments

Comments
 (0)