Skip to content

Problem with using "mps" device  #2977

Open
@guptaaryan16

Description

@guptaaryan16

🐛 Bug description

The template is for Vision classification from code generator app


from pprint import pformat
from typing import Any

from ignite import distributed as idist
import yaml
from data import setup_data
from ignite.engine import Events
from ignite.metrics import Accuracy, Loss
from ignite.utils import manual_seed
from models import setup_model
from torch import nn, optim
from trainers import setup_evaluator, setup_trainer
from utils import *


def run(local_rank: int, config: Any):
    # make a certain seed
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # create output folder
    config.output_dir = setup_output_dir(config, rank)

    # donwload datasets and create dataloaders
    dataloader_train, dataloader_eval = setup_data(config)

    # model, optimizer, loss function, device
    device = torch.device('mps')
    model = idist.auto_model(setup_model(config.model))
    optimizer = idist.auto_optim(optim.Adam(model.parameters(), config.lr))
    loss_fn = nn.CrossEntropyLoss().to(device=device)

    # trainer and evaluator
    trainer = setup_trainer(
        config, model, optimizer, loss_fn, device, dataloader_train.sampler
    )
    evaluator = setup_evaluator(config, model, device)

    # attach metrics to evaluator
    accuracy = Accuracy(device=device)
    metrics = {
        "eval_accuracy": accuracy,
        "eval_loss": Loss(loss_fn, device=device),
        "eval_error": (1.0 - accuracy) * 100,
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # setup engines logger with python logging
    # print training configurations
    logger = setup_logging(config)
    logger.info("Configuration: \n%s", pformat(vars(config)))
    (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
    trainer.logger = evaluator.logger = logger

    # setup ignite handlers
    to_save_train = {"model": model, "optimizer": optimizer, "trainer": trainer}
    to_save_eval = {"model": model}
    ckpt_handler_train, ckpt_handler_eval = setup_handlers(
        trainer, evaluator, config, to_save_train, to_save_eval
    )
    # experiment tracking
    if rank == 0:
        exp_logger = setup_exp_logging(config, trainer, optimizer, evaluator)

    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=config.log_every_iters),
        log_metrics,
        tag="train",
    )

    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # let's try run evaluation first as a sanity check
    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)

    # setup if done. let's run the training
    trainer.run(
        dataloader_train,
        max_epochs=config.max_epochs,
        epoch_length=config.train_epoch_length,
    )
    # close logger
    if rank == 0:
        exp_logger.close()

    # show last checkpoint names
    logger.info(
        "Last training checkpoint name - %s",
        ckpt_handler_train.last_checkpoint,
    )

    logger.info(
        "Last evaluation checkpoint name - %s",
        ckpt_handler_eval.last_checkpoint,
    )


# main entrypoint
def main():
    config = setup_config()
    with idist.Parallel(backend=None) as p:
        p.run(run, config=config)


if __name__ == "__main__":
    main()


Error Report

[ignite]: Engine run starting with max_epochs=20.
[ignite]: Engine run starting with max_epochs=1.
[ignite]: Current run is terminating due to exception: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
[ignite]: Engine run is terminating due to exception: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
[ignite]: Engine run is terminating due to exception: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
Traceback (most recent call last):
  File "/Users/guptaaryan16/Downloads/ignite-template-vision-classification-2/main.py", line 121, in <module>
    main()
  File "/Users/guptaaryan16/Downloads/ignite-template-vision-classification-2/main.py", line 117, in main
    p.run(run, config=config)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/distributed/launcher.py", line 316, in run
    func(local_rank, *args, **kwargs)
  File "/Users/guptaaryan16/Downloads/ignite-template-vision-classification-2/main.py", line 92, in run
    trainer.run(
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 892, in run
    return self._internal_run()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 935, in _internal_run
    return next(self._internal_run_generator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 993, in _internal_run_as_gen
    self._handle_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 638, in _handle_exception
    raise e
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 946, in _internal_run_as_gen
    self._fire_event(Events.STARTED)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 425, in _fire_event
    func(*first, *(event_args + others), **kwargs)
  File "/Users/guptaaryan16/Downloads/ignite-template-vision-classification-2/main.py", line 89, in _
    evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 892, in run
    return self._internal_run()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 935, in _internal_run
    return next(self._internal_run_generator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 993, in _internal_run_as_gen
    self._handle_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 638, in _handle_exception
    raise e
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 959, in _internal_run_as_gen
    epoch_time_taken += yield from self._run_once_on_dataset_as_gen()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 1087, in _run_once_on_dataset_as_gen
    self._handle_exception(e)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 638, in _handle_exception
    raise e
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ignite/engine/engine.py", line 1068, in _run_once_on_dataset_as_gen
    self.state.output = self._process_function(self, self.state.batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guptaaryan16/Downloads/ignite-template-vision-classification-2/trainers.py", line 66, in eval_function
    outputs = model(samples)
              ^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torchvision/models/resnet.py", line 268, in _forward_impl
    x = self.conv1(x)
        ^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same

Environment

(MacBook M1)

  • PyTorch Version (e.g., 1.4): 2.0.1
  • Ignite Version (e.g., 0.3.0):0.4.12
  • OS (e.g., Linux): MacOs
  • How you installed Ignite (conda, pip, source): pip
  • Python version:3.11
  • Any other relevant information:

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions