神经网络模型训练时,有时我们需要在不同模型之间共享网络参数。下面我将通过一个案例展示如何共享模型参数。注意:本文使用 load_state_dict 实现,本质上是把一个模型的参数值复制(迁移)到另一个模型中;复制完成后,两个模型的参数彼此独立,后续训练时不会同步更新。

⭐参数共享模块的模型结构必须完全一致才能实现参数共享

一. 指定共享某一模块

假设我们有以下两个模型:

class ANN1(nn.Module):
    """Two-stage fully connected network used as the parameter source.

    ``nn_same`` is the block whose parameters are meant to be copied to
    another model; ``nn_diff`` is the model-specific output head.

    Args:
        features: Number of input features per sample.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        # Transferable block: features -> 128 followed by ReLU.
        self.nn_same = nn.Sequential(
            nn.Linear(features, 128),
            nn.ReLU(),
        )
        # Model-specific head: 128 -> 1.
        self.nn_diff = nn.Sequential(nn.Linear(128, 1))

    def forward(self, x):
        """Run ``x`` of shape (batch_size, features) through both stages."""
        hidden = self.nn_same(x)
        return self.nn_diff(hidden)
class ANN2(nn.Module):
    """Two-stage fully connected network used as the parameter target.

    Its ``nn_same`` block has the same layer layout as the source model's
    so a state dict taken from that block can be loaded here; ``nn_diff``
    is this model's own output head.

    Args:
        features: Number of input features per sample.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        # Transferable block: features -> 128 followed by ReLU.
        self.nn_same = nn.Sequential(
            nn.Linear(features, 128),
            nn.ReLU(),
        )
        # Model-specific head: 128 -> 1.
        self.nn_diff = nn.Sequential(nn.Linear(128, 1))

    def forward(self, x):
        """Run ``x`` of shape (batch_size, features) through both stages."""
        hidden = self.nn_same(x)
        return self.nn_diff(hidden)

# Build one instance of each network with 10 input features, then print
# their module trees so the layer layouts can be compared.
model1, model2 = ANN1(10), ANN2(10)
print(model1, model2, sep="\n")
ANN1(
  (nn_same): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
ANN2(
  (nn_same): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)

其中 nn_same 代表要共享参数的模块,模块名称可以不相同,但是模块结构必须完全相同。

因为模型初始化时参数是随机初始化的,所以两个模型的参数肯定不相同。假如我们要将 model1 中 nn_same 模块的参数迁移到 model2 中的 nn_same 模块,首先看一下 model1.nn_same 的参数:

# Print every entry of model1.nn_same's state dict (the values that will be
# transferred), one "<key> <tensor>" pair per entry.
for key, value in model1.nn_same.state_dict().items():
    print(key, "\t", value)
0.weight      tensor([[ 0.1321, -0.0178,  0.1631,  ..., -0.2531, -0.1584,  0.0588],
        [-0.2466, -0.0381,  0.2394,  ..., -0.2924, -0.1267, -0.1791],
        [-0.1713, -0.0716,  0.0598,  ...,  0.1655, -0.1947,  0.0927],
        ...,
        [-0.1795, -0.3082, -0.2846,  ...,  0.2588, -0.0998, -0.1285],
        [-0.2739, -0.1587,  0.1803,  ..., -0.1905, -0.2832, -0.0724],
        [ 0.1375, -0.1854, -0.1928,  ...,  0.1470,  0.2928,  0.1385]])
0.bias      tensor([-0.2251, -0.3036,  0.2147, -0.0798, -0.1079, -0.0396, -0.1078,  0.1006,
        -0.1884, -0.0616,  0.0698,  0.0044,  0.1615, -0.2090,  0.0584, -0.0743,
         ···,
         0.3010, -0.1674,  0.0982,  0.2267, -0.0865, -0.1350, -0.2501,  0.1475,
         0.0187,  0.0819,  0.1840, -0.0988,  0.0133, -0.2082,  0.0376,  0.2993])

下面我们进行参数迁移:

print("****************迁移前*****************")
# Show model2.nn_same's parameters before the transfer.
for name, tensor in model2.nn_same.state_dict().items():
    print(name, "\t", tensor)

# Take the state dict of model1's nn_same block only. Its keys are relative
# to that block ("0.weight", "0.bias"), so they line up with model2.nn_same.
model_nn_same = model1.nn_same.state_dict()
# Copy those values into model2.nn_same and nothing else. strict=True (the
# default) requires the keys on both sides to match exactly; strict=False
# would skip keys that are present on only one side.
model2.nn_same.load_state_dict(model_nn_same, strict=True)

print("****************迁移后*****************")
# Show model2.nn_same's parameters after the transfer.
for name, tensor in model2.nn_same.state_dict().items():
    print(name, "\t", tensor)
# model2.nn_same now holds model1.nn_same's values; model2.nn_diff was not
# touched because only the nn_same submodule was loaded.
****************迁移前*****************
0.weight      tensor([[-0.1030, -0.0111,  0.0989,  ..., -0.3142, -0.0167,  0.0485],
        [ 0.1671,  0.2833,  0.1353,  ...,  0.1657, -0.2497, -0.1680],
        [ 0.0470,  0.1208,  0.1707,  ..., -0.0018,  0.2497,  0.0419],
        ...,
        [-0.2406, -0.2757,  0.2527,  ..., -0.0888, -0.2772,  0.1019],
        [-0.3035, -0.0227, -0.0194,  ...,  0.1280, -0.1167,  0.1060],
        [ 0.0565,  0.1870, -0.2729,  ..., -0.1215,  0.1343, -0.1057]])
0.bias      tensor([ 0.0855,  0.3137,  0.2336, -0.2197,  0.0132, -0.1812, -0.1490, -0.1348,
         0.1027,  0.0284,  0.1064,  0.2046,  0.1106, -0.2034, -0.1283, -0.1561,
         ···,
         0.0328, -0.1035, -0.2942, -0.2368, -0.2290,  0.1846, -0.0270,  0.1286,
        -0.2331,  0.1111,  0.2172, -0.2865,  0.2086, -0.1388, -0.2077, -0.2976])
****************迁移后*****************
0.weight      tensor([[ 0.1321, -0.0178,  0.1631,  ..., -0.2531, -0.1584,  0.0588],
        [-0.2466, -0.0381,  0.2394,  ..., -0.2924, -0.1267, -0.1791],
        [-0.1713, -0.0716,  0.0598,  ...,  0.1655, -0.1947,  0.0927],
        ...,
        [-0.1795, -0.3082, -0.2846,  ...,  0.2588, -0.0998, -0.1285],
        [-0.2739, -0.1587,  0.1803,  ..., -0.1905, -0.2832, -0.0724],
        [ 0.1375, -0.1854, -0.1928,  ...,  0.1470,  0.2928,  0.1385]])
0.bias      tensor([-0.2251, -0.3036,  0.2147, -0.0798, -0.1079, -0.0396, -0.1078,  0.1006,
        -0.1884, -0.0616,  0.0698,  0.0044,  0.1615, -0.2090,  0.0584, -0.0743,
         ···,
         0.3010, -0.1674,  0.0982,  0.2267, -0.0865, -0.1350, -0.2501,  0.1475,
         0.0187,  0.0819,  0.1840, -0.0988,  0.0133, -0.2082,  0.0376,  0.2993])

可以看到 model2 的 nn_same 模块的参数已经与 model1 的 nn_same 模块的参数一致。

二. 共享所有相同名称的模块

假设我们有以下两个模型:

class ANN1(nn.Module):
    """Source model with two transferable blocks and its own output head.

    ``nn_same1`` and ``nn_same2`` are the blocks whose names and layouts
    match the target model; ``nn_diff1`` exists only on this model.

    Args:
        features: Number of input features per sample.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        # First transferable block: features -> 128 followed by ReLU.
        self.nn_same1 = nn.Sequential(
            nn.Linear(features, 128),
            nn.ReLU(),
        )
        # Second transferable block, same layout as the first.
        self.nn_same2 = nn.Sequential(
            nn.Linear(features, 128),
            nn.ReLU(),
        )
        # Model-specific head: 128 -> 1.
        self.nn_diff1 = nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        """Return the head's output for ``x`` of shape (batch_size, features).

        Both transferable blocks take the raw input and emit 128 values, so
        they are applied in parallel and summed to match the 128-wide input
        of ``nn_diff1``.
        """
        hidden = self.nn_same1(x) + self.nn_same2(x)
        return self.nn_diff1(hidden)
class ANN2(nn.Module):
    """Target model with two transferable blocks and its own output head.

    ``nn_same1`` and ``nn_same2`` are the blocks whose names and layouts
    match the source model; ``nn_diff2`` exists only on this model.

    Args:
        features: Number of input features per sample.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        # First transferable block: features -> 128 followed by ReLU.
        self.nn_same1 = nn.Sequential(
            nn.Linear(features, 128),
            nn.ReLU(),
        )
        # Second transferable block, same layout as the first.
        self.nn_same2 = nn.Sequential(
            nn.Linear(features, 128),
            nn.ReLU(),
        )
        # Model-specific head: 128 -> 1.
        self.nn_diff2 = nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        """Return the head's output for ``x`` of shape (batch_size, features).

        Both transferable blocks take the raw input and emit 128 values, so
        they are applied in parallel and summed to match the 128-wide input
        of ``nn_diff2``.
        """
        hidden = self.nn_same1(x) + self.nn_same2(x)
        return self.nn_diff2(hidden)

# Build one instance of each network with 10 input features, then print
# their module trees so the matching and differing block names are visible.
model1, model2 = ANN1(10), ANN2(10)
print(model1, model2, sep="\n")
ANN1(
  (nn_same1): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_same2): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff1): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
ANN2(
  (nn_same1): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_same2): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff2): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)

假如我们要将 model1 中 nn_same1、nn_same2 模块的参数迁移到 model2 中的 nn_same1、nn_same2 模块:

print("****************迁移前*****************")
# Show all of model2's parameters before the transfer.
for name, tensor in model2.state_dict().items():
    print(name, "\t", tensor)

# Take the state dict of the whole source model.
model_all = model1.state_dict()
# Load it into model2. strict=False skips keys that exist on only one side
# (nn_diff1.* in model1, nn_diff2.* in model2); strict=True (the default)
# would raise a RuntimeError for those mismatched keys.
model2.load_state_dict(model_all, strict=False)

print("****************迁移后*****************")
# Show all of model2's parameters after the transfer.
for name, tensor in model2.state_dict().items():
    print(name, "\t", tensor)
# nn_same1 and nn_same2 now hold model1's values; nn_diff2 is unchanged.

其中需要注意的是,在 model2.load_state_dict(model_all, strict=False) 中,strict=False 表示两个模型的模块名不需要完全匹配,只会更新名称相同的模块。如果两个模型的模块名不完全相同但是 strict=True,那么就会报错:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-56-069ae53e28f3> in <module>
      4 
      5 model_all = model1.state_dict() ##获取model的所有的参数
----> 6 model2.load_state_dict(model_all,strict=True) #更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
      7 
      8 print("****************迁移后*****************")

D:\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
   1481         if len(error_msgs) > 0:
   1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1485 

RuntimeError: Error(s) in loading state_dict for ANN2:
    Missing key(s) in state_dict: "nn_diff2.0.weight", "nn_diff2.0.bias". 
    Unexpected key(s) in state_dict: "nn_diff1.0.weight", "nn_diff1.0.bias".