Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions release/air_tests/air_benchmarks/workloads/torch_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from pathlib import Path
from typing import Dict, Tuple
import tempfile

import click
import numpy as np
Expand Down Expand Up @@ -205,13 +206,23 @@ def collate_fn(x):

local_time_taken = time.monotonic() - local_start_time

if use_ray:
train.report(dict(loss=loss, local_time_taken=local_time_taken))
else:
print(f"Reporting loss: {loss:.4f}")
if local_rank == 0:
with open(VANILLA_RESULT_JSON, "w") as f:
json.dump({"loss": loss, "local_time_taken": local_time_taken}, f)
if use_ray:
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
if train.get_context().get_world_rank() == 0:
torch.save(
model.state_dict(),
os.path.join(temp_checkpoint_dir, "model.pt"),
)

train.report(
dict(loss=loss, local_time_taken=local_time_taken),
checkpoint=train.Checkpoint.from_directory(temp_checkpoint_dir),
)
else:
print(f"Reporting loss: {loss:.4f}")
if local_rank == 0:
with open(VANILLA_RESULT_JSON, "w") as f:
json.dump({"loss": loss, "local_time_taken": local_time_taken}, f)


def train_torch_ray_air(
Expand All @@ -223,7 +234,7 @@ def train_torch_ray_air(
) -> Tuple[float, float, float]:
# This function is kicked off by the main() function and runs a full training
# run using Ray AIR.
from ray.train import ScalingConfig
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

def train_loop(config):
Expand All @@ -234,11 +245,11 @@ def train_loop(config):
train_loop_per_worker=train_loop,
train_loop_config=config,
scaling_config=ScalingConfig(
trainer_resources={"CPU": 0},
num_workers=num_workers,
resources_per_worker={"CPU": cpus_per_worker},
use_gpu=use_gpu,
),
run_config=RunConfig(storage_path="/mnt/cluster_storage"),
)
result = trainer.fit()
time_taken = time.monotonic() - start_time
Expand Down
2 changes: 1 addition & 1 deletion release/release_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@

run:
timeout: 4800
script: python workloads/torch_benchmark.py run --num-runs 3 --num-epochs 120 --num-workers 16 --cpus-per-worker 4 --batch-size 1024 --use-gpu
script: RAY_TRAIN_V2_ENABLED=1 python workloads/torch_benchmark.py run --num-runs 3 --num-epochs 120 --num-workers 16 --cpus-per-worker 4 --batch-size 1024 --use-gpu

wait_for_nodes:
num_nodes: 4
Expand Down