!转载请注明原文地址!——东方旅行者

更多行人重识别文章移步我的专栏:行人重识别专栏

本文目录

  • 难样本挖掘三元组损失(TriHard_Loss.py)
    • 一、难样本挖掘三元组损失作用
    • 二、难样本挖掘三元组损失编写思路
    • 三、代码
    • 四、测试结果

难样本挖掘三元组损失(TriHard_Loss.py)

一、难样本挖掘三元组损失作用

用于计算度量损失,与表征学习阶段分类损失协同使用反向传播优化网络参数,且难样本三元组损失有利于网络学到更好的特征。

二、难样本挖掘三元组损失编写思路

在实现难样本挖掘三元损失时借助相似度矩阵进行计算,将批次图片按照顺序形成一个大小为P×K的方阵,方阵元素(0,2)代表第0张图片与第二张图片的相似度。如下图所示(图片来自浙江大学罗浩博士教学视频
难样本挖掘三元组损失实现方法示意图
红色区域为每个行人与各自正样本之间的距离,而绿色区域为每个行人与各自负样本的距离。将该矩阵进行变换,将红色区域移到同一侧,绿色区域同一侧,对红色区域按行求最大值得到正样本最大距离向量(P×K,1),对绿色区域按行求最小值得到负样本最小距离向量(P*K,1),得到这两个向量即可根据公式计算难样本挖掘三元组损失。

编写代码时,计算损失需要继承nn.Module,重写__init__方法与forward方法。
__init__方法需要传入margin,并调用MarginRankingLoss计算三元组损失。
在这里插入图片描述

三、代码

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from IPython import embed

"""
本文件用于自定义难样本挖掘三元组损失,定义难样本挖掘三元组损失计算过程。
"""
class TripletLoss(nn.Module):
    def __init__(self, margin=0.3):
        super().__init__()
        self.margin=margin
        #计算三元组损失使用的函数
        self.ranking_loss=nn.MarginRankingLoss(margin=margin)
        
    def forward(self, inputs, targets):
        n=inputs.size(0)
        """
        计算图片之间的欧氏距离
        矩阵A,B欧氏距离等于√(A^2 + (B^T)^2 - 2A(B^T))
        """
        #计算A^2
        distance=torch.pow(inputs,2).sum(dim=1, keepdim=True).expand(n,n)
        #计算A^2 + (B^T)^2
        distance=distance+distance.t()
        #计算A^2 + (B^T)^2 - 2A(B^T)
        distance.addmm(1,-2,inputs,inputs.t())
        #计算√(A^2 + (B^T)^2 - 2A(B^T))
        distance=distance.clamp(min=1e-12).sqrt()#该distance矩阵为对称矩阵
        
        #获取对角线
        mask=targets.expand(n,n)==targets.expand(n,n).t()#mask矩阵用于区分红绿色区域,即正样本区与负样本区,便于进行损失计算。
        
        #list类型
        distance_ap,distance_an=[],[]
        
        for i in range(n):
            distance_ap.append(distance[i][mask[i]].max().unsqueeze(0))#distance[i][mask[i]]使distance保留正样本区
            distance_an.append(distance[i][mask[i]==0].min().unsqueeze(0))#distance[i][mask[i]==0]使distance保留负样本区
        
        #经过for循环后,正样本最大距离与负样本最小距离都存储在list当中,需要将list元素连接成一个torch张量
        distance_ap=torch.cat(distance_ap)
        distance_an=torch.cat(distance_an)
        #y指明ranking_loss前一个参数大于后一个参数
        y=torch.ones_like(distance_an)
        loss=self.ranking_loss(distance_an, distance_ap, y)
        
        return loss

if __name__=='__main__':
    target=[1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6,7,7,7,7,8,8,8,8]
    target=torch.Tensor(target)
    features=torch.rand(32,2048)
    a=TripletLoss()
    loss=a.forward(features,target)
    print(loss)

四、测试结果

tensor(0.8285)