[TensorFlow 2.0] Training on the Fashion MNIST dataset

December 12, 2020, 09:10

Dataset introduction

 

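This post uses the Fashion MNIST dataset, which contains 70,000 grayscale images in 10 categories. Each image shows a single article of clothing at low resolution (28 x 28 pixels), as shown below (image from the official TensorFlow documentation):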

 

[Figure: sample images from the Fashion MNIST dataset]

 

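Each image is a 28×28 NumPy array with pixel values ranging from 0 to 255. The labels are an array of integers ranging from 0 to 9. These correspond to the class of clothing that the image shows: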

 

[Figure: table mapping the labels 0 to 9 to clothing classes]
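
As a small sketch, the label-to-class mapping can be kept in a plain Python list so that an integer label can be turned into a readable name. The class names follow the order used in the official TensorFlow tutorial; the names `class_names` and `label_to_name` are illustrative and do not appear in the code of this post.

```python
# Class names in label order (0-9), as listed in the official TensorFlow tutorial.
class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_to_name(label):
    """Map an integer label (0-9) to its clothing class name (hypothetical helper)."""
    return class_names[int(label)]


print(label_to_name(9))  # Ankle boot
```

With the data loaded as in the code below, something like `label_to_name(Y_train_all[0])` would give the class of the first training image.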

 

Code

 

import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
print(tf.__version__)
# Check whether a GPU is available (this call is deprecated in later TF 2.x
# releases in favor of tf.config.list_physical_devices('GPU'))
print(tf.test.is_gpu_available())
# Load the Fashion MNIST dataset
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train_all, Y_train_all), (X_test, Y_test) = fashion_mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1]
X_train_all = X_train_all / 255
X_test = X_test / 255
# Split a validation set off the training set so that accuracy is checked after every epoch
x_valid, x_train = X_train_all[:5000], X_train_all[5000:]
y_valid, y_train = Y_train_all[:5000], Y_train_all[5000:]
# Build the model with tf.keras.Sequential
# relu: y = max(0, x), i.e. the larger of 0 and x
# softmax: turns the output vector into a probability distribution, e.g. for x = [x1, x2, x3],
#          y = [e^x1/sum, e^x2/sum, e^x3/sum],
#          sum = e^x1 + e^x2 + e^x3
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),   # Flatten unrolls the 2D input array into a 1D array
        tf.keras.layers.Dense(256, activation='relu'),   # fully connected layer with relu activation
        tf.keras.layers.Dropout(0.2),                    # dropout to reduce overfitting
        tf.keras.layers.Dense(10, activation='softmax')  # output layer
    ]
)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # cross-entropy loss for integer labels
              metrics=['accuracy'])
model.summary()  # print the model structure
# history records the metric values collected during training
history = model.fit(x_train, y_train, epochs=5,
                    validation_data=(x_valid, y_valid))
print('history:', history.history)
# Plot the values stored in history
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.ylim(0, 1)
plt.show()
# Evaluate the trained model on the test set
model.evaluate(X_test, Y_test, verbose=2)
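
The comments in the code above describe relu and softmax, and the loss is a cross-entropy over integer labels. The following is a minimal pure-Python sketch of that math for a single sample, not the Keras implementation (which also handles batching and numerical clipping). The function names are illustrative, and subtracting the maximum inside `softmax` is a common numerical-stability trick added here as an assumption.

```python
import math


def relu(x):
    """relu: y = max(0, x)."""
    return max(0.0, x)


def softmax(logits):
    """Turn a list of scores into a probability distribution."""
    # Subtracting the maximum does not change the result but avoids overflow in exp.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs, label):
    """Cross-entropy for one sample whose true class is the integer `label`."""
    return -math.log(probs[label])


probs = softmax([1.0, 2.0, 3.0])
print(probs)                                      # three probabilities, largest last
print(sum(probs))                                 # approximately 1.0
print(sparse_categorical_crossentropy(probs, 2))  # loss when the true class is 2
```

The loss is small when the model assigns a high probability to the true class and grows as that probability approaches 0.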

 

Model structure

 

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten (Flatten)            (None, 784)               0
_________________________________________________________________
dense (Dense)                (None, 256)               200960
_________________________________________________________________
dropout (Dropout)            (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570
=================================================================
Total params: 203,530
Trainable params: 203,530
Non-trainable params: 0
_________________________________________________________________
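
The parameter counts in the summary can be checked by hand: a Dense layer has one weight per input-unit pair plus one bias per unit, while Flatten and Dropout have no parameters. A short sketch of that arithmetic, with `dense_params` as an illustrative helper name:

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units


flatten_out = 28 * 28                     # Flatten: 28x28 image -> 784 values
hidden = dense_params(flatten_out, 256)   # dense:   784 -> 256
output = dense_params(256, 10)            # dense_1: 256 -> 10

print(flatten_out, hidden, output, hidden + output)
# 784 200960 2570 203530
```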

 

Training process

 

Train on 55000 samples, validate on 5000 samples
Epoch 1/5
55000/55000 [==============================] - 6s 106us/sample - loss: 0.5183 - accuracy: 0.8162 - val_loss: 0.3885 - val_accuracy: 0.8598
Epoch 2/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3908 - accuracy: 0.8570 - val_loss: 0.3656 - val_accuracy: 0.8696
Epoch 3/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3585 - accuracy: 0.8697 - val_loss: 0.3203 - val_accuracy: 0.8836
Epoch 4/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3358 - accuracy: 0.8767 - val_loss: 0.3326 - val_accuracy: 0.8796
Epoch 5/5
55000/55000 [==============================] - 5s 98us/sample - loss: 0.3237 - accuracy: 0.8808 - val_loss: 0.3297 - val_accuracy: 0.8824
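
The log can also be read programmatically. The sketch below copies the validation numbers from the log above into plain lists and finds the best epoch; in the actual script these lists should be available as `history.history['val_accuracy']` and `history.history['val_loss']`. The variable names are illustrative.

```python
# Values copied from the training log above (epochs 1 to 5).
val_accuracy = [0.8598, 0.8696, 0.8836, 0.8796, 0.8824]
val_loss = [0.3885, 0.3656, 0.3203, 0.3326, 0.3297]

# Epoch numbers are 1-based, list indices are 0-based.
best_acc_epoch = val_accuracy.index(max(val_accuracy)) + 1
best_loss_epoch = val_loss.index(min(val_loss)) + 1

print(best_acc_epoch, max(val_accuracy))   # 3 0.8836
print(best_loss_epoch, min(val_loss))      # 3 0.3203
```

In this run the validation metrics peak at epoch 3 and then fluctuate slightly, while the training loss keeps decreasing.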

 

[Figure: plot of the training and validation loss and accuracy over the 5 epochs]

 

 
