PyTorch officially released a preview of torch.fx in version 1.9 and a stable version in 1.10. The main purpose of this toolkit is to transform nn.Module instances, in other words, to manipulate models programmatically. That may sound like an odd capability, but once you dig into it you will find it has plenty of uses. This post first gives a brief introduction to the core concepts of torch.fx and then walks through a few concrete applications.

torch.fx has three main components: a symbolic tracer, an intermediate representation (IR), and Python code generation. To get a quick feel for them, here is a simple example:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

The symbolic tracer performs symbolic execution of the module's forward code: it feeds in fake values called Proxies, and every operation applied to them is recorded. The process is somewhat similar to how TensorFlow builds a static graph, with Proxies playing a role analogous to placeholders.

This tracing yields an intermediate representation of the computation: a torch.fx.Graph, which is essentially the same concept as a TensorFlow Graph. The Graph records all of the operations. Concretely, a Graph consists of a series of torch.fx.Node objects; a Node is the basic unit of the Graph and corresponds to a single operation. Node.op records the kind of operation, which is one of: placeholder, get_attr, call_function, call_module, call_method, or output. The following example illustrates these types:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

# Print all nodes of the graph
gm.graph.print_tabular()

opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7f98c972d200>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7f98c972d200>  (sum_1, 3)          {}
output         output         output                                                   (topk,)             {}

As you can see, besides op each Node also has a name, target, args and kwargs, whose meaning differs slightly by op type. placeholder nodes are the graph's inputs and output is the graph's output; for both, the target matches the name. get_attr fetches a parameter (or attribute) of the module; call_function calls a free function, with target pointing at the function itself; call_module calls a submodule, with target being the submodule's qualified name; call_method calls a method on its first argument (for example a Tensor method such as relu). args and kwargs are the positional and keyword arguments of the op, and many args are themselves other Nodes (referenced by name above); it is through these references that the Nodes are linked together and form the Graph.
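Because args reference other Nodes, the Graph can be walked programmatically. As a minimal sketch, here is one way to inspect each Node of the gm traced above together with its input Nodes:

for node in gm.graph.nodes:
    # node.args can mix Node inputs with plain constants (e.g. the 3 passed to topk)
    inputs = [a.name for a in node.args if isinstance(a, torch.fx.Node)]
    print(node.op, node.name, node.target, inputs)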

The last component, Python code generation, automatically produces executable Python code that matches the semantics of the Graph.
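Concretely, the generated source lives in the GraphModule's code attribute, and it is regenerated from the Graph whenever recompile() is called, which is how graph edits end up as new forward code. A minimal sketch:

print(gm.code)   # Python source generated for forward()
gm.graph.lint()  # optional: check that the Graph is well-formed
gm.recompile()   # regenerate gm.code / forward() from the current Graph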

At first glance, what torch.fx does is turn a Module into a static graph, so what does that have to do with transforming Modules? Consider this: if we take the Graph obtained by tracing a Module, transform it, and then apply the Python code generation tool, we end up with a transformed Module. The full pipeline is: symbolic tracing -> intermediate representation -> transforms -> Python code generation, which gives a Python-to-Python transformation from one Module to another. The overall code flow looks like this:

import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)
 
    # Step 2: Modify this Graph or create a new one
    graph = ...
 
    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

The resulting torch.fx.GraphModule is just like a normal nn.Module, except that it additionally carries graph and code attributes, and its forward executes the code generated from the graph. Here is a simple example of modifying a Module: we replace every torch.add() in the module with torch.mul():

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    graph : torch.fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return torch.fx.GraphModule(m, graph)
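Applying this transform to M swaps the call_function target from torch.add to torch.mul, so the returned GraphModule multiplies its inputs instead of adding them. A quick usage sketch:

m = M()
transformed = transform(m)
x, y = torch.rand(3), torch.rand(3)
print(torch.allclose(transformed(x, y), x * y))  # True: add has been replaced by mul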

So what are the use cases for torch.fx? There are actually quite a few scenarios that call for transforming modules, and the official documentation already provides several concrete examples:

Take Conv/Batch Norm fusion as an example. At inference time, folding BN into the preceding Conv so that the two become a single operation speeds up inference, and torch.fx makes this easy to implement. A concrete implementation (this is essentially what torch.fx.experimental.optimization.fuse does) looks like this:

import copy
from typing import Any, Dict, Iterable, Tuple, Type

import torch
import torch.fx as fx
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

def _parent_name(target: str) -> Tuple[str, str]:
    # Helper (mirrors the one in torch.fx.experimental.optimization):
    # splits a qualified name into parent path and last atom,
    # e.g. 'layer1.0.bn1' -> ('layer1.0', 'bn1')
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    if len(node.args) == 0:
        return False
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True


def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            # Find the target Node: args[0] is the Conv node, node.target is the BN module
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                # Fuse the BN parameters into the Conv
                fused_conv = fuse_conv_bn_eval(conv, bn)
                # Replace the Conv node's module: swap the fused module in for the original Conv
                replace_node_module(node.args[0], modules, fused_conv)
                # Redirect every use of the BN node to the (now fused) Conv node
                node.replace_all_uses_with(node.args[0])
                # Delete the BN node
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)

from torch.fx.experimental.optimization import fuse
from torchvision.models import resnet18

model = resnet18()
model.eval() # fusion must be done in eval mode

# Part of print(model), truncated:
'''
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''

fused_model = fuse(model)

# Part of print(fused_model), truncated: the BN layers have been folded into the preceding Convs
'''
  (layer4): Module(
    (0): Module(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (downsample): Module(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
      )
    )
    (1): Module(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''
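As a quick sanity check (a minimal sketch), the fused model should produce the same outputs as the original in eval mode, up to small numerical differences:

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    # Allow a small tolerance for floating-point differences introduced by the fusion
    print(torch.allclose(model(x), fused_model(x), atol=1e-5))  # True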

Another practical use of torch.fx is feature extraction from a model, for instance when we want intermediate features for analysis, or want to feed intermediate features into another model such as a detection or segmentation head. torchvision already supports this via feature_extraction. Here is a simple example that extracts the C4 and C5 features of a ResNet:

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Build the model
model = torchvision.models.resnet50()

# List the graph node names of the model (for train and eval mode)
train_nodes, eval_nodes = get_graph_node_names(model)

# Specify which nodes to return as outputs
return_nodes = {'layer3.5.relu_2': 'C4', 'layer4.2.relu_2': 'C5'}

# Rebuild the model so that it returns those intermediate features
n_model = create_feature_extractor(model, return_nodes)

out = n_model(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
# [('C4', torch.Size([1, 1024, 14, 14])), ('C5', torch.Size([1, 2048, 7, 7]))]

Note that the rebuilt model drops every Node that the requested outputs do not depend on; for example, the final fc layer is no longer present here.
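Since create_feature_extractor returns a torch.fx.GraphModule, one way to confirm this (a minimal sketch, assuming the returned module exposes its graph attribute like a GraphModule does) is to list the node names that remain in its graph:

# The nodes after layer4.2.relu_2 (avgpool, fc) are gone
print([node.name for node in n_model.graph.nodes])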

Finally, the torch.fx documentation covers much more than this; we have only touched on the main parts here. If you are interested, it is worth digging into the docs further.