
Commit c9c584f

Add support for Intel GPU to ImageNet example
1 parent 5dfeb46 commit c9c584f

File tree: imagenet/README.md, imagenet/main.py

2 files changed: +24 −7 lines

imagenet/README.md (+3 −1)
@@ -33,7 +33,9 @@ python main.py -a resnet18 --dummy

 ## Multi-processing Distributed Data Parallel Training

-You should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+If running on CUDA, you should always use the NCCL backend for multi-processing distributed training since it currently provides the best distributed training performance.
+
+For XPU, multiprocessing is not supported as of PyTorch 2.6.

 ### Single node, multiple GPUs:

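Not part of the commit: a minimal sketch of what this note implies for launching training, assuming the mp.spawn-based pattern and argument names (multiprocessing_distributed, world_size, gpu, dist_backend) defined by main.py's argument parser; launch itself is a hypothetical helper, not a function in the example.

import torch
import torch.multiprocessing as mp

def launch(main_worker, ngpus_per_node, args):
    if args.multiprocessing_distributed and torch.cuda.is_available():
        # CUDA: NCCL currently gives the best multi-process DDP performance.
        args.dist_backend = "nccl"
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # XPU: multiprocessing is not supported as of PyTorch 2.6, so run a
        # single worker process instead.
        main_worker(args.gpu, ngpus_per_node, args)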
imagenet/main.py (+21 −6)
@@ -147,7 +147,7 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()

-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not torch.cuda.is_available() and not torch.backends.mps.is_available() and not torch.xpu.is_available():
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
@@ -171,6 +171,9 @@ def main_worker(gpu, ngpus_per_node, args):
     elif args.gpu is not None and torch.cuda.is_available():
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+        model = model.to(device)
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
         model = model.to(device)
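Not part of the commit: a minimal standalone sketch of the XPU placement branch added above, using torchvision's resnet18 purely as an example model.

import torch
import torchvision.models as models

model = models.resnet18()

if torch.xpu.is_available():                 # Intel GPU backend (torch.xpu)
    device = torch.device("xpu")
elif torch.backends.mps.is_available():      # Apple Silicon fallback, as in main.py
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)

# Dummy forward pass to confirm the model runs on the selected device.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224, device=device))
print(out.shape)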
@@ -187,10 +190,15 @@ def main_worker(gpu, ngpus_per_node, args):
             device = torch.device('cuda:{}'.format(args.gpu))
         else:
             device = torch.device("cuda")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
         device = torch.device("cpu")
+
+    print(f"Device to use: {device.type}")
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)
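Not part of the commit: the fallback order established above (CUDA → XPU → MPS → CPU) written as a standalone helper; pick_device and its gpu_index parameter are hypothetical names, not identifiers from main.py.

import torch
import torch.nn as nn

def pick_device(gpu_index=None):
    # Prefer CUDA, then Intel XPU, then Apple MPS, then CPU -- the same order
    # main_worker uses after this change.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{gpu_index}") if gpu_index is not None else torch.device("cuda")
    if torch.xpu.is_available():
        return torch.device("xpu")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(f"Device to use: {device.type}")
criterion = nn.CrossEntropyLoss().to(device)   # loss lives on the same device, as in main.py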

@@ -354,14 +362,19 @@ def run_validate(loader, base_progress=0):
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
+
                 if torch.cuda.is_available():
+                    if args.gpu is not None:
+                        images = images.cuda(args.gpu, non_blocking=True)
                     target = target.cuda(args.gpu, non_blocking=True)

+                elif torch.xpu.is_available():
+                    images = images.to("xpu")
+                    target = target.to("xpu")
+                elif torch.backends.mps.is_available():
+                    images = images.to('mps')
+                    target = target.to('mps')
+
                 # compute output
                 output = model(images)
                 loss = criterion(output, target)
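Not part of the commit: the per-batch transfer logic above expressed as a small helper; move_batch is a hypothetical name, not from main.py.

import torch

def move_batch(images, target, gpu=None):
    # Mirror the branch order added in run_validate: CUDA (with images pinned to
    # a specific GPU index when given), then XPU, then MPS; otherwise keep CPU.
    if torch.cuda.is_available():
        if gpu is not None:
            images = images.cuda(gpu, non_blocking=True)
        target = target.cuda(gpu, non_blocking=True)
    elif torch.xpu.is_available():
        images = images.to("xpu")
        target = target.to("xpu")
    elif torch.backends.mps.is_available():
        images = images.to("mps")
        target = target.to("mps")
    return images, target

# Dummy batch shaped like one ImageNet mini-batch.
images, target = move_batch(torch.randn(8, 3, 224, 224), torch.randint(0, 1000, (8,)))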
@@ -443,6 +456,8 @@ def update(self, val, n=1):
     def all_reduce(self):
         if torch.cuda.is_available():
             device = torch.device("cuda")
+        elif torch.xpu.is_available():
+            device = torch.device("xpu")
         elif torch.backends.mps.is_available():
             device = torch.device("mps")
         else:
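Not part of the commit: the hunk above only selects the reduction device; presumably, as is common for this kind of meter, the chosen device is then used to stage the running values for a collective sum across ranks. A minimal sketch of that pattern, assuming the process group is already initialized (e.g. with the NCCL backend on CUDA); all_reduce_sum_count is a hypothetical name.

import torch
import torch.distributed as dist

def all_reduce_sum_count(total, count, device):
    # Pack the running sum and count into one tensor on the reduction device,
    # sum it across all ranks, and unpack the global values.
    buf = torch.tensor([total, count], dtype=torch.float32, device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf[0].item(), buf[1].item()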
