学习到了很多有用的东西。MMdetetion是现在最著名、算法包最多并且使用人数最多的训练框架,其中的源码非常值得学习,今天总结下我对其中HOOK(钩子)机制的理解。

​ MMdetection最近更新很多,我以2.4.0版本的代码进行解读,分享自己的理解,也吸纳观众的点评。HOOK、Runer的定义在MMCV当中,MMdetection和MMCV是版本匹配的,我这里使用的是MMCV 1.1.2的代码。(HOOK相关的定义主要在MMCV中,下面用的代码都是摘自于MMCV)。

1.HOOK机制的作用

​ MMdetection中的HOOK可以理解为一种触发器,也可以理解为一种训练框架的架构规范,它规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。

​ 首先看一下HOOK的基类定义

# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry

HOOKS = Registry('hook')


class Hook:

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

​ 可以说基类函数中定义了许多我们在模型训练中需要用到的一些功能,如果想定义一些操作我们就可以继承这个类并定制化我们的功能,可以看到HOOK中每一个参数都是有runner作为参数传入的。关于Runner的作用下一篇文章接着说,简而言之,Runner是一个模型训练的工厂,在其中我们可以加载数据、训练、验证以及梯度backward等等全套流程。MMdetection在设计的时候也为runner传入丰富的参数,定义了一个非常好的训练范式。在你的每一个hook函数中,都可以对runner进行你想要的操作。

​ 而HOOK是怎么嵌套进runner中的呢?其实是在Runner中定义了一个hook的list,list中的每一个元素就是一个实例化的HOOK对象。其中提供了两种注册hook的方法,register_hook是传入一个实例化的HOOK对象,并将它插入到一个列表中,register_hook_from_cfg是传入一个配置项,根据配置项来实例化HOOK对象并插入到列表中。当然第二种方法又是MMLab的开源生态中定义的一种基础方法mmcv.build_from_cfg了,无论在MMdetection还是其他MMLab开源的算法框架中,都遵循着MMCV的这套基于配置项实例化对象的方法。毕竟MMCV是提供了一个基础的功能,服务于各个算法框架,这也是为什么MMLab的代码高质量的原因。不仅仅是算法的复现,更是架构、编程范式的一种体现,真·代码如诗

def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.
        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        # hook是分优先级插入到list中的,在MMdetection中不同的HOOK是有优先级的,为什么呢?稍后在hook的调用中解释哈
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.
        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
              and 'priority' indicating its type and priority.
        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

​ 调用HOOK函数

def call_hook(self, fn_name):
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

​ 可以看到HOOK是调用的时候是遍历List,然后根据HOOK的名字来调用。这也是为什么要区分优先级的原因,优先级越高的放在List的前面,这样就能更快地被调用。当你想用before_run_epoch来做A和B两件事情的时候,在runner里面就是调用一次self.before_run_epoch,但是先做A还是先做B,就是通过不同的HOOK的优先级来决定了。比如在evaluation的时候对需要做测试,但是测试前对参数做滑动平均。比如emaHOOK中的72行,也写明了要在测试之前做指数滑动平均。

def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

checkpoint.py的HOOK中,同样也定义了after_train_epoch函数如下:

@master_only
    def after_train_epoch(self, runner):
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return

        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
            current_epoch = runner.epoch + 1
            for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
                ckpt_path = os.path.join(self.out_dir,
                                         filename_tmpl.format(epoch))
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                else:
                    break

​ 从测试代码中可以看到不同的HOOK虽然都是重写了after_train_epoch函数,但是调用的顺序还是先调用ema.py中的,然后再调用checkpoint.py中的after_train_epoch

resume_ema_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner()
    runner.model = demo_model
    # 设置了HIGHREST的优先级
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)

具体的优先级定义有以下7种,作为HOOK的类成员属性。具体定义在链接中。

