pytorch使用过程中,经常需要对张量进行维度校准。我们简单把一个2x2的tensor升维到1x2x2,该怎么做呢?

方法一:a[None]

import torch
a = torch.rand(2, 2)
b = a[None]
print(b.shape)
# torch.Size([1, 2, 2])

方法二:a.unsqueeze(0)

import torch
a = torch.rand(2, 2)
b = a.unsqueeze(0)
print(b.shape)
# torch.Size([1, 2, 2])

虽然方法二比方法一更繁琐,但比方法一更加灵活。如果你想升维度2x2->2x1x2,可以:

b = a.unsqueeze(1)
print(b.shape)
# torch.Size([2, 1, 2])

同理,a.unsqueeze(2)也可以升维成(2,2,1)。方法一的优势是,看起来更老练。