深度学习PyTorch笔记(3):Tensor的索引

这是《动手学深度学习》(PyTorch版)(Dive-into-DL-PyTorch)的学习笔记,里面有一些代码是我自己拓展的。

其他笔记在专栏 深度学习 中。

1.2.2 索引

裁剪

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.clamp(x, 2, 7)  #对x进行在2和7之间的裁剪
tensor([[2, 2, 3],
       [4, 5, 6],
       [7, 7, 7]])

x[行切片,列切片]

x = torch.tensor([[1,2,3,4], [3,4,5,6], [0,9,0,1], [8,2,1,3]])
print(x)
tensor([[1, 2, 3, 4],
        [3, 4, 5, 6],
        [0, 9, 0, 1],
        [8, 2, 1, 3]])
print(x[1:,])  #逗号前面是行切片,索引为1的行开始切片到最后,当对列没有修改时,逗号可省略。x[1:]=x[1:,]=x[1:,:]
tensor([[3, 4, 5, 6],
        [0, 9, 0, 1],
        [8, 2, 1, 3]])
print(x[1:3,])  #只切片索引为1的行
tensor([[3, 4, 5, 6],
        [0, 9, 0, 1]])
print(x[1:3,1:3])  #切片索引为1、2和行和列
tensor([[4, 5],
        [9, 0]])
print(x[:,1:3])  #但是这里如果是x[,1:3]会报错
tensor([[2, 3],
        [4, 5],
        [9, 0],
        [2, 1]])
print(x[::2, ::3])  #跳着访问,第0行和第2行,第0列和第3列
tensor([[1, 4],
        [0, 1]])
print(x[-1])  #可以用负索引
tensor([8, 2, 1, 3])
x[1, 2] = 10  #根据索引更改
print(x)
x[0:2, :] = 12
print(x)
tensor([[ 1,  2,  3,  4],
        [ 3,  4, 10,  6],
        [ 0,  9,  0,  1],
        [ 8,  2,  1,  3]])
tensor([[12, 12, 12, 12],
        [12, 12, 12, 12],
        [ 0,  9,  0,  1],
        [ 8,  2,  1,  3]])

可以用切片来修改:

x = torch.tensor([[1,2,3,4], [3,4,5,6], [0,9,0,1], [8,2,1,3]])
y = x[0,]
print(x)
print(y)

y += 10
print(y)
print(x[0])  

x[0] -= 5
print(x[0])
print(y)
tensor([[1, 2, 3, 4],
        [3, 4, 5, 6],
        [0, 9, 0, 1],
        [8, 2, 1, 3]])
tensor([1, 2, 3, 4])
tensor([11, 12, 13, 14])
tensor([11, 12, 13, 14])
tensor([6, 7, 8, 9])
tensor([6, 7, 8, 9])

  • 这里需要注意!!!索引出来的结果与原数据内存共享,修改了一个,另一个也修改,所以后面的结果是同步的

PyTorch还提供了一些高级的选择函数:

#index_select(input, dim, index):在指定维度dim上选取(dim=0,按行选取;dim=1,按列选取),index是索引的序号。
x = torch.tensor([[1,2,3,4], [3,4,5,6], [0,9,0,1], [8,2,1,3]])
print(x)
y = torch.index_select(x, 0, torch.tensor([0, 2]))  #按行选取,索引第0行和第2行
print(y)
z = torch.index_select(x, 1, torch.tensor(1))  #按列选取,第1列
print(z)
tensor([[1, 2, 3, 4],
        [3, 4, 5, 6],
        [0, 9, 0, 1],
        [8, 2, 1, 3]])
tensor([[1, 2, 3, 4],
        [0, 9, 0, 1]])
tensor([[2],
        [4],
        [9],
        [2]])
#masked_select(input, mask):mask取出的是布尔值索引(掩码)(即真为1,假为0),然后根据取出的非0掩码从中取值
x = torch.tensor([[0,2,4], [1,3,5]])
print(x)
y = torch.masked_select(x, x<5)  #x<5时,布尔值索引是[1,1,1,1,1,0],所以取出[0, 2, 4, 1, 3]
print(y)
tensor([[0, 2, 4],
        [1, 3, 5]])
tensor([0, 2, 4, 1, 3])
#nonzero(input):取出非0元素的下标
x = torch.tensor([[0,2,4], [1,0,5]])
print(torch.nonzero(x))
tensor([[0, 1],
        [0, 2],
        [1, 0],
        [1, 2]])
#gather(input, dim, index):根据index,在dim维度上选取数据,输出的size与index一样
x = torch.tensor([[0,2,4], [1,3,5], [2,1,0]])
print(x)
index = torch.tensor([[2,0], [1,2], [0,1]])
print(index)

a = torch.gather(x, 0, index)
b = torch.gather(x, 1, index)
print(a)
print(b)
tensor([[0, 2, 4],
        [1, 3, 5],
        [2, 1, 0]])
tensor([[2, 0],
        [1, 2],
        [0, 1]])
tensor([[2, 2],
        [1, 1],
        [0, 3]])
tensor([[4, 0],
        [3, 5],
        [2, 1]])

a中,dim=0,表示在行上取数据。那么就以列作为取值的基准。index中第一列的[2,1,0]表示a的第0列是x的第0列中,行号为[2,1,0]的数,以此类推。
由于index只有两列,所以a的结果不涉及x的第3列。

b中,dim=1,表示在列上取数据。那么就以行作为取值的基准。index中第一行的[2,0]表示b的第0行是x的第0行中,列号为[2,0]的数,以此类推。