+------------+------------+
    | Level      | Value      |
    +============+============+
    | HIGHEST    | 0          |
    +------------+------------+
    | VERY_HIGH  | 10         |
    +------------+------------+
    | HIGH       | 30         |
    +------------+------------+
    | NORMAL     | 50         |
    +------------+------------+
    | LOW        | 70         |
    +------------+------------+
    | VERY_LOW   | 90         |
    +------------+------------+
    | LOWEST     | 100        |
    +------------+------------+

2.举一个简单的例子

​ 最近打算好好锻炼身体,健康生活,努力工作,我打算让自己变得更加自律。我给自己定下了几个条例,每天吃早饭之前得晨练30分钟,运动完之后才会感觉充满活力。每天吃午饭之前我得跑上一个实验,吃完饭之后回来刚好可以看下中间结果,吃完午饭之后我感觉结果没问题我需要午休30分钟, 晚上下班前我如果没什么事再锻炼30分钟。秉承着这样的原则我给自己定义一个HOOK来规范我的生活。

  • 定义我的HOOK
import sys
class HOOK:

    def before_breakfirst(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))

    def after_breakfirst(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))

    def before_lunch(self, runner):
        print('{}:吃午饭之前跑上实验'.format(sys._getframe().f_code.co_name))

    def after_lunch(self, runner):
        print('{}:吃完午饭午休30分钟'.format(sys._getframe().f_code.co_name))

    def before_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))

    def after_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))

    def after_finish_work(self, runner, are_you_busy=False):
        if are_you_busy:
            print('{}:今天事贼多,还是加班吧'.format(sys._getframe().f_code.co_name))
        else:
            print('{}:今天没啥事,去锻炼30分钟'.format(sys._getframe().f_code.co_name))
  • 定义我的Runner
class Runner(object):
    def __init__(self, ):
        pass
        self._hooks = []

    def register_hook(self, hook):
        # 这里不做优先级判断,直接在头部插入HOOK
        self._hooks.insert(0, hook)

    def call_hook(self, hook_name):
        for hook in self._hooks:
            getattr(hook, hook_name)(self)

    def run(self):
        print('开始启动我的一天')
        self.call_hook('before_breakfirst')
        self.call_hook('after_breakfirst')
        self.call_hook('before_lunch')
        self.call_hook('after_lunch')
        self.call_hook('before_dinner')
        self.call_hook('after_dinner')
        self.call_hook('after_finish_work')
        print('~~睡觉~~')
  • 运行main函数,注册HOOK并且调用Runner.run()开启我的一天
from MyHook import HOOK
from MyRunner import Runner
runner = Runner()
hook = HOOK()
runner.register_hook(hook)
runner.run()
  • 得到的输出结果如下:
开始启动我的一天
before_breakfirst:吃早饭之前晨练30分钟
after_breakfirst:吃早饭之前晨练30分钟
before_lunch:吃午饭之前跑上实验
after_lunch:吃完午饭午休30分钟
before_dinner: 没想好做什么
after_dinner: 没想好做什么
after_finish_work:今天没啥事,去锻炼30分钟
~~睡觉~~

3.总结

​ MMdetection中的HOOK设计巧妙,很好地对算法训练、测试进行了抽象和解耦。每一个做上层算法模型的,都值得一看。感谢MMLab贡献这么优质的代码,让我等凡夫俗子醍醐灌顶。

​ 除了HOOK之外,这个代码中还有很多优质的思想。比如Runner是怎么做到包办一切的?注册器这个中枢管理系统是怎么工作的?多卡训练的一些坑是怎么解决的?等等等等,我也在持续地学习和消化。路漫漫其修远兮,吾将上下而求索。

​ 一个小题目:我的代码中每个函数输出的时候都会打印出这个函数名,这个可以用装饰器很方便地解决奥。装饰器这个东西在MMLab的系列项目中有大量的应用。其中对fp16的支持让大家赞不绝口。接下来有时间,对Runner、Register、装饰器这些东西好好盘一盘。