!转载请注明原文地址!——东方旅行者

更多行人重识别文章移步我的专栏:行人重识别专栏

本文目录

  • 数据管理器(dataset_manager.py)
    • 一、数据管理器作用
    • 二、数据管理器编写思路
    • 三、代码
    • 四、测试结果

数据管理器(dataset_manager.py)

一、数据管理器作用

该文件主要负责指定数据集路径处理原始数据集并生成数据索引列表返回子数据集相关参数(子集行人ID数量,子集图片数量)。因为Market1501已经划分好训练集、测试集与查询集,所以直接可以根据路径提取这三个数据集。

二、数据管理器编写思路

  1. 指定数据集根目录路径
  2. 分别指定训练集、测试集、查询集的路径
  3. 通过这三个子数据集路径获得子集下所有图片的地址,通过每个图片的地址就可以得到行人ID与摄像机ID等信息,根据这些信息生成一个索引列表,类型为list,列表中每个元素都是一个三元组(数据图片地址,行人ID,摄像头ID),与此同时获取子数据集相关信息,如子集行人ID数量,子集图片数量等参数
  4. 三个数据集索引列表生成完毕,打印相关参数到控制台

索引列表如下所示:

[
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000451_03.jpg',0,0),
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000551_01.jpg',0,0),
('./data/Market-1501-v15.09.15\\bounding_box_train\\0002_c1s1_000776_01.jpg',0,0)
]

因为每一个子集中行人ID不一定连续,所以为了便于训练,一般要对训练集的行人ID进行重排,便于训练。所以需要使用一个名称为pid2label的Map来记录原始ID与重排ID的对应关系。

三、代码

import os
import os.path as osp
import numpy as np
import glob
import re
from IPython import embed
"""
Market1501类用于
1.指定数据集路径
2.处理原始数据集并生成数据索引列表
3.返回子数据集的相关参数(子集行人ID数量,子集图片数量)
"""
class Market1501(object):
    dataset_dir='data/Market-1501-v15.09.15'#指定数据集路径
    
    def __init__(self,root='./',**kwargs):
        self.dataset_dir=osp.join(root,self.dataset_dir)
        self.train_dir=osp.join(self.dataset_dir,'bounding_box_train')#训练集
        self.gallery_dir=osp.join(self.dataset_dir,'bounding_box_test')#测试集
        self.query_dir=osp.join(self.dataset_dir,'query')#查询集
        
        train, num_train_pids, num_train_imgs=self._process_dir(self.train_dir,relabel=True)
        query, num_query_pids, num_query_imgs=self._process_dir(self.query_dir,relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs=self._process_dir(self.gallery_dir,relabel=False)
        
        num_total_pids=num_train_pids+num_query_pids
        num_total_imgs=num_train_imgs+num_query_imgs
        
        print("=> Market1501 loaded")
        print("------------------------------------------------------------------------")
        print("  subset: train  \t| num_id: {:5d}  \t|  num_imgs:{:8d}  ".format(num_train_pids,num_train_imgs))
        print("  subset: query  \t| num_id: {:5d}  \t|  num_imgs:{:8d}  ".format(num_query_pids,num_query_imgs))
        print("  subset: gallery \t| num_id: {:5d}  \t|  num_imgs:{:8d}  ".format(num_gallery_pids,num_gallery_imgs))
        print("------------------------------------------------------------------------")
        print("  total \t\t\t| num_id: {:5d}  \t|  num_imgs:{:8d}  ".format(num_total_pids,num_total_imgs))
        print("------------------------------------------------------------------------")
        
        self.train=train
        self.query=query
        self.gallery=gallery
        self.num_train_pids=num_train_pids
        self.num_query_pids=num_query_pids
        self.num_gallery_pids=num_gallery_pids
        
    def _process_dir(self,dir_path,relabel=False):
        img_paths=glob.glob(osp.join(dir_path,'*.jpg'))
        pid_container=set()
        
        for img_path in img_paths:
            pid=int(img_path.split("\\")[-1].split("_")[0])
            if pid==-1:continue
            pid_container.add(pid)
        
        pid2label={pid:label for label,pid in enumerate(pid_container)}
        
        dataset=[]
        
        for img_path in img_paths:
            str_list=img_path.split("\\")[-1].split("_")
            pid=int(str_list[0])
            cid=int(str_list[1][1:2])
            if pid==-1:continue
            assert 0<=pid <=1501
            assert 1<=cid<=6
            cid+=-1
            if relabel:
                pid=pid2label[pid]
            dataset.append((img_path,pid,cid))
        
        num_pids=len(pid_container)
        num_imgs=len(img_paths)
        #返回一个数据为三元组(图片地址,行人ID,摄像机ID)的索引列表形式的数据集,行人ID数量,图片数量
        return dataset, num_pids, num_imgs    

if __name__=='__main__':
    data=Market1501()

四、测试结果

data_manager测试结果