!转载请注明原文地址!——东方旅行者

更多行人重识别文章移步我的专栏:行人重识别专栏

难样本挖掘算法

一、难样本挖掘算法作用

本文件用于将先前在TriHard_Loss.py中编写的难样本挖掘算法单独形成一个文件,并加之返回难样本在距离矩阵中的索引的功能,便于在计算局部特征损失时也可以使用难样本挖掘相关结果。

二、难样本挖掘算法编写思路

主要的函数hard_example_mining(dist_mat, labels, return_indexs=False)。

输入:
1.距离矩阵dist_mat,维度(batch_size,batch_size)
2.本批次特征向量对应的行人ID labels,维度(batch_size)
3.是否返回最小相似度正样本与最大相似度负样本所对应的距离矩阵的序号return_indexs,默认为False
输出:
1.正样本区最小相似度张量dist_ap,维度(batch_size)
2.负样本区最大相似度张量dist_an,维度(batch_size)
3.正样本区最小相似度样本对应的距离矩阵下标p_indexs,维度(batch_size)
4.负样本区最大相似度样本对应的距离矩阵下标,n_indexs,维度(batch_size)

hard_example_mining先判断距离矩阵是不是二维,若不是二维则报错。然后判断距离矩阵是否是方阵,若不是方阵则报错。然后计算正样本区掩码与负样本区掩码,然后计算最小相似度(最大距离)正样本距离与最小相似度所对应正样本的序号(序号范围0~n-1)。计算最大相似度(最小距离)负样本距离与最大相似度所对应负样本的序号(序号范围0~n-1)。然后进行维度压缩。受max函数的影响,计算的序号是对应样本在某一行的索引,不是在距离矩阵中的索引,需要对索引进行计算转换成距离矩阵的索引。例如,假设每个ID取样4个样本,则每一行就有四个样本距离,经过计算后某一难样本的序号为0-3某一数字,但该样本实际在距离矩阵中索引为18,所以就需要进行计算,将行中的索引转化成距离矩阵中的索引。

假设距离矩阵为

[
[0, 1, 3, 5],
[1, 0 ,4, 6],
[3, 4, 0, 2],
[5, 6, 2, 0]
]

行人ID为

[0, 0, 1, 1]

则经过max函数后返回的最小相似度正样本的索引为[1,0,1,0],但实际索引应该为[1,0,3,2]。所以就需要进行计算进行索引转换。
根据计算得到的序号计算最小相似度正样本与最大相似度负样本在距离矩阵中的序号,这里需要使用torch.gather函数。计算最小相似度正样本与最大相似度负样本在距离矩阵中的序号后同样需要进行维度压缩。

三、代码

import torch

"""
本文件用于定义难样本挖掘算法
输入:
1.距离矩阵dist_mat,维度(batch_size,batch_size)
2.本批次特征向量对应的行人ID,维度(batch_size)
3.是否返回最小相似度正样本与最大相似度负样本所对应的距离矩阵的序号return_indexs,默认为False

输出:
1.正样本区最小相似度张量dist_ap,维度(batch_size)
2.负样本区最大相似度张量dist_an,维度(batch_size)
3.正样本区最小相似度样本对应的距离矩阵下标p_indexs,维度(batch_size)
4.负样本区最大相似度样本对应的距离矩阵下标,n_indexs,维度(batch_size)
"""
def hard_example_mining(dist_mat, labels, return_inds=False):
    #先判断距离矩阵是不是二维,若不是二维则报错
    assert len(dist_mat.size()) == 2
    #判断距离矩阵是否是方阵,若不是方阵则报错
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)#获取方阵长度
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())#正样本区掩码,负样本区元素为0
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())#负样本区掩码,正样本区元素为0
    
    #计算最小相似度(最大距离)正样本距离与最小相似度所对应正样本的序号(序号范围0~n-1)
    """
    .contiguous()用于将正样本区的距离拉成一维连续向量
    .view(N,-1)用于按照N为行形成矩阵
    torch.max函数不仅返回每一列中最大值的那个元素,并且返回最大值对应索引
    """
    dist_ap, relative_p_indexs = torch.max(dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    
    #计算最大相似度(最小距离)负样本距离与最大相似度所对应负样本的序号(序号范围0~n-1)
    dist_an, relative_n_indexs = torch.min(
    dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    
    #上面计算得到的dist_ap与dist_an维度为(batch_size,1)需要将最后一维进行压缩
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)
    
    #根据计算得到的序号计算最小相似度正样本与最大相似度负样本在距离矩阵中的序号
    if return_inds:
        indexs = (labels.new().resize_as_(labels).copy_(torch.arange(0, N).long()).unsqueeze( 0).expand(N, N))
        """
        gather函数的用法torch.gather(input, dim, index, out=None) 
        就是从index中找到某值,作为input的某一维度的索引,取出的input的值作为output的某一元素
        核心思想:
        out[i][j][k] = input[index[i][j][k]] [j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]   # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]   # if dim == 2
        注,output的维度同index一样
        
        对于下式就是p_indexs[i][0]=indexs[is_neg].contiguous().view(N, -1)[i][relative_n_indexs.data[i][0]]
        因为index即relative_n_indexs.data最后一维值为1,所以只能等于0
        """
        #计算最小相似度正样本与最大相似度负样本在距离矩阵中的序号,结果维度(batch_size,1)
        p_indexs = torch.gather(indexs[is_pos].contiguous().view(N, -1), 1, relative_p_indexs.data)
        n_indexs = torch.gather(indexs[is_neg].contiguous().view(N, -1), 1, relative_n_indexs.data)
        
        #将结果最后一维压缩
        p_indexs = p_indexs.squeeze(1)
        n_indexs = n_indexs.squeeze(1)
        
        return dist_ap, dist_an, p_indexs, n_indexs
    return dist_ap, dist_an