Skip to content

Commit 2c435c7

Browse files
authored
Use torch.accelerator API in GAT example (#1335)
Refactor GAT example to utilize `torch.accelerator` API `torch.accelerator` API allows to abstract some of the accelerator specifics in the user scripts. By leveraging this API, the code becomes more adaptable to various hardware accelerators. Signed-off-by: jafraustro <[email protected]>
1 parent 65722fe commit 2c435c7

File tree

3 files changed

+9
-14
lines changed

3 files changed

+9
-14
lines changed

gat/README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ options:
8787
--concat-heads wether to concatinate attention heads, or average over them (default: False)
8888
--val-every VAL_EVERY
8989
epochs to wait for print training and validation evaluation (default: 20)
90-
--no-cuda disables CUDA training
91-
--no-mps disables macOS GPU training
90+
--no-accel disables accelerator
9291
--dry-run quickly check a single pass
9392
--seed S random seed (default: 13)
9493
```

gat/main.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -303,29 +303,25 @@ def test(model, criterion, input, target, mask):
303303
help='dimension of the hidden representation (default: 64)')
304304
parser.add_argument('--num-heads', type=int, default=8,
305305
help='number of the attention heads (default: 4)')
306-
parser.add_argument('--concat-heads', action='store_true', default=False,
306+
parser.add_argument('--concat-heads', action='store_true',
307307
help='wether to concatinate attention heads, or average over them (default: False)')
308308
parser.add_argument('--val-every', type=int, default=20,
309309
help='epochs to wait for print training and validation evaluation (default: 20)')
310-
parser.add_argument('--no-cuda', action='store_true', default=False,
310+
parser.add_argument('--no-accel', action='store_true',
311311
help='disables CUDA training')
312-
parser.add_argument('--no-mps', action='store_true', default=False,
313-
help='disables macOS GPU training')
314-
parser.add_argument('--dry-run', action='store_true', default=False,
312+
parser.add_argument('--dry-run', action='store_true',
315313
help='quickly check a single pass')
316314
parser.add_argument('--seed', type=int, default=13, metavar='S',
317315
help='random seed (default: 13)')
318316
args = parser.parse_args()
319317

320318
torch.manual_seed(args.seed)
321-
use_cuda = not args.no_cuda and torch.cuda.is_available()
322-
use_mps = not args.no_mps and torch.backends.mps.is_available()
319+
320+
use_accel = not args.no_accel and torch.accelerator.is_available()
323321

324322
# Set the device to run on
325-
if use_cuda:
326-
device = torch.device('cuda')
327-
elif use_mps:
328-
device = torch.device('mps')
323+
if use_accel:
324+
device = torch.accelerator.current_accelerator()
329325
else:
330326
device = torch.device('cpu')
331327
print(f'Using {device} device')

gat/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
torch
22
requests
3-
numpy<2
3+
numpy

0 commit comments

Comments
 (0)