paddle几行代码轻松实现CNN卷积神经网络在MNIST上的图像分类
ztj100 2024-11-08 15:07 16 浏览 0 评论
前面几期的文章,我们介绍了paddle飞桨AI框架
PaddlePaddle飞桨深度学习实现手写数字识别任务——模型识别篇
本期我们重点介绍一下如何使用paddle来构建CNN卷积神经网络,并在MNIST数据集上面进行相关的数据加载与训练
----1----
MNIST的数据加载
前面已经提到,每一份MINIST数据都由图片以及标签组成。我们将图片命名为“x”,将标记数字的标签命名为“y”。训练数据集和测试数据集都是同样的结构,例如:训练的图片名为 mnist.train.images 而训练的标签名为 mnist.train.labels。
每一个图片均为28×28像素,我们可以将其理解为一个二维数组的结构:
在实际应用中,保存到本地的数据存储格式多种多样,如MNIST数据集以json格式存储在本地,其数据存储结构如下图 所示。
data包含三个元素的列表:train_set、val_set、 test_set,包括50 000条训练样本、10 000条验证样本、10 000条测试样本。每个样本包含手写的数字图片和对应的标签。
- train_set(训练集):用于确定模型参数。
- val_set(验证集):用于调节模型超参数(如多个网络结构、正则化权重的最优选择)。
- test_set(测试集):用于估计应用效果(没有在模型中应用过的数据,更贴近模型在真实场景应用的效果)。
train_set包含两个元素的列表:train_images、train_labels。
- train_images:[50 000, 784]的二维列表,包含50 000张图片。每张图片用一个长度为784的向量表示,内容是28*28尺寸的像素灰度值(黑白图片)。
- train_labels:[50 000, ]的列表,表示这些图片对应的分类标签,即0~9之间的一个数字。
import os
import random
import paddle
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import gzip
import json
from paddle.nn import Conv2D, MaxPool2D, Linear
import paddle.nn.functional as F
# 定义数据集读取器
def load_data(mode='train'):
# 加载数据
datafile = 'dataset/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
data = json.load(gzip.open(datafile))
print('mnist dataset load done')
# 读取到的数据区分训练集,验证集,测试集
train_set, val_set, eval_set = data
# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS
IMG_ROWS = 28
IMG_COLS = 28
if mode == 'train':
# 获得训练数据集
imgs, labels = train_set[0], train_set[1]
elif mode == 'valid':
# 获得验证数据集
imgs, labels = val_set[0], val_set[1]
elif mode == 'eval':
# 获得测试数据集
imgs, labels = eval_set[0], eval_set[1]
else:
raise Exception("mode can only be one of ['train', 'valid', 'eval']")
#校验数据
imgs_length = len(imgs)
assert len(imgs) == len(labels), \
"length of train_imgs({}) should be the same as train_labels({})".format(
len(imgs), len(labels))
# 定义数据集每个数据的序号, 根据序号读取数据
index_list = list(range(imgs_length))
# 读入数据时用到的batchsize
BATCHSIZE = 100
# 定义数据生成器
def data_generator():
if mode == 'train':
random.shuffle(index_list)
imgs_list = []
labels_list = []
for i in index_list:
#神经网络重点修改代码
img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
label = np.reshape(labels[i], [1]).astype('int64')
imgs_list.append(img)
labels_list.append(label)
if len(imgs_list) == BATCHSIZE:
yield np.array(imgs_list), np.array(labels_list)
imgs_list = []
labels_list = []
# 如果剩余数据的数目小于BATCHSIZE,
# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch
if len(imgs_list) > 0:
yield np.array(imgs_list), np.array(labels_list)
return data_generator
与往期代码处理数据不同的是如下2行代码
img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
label = np.reshape(labels[i], [1]).astype('int64')
----2----
CNN卷积神经网络的搭建
# 卷积神经网络实现
class MNIST(paddle.nn.Layer):
def __init__(self):
super(MNIST, self).__init__()
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
# 定义一层全连接层,输出维度是10
self.fc = Linear(in_features=980, out_features=10)
# 定义网络前向计算过程,卷积后紧接着使用池化层,最后使用全连接层计算最终输出
# 卷积层激活函数使用Relu
def forward(self, inputs):
x = self.conv1(inputs)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.reshape(x, [x.shape[0], 980])
x = self.fc(x)
return x
神经网络的搭建如上面代码所示,最终CNN输出10个MNIST数据集的数字分类,后期我们讲根据此神经网络进行数字识别
----3----
CNN卷积神经网络的训练与模型保存
def train(model):
model.train()
#调用加载数据的函数,获得MNIST训练数据集
train_loader = load_data('train')
# 使用SGD优化器,learning_rate设置为0.01
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
#opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())
# 训练10轮
EPOCH_NUM = 10
# MNIST图像高和宽
IMG_ROWS, IMG_COLS = 28, 28
for epoch_id in range(EPOCH_NUM):
for batch_id, data in enumerate(train_loader()):
#准备数据
images, labels = data
images = paddle.to_tensor(images)
labels = paddle.to_tensor(labels)
#前向计算的过程
predicts = model(images)
#计算损失,使用交叉熵损失函数,取一个批次样本损失的平均值
loss = F.cross_entropy(predicts, labels)
avg_loss = paddle.mean(loss)
#每训练200批次的数据,打印下当前Loss的情况
if batch_id % 200 == 0:
print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程
avg_loss.backward()
# 最小化loss,更新参数
opt.step()
# 清除梯度
opt.clear_grad()
#保存模型参数
paddle.save(model.state_dict(), 'model/mnist_cnn_3.pdparams')
model = MNIST()
train(model)
这里跟前期代码重点关注的是如下代码,loss我们是使用交叉熵损失函数
#计算损失,使用交叉熵损失函数,取一个批次样本损失的平均值
loss = F.cross_entropy(predicts, labels)
loading mnist dataset from dataset/mnist.json.gz ......
mnist dataset load done
epoch: 0, batch: 0, loss is: [2.5198565]
epoch: 0, batch: 200, loss is: [0.28517285]
epoch: 0, batch: 400, loss is: [0.2283621]
epoch: 1, batch: 0, loss is: [0.22373457]
epoch: 1, batch: 200, loss is: [0.12276303]
epoch: 1, batch: 400, loss is: [0.16722086]
epoch: 2, batch: 0, loss is: [0.16171335]
epoch: 2, batch: 200, loss is: [0.07980514]
epoch: 2, batch: 400, loss is: [0.22929193]
epoch: 3, batch: 0, loss is: [0.11811857]
epoch: 3, batch: 200, loss is: [0.08179446]
epoch: 3, batch: 400, loss is: [0.18990536]
epoch: 4, batch: 0, loss is: [0.09782474]
epoch: 4, batch: 200, loss is: [0.1028099]
epoch: 4, batch: 400, loss is: [0.1387132]
epoch: 5, batch: 0, loss is: [0.23205265]
epoch: 5, batch: 200, loss is: [0.16299754]
epoch: 5, batch: 400, loss is: [0.14402886]
epoch: 6, batch: 0, loss is: [0.08868036]
epoch: 6, batch: 200, loss is: [0.0643991]
epoch: 6, batch: 400, loss is: [0.06269972]
epoch: 7, batch: 0, loss is: [0.09487689]
epoch: 7, batch: 200, loss is: [0.04431003]
epoch: 7, batch: 400, loss is: [0.07480995]
epoch: 8, batch: 0, loss is: [0.15115222]
epoch: 8, batch: 200, loss is: [0.07555249]
epoch: 8, batch: 400, loss is: [0.24242447]
epoch: 9, batch: 0, loss is: [0.03684139]
epoch: 9, batch: 200, loss is: [0.04600935]
epoch: 9, batch: 400, loss is: [0.22289713]
神经网络训练完成后,我们把训练的模型进行保存,以便进行数字识别,可以看到使用CNN卷积神经网络,loss已经降到了0.05以下,若多训练几次,此loss会更小
ok,有了此模型,我们便可以利用训练好的模型进行数字识别了
----4----
CNN卷积神经网络手写数字识别
首先,我们准备好需要识别的手写数字与上面训练好的模型,然后我们进行数字识别
第一步 搭建CNN卷积神经网络模型
# 导入图像读取第三方库
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear
# 定义mnist数据识别网络结构
class MNIST(paddle.nn.Layer):
def __init__(self):
super(MNIST, self).__init__()
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
# 定义一层全连接层,输出维度是1
self.fc = Linear(in_features=980, out_features=10)
# 定义网络前向计算过程,卷积后紧接着使用池化层,最后使用全连接层计算最终输出
# 卷积层激活函数使用Relu
def forward(self, inputs):
x = self.conv1(inputs)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.reshape(x, [x.shape[0], 980])
x = self.fc(x)
return x
搭建的神经网络模型跟训练的代码完全一样,直接复制上面的代码即可
第二步 加载图片
# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):
# 从img_path中读取图像,并转为灰度图
im = Image.open(img_path).convert('L')
#plt.imshow(im,cmap='gray')
# print(np.array(im))
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
# 图像归一化,保持和数据集的数据范围一致
im = 1 - im / 255
return im
加载图片后,我们需要对图片进行相关的预处理操作,包括转换到灰度空间,缩放到28*28尺寸,并进行数据的归一化
第三步 加载模型进行预测
# 定义预测过程
model = MNIST()
params_file_path = 'model/mnist_cnn_3.pdparams'
img_path = 'image/example_6.jpg'
# 加载模型参数
param_dict = paddle.load(params_file_path)
model.load_dict(param_dict)
# 灌入数据
model.eval()
tensor_img = load_image(img_path)
#模型反馈10个分类标签的对应概率
result = model(paddle.to_tensor(tensor_img))
print('result',result)
#取概率最大的标签作为预测输出
lab = np.argsort(result.numpy())
print(lab)
print("本次预测的数字是: ", lab[0][-1])
我们加载上面代码训练好的模型,并使用paddle.load进行模型的加载,然后使用model(paddle.to_tensor(tensor_img))函数进行数字的预测,最后打印输出预测的数字
result Tensor(shape=[1, 10], dtype=float32, place=CPUPlace, stop_gradient=False,
[[ 0.68616289, -2.16074777, -0.68914777, -4.40827608, -0.68840426,
1.91024637, 8.90141392, -4.07809448, 1.31746018, -2.95059443]])
[[3 7 9 1 2 4 0 8 5 6]]
本次预测的数字是: 6
结论
相比前期的模型训练与预测,CNN卷积神经网络不仅loss会加速减少,且利用训练好的模型可以精确的识别出训练的数字,本教程并没有完全使用paddle的高级API进行CNN的搭建与预测,其代码相对多一些,有关paddle的高级API,我们后期进行相关技术的分享
相关推荐
- 从IDEA开始,迈进GO语言之门(idea got)
-
前言笔者在学习GO语言编程的时候,GO语言在国内还没有像JAVA/Php/Python那样普及,绕了不少的弯路,要开始入门学习一门编程语言,最好就先从选择一个好的编程语言的开发环境开始,有了这个开发环...
- 基于SpringBoot+MyBatis的私人影院java网上购票jsp源代码Mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于SpringBoot...
- 基于springboot的个人服装管理系统java网上商城jsp源代码mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于springboot...
- 基于springboot的美食网站Java食品销售jsp源代码Mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍基于springboot...
- 贸易管理进销存springboot云管货管账分析java jsp源代码mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述贸易管理进销存spring...
- SpringBoot+VUE员工信息管理系统Java人员管理jsp源代码Mysql
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目介绍SpringBoot+V...
- 目前见过最牛的一个SpringBoot商城项目(附源码)还有人没用过吗
-
帮粉丝找了一个基于SpringBoot的天猫商城项目,快速部署运行,所用技术:MySQL,Druid,Log4j2,Maven,Echarts,Bootstrap...免费给大家分享出来前台演示...
- SpringBoot+Mysql实现的手机商城附带源码演示导入视频
-
今天为大家带来的是基于SpringBoot+JPA+Thymeleaf框架的手机商城管理系统,商城系统分为前台和后台、前台用的是Bootstrap框架后台用的是SpringBoot+JPA都是现在主...
- 全网首发!马士兵内部共享—1658页《Java面试突击核心讲》
-
又是一年一度的“金九银十”秋招大热门,为助力广大程序员朋友“面试造火箭”,小编今天给大家分享的便是这份马士兵内部的面试神技——1658页《Java面试突击核心讲》!...
- SpringBoot数据库操作的应用(springboot与数据库交互)
-
1.JDBC+HikariDataSource...
- SpringBoot 整合 Flink 实时同步 MySQL
-
1、需求在Flink发布SpringBoot打包的jar包能够实时同步MySQL表,做到原表进行新增、修改、删除的时候目标表都能对应同步。...
- SpringBoot + Mybatis + Shiro + mysql + redis智能平台源码分享
-
后端技术栈基于SpringBoot+Mybatis+Shiro+mysql+redis构建的智慧云智能教育平台基于数据驱动视图的理念封装element-ui,即使没有vue的使...
- Springboot+Mysql舞蹈课程在线预约系统源码附带视频运行教程
-
今天发布的是由【猿来入此】的优秀学员独立做的一个基于springboot脚手架的Springboot+Mysql舞蹈课程在线预约系统,系统项目源代码在【猿来入此】获取!https://www.yuan...
- SpringBoot+Mysql在线众筹系统源码+讲解视频+开发文档(参考论文
-
今天发布的是由【猿来入此】的优秀学员独立做的一个基于springboot脚手架的在线众筹管理系统,主要实现了普通用户在线参与众筹基本操作流程的全部功能,系统分普通用户、超级管理员等角色,除基础脚手架外...
- Docker一键部署 SpringBoot 应用的方法,贼快贼好用
-
这两天发现个Gradle插件,支持一键打包、推送Docker镜像。今天我们来讲讲这个插件,希望对大家有所帮助!GradleDockerPlugin简介...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 从IDEA开始,迈进GO语言之门(idea got)
- 基于SpringBoot+MyBatis的私人影院java网上购票jsp源代码Mysql
- 基于springboot的个人服装管理系统java网上商城jsp源代码mysql
- 基于springboot的美食网站Java食品销售jsp源代码Mysql
- 贸易管理进销存springboot云管货管账分析java jsp源代码mysql
- SpringBoot+VUE员工信息管理系统Java人员管理jsp源代码Mysql
- 目前见过最牛的一个SpringBoot商城项目(附源码)还有人没用过吗
- SpringBoot+Mysql实现的手机商城附带源码演示导入视频
- 全网首发!马士兵内部共享—1658页《Java面试突击核心讲》
- SpringBoot数据库操作的应用(springboot与数据库交互)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)