神经网络从0到1(四)——数据集搭建

139
0
2020年10月9日 16时13分

引言

在前面部分的预处理中,我们总共完成了以下的功能,在图片中框出车牌部分,并将识别出来的车牌进行二值化处理,再进行字符分割。
到此,我们便可以进入pytorch的学习了,这一小节会教大家如何利用我们分割的字符来搭建神经网络数据集。我下面会逐块介绍代码,以防大家不知道如何拼接代码,我把整个工程放在下面云盘链接了,代码里需要修改一下路径参数,然后运行MAIN.py即可。


链接:https://pan.baidu.com/s/1X13WmUEE07Cuyelx8pKAPQ
提取码:2ma1
复制这段内容后打开百度网盘手机App,操作更方便哦


下面进入正题!


正文

接上一期的代码,我们获得了字符分割后的图片,我们利用下面代码,将分割后的图片单独切割出来并设置其像素为40*60(尺寸小训练起来不会特别消耗计算量,并且我们需要识别的图片较为简单,因此不需要很多像素点)

        for f in range(10):
            if split_line[f] == 0:
                flag_posi = f
                break
        #print(flag_posi)
        flag_step = 0
        for g in range(flag_posi+1):
            if flag_step == 0 and g < flag_posi and split_line[g] - 0 >=18:
                split0 = img_final_resize[1:49, 1:split_line[g]]
                split0_resize = cv2.resize(split0, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split0_resize',0)
                cv2.resizeWindow('split0_resize', 100, 200)
                cv2.moveWindow('split0_resize',400,600)
                cv2.imshow('split0_resize', split0_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/0.jpg',split0_resize)
            elif flag_step == 1 and g < flag_posi and split_line[1] - split_line[0] >=18:
                split1 = img_final_resize[1:49, split_line[0]:split_line[1]]
                split1_resize = cv2.resize(split1, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split1_resize', 0)
                cv2.resizeWindow('split1_resize', 100, 200)
                cv2.moveWindow('split1_resize', 550, 600)
                cv2.imshow('split1_resize', split1_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/1.jpg', split1_resize)

            elif flag_step == 2 and g < flag_posi and split_line[2] - split_line[1] >=18:
                split2 = img_final_resize[1:49, split_line[1]:split_line[2]]
                split2_resize = cv2.resize(split2, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split2_resize', 0)
                cv2.resizeWindow('split2_resize', 100, 200)
                cv2.moveWindow('split2_resize', 700, 600)
                cv2.imshow('split2_resize', split2_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/2.jpg', split2_resize)

            elif flag_step == 3 and g < flag_posi and split_line[3] - split_line[2] >=18:
                split3 = img_final_resize[1:49, split_line[2]:split_line[3]]
                split3_resize = cv2.resize(split3, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split3_resize', 0)
                cv2.resizeWindow('split3_resize', 100, 200)
                cv2.moveWindow('split3_resize', 850, 600)
                cv2.imshow('split3_resize', split3_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/3.jpg', split3_resize)

            elif flag_step == 4 and g < flag_posi and split_line[4] - split_line[3] >=18:
                split4 = img_final_resize[1:49, split_line[3]:split_line[4]]
                split4_resize = cv2.resize(split4, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split4_resize', 0)
                cv2.resizeWindow('split4_resize', 100, 200)
                cv2.moveWindow('split4_resize', 1000, 600)
                cv2.imshow('split4_resize', split4_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/4.jpg', split4_resize)

            elif flag_step == 5 and g < flag_posi and split_line[5] - split_line[4] >=18:
                split5 = img_final_resize[1:49, split_line[4]:split_line[5]]
                split5_resize = cv2.resize(split5, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split5_resize', 0)
                cv2.resizeWindow('split5_resize', 100, 200)
                cv2.moveWindow('split5_resize', 1150, 600)
                cv2.imshow('split5_resize', split5_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/5.jpg', split5_resize)

            elif flag_step == 6 and g < flag_posi and split_line[6] - split_line[5] >=18:
                split6 = img_final_resize[1:49, split_line[5]:split_line[6]]
                split6_resize = cv2.resize(split6, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split6_resize', 0)
                cv2.resizeWindow('split6_resize', 100, 200)
                cv2.moveWindow('split6_resize', 1300, 600)
                cv2.imshow('split6_resize', split6_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/6.jpg', split6_resize)

            elif flag_step == 7 and g < flag_posi and split_line[7] - split_line[6] >=18:
                split7 = img_final_resize[1:49, split_line[6]:split_line[7]]
                split7_resize = cv2.resize(split7, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split7_resize', 0)
                cv2.resizeWindow('split7_resize', 100, 200)
                cv2.moveWindow('split7_resize', 1450, 600)
                cv2.imshow('split7_resize', split7_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/7.jpg', split7_resize)

            elif flag_step == 8 and g < flag_posi and split_line[8] - split_line[7] >=18:
                split8 = img_final_resize[1:49, split_line[7]:split_line[8]]
                split8_resize = cv2.resize(split8, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split8_resize', 0)
                cv2.resizeWindow('split8_resize', 100, 200)
                cv2.moveWindow('split8_resize', 1600, 600)
                cv2.imshow('split8_resize', split8_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/8.jpg', split8_resize)

            elif flag_step == 9 and g < flag_posi and split_line[9] - split_line[8] >=18:
                split9 = img_final_resize[1:49, split_line[8]:split_line[9]]
                split9_resize = cv2.resize(split9, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split9_resize', 0)
                cv2.resizeWindow('split9_resize', 100, 200)
                cv2.moveWindow('split9_resize', 1750, 600)
                cv2.imshow('split9_resize', split9_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/9.jpg', split9_resize)
            elif g == flag_posi and 199 - split_line[flag_posi-1] >=15:
                split10 = img_final_resize[1:49, split_line[flag_posi-1]:199]
                split10_resize = cv2.resize(split10, dsize=(40, 60), fx=0, fy=0)
                cv2.namedWindow('split10_resize', 0)
                cv2.resizeWindow('split10_resize', 100, 200)
                cv2.moveWindow('split10_resize', 1600, 600)
                cv2.imshow('split10_resize', split10_resize)
                cv2.imwrite('D:/PycharmProjects/Num_distinguish/test_data/10.jpg', split10_resize)
            flag_step += 1
#以上为显示被分割出的子字符

神经网络从0到1(四)——数据集搭建插图
神经网络从0到1(四)——数据集搭建插图(1)
剪切完之后的效果如上图,可以得到单独的字符图片,我们将其保存在工程目录下的train_data/data目录下。并按要求进行命名,如下图
神经网络从0到1(四)——数据集搭建插图(2)
再在工程目录下train_data文件夹中创建labels.txt标签文件,这个文件中将会把图片的路径以及对应编号标明,如下图
神经网络从0到1(四)——数据集搭建插图(3)
我们用0~9来编码字符0~9,然后用10以后的数字来编码英文字母以及汉字。由于我们这个工程只是带大家入门,学会怎么使用pytorch,因此就不对所有字母以及所有地区的汉字简称进行编码了,仅仅对几个例子编码。


我们取出剪切出来的字符,大约三四十张,并制作好对应的标签,下面我们用代码创建数据集。
首先我们包含一些必要的头文件

import torch.utils.data.dataset
from torchvision import transforms
import os
from PIL import Image
import torch
import torch.nn as nn

接下来我们需要定义一个类用来描述我们的数据集,这个类名我们定义为MyDataset,其中我们需要完成这个类中必须要存在的三个构造方法,代码如下:

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, label_file_path):
        super().__init__()
        with open(label_file_path, 'r', encoding='UTF-8') as f:
            # (image_path(str), image_label(str))
            self.imgs = list(map(lambda line: line.strip().split(' '), f))

    def __getitem__(self, index):
        path, label = self.imgs[index]
        img = Image.open(path).convert('RGB')
        transform = transforms.ToTensor()
        img = transform(img)
        label = int(label)
        return img, label

    def __len__(self):  # 返回数据集的长度,即多少张图片
        return len(self.imgs)

这个类是自定义数据集所必须具备的类,定义完之后我们加载数据集即可,代码如下,其中test_data文件夹会放置我们之后识别用的数据集。

train_path = 'D:/PycharmProjects/Num_distinguish/train_data/labels.txt'
test_path = 'D:/PycharmProjects/Num_distinguish/test_data'

train_data = CNN_h.MyDataset(train_path)
train_loader = Data.DataLoader(dataset=train_data, batch_size=1, shuffle=True, num_workers=0)

总结

这章开始,我们进入神经网络的学习,这部分是构造数据集的一个方法,下一节我们就开始介绍如何搭建CNN网络来进行图片的训练。
急于想看到结果以及整个工程代码的同学可以在本章开头的百度网盘链接中下载整个工程,只需要运行MAIN.py,按提示输入即可。

发表评论

后才能评论