Skip to content

当使用Wandb作为可视化后端时,程序卡在Saving Checkpoint #623

@LovingThresh

Description

@LovingThresh

使用示例代码

原始代码无误

原始代码 + TensorBoard 可视化后端无误

原始代码 + Wandb 可视化后端会一直卡在Saving Checkpoint

import torchvision
from torch.optim import SGD
import torch.nn.functional as F
from mmengine.runner import Runner
from mmengine.model import BaseModel
from torch.utils.data import DataLoader
from mmengine.evaluator import BaseMetric
import torchvision.transforms as transforms


class MMResNet50(BaseModel):
    """ResNet-50 from torchvision wrapped as an MMEngine ``BaseModel``."""

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    # `forward` has to be a method of the class: defined at module level it
    # would not override `BaseModel.forward`.
    def forward(self, imgs, labels, mode):
        """Run the ResNet and return either a loss dict or predictions.

        Args:
            imgs: Input batch fed to ``self.resnet``.
            labels: Targets for ``imgs``; used for the loss and passed
                through unchanged in predict mode.
            mode: ``'loss'`` or ``'predict'``.

        Returns:
            ``{'loss': cross_entropy}`` when ``mode == 'loss'``; the tuple
            ``(x, labels)`` when ``mode == 'predict'``; ``None`` for any
            other mode (no branch matches).
        """
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


# Per-channel mean/std passed to `transforms.Normalize` for both loaders.
norm_cfg = {'mean': [0.491, 0.482, 0.447], 'std': [0.202, 0.199, 0.201]}

# Training pipeline: random crop + horizontal flip, then tensor conversion
# and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**norm_cfg),
])

# CIFAR-10 training split, downloaded into 'data/cifar10' when requested.
train_dataset = torchvision.datasets.CIFAR10(
    'data/cifar10',
    train=True,
    download=True,
    transform=train_transform,
)

# Shuffled batches of 32 for training.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

# Validation pipeline: tensor conversion and normalization only.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**norm_cfg),
])

# CIFAR-10 split selected by `train=False`, stored in 'data/cifar10'.
val_dataset = torchvision.datasets.CIFAR10(
    'data/cifar10',
    train=False,
    download=True,
    transform=val_transform,
)

# Unshuffled batches of 32 for validation.
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
)


class Accuracy(BaseMetric):
    """Classification accuracy metric, reported as a percentage."""

    def process(self, data_batch, data_samples):
        """Record the number of correct predictions for one batch.

        Args:
            data_batch: The input batch; not used here.
            data_samples: A pair ``(score, gt)`` of model outputs and
                ground-truth labels.
        """
        score, gt = data_samples
        # Save the intermediate results of one batch to `self.results`.
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    # `compute_metrics` has to be a method of the class: defined at module
    # level it would not be part of the metric.
    def compute_metrics(self, results):
        """Aggregate the per-batch records into the final accuracy.

        Args:
            results: Iterable of dicts with ``'correct'`` and
                ``'batch_size'`` entries, as appended by ``process``.

        Returns:
            ``{'accuracy': 100 * total_correct / total_size}``.
        """
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Return a dict of metric results keyed by metric name.
        return dict(accuracy=100 * total_correct / total_size)


# Build the runner for one training epoch with validation after it, a
# checkpoint hook with interval 1, and Wandb as the visualization backend.
runner = Runner(
    # Model to train and validate; it must satisfy the required interface.
    model=MMResNet50(),
    # Working directory for training logs and checkpoint files.
    work_dir='./work_dir',
    # Training data loader; must follow the PyTorch data loader protocol.
    train_dataloader=train_dataloader,
    # Optimizer wrapper for model optimization; it also provides extras
    # such as AMP and gradient accumulation.
    optim_wrapper={
        'optimizer': {'type': SGD, 'lr': 0.001, 'momentum': 0.9},
    },
    # Training config: number of epochs, validation interval, etc.
    train_cfg={'by_epoch': True, 'max_epochs': 1, 'val_interval': 1},
    # Validation data loader; must follow the PyTorch data loader protocol.
    val_dataloader=val_dataloader,
    # Validation config holding any extra parameters validation needs.
    val_cfg={},
    default_hooks={
        'checkpoint': {'type': 'CheckpointHook', 'interval': 1},
    },
    # Evaluator used for validation: the default evaluator with the
    # Accuracy metric.
    val_evaluator={'type': Accuracy},
    visualizer={
        'type': 'Visualizer',
        'vis_backends': [
            {
                'type': 'WandbVisBackend',
                'init_kwargs': {'project': 'my-awesome-project'},
            },
        ],
    },
)

runner.train()

10/17 22:16:10 - mmengine - INFO - Epoch(train) [1][1560/1563] lr: 1.0000e-03 eta: 0:00:00 time: 0.0515 data_time: 0.0101 memory: 369 loss: 1.9646
10/17 22:16:11 - mmengine - INFO - Exp name: 20221017_221433
10/17 22:16:11 - mmengine - INFO - Saving checkpoint at 1 epochs (就一直卡在这,也不知道是怎么回事)

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions