!转载请注明原文地址!——东方旅行者

更多行人重识别文章移步我的专栏:行人重识别专栏

本文目录

  • 数据处理器(transform.py)
    • 一、数据处理器作用
    • 二、数据处理器编写思路
    • 三、代码
    • 四、测试结果

数据处理器(transform.py)

一、数据处理器作用

模型读取数据(无论是训练还是测试)时,需要对数据进行必要的处理,如尺度统一、水平变换(训练时需要,测试时不需要)、将图片转为张量、归一化等。
用于对数据集图片进行尺寸标准化随机裁剪.

二、数据处理器编写思路

处理方法可以自定义也可以使用torchvision.transforms中的方法。

  1. 首先引用from torchvision.transforms import *,便于在引用自定义转换类时可以调用其他的转换方法。同时一个自定义的转换类需要实现其__init__方法与__call__方法。
  2. __init__方法需要传入目标长度、目标宽度、概率(随机裁剪需要进行概率判断)、插值方法(默认线性插值)
  3. __call__方法,如果随机数产生的概率小于给定概率,则直接将图片尺度归一,返回标准尺寸图片。否则,首先根据插值方法先将图片放大到目标高度与目标宽度的1.125倍。然后使用随机数在可选范围内选择一个裁剪的起点,然后指定起点横纵坐标,与长度宽度,进行裁剪,返回裁剪后的标准尺寸图片

三、代码

from torchvision.transforms import *
from PIL import Image
import random
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#防止服务挂掉

"""
模型读取数据(无论是训练还是测试)时,需要对数据进行必要的处理,如尺度统一、水平变换(训练时需要,测试时不需要)、将图片转为张量、归一化等
该文件用于对数据集图片进行尺寸标准化与随机裁剪
"""
class Random2DTransform(object):
    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):#目标高度、目标宽度、概率p(用于是否进行随机裁剪)、插值方法(默认线性插值)
        self.height=height
        self.width=width
        self.p=p
        self.interpolation=interpolation
        
    def __call__(self, img):
        #若小于目标概率,则直接将图片尺寸标准化
        if random.random() < self.p:
            img=img.resize((self.width, self.height), self.interpolation)
        #若大于目标概率,则先根据插值方法扩大图片,并进行随机裁剪
        else:
            new_width=int(round(self.width*1.125))
            new_height=int(round(self.height*1.125))
            #先将图片扩大到目标长宽的1.125倍,然后再随机裁剪
            resize_img=img.resize((new_width, new_height), self.interpolation)
            x_maxrange=new_width-self.width
            y_maxrange=new_height-self.height
            #计算随机裁剪XY轴起点
            x_start=int(round(random.uniform(0,x_maxrange)))
            y_start=int(round(random.uniform(0,y_maxrange)))
            #进行裁剪
            img=resize_img.crop((x_start, y_start, x_start+self.width, y_start+self.height))
        return img
    
if __name__=='__main__':
    from dataset_manager import Market1501
    from dataset_loader import ImageDataset
    dataset=Market1501()
    train_loader=ImageDataset(dataset.train)
    plt.figure()
    j=1
    #从训练集中获取前两张图片进行处理,并使用matplot显示图片
    for batch_id, (img, pid, cid) in enumerate(train_loader):
        if(batch_id<2):  
            transform=Random2DTransform(64,64,0.5)
            img_t=transform(img)
            img_t=np.array(img_t)
            plt.subplot(1,2,j)
            plt.imshow(img) # 显示图片
            plt.savefig()
            j=j+1
            plt.subplot(1,2,j)
            plt.imshow(img_t) # 显示图片
            plt.show()
            j=1

四、测试结果

tranform测试结果1
tranform测试结果2