实战Omniglot数据集识别(手写数字MNIST升级版)

364
0
2020年12月13日 09时06分

实战Omniglot数据集识别(手写数字MNIST升级版)

 

目录

  • 实战Omniglot数据集识别(手写数字MNIST升级版)
    • 经典方法
    • 简单方法
      • 最简单——无脑全连接
      • 简单卷积网络
      • 添加STN与inception
      • 改变loss函数
      • 加深网络
    • 小样本学习
      • 度量学习(metric learning)
      • 数据增强(data augmentation)
      • 元学习(meta learning)
      • 语义的方法(semantic)
    • 升级方法尝试
      • 数据增强
      • 原型网络
    • 总结

 

最近模式识别老师布置了一个大作业:手写字符识别
这里用到的数据集是Omniglot,如下:

 

在这里插入图片描述

 

这个数据集可谓是手写数字识别的爸爸呀,手写数字识别作为识别领域最简单的任务(大概吧),Omniglot相比它的难点就在于:Omniglot具有1623个类别,但每个类别只有20张图片。
今天正好周末,我们就来打(虐)发(待)一下时(电)间(脑)吧。

 

经典方法

 

在做什么任务之前,我们都应该想想经典方法可不可以实现。
用于分类,我们可以使用SVM或者朴素贝叶斯分类等。
我尝试了SVM,速度真是惨不忍睹,之后我尝试了朴素贝叶斯分类法,准确率大概是0.13左右。

 

简单方法

 

最简单——无脑全连接

 

我们先从最简单的方法开始,直接把图片展成一个向量,然后送入全连接。
网络结构如下:

 

QQ图片20201210105551

 

每一层均使用batchnorm,ReLU激活,并使用DropOut,概率为0.5。
这方法简单但准确率并不高,甚至和朴素贝叶斯分类方法准确率不相上下,训练过程如下:

 

在这里插入图片描述

 

测试集准确率在0.18左右浮动,就连训练集准确率也没有突破0.2。
并不想浪费时间在调整这样的无脑网络上,所以我没有继续想办法优化,仅作为一个尝试,接下来开始卷积网络。

 

简单卷积网络

 

首先作为尝试,我先选择简单的卷积网络进行试验,设计网络结构如下:

 

(括号里表示in_channel,out_channel,kernal_size,stride)
Conv2d(1,64,3,2)
Conv2d(64,128,3,2)
Conv2d(128,256,3,1)
Conv2d(256,512,3,1)
View(1,512*3*3)
Full_connect(512*3*3,2048)
Full_connect(2048,1623)
层间均用ReLU激活,全连接中添加batchnorm层(由于通过尝试发现卷积层间添加batchnorm时会导致准确率降低,所以在这里不添加)

 

训练设置weight_decay为1e-4,初始学习率为1e-4,并以0.9指数每两轮衰减一次。训练结果如下:

 

在这里插入图片描述

 

测试准确率收敛到了0.72,也算好了很多了。同时也可以看出这个数据集确实没有MNIST那么简单。
但在训练时可以发现,过拟合现象很严重,训练集预测准确率可以达到0.99。通过设置weight_dacay,添加DropOut,也没有很大的改善。看来只是按照普通的方法进行卷积还是有缺陷。

 

添加STN与inception

 

数据集中的图片每一个字符可能不是正的,这时候就需要网络具有旋转不变性。考虑到这一点,我在网络输入图片的时候添加了STN模块。同时我们识别时也需要从不同尺度看这张图片,然后通过特征融合得到不同尺度的特征,所以我添加了inception的思路。

 

在这里插入图片描述

 

图为STN模块,具体细节可自行查找。
我使用具体网络结构如下:

 

输入先通过STN模块调整方向,然后分为两路:
第一路:小卷积核
Conv2d(1, 16, kernel_size=3, padding=1)
Conv2d(16, 32, kernel_size=3, padding=1)
第二路:大卷积核
Conv2d(1, 16, kernel_size=7, padding=3)
Conv2d(16, 32, kernel_size=7, padding=3)
将两路的特征堆叠连接,送入如下卷积层
Conv2d(64, 128, kernel_size=3, padding=1)
Conv2d(128, 256, kernel_size=3, padding=1)
Conv2d(256, 512, kernel_size=3)
Conv2d(512, 1024, kernel_size=3)
Conv2d(1024, 2048, kernel_size=3)
这时特征已经缩减为一个向量,直接送入如下全连接层:
Full_connect(2048, 2048)
Full_connect(2048, 1623)
各层均用ReLU激活,全连接层用DropOut防止过拟合

 

