
Commit c317713

Add support for Intel GPU to Siamese Network example
1 parent 5dfeb46 commit c317713

2 files changed: +47 −5 lines changed


siamese_network/README.md

+36 −1

@@ -1,7 +1,42 @@
 # Siamese Network Example
 
+Siamese network for image similarity estimation.
+The network is composed of two identical networks, one for each input.
+The output of each network is concatenated and passed to a linear layer.
+The output of the linear layer is passed through a sigmoid function.
+[FaceNet](https://arxiv.org/pdf/1503.03832.pdf) is a variant of the Siamese network.
+This implementation varies from FaceNet as we use the `ResNet-18` model from
+[Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) as our feature extractor.
+In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
+
 ```bash
 pip install -r requirements.txt
 python main.py
-# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2
 ```
+
+Optionally, you can add the following arguments to customize your execution.
+
+```bash
+--batch-size        input batch size for training (default: 64)
+--test-batch-size   input batch size for testing (default: 1000)
+--epochs            number of epochs to train (default: 14)
+--lr                learning rate (default: 1.0)
+--gamma             learning rate step gamma (default: 0.7)
+--no-cuda           disables CUDA training
+--no-xpu            disables XPU training
+--no-mps            disables macOS GPU training
+--dry-run           quickly check a single pass
+--seed              random seed (default: 1)
+--log-interval      how many batches to wait before logging training status
+--save-model        Saving the current Model
+```
+
+If a GPU device (CUDA, XPU, or MPS) is detected, the example will be executed on the GPU by default; otherwise, it will run on the CPU.
+
+To disable the GPU option, add the appropriate argument to the command. For example:
+
+```bash
+python main.py --no-xpu
+```
+
+This command will execute the example on the CPU even if your system successfully detects an XPU.

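The README text added above describes the model's wiring: a shared trunk applied to both inputs, concatenated embeddings, a linear head, and a sigmoid score trained with `BCELoss`. The sketch below only illustrates that wiring and is not the repository's `main.py`; the class name `SiameseSketch`, the 256-unit hidden layer, and the 224×224 RGB dummy inputs are assumptions made for the example (the actual example trains on MNIST, so its input handling differs).

```python
import torch
import torch.nn as nn
import torchvision


class SiameseSketch(nn.Module):
    """Illustrative Siamese model: shared ResNet-18 trunk, concatenated
    embeddings, a linear head, and a sigmoid similarity score."""

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18(weights=None)
        # Drop the classification head; keep the 512-dim pooled features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.head = nn.Sequential(
            nn.Linear(512 * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def embed(self, x):
        # Both inputs pass through the *same* weights (weight sharing).
        return self.backbone(x).flatten(start_dim=1)  # (N, 512)

    def forward(self, x1, x2):
        z = torch.cat([self.embed(x1), self.embed(x2)], dim=1)  # (N, 1024)
        return self.sigmoid(self.head(z)).squeeze(1)            # (N,)


# BCELoss on the sigmoid output, as the README notes, instead of TripletLoss.
model = SiameseSketch()
criterion = nn.BCELoss()
x1 = torch.randn(4, 3, 224, 224)
x2 = torch.randn(4, 3, 224, 224)
targets = torch.randint(0, 2, (4,)).float()  # 1 = same class, 0 = different
loss = criterion(model(x1, x2), targets)
```
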
siamese_network/main.py

+11 −4

@@ -247,32 +247,39 @@ def main():
                         help='learning rate (default: 1.0)')
     parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                         help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
+    parser.add_argument('--no-cuda', action='store_true',
                         help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
+    parser.add_argument('--no-xpu', action='store_true',
+                        help='disables XPU training')
+    parser.add_argument('--no-mps', action='store_true',
                         help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
+    parser.add_argument('--dry-run', action='store_true',
                         help='quickly check a single pass')
     parser.add_argument('--seed', type=int, default=1, metavar='S',
                         help='random seed (default: 1)')
     parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                         help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
+    parser.add_argument('--save-model', action='store_true',
                         help='For Saving the current Model')
     args = parser.parse_args()

     use_cuda = not args.no_cuda and torch.cuda.is_available()
+    use_xpu = not args.no_xpu and torch.xpu.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

     torch.manual_seed(args.seed)

     if use_cuda:
         device = torch.device("cuda")
+    elif use_xpu:
+        device = torch.device("xpu")
     elif use_mps:
         device = torch.device("mps")
     else:
         device = torch.device("cpu")

+    print('Device to use: ', device)
+
     train_kwargs = {'batch_size': args.batch_size}
     test_kwargs = {'batch_size': args.test_batch_size}
     if use_cuda:
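
The device-selection logic this commit adds follows a simple preference order: CUDA, then Intel XPU, then Apple MPS, then CPU. Below is a standalone sketch of that pattern; the `hasattr(torch, "xpu")` guard and the `pick_device` helper are assumptions added for illustration (the guard covers older PyTorch builds that don't ship the XPU backend) and are not part of the committed code.

```python
import argparse

import torch


def pick_device(no_cuda: bool = False, no_xpu: bool = False, no_mps: bool = False) -> torch.device:
    """Mirror the fallback order added by this commit: CUDA > XPU > MPS > CPU."""
    if not no_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu exists only in builds with Intel GPU support, hence the guard.
    if not no_xpu and hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    if not no_mps and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--no-cuda", action="store_true", help="disables CUDA training")
    parser.add_argument("--no-xpu", action="store_true", help="disables XPU training")
    parser.add_argument("--no-mps", action="store_true", help="disables macOS GPU training")
    args = parser.parse_args()

    device = pick_device(args.no_cuda, args.no_xpu, args.no_mps)
    print("Device to use: ", device)

    # Tensors and modules are then moved onto the chosen device as usual.
    x = torch.randn(8, 10).to(device)
```

Passing `--no-xpu`, as in the README example, skips the XPU branch, so the run falls through to MPS or CPU even if an Intel GPU is detected.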
