为了看懂基于MMDetection/MMDetection3D的目标检测模型代码,有必要先了解一些重要但平时不常用的python基础知识。

1.类的继承

参考:Python中的init和super() - 知乎

python定义类的语句如下:

class ClassName:
    <statements>

也可在类名后加括号,括号内写上另一个已定义类的名称表示新类继承旧类的属性和方法:

class DerivedClassName(BaseClassName): 
    <statements>

这里DerivedClassName称为子类(派生类),BaseClassName称为父类(基类)。

假设我们有一个类Fruit,其定义如下:

class Fruit:
    def __init__(self, name="Apple"):
        self.name = name

以Fruit作为父类,定义下面的子类:

class Apple(Fruit):
    pass

创建该子类的实例:

f = Apple()

使用print函数输出其属性:

print(f.name)     # 输出Apple

可以看到,即使在定义Apple类时没有类似name="Apple"的语句,也能获取其name属性的值。这就是因为Apple类继承了其父类Fruit的属性和初始化方法。

被继承的方法可以重写。例如,新建一个Apple_Init类(仍以Fruit为父类),创建其实例并输出属性:

class Apple_Init(Fruit):
    def __init__(self, color):
        self.color = color

fi = Apple_Init('red')

print(fi.name)    # 该语句会报错说Fruit_Init没有name属性
print(fi.color)   # 输出red

可以看到,由于定义Apple_Init时定义了初始化方法,覆盖了继承自其父类的初始化方法,因此访问属性name失败。

若要同时继承其父类的初始化方法并添加新的属性,则可以使用super()函数,其语法为

super(DerivedClassName, self).BaseClassMethodName(*ArgsOfBaseClassMethod)

表示继承父类BaseClassName的BaseClassMethodName方法(参数为ArgsOfBaseClassMethod)。

仍以Fruit为父类创建另一Apple_Super类,在继承Fruit类初始化方法的基础上添加新的属性:

class Apple_Super(Fruit):
    def __init__(self, name, color):
        self.color = color
        super(Apple_Super, self).__init__(name)

fs = Apple_Super('Apple','red')

print(fs.name)    # 输出Apple
print(fs.color)   # 输出red

可见Apple_Super类的实例同时拥有父类属性name和子类属性color。

2.函数修饰符@

参考:mmdetecion 中类注册的实现(@x.register_module())

假设现在有函数func1,以函数为参数:

def func1(fn):
    fn()
    print(1)

假设现在有另一函数func2,功能是在屏幕上打印“2”。我们使用func1修饰func2:

@func1
def func2():
    print(2)

然后执行func2,可观察到依次输出2和1。实际上,经过@的修饰,func2与下面的语句等价:

def func1(fn):
    fn()
    print(1)

def func2():
    print(2)

func1(func2)

类似的,若func以类为输入和返回值,修饰类my_class:

def func(cls):
    print(0)
    return cls

@func
class my_class():
    def __init__(self):
        ...

直接运行上述程序(注意该程序并没有对类进行实例化),会发现屏幕输出0,说明执行了func中的语句。

在MMDetection中,自定义模型时用到的@x.register_module()即在运行时调用注册函数(维护一个模块列表,在搭建(build)时从配置文件的type字段取出类别名,然后将剩余字段传入该类中进行初始化)。

3.*args和**kwargs

若定义函数时允许函数接收可变数量的参数,可以使用*args或**kwargs。二者区别在于:

(1)*args会将非键值对参数打包为元组。例如

def func(*args):
    print(args)

func("name", "color")  # 输出('name', 'color')

(2)**kwargs会将键值对参数打包为字典。例如

def func(**kwargs):
    print(kwargs)

func(name='apple',color='red')  # 输出{'name': 'apple', 'color': 'red'}

注意,在同时使用*args和**kwargs时,*args必须放在**kwargs前面。此外,args和kwargs的名称可以随意修改。
类似地,在传参时,也可以使用*和**,会自动将传入的列表/元组以及字典分开。例如:

def func(name,color):
    print(name,color)

func_inputs = ['apple','red']
func(*func_inputs)    # 输出:apple red

附:其他一些pytorch中可能会遇见的不常用操作
(1)None作为tensor索引,如

a = torch.zeros(2,3)
print(a[:,None,None].shape)   # 输出torch.Size([2,1,1,3])

即None作为索引时相对于在相应的维度进行一次unsqueeze操作。
(2)@运算符,即矩阵乘法。

a = torch.zeros(2,3)
b = torch.zeros(3,2)
c = a @ b    # 即a和b的矩阵乘积