Skip to content

[Feature] Support getting model from the name defined in the model-index file. #1236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions docs/en/api/apis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ These are some high-level APIs for classification tasks.
:local:
:backlinks: top

Inference
Model
------------------

.. autosummary::
:toctree: generated
:nosignatures:
.. autofunction:: list_models

.. autofunction:: get_model

.. autofunction:: init_model


Inference
------------------

init_model
inference_model
.. autofunction:: inference_model
1 change: 1 addition & 0 deletions mmcls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mmengine
from mmengine.utils import digit_version

from .apis import * # noqa: F401, F403
from .version import __version__

mmcv_minimum_version = '2.0.0rc1'
Expand Down
7 changes: 5 additions & 2 deletions mmcls/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model
from .inference import inference_model
from .model import ModelHub, get_model, init_model, list_models

__all__ = ['init_model', 'inference_model']
__all__ = [
'init_model', 'inference_model', 'list_models', 'get_model', 'ModelHub'
]
66 changes: 10 additions & 56 deletions mmcls/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Union

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner import load_checkpoint
from mmengine.dataset import Compose, default_collate
from mmengine.model import BaseModel
from mmengine.registry import DefaultScope

from mmcls.models import build_classifier
from mmcls.utils import register_all_modules
import mmcls.datasets # noqa: F401


def init_model(config, checkpoint=None, device='cuda:0', options=None):
    """Build a classifier from a config and optionally load its weights.

    Args:
        config (str or :obj:`mmengine.Config`): Path of the config file, or
            an already loaded config object.
        checkpoint (str, optional): Path of the checkpoint to load. When it
            is None no weights are loaded.
        device (str): Device the model is moved to with ``model.to``.
            Defaults to ``'cuda:0'``.
        options (dict, optional): Settings merged into the config before the
            model is built.

    Returns:
        nn.Module: The classifier, switched to eval mode, with the config
        stored in its ``cfg`` attribute.

    Raises:
        TypeError: If ``config`` is neither a string nor a ``Config``.
    """
    register_all_modules()

    from_file = isinstance(config, str)
    if not from_file and not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    cfg = Config.fromfile(config) if from_file else config

    if options is not None:
        cfg.merge_from_dict(options)
    # Fall back to the top-level ``data_preprocessor`` setting when the
    # model config does not define its own.
    top_level_preprocessor = cfg.get('data_preprocessor', None)
    cfg.model.setdefault('data_preprocessor', top_level_preprocessor)

    model = build_classifier(cfg.model)

    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = loaded.get('meta', {})
        if 'dataset_meta' in meta:
            # Checkpoint layout of mmcls 1.x.
            model.CLASSES = meta['dataset_meta']['classes']
        elif 'CLASSES' in meta:
            # Checkpoint layout of mmcls < 1.x.
            model.CLASSES = meta['CLASSES']
        else:
            from mmcls.datasets.categories import IMAGENET_CATEGORIES
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = IMAGENET_CATEGORIES

    # Keep the config on the model so later API calls can read it.
    model.cfg = cfg
    model.to(device)
    model.eval()
    return model


def inference_model(model, img):
def inference_model(model: BaseModel, img: Union[str, np.ndarray]):
"""Inference image(s) with the classifier.

Args:
Expand All @@ -67,7 +21,6 @@ def inference_model(model, img):
result (dict): The classification results that contains
`class_name`, `pred_label` and `pred_score`.
"""
register_all_modules()
cfg = model.cfg
# build the data pipeline
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
Expand All @@ -79,9 +32,10 @@ def inference_model(model, img):
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
test_pipeline_cfg.pop(0)
data = dict(img=img)
test_pipeline = Compose(test_pipeline_cfg)
with DefaultScope.overwrite_default_scope('mmcls'):
test_pipeline = Compose(test_pipeline_cfg)
data = test_pipeline(data)
data = pseudo_collate([data])
data = default_collate([data])

# forward the model
with torch.no_grad():
Expand Down
220 changes: 220 additions & 0 deletions mmcls/apis/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Union

from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmengine.utils import get_installed_path
from modelindex.load_model_index import load
from modelindex.models.Model import Model

