目标检测(MMDetection)-Registry

​ 工作日趋繁忙,上一篇讲HOOK的文章获得了大家的赞同受宠若惊,虽然才疏学浅,今天就接着上次的内容,讲一下MMDetection中注册器Registry,希望能帮助到大家理解这个框架。

1. Registry注册器实现以及作用原理

​ 注册器其实在HOOK中就已经有体现,在MMDection中所有功能都是基于注册器来完成模块化操作的。其中最经典的就是在MMdetection构件模型的build.py中就通过注册器完成模型的模块化。

​ 首先,要明确注册器的使用目的就是为了在算法训练、调参中通过直接更改配置文件,你想从Faster RCNN切换到Retina,或者是更改学习率等,只需要更改配置即可。因为注册器完成了字符串->类的映射,代码会自动解析你config中的内容。下面这个就是在build.py中的实例化的注册器。

from mmcv.utils import Registry, build_from_cfg
from torch import nn

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

​ 其中,模型被拆解成了Backbones、Necks、Roi_extractors、Shared_Heads、Heads、Losses、Detectors几个部分,当时如果我们需要添加更多的模块也是可以的,只需要实例化我们各自的注册器就行。

​ 从代码中也可以看到注册器只是通过一个类完成了string类型->类名的一个映射而已。在Registry代码中,注册类的源码如下。

def _register_module(self, module_class, module_name=None, force=False):
    # 首先判断判断参数是否是Class类别
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
          # 获取类名
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        # 核心就这一句话,是不是超级简单?就是一个dict,key值是string, value是类
        self._module_dict[module_name] = module_class

​ 咋一看将string和Class之间的映射链接起来也不过如此啊,不过就是一个dict罢了,python初学者都能理解的东西。但是高就高在大道至简,看似平凡普通的东西,在大师的手里就是一个利器。

​ 首先,注册类的操作,在MMDetection中使用了装饰器解决了。不懂装饰器的可以先行百度,简单来说。装饰器的功能就是有A和B两个事情要做。但是做B事情的时候你希望能用到A的一些功能来达成你的需求。在Registry.py中,有一个用于装饰器函数来实现这个功能,并且注释也十分浅显易懂。

def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
        # 这里写了三种将类注册到backbones注册器中的方法。
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
                # 一个小扩展,如果你不想每个string的名字都是和类名完全一样,你也可以自定义你记得住的名字,但是这样的话你在config中写你的Backbone名字就是'mnet'了
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
                # 这是不用装饰器的时候,我们的正常操作,当然这显然没够上python优雅简洁的特点。
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')
                # 以下方式是用的最多的一种
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

​ 理解了注册器功能、实现代码和使用方式之后,好像对这玩意已经掌握了。但是我有一点需要说明的是,当你自己使用了一个注册器用装饰器添加了之后就可以直接在外部调用了吗?这里有一个小细节,如果你写了一个类并通过注册了,那么需要导入过这个类才会生效,具体来说就是在某个地方import这个类之后才会有用,(不过在我的demo中将自己的类和注册器写在了一个函数里就不用了)。

​ 当然注册器只是完成一个模块化功能,具体怎么把这些模块化的东西组织起来其实就是一些逻辑控制了。需要注意的是torch.nn.Sequential代码的操作并不是把backbone-neck-head组装起来的地方,如下所示,这只是把一些更细粒度的模块组成一起而已。比如配置文件的第44行。

def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, is is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        # 注意这只是把一些细节的模块拼在一起
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

​ 其实从设计原则上来说,backbone neck head loss等等应该是要在代码中手动操作的。比如在faster_rcnn中(faster_rcnn继承这个基类),forward的流程是这样的。

def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.
            proposals : override rpn proposals with custom proposals. Use when
                `with_rpn` is False.
        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 gt_bboxes_ignore, gt_masks,
                                                 **kwargs)
        losses.update(roi_losses)

        return losses

其中self.extract_feat函数是backbone+neck的输出feature map的操作。

def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

​ 总而言之,注册器只是提供从配置文件生成实例对象的一种方式,最重要的还是在各个detector中对这些模块的调用,以及深度学习算法训练、测试、inference整个流程的处理。

2. 注册器实战一下

​ 学习需要学会模仿,锻炼需要健康饮食。饭后要吃水果促进肠胃吸收,主食和水果也要互相搭配才能活络肠胃,用注册器来实战一下这个逻辑流程。

​ 首先文件目录如下:

.
├── Registry.py
├── builder.py
├── demo.py
└── lunch.py

Registry.py摘自mmcv,代码较长就不贴了

Builder.py中定义了两种注册器:FRUIT FOOD,以及对应的两个类,通过装饰器的方式进行了注册。

from Registry import Registry, build_from_cfg

FRUIT = Registry('fruit')
FOOD = Registry('food')


def build(cfg, registry, default_args=None):
    return build_from_cfg(cfg, registry, default_args)


def build_fruit(cfg):
    return build(cfg, FRUIT)


def build_food(cfg):
    return build(cfg, FOOD)


@FOOD.register_module()
class Rice():
    def __init__(self, name):
        self.name = name


@FRUIT.register_module()
class Apple():
    def __init__(self, name):
        self.name = name

配置文件lunch.py中,定义了我今天想吃的东西

lunch=dict(
    food=dict(type='Rice', name='东北大米'),
    fruit=dict(type='Apple', name='青苹果')
)

​ Demo.py中为具体的调用,可以看到我的代码里没有import Riceimport Fruit因为为了简单起见我把他们写在了一个函数里,如果你把RiceFruit两个类在另一个python文件中定义的话,就需要导包了。这一点和python的运行流程、包导入规则有关,也是新手经常犯的错误,有时间我也会总结一下这个知识点。

from builder import build_fruit, build_food
from lunch import lunch
class COOKER():
    def __init__(self,food, fruit):
        print('今日饮食清单:{}, {}'.format(food, fruit))
        self.food = build_food(food)
        self.fruit = build_fruit(fruit)
    def run(self):
        print('具体饮食计划')
        print('主食吃: {}'.format(self.food.name))
        print('水果吃: {}'.format(self.fruit.name))

cook = COOKER(**lunch)
cook.run()

运行demo.py输出的结果如下:

今日饮食清单:{'type': 'Rice', 'name': '东北大米'}, {'type': 'Apple', 'name': '青苹果'}
具体饮食计划
主食吃: 东北大米
水果吃: 青苹果

​ 这只是一个简单的示范注册器的例子,但是基本也说明了注册器的使用方式和运行原理,一句话解释注册器就是“为模块化服务的字符串->类的字典”。希望这篇文章可以让你彻底理解它,下一篇盘一下Runner。盘完HOOK,Registry,Runner之后,MMDetection框架的基本要素应该是差不多了,接下来就是盘算法了。anchor-free和anchor-based都统统拿下,除此之外也会整理一些自己在学习python,学习算法的过程中的一些资料,踩过的坑。