handle older PyTorch versions #53

Merged
merged 32 commits into from
Aug 13, 2024
Commits:
1514090 handle older PyTorch versions (pmcurtin, Jul 25, 2024)
91a246a fix (pmcurtin, Jul 25, 2024)
41fdd70 Update main.yml (pmcurtin, Aug 12, 2024)
c705041 tox (pmcurtin, Jul 30, 2024)
11465e1 pixi toml features (pmcurtin, Aug 12, 2024)
332215d update lock (pmcurtin, Aug 12, 2024)
d76e207 Update main.yml (pmcurtin, Aug 12, 2024)
abb21c5 add editable, pytest (pmcurtin, Aug 12, 2024)
7f0496c fix tests (pmcurtin, Aug 12, 2024)
e3c29a2 Update main.yml (pmcurtin, Aug 12, 2024)
f8b4cbf Update main.yml (pmcurtin, Aug 12, 2024)
a0a964b remove numpy requirement? (pmcurtin, Aug 12, 2024)
6507dcf numpy<2, for failing tests (pmcurtin, Aug 12, 2024)
5e585f9 format (pmcurtin, Aug 12, 2024)
294937e Delete tox.ini (apoorvkh, Aug 13, 2024)
dfa4278 updates to matrix test action (apoorvkh, Aug 13, 2024)
f3cec00 Merge branch 'main' into pytorch-compatibility (apoorvkh, Aug 13, 2024)
fb599f9 Update main.yml (apoorvkh, Aug 13, 2024)
7cf6f54 Update main.yml (apoorvkh, Aug 13, 2024)
f7124f3 downgrade pixi (apoorvkh, Aug 13, 2024)
67239be Update main.yml (apoorvkh, Aug 13, 2024)
71de1aa Delete .github/actions directory (apoorvkh, Aug 13, 2024)
24d6d9c Update main.yml (apoorvkh, Aug 13, 2024)
f1ff9cc Update main.yml (apoorvkh, Aug 13, 2024)
97032c3 Updated numpy version bounds (apoorvkh, Aug 13, 2024)
4d41601 note about numpy upper bound (apoorvkh, Aug 13, 2024)
bee7fec Merge branch 'main' into pytorch-compatibility (apoorvkh, Aug 13, 2024)
6702522 Update advanced.rst (apoorvkh, Aug 13, 2024)
77ae65d Automatically generate pytorch versions for testing (apoorvkh, Aug 13, 2024)
38e4312 Update main.yml (apoorvkh, Aug 13, 2024)
d87bb91 Update main.yml (apoorvkh, Aug 13, 2024)
2a08afd Update main.yml (apoorvkh, Aug 13, 2024)
16 changes: 0 additions & 16 deletions .github/actions/setup-env/action.yml

This file was deleted.

78 changes: 53 additions & 25 deletions .github/workflows/main.yml
@@ -9,54 +9,82 @@ on:

jobs:

setup-dev-env:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup-env
with:
environment: dev

check:
runs-on: ubuntu-latest
needs: setup-dev-env
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup-env
- uses: prefix-dev/setup-[email protected]
with:
environment: dev
pixi-version: v0.27.1
frozen: true
cache: true
environments: dev
activate-environment: dev
- run: ruff check
if: success() || failure()
- run: ruff format --check
if: success() || failure()

##

setup-default-env:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup-env
with:
environment: default

typecheck:
runs-on: ubuntu-latest
needs: setup-default-env
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup-env
- uses: prefix-dev/setup-[email protected]
with:
environment: default
pixi-version: v0.27.1
frozen: true
cache: true
environments: default
activate-environment: default
- run: pyright
if: success() || failure()

##

get-pytorch-versions:
runs-on: ubuntu-latest
outputs:
versions: ${{ steps.get-pytorch-versions.outputs.versions }}
steps:
- name: Get PyTorch versions
id: get-pytorch-versions
run: |
VERSIONS=$(
curl -s https://pypi.org/pypi/torch/json | jq -r '.releases | keys[]' |
# remove versions <2.0; strip "patch" from versions
grep -v '^1\.' | grep -E '\.[0]+$' | sort -V | sed 's/\.0$//' |
# to JSON array
jq -R . | jq -sc .
)
echo "versions=$VERSIONS" >> $GITHUB_OUTPUT
# e.g. ["2.0","2.1","2.2","2.3","2.4"]

test:
runs-on: ubuntu-latest
needs: setup-default-env
needs: get-pytorch-versions
strategy:
fail-fast: false
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
pytorch: ${{fromJson(needs.get-pytorch-versions.outputs.versions)}}

steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup-env

- name: Inject (python / pytorch) test deps into pixi.toml / pyproject.toml
run: |
sed -i 's/python = "3.8.1"/python = "${{ matrix.python }}.*"/' pixi.toml
sed -i 's/torch>=2.0.0/torch~=${{ matrix.pytorch }}/' pyproject.toml

- uses: prefix-dev/[email protected]
with:
environment: default
pixi-version: v0.27.1
locked: false
frozen: false
cache: false
environments: default
activate-environment: default

- run: pytest tests
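
The `get-pytorch-versions` job above reduces the PyPI release list to one entry per minor version. A rough Python sketch of that filtering step is given here for illustration only. The function name `minor_versions` and the sample list are hypothetical, and the sketch selects on the parsed major version (`>= 2`) rather than on the `1.` prefix that the shell pipeline uses, so the two are not guaranteed to agree on every input.

```python
import json


def minor_versions(releases, min_major=2):
    """Return sorted "major.minor" strings for X.Y.0 releases with major >= min_major."""
    selected = set()
    for version in releases:
        parts = version.split(".")
        # Keep only plain X.Y.Z releases; this skips rc/post/dev tags.
        if len(parts) != 3 or not all(p.isdigit() for p in parts):
            continue
        major, minor, patch = (int(p) for p in parts)
        # Keep the first release of each minor series (patch == 0).
        if major >= min_major and patch == 0:
            selected.add((major, minor))
    # Sorting numeric tuples plays the role of `sort -V` (2.10 after 2.2).
    return [f"{major}.{minor}" for major, minor in sorted(selected)]


# Hypothetical sample of release keys, not real PyPI output.
sample = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.10.0", "2.2.0", "2.4.0rc1"]

# Compact JSON array, the shape that fromJson() expects in the workflow.
print(json.dumps(minor_versions(sample), separators=(",", ":")))
```

In the workflow, the resulting array feeds the `pytorch` axis of the test matrix, and each value is substituted into `pyproject.toml` as `torch~=<version>`.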
6 changes: 5 additions & 1 deletion docs/source/advanced.rst
@@ -35,4 +35,8 @@ The :mod:`torchrunx.launch` ``env_vars`` argument allows the user to specify whi
:mod:`torchrunx.launch` also accepts the ``env_file`` argument, which is designed to expose more advanced environmental configuration to the user. When a file is provided as this argument, the launcher will source the file on each node before executing the agent. This allows for custom bash scripts to be provided in the environmental variables, and allows for node-specific environmental variables to be set.

..
-TODO: example env_file
+TODO: example env_file
+
+Numpy >= 2.0
+------------
+only supported if `torch>=2.3`
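
The note added to `advanced.rst` states a version constraint: NumPy 2.0 or later is only supported with `torch>=2.3`. A minimal sketch of how a user might check an installed pair against that rule follows. The helpers `parse_release` and `numpy_supported` are hypothetical and are not part of torchrunx, PyTorch, or NumPy.

```python
def parse_release(version):
    """Parse the leading (major, minor) pair from a string such as '2.3.1+cu121'."""
    # Drop any local build tag after '+', then take the first two components.
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def numpy_supported(torch_version, numpy_version):
    """Apply the documented rule: NumPy >= 2.0 requires torch >= 2.3."""
    if parse_release(numpy_version)[0] < 2:
        return True  # the note only restricts NumPy 2.x and later
    return parse_release(torch_version) >= (2, 3)


print(numpy_supported("2.2.2", "1.26.4"))
print(numpy_supported("2.2.2", "2.0.1"))
print(numpy_supported("2.3.0+cu121", "2.0.0"))
```

With both packages installed, the same check could be run as `numpy_supported(torch.__version__, numpy.__version__)`.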
3 changes: 1 addition & 2 deletions examples/slurm_poc.py
@@ -21,7 +21,7 @@ def test_launch():
print("PASS")


-def simple_matmul(test):
+def simple_matmul():
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")
@@ -35,7 +35,6 @@ def simple_matmul(test):

i = torch.rand((500, 100), device=device) # batch, dim
o = torch.matmul(i, w)
-print(test)
dist.all_reduce(o, op=dist.ReduceOp.SUM)
print(i)
return o.detach().cpu()