import mmcls.models # noqa: F401
from mmcls.registry import MODELS


class ModelHub:
    """A hub to host the meta information of all pre-defined models.

    Every registered metainfo is stored in a class-level dict keyed by the
    lower-cased model name, so registration and lookup are both
    case-insensitive.
    """
    # Maps the lower-cased model name to the metainfo parsed from a
    # model-index file.
    _models_dict = {}

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike,
                                                  None] = None):
        """Parse the model-index file and register all models.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file. If None, the directory
                of the file each model is defined in is used.

        Raises:
            ValueError: If a model name, compared case-insensitively, has
                already been registered.
        """
        model_index = load(str(model_index_path))
        model_index.build_models_with_collections()

        for metainfo in model_index.models:
            model_name = metainfo.name.lower()
            # The registry keys are lower-cased, so the conflict check has
            # to use the lower-cased name too. Checking the raw name would
            # let names with upper-case characters overwrite each other.
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} is conflict in {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
            cls._models_dict[model_name] = metainfo

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        Args:
            model_name (str): The name of model. The lookup is
                case-insensitive.

        Returns:
            modelindex.models.Model: A deep copy of the metainfo of the
            specified model. If its config is stored as a path string, it
            is loaded into a ``Config`` object on the returned copy.

        Raises:
            ValueError: If no model is registered under ``model_name``.
        """
        # Work on a copy so that the registry keeps the config path only;
        # the config file itself is loaded lazily, on each request.
        metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
        if metainfo is None:
            raise ValueError(f'Failed to find model {model_name}.')
        if isinstance(metainfo.config, str):
            metainfo.config = Config.fromfile(metainfo.config)
        return metainfo

    @staticmethod
    def _expand_config_path(metainfo: 'Model',
                            config_prefix: Union[str, PathLike, None] = None):
        """Resolve the config path of a metainfo to an absolute path.

        Args:
            metainfo (Model): The metainfo whose ``config`` and ``filepath``
                attributes are read.
            config_prefix (str | PathLike | None): The directory relative
                config paths are joined to. If None, the directory of
                ``metainfo.filepath`` is used.

        Returns:
            str | None: ``metainfo.config`` unchanged if it is None or
            already absolute, otherwise the absolute path of the config
            joined to the prefix.
        """
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))

        return config_path


# Register the models shipped with mmcls. Both the model-index file and the
# prefix used for the config paths it lists point into the ``.mim`` folder
# of the installed package.
mmcls_root = Path(get_installed_path('mmcls'))
_mim_dir = mmcls_root / '.mim'
model_index_path = _mim_dir / 'model-index.yml'
ModelHub.register_model_index(model_index_path, config_prefix=_mim_dir)


def init_model(config, checkpoint=None, device=None, **kwargs):
    """Build a classifier from a config and optionally load its weights.

    Args:
        config (str | PathLike | :obj:`mmengine.Config`): Path of the config
            file, or an already loaded config object.
        checkpoint (str, optional): Path of the checkpoint to load. When it
            is None no weights are loaded.
        device (str | torch.device | None): The value passed to
            ``model.to``. Defaults to None.
        **kwargs: Extra settings merged into the ``model`` field of the
            config before the model is built.

    Returns:
        nn.Module: The constructed model, switched to eval mode, with the
        config stored in its ``cfg`` attribute.

    Raises:
        TypeError: If ``config`` is not a path or a ``Config`` object.
    """
    from_file = isinstance(config, (str, PathLike))
    if not from_file and not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    cfg = Config.fromfile(config) if from_file else config

    if kwargs:
        cfg.merge_from_dict({'model': kwargs})
    # Fall back to the top-level ``data_preprocessor`` setting when the
    # model config does not define its own.
    top_level_preprocessor = cfg.get('data_preprocessor', None)
    cfg.model.setdefault('data_preprocessor', top_level_preprocessor)

    model = MODELS.build(cfg.model)

    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = loaded.get('meta', {})
        if 'dataset_meta' in meta:
            # Checkpoint layout of mmcls 1.x.
            model.CLASSES = meta['dataset_meta']['classes']
        elif 'CLASSES' in meta:
            # Checkpoint layout of mmcls < 1.x.
            model.CLASSES = meta['CLASSES']
        else:
            from mmcls.datasets.categories import IMAGENET_CATEGORIES
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = IMAGENET_CATEGORIES

    # Keep the config on the model so later API calls can read it.
    model.cfg = cfg
    model.to(device)
    model.eval()
    return model


def get_model(model_name, pretrained=False, device=None, **kwargs):
    """Build one of the pre-defined models from its registered name.

    Args:
        model_name (str): The name of model, looked up in ``ModelHub``.
        pretrained (bool | str): A string is used as the checkpoint to load.
            Any other truthy value loads the weights recorded in the model's
            metainfo. A falsy value loads no weights. Defaults to False.
        device (str | torch.device | None): Forwarded to ``init_model``.
            Defaults to None.
        **kwargs: Other keyword arguments of the model config, forwarded to
            ``init_model``.

    Returns:
        mmengine.model.BaseModel: The result model.

    Raises:
        ValueError: If pretrained weights are requested but the metainfo
            has none, or if the metainfo has no config.

    Examples:
        Get a ResNet-50 model and extract images feature:

        >>> import torch
        >>> from mmcls import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model(
        ...     'resnet50_8xb32_in1k',
        ...     pretrained=True,
        ...     backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get Swin-Transformer model with pre-trained weights and inference:

        >>> from mmcls import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'
    """
    metainfo = ModelHub.get(model_name)

    # Decide which checkpoint, if any, is handed to ``init_model``.
    ckpt = None
    if isinstance(pretrained, str):
        ckpt = pretrained
    elif pretrained:
        ckpt = metainfo.weights
        if ckpt is None:
            raise ValueError(
                f"The model {model_name} doesn't have pretrained weights.")

    if metainfo.config is None:
        raise ValueError(
            f"The model {model_name} doesn't support building by now.")

    return init_model(metainfo.config, ckpt, device=device, **kwargs)


def list_models(pattern=None) -> List[str]:
    """List all models available in MMClassification.

    Args:
        pattern (str | None): A wildcard pattern to match model names. It
            is matched case-insensitively and any postfix after the pattern
            is accepted. If None, all model names are returned.
            Defaults to None.

    Returns:
        List[str]: a sorted list of model names.

    Examples:
        List all models:

        >>> from mmcls import list_models
        >>> print(list_models())

        List ResNet-50 models on ImageNet-1k dataset:

        >>> from mmcls import list_models
        >>> print(list_models('resnet*in1k'))
        ['resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb32_in1k']
    """
    model_names = ModelHub._models_dict.keys()
    if pattern is None:
        return sorted(model_names)
    # The registry keys are lower-cased, while the case sensitivity of
    # ``fnmatch.filter`` depends on the platform. Lower-casing the pattern
    # makes the match case-insensitive everywhere, like ``ModelHub.get``.
    # The trailing '*' always matches keys with any postfix.
    matches = fnmatch.filter(model_names, pattern.lower() + '*')
    # Sort so that both branches return names in the same order.
    return sorted(matches)
12 changes: 5 additions & 7 deletions mmcls/models/backbones/timm_backbone.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None

import warnings

from mmengine.logging import MMLogger
Expand Down Expand Up @@ -68,10 +63,13 @@ def __init__(self,
in_channels=3,
init_cfg=None,
**kwargs):
if timm is None:
raise RuntimeError(
try:
import timm
except ImportError:
raise ImportError(
'Failed to import timm. Please run "pip install timm". '
'"pip install dataclasses" may also be needed for Python 3.6.')

if not isinstance(pretrained, bool):
raise TypeError('pretrained must be bool, not str for model path')
if features_only and checkpoint_path:
Expand Down
5 changes: 4 additions & 1 deletion mmcls/models/classifiers/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcls.registry import MODELS
Expand Down Expand Up @@ -96,7 +97,9 @@ def __init__(self,
**kwargs)
self.model = AutoModelForImageClassification.from_config(config)

self.loss_module = MODELS.build(loss)
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss

self.with_cp = with_cp
if self.with_cp:
Expand Down
Loading