!转载请注明原文地址!——东方旅行者
更多行人重识别文章移步我的专栏:行人重识别专栏
本文目录
- 随机采样器(sampler.py)
- 一、随机采样器作用
- 二、随机采样器编写思路
- 三、代码
- 四、测试结果
随机采样器(sampler.py)
一、随机采样器作用
本文件用于自定义采样器类RandomIdentitySampler,RandomIdentitySampler根据指定的数据集(索引列表)与采样数量进行采样,最后返回记录采样数据图片序号的列表迭代器。
在度量学习中,模型需要通过网络学习出图片间的相似度,相同行人图片相似度高于不同行人。所以需要一个批次中既有相同行人图片也要有不同行人图片,所以需要使用随机采样器进行采样。
二、随机采样器编写思路
该文件中的RandomIdentitySampler需要继承torch.utils.data下的sampler.py中的Sampler类,同时需要重写__iter__ 方法(用于迭代获取数据集内数据)与__len__方法(用于返回采样器采样的数据的总长度)。
- __init__方法需要传递一个数据集(类型list,即data_manager生成的三个子数据集的索引列表)还有一个指明每个ID采样数量的参数num_instances。然后声明一个字典(它的value为一个列表)用于存储行人ID与对应数据图片的序号。通过字典的keys()方法获得行人ID并使用list()方法生成行人ID列表。字典元素为(key,[图片1序号,图片2序号……])
- __iter__方法用于迭代获取数据集内元素。先将行人ID打乱顺序重排,然后声明一个result结果列表用于存储采样结果,迭代获取重排的行人ID,进行类型转换,从字典中获取该行人ID的图片列表,并通过列表获得该行人图片的数量,如果数量小于参数num_instances则可以进行重复采样,否则不需要重复采样。使用np.random.choice进行采样,并将采样后的图片加入result列表中。最后然后需要用iter()方法生成一个result的迭代器。Result列表的元素是数据图片的序号。
- __len__方法直接返回最后result列表的长度。
三、代码
import torch
import numpy as np
from torch.utils.data.sampler import Sampler
from collections import defaultdict
from IPython import embed
"""
本文件用于自定义采样器类RandomIdentitySampler
RandomIdentitySampler根据指定的数据集(索引列表)与采样数量进行采样,最后返回记录采样数据图片序号的列表迭代器
在度量学习中,模型需要通过网络学习出图片间的相似度,相同行人图片相似度高于不同行人。所以需要一个批次中既有相同行人图片也要有不同行人图片,所以需要使用随机采样器进行采样。
"""
class RandomIdentitySampler(Sampler):
def __init__(self, data_source, num_instances=4):
self.data_source=data_source
self.num_instances=num_instances
self.index_dic= defaultdict(list)#改字典用于记录行人ID与其数据图片序号的对应关系
for index, (_,pid,_) in enumerate(data_source):
self.index_dic[pid].append(index)#将该行人ID对应的数据图片序号加入字典
self.pids=list(self.index_dic.keys())#获取行人ID列表
self.num_indentities=len(self.pids)
def __iter__(self):
indices=torch.randperm(self.num_indentities)#对行人ID打乱顺序重排
result=[]#result列表用于存储采样数据图片的序号
for i in indices:
#注意类型转换,字典的索引不能是Tensor类型
t=self.index_dic[int(i)]
#如果该pid拥有的图片少于num_instances,则可以重复采样
replace=False if len(t)>=self.num_instances else True
#采样
t=np.random.choice(t, size=self.num_instances, replace=replace)
result.extend(t)
#!!一定要返回result的迭代器,不能只返回result列表!!
return iter(result)
def __len__(self):
return self.num_instances*self.num_indentities
if __name__=='__main__':
from dataset_manager import Market1501
dataset=Market1501()
print(type(dataset.train))
sampler=RandomIdentitySampler(dataset.train, num_instances=4)
b=sampler.__iter__()
print('采样样本list的长度为{}'.format(sampler.__len__()))
四、测试结果
样本长度3004=训练集行人ID数量(751)×每个行人取样四张图片(4)
=> Market1501 loaded
------------------------------------------------------------------------
subset: train | num_id: 751 | num_imgs: 12936
subset: query | num_id: 750 | num_imgs: 3368
subset: gallery | num_id: 751 | num_imgs: 19732
------------------------------------------------------------------------
total | num_id: 1501 | num_imgs: 16304
------------------------------------------------------------------------
采样样本list的长度为3004
评论(0)
您还未登录,请登录后发表或查看评论