训练设置weight_decay为1e-4,初始学习率为1e-4,并以0.9指数每两轮衰减一次。训练结果如下:

 

在这里插入图片描述

 

可见测试集准确率上升至0.75左右,上升不是很明显。

 

改变loss函数

 

目前目标识别领域常用的loss函数除了交叉熵损失,还有focal loss,该损失是交叉熵的拓展,往往比交叉熵有更好的效果。通常无用的易分反例样本会使得模型的整体学习方向跑偏,导致无效学习,所以该损失通过调整权重降低这些样本的影响,如下:

 

在这里插入图片描述
在这里插入图片描述

 

γ \gammaγ等于0时,该损失退化为交叉熵。

使用该损失函数,同时使用之前最简单的卷积结构,训练设置weight_decay为1e-4,初始学习率为1e-4,并以0.9指数每两轮衰减一次,γ \gammaγ取2。训练结果如下:

 

在这里插入图片描述

 

可见准确率也上升至了0.75左右。

 

加深网络

 

看来这并不是简单的任务,我们通过更深的网络进行尝试:
直接使用未进行预训练的ResNet50的结构,将输出全连接的最后输出通道改为1,维度改为1623,训练结果如下:

 

在这里插入图片描述

 

可见效果良好,可以达到0.99准确率,但是收敛慢,训练慢(毕竟太深了)。在1080ti上跑50个epoch用了半小时。

 

小样本学习

 

一顿乱试之后,我们该静下来想想为什么了,有没有方法能够花费较少的时间快速收敛且执行效率高呢?

 

其实,该任务属于小样本学习,即样本量非常少。目前,解决该难题的方法大致有如下四种:

 

1、度量学习(metric learning)
2、数据增强(data augmentation)
3、元学习(meta learning)
4、语义的方法(semantic)
我们一一来解释一下:

 

度量学习(metric learning)

 

即将待检测样本通过神经网络Embeding到另一个空间域内,在该空间中,每个样本为一个高维点,高维点之间距离越近代表这两个样本越可能是同一个类别。距离可以取各种距离,这也就是其名称“度量”之意。神经网络需要学习的也就是这样的一个映射,这里有一些有名的损失函数如triplet loss与reconstructive loss。

 

数据增强(data augmentation)

 

这个应该不用多说了吧,就是通过各种骚方法扩充数据集增加可识别率。

 

元学习(meta learning)

 

这应该也是目前的一个热点,包含面较为广泛,其根本用意就是我们常听到的“learning to learn”。他包含的方法有孪生网络、原型网络以及一些其他的方法。

 

语义的方法(semantic)

 

由于小样本学习困难的本质还是在于信息不够多,我们就想办法引入一些语义的信息来帮助分类。

 

下面,我们将运用数据增强以及原型网络来试一下下:

 

升级方法尝试

 

数据增强

 

由于对于字符,镜面翻转与随机旋转都不行,我采用了对每一个字符进行开闭运算的方法将数据集扩充了一倍,这里使用每类中35张图片作为训练集,5张图片作为测试集。使用上一节中最简单的卷积网络训练,使用focal loss,结果如下:

 

在这里插入图片描述

 

非常amazing啊,测试集准确率训练10个epoch时达到了0.9,18epoch时达到了0.99。收敛如此之快让我们领会到了数据的重要性。

 

原型网络

 

原型网络为解决小样本学习的元学习方法中的一种,我这里运用了原型网络最初的论文:Prototypical Networks for Few-shot Learning中的方法,简要介绍一下:

 

在这里插入图片描述

 

如图,神经网络学习一个空间Embeding,将数据映射到另一空间,然后求同一类别的均值,作为该类别的原型。如下:

 

在这里插入图片描述

 

其中f即为该神经网络:

 

在这里插入图片描述

 

然后引入一个新的数据,判断其到每个原型的欧氏距离的softmax值,作为其属于该类别的概率:

 

在这里插入图片描述

 

损失函数要做的就是最大化正确识别时的这个概率,如下:

 

在这里插入图片描述

 

运用这样的方法,我使用的网络结构如下:

 

Conv2d(1,64,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Conv2d(64,128,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Conv2d(128,256,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Conv2d(256,512,kernal_size=3,stride=2,padding=1),BatchNorm(),ReLU()
Flatten()

 

然后进行训练,结果更加amazing:
一轮直接收敛:
每类10个做训练集,10个做测试集时,训练一轮后测试集准确率到达0.988,后面最高到达0.99.
每类2个做训练集,18个做测试集时,一轮训练后测试集准确率到达0.96,后面最高到达0.97。

 

总结

 

完成了老师布置的作业,终于能去快乐地玩耍了emm。

发表评论

后才能评论