问题背景
self.TowerA = xxx
self.TowerB = xxx
···
程序实现
class Tower(nn.Module): #Tower模型结构
def __init__(self):
super(Tower, self).__init__()
p = 0
self.tower = nn.Sequential(
nn.Linear(64, 64),
nn.ReLU(),
nn.Dropout(p),
nn.Linear(64, 32),
nn.ReLU(),
nn.Dropout(p),
nn.Linear(32, 1)
)
def forward(self, x):
out = self.tower(x)
return out
class SharedBottom(nn.Module):
def __init__(self,feature_size,n_task):
super(SharedBottom, self).__init__()
self.n_task = n_task
p = 0
self.sharedlayer = nn.Sequential(
nn.Linear(feature_size, 128),
nn.ReLU(),
nn.Dropout(p),
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(p)
)
'''下面三种定义方式等价'''
#方法1
# self.tower1 = Tower()
# self.tower2 = Tower()
# ···
'''方法2和方法3为批量定义的写法,将所有tower存入一个列表中,方便推理'''
#方法2
# self.towers = [Tower() for i in range(n_task)]
# for i in range(n_task):
# setattr(self, "tower"+str(i+1), self.towers[i]) #语法:setattr(object, name, value)
#方法3
for i in range(n_task):
setattr(self, "tower"+str(i+1), Tower()) #语法:setattr(object, name, value)
self.towers = [getattr(self,"tower"+str(i+1)) for i in range(n_task)] #语法:getattr(object, name)
def forward(self, x):
h_shared = self.sharedlayer(x)
#如果像方式一那样定义,那么推理时就需要按以下方式,很麻烦
# out1 = self.tower1(h_shared)
# out2 = self.tower2(h_shared)
# ···
#将所有tower存入列表中,即可用循环来实现推理,len(out)=n_task
out = [tower(h_shared) for tower in self.towers]
return out
Model = SharedBottom(feature_size=32, n_task=2) #feature_size表示输入特征数量,n_task表示任务数量
print(Model)
当n=2时,输出为:
SharedBottom(
(sharedlayer): Sequential(
(0): Linear(in_features=32, out_features=128, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=128, out_features=64, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
)
(tower1): Tower(
(tower): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=64, out_features=32, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
(6): Linear(in_features=32, out_features=1, bias=True)
)
)
(tower2): Tower(
(tower): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=64, out_features=32, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
(6): Linear(in_features=32, out_features=1, bias=True)
)
)
)
当n=3时,输出为:
SharedBottom(
(sharedlayer): Sequential(
(0): Linear(in_features=32, out_features=128, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=128, out_features=64, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
)
(tower1): Tower(
(tower): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=64, out_features=32, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
(6): Linear(in_features=32, out_features=1, bias=True)
)
)
(tower2): Tower(
(tower): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=64, out_features=32, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
(6): Linear(in_features=32, out_features=1, bias=True)
)
)
(tower3): Tower(
(tower): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=64, out_features=32, bias=True)
(4): ReLU()
(5): Dropout(p=0, inplace=False)
(6): Linear(in_features=32, out_features=1, bias=True)
)
)
)
评论(0)
您还未登录,请登录后发表或查看评论