!转载请注明原文地址!——东方旅行者

更多行人重识别文章移步我的专栏:行人重识别专栏

本文目录

  • 随机采样器(sampler.py)
    • 一、随机采样器作用
    • 二、随机采样器编写思路
    • 三、代码
    • 四、测试结果

随机采样器(sampler.py)

一、随机采样器作用

本文件用于自定义采样器类RandomIdentitySampler,RandomIdentitySampler根据指定的数据集(索引列表)与采样数量进行采样,最后返回记录采样数据图片序号的列表迭代器。
在度量学习中,模型需要通过网络学习出图片间的相似度,相同行人图片相似度高于不同行人。所以需要一个批次中既有相同行人图片也要有不同行人图片,所以需要使用随机采样器进行采样。

二、随机采样器编写思路

该文件中的RandomIdentitySampler需要继承torch.utils.data下的sampler.py中的Sampler类,同时需要重写__iter__ 方法(用于迭代获取数据集内数据)与__len__方法(用于返回采样器采样的数据的总长度)。

  1. __init__方法需要传递一个数据集(类型list,即data_manager生成的三个子数据集的索引列表)还有一个指明每个ID采样数量的参数num_instances。然后声明一个字典(它的value为一个列表)用于存储行人ID与对应数据图片的序号。通过字典的keys()方法获得行人ID并使用list()方法生成行人ID列表。字典元素为(key,[图片1序号,图片2序号……])
  2. __iter__方法用于迭代获取数据集内元素。先将行人ID打乱顺序重排,然后声明一个result结果列表用于存储采样结果,迭代获取重排的行人ID,进行类型转换,从字典中获取该行人ID的图片列表,并通过列表获得该行人图片的数量,如果数量小于参数num_instances则可以进行重复采样,否则不需要重复采样。使用np.random.choice进行采样,并将采样后的图片加入result列表中。最后然后需要用iter()方法生成一个result的迭代器。Result列表的元素是数据图片的序号。
  3. __len__方法直接返回最后result列表的长度。

三、代码

import torch
import numpy as np
from torch.utils.data.sampler import Sampler
from collections import defaultdict
from IPython import embed

"""
本文件用于自定义采样器类RandomIdentitySampler
RandomIdentitySampler根据指定的数据集(索引列表)与采样数量进行采样,最后返回记录采样数据图片序号的列表迭代器
在度量学习中,模型需要通过网络学习出图片间的相似度,相同行人图片相似度高于不同行人。所以需要一个批次中既有相同行人图片也要有不同行人图片,所以需要使用随机采样器进行采样。
"""
class RandomIdentitySampler(Sampler):
    def __init__(self, data_source, num_instances=4):
        self.data_source=data_source
        self.num_instances=num_instances
        self.index_dic= defaultdict(list)#改字典用于记录行人ID与其数据图片序号的对应关系
        for index, (_,pid,_) in enumerate(data_source):
            self.index_dic[pid].append(index)#将该行人ID对应的数据图片序号加入字典
        self.pids=list(self.index_dic.keys())#获取行人ID列表
        self.num_indentities=len(self.pids)
    def __iter__(self):
        indices=torch.randperm(self.num_indentities)#对行人ID打乱顺序重排
        result=[]#result列表用于存储采样数据图片的序号
        for i in indices:
            #注意类型转换,字典的索引不能是Tensor类型
            t=self.index_dic[int(i)]
            #如果该pid拥有的图片少于num_instances,则可以重复采样
            replace=False if len(t)>=self.num_instances else True
            #采样
            t=np.random.choice(t, size=self.num_instances, replace=replace)
            result.extend(t)
        #!!一定要返回result的迭代器,不能只返回result列表!!
        return iter(result)
    def __len__(self):
        return self.num_instances*self.num_indentities

if __name__=='__main__':
    from dataset_manager import Market1501
    dataset=Market1501()
    print(type(dataset.train))
    sampler=RandomIdentitySampler(dataset.train, num_instances=4)
    b=sampler.__iter__()
    print('采样样本list的长度为{}'.format(sampler.__len__()))

四、测试结果

样本长度3004=训练集行人ID数量(751)×每个行人取样四张图片(4)

=> Market1501 loaded
------------------------------------------------------------------------
  subset: train  	| num_id:   751  	|  num_imgs:   12936  
  subset: query  	| num_id:   750  	|  num_imgs:    3368  
  subset: gallery 	| num_id:   751  	|  num_imgs:   19732  
------------------------------------------------------------------------
  total 			| num_id:  1501  	|  num_imgs:   16304  
------------------------------------------------------------------------
采样样本list的长度为3004