
Image Classification with TensorFlow 2.0+

ztj100 2024-11-27 23:33

Abstract

This article uses a CNN to classify images, drawn from the Kaggle Dogs vs. Cats dataset: 10,000 images in total (5,000 cats and 5,000 dogs). The model is a custom CNN built with TensorFlow 2.0 or later. From this article you will learn the techniques commonly used in image classification, including:

1. Image augmentation
2. Splitting the data into training and validation sets
3. Saving the best model with ModelCheckpoint
4. Adjusting the learning rate with ReduceLROnPlateau
5. Plotting the loss and accuracy curves and saving them as JPG images

Network Walkthrough

Training

1. Import dependencies

# Use the public tensorflow.keras API throughout; mixing it with the private
# tensorflow.python.keras modules (as some older tutorials do) can produce
# incompatible layer and model classes.
import os
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from tensorflow.keras import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Dense, Dropout, BatchNormalization,
                                     Conv2D, MaxPooling2D, GlobalAveragePooling2D,
                                     PReLU, Activation)
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import img_to_array

2. Set global parameters

norm_size = 100            # size (in pixels) of the images fed to the network
datapath = 'train'         # root directory of the images
EPOCHS = 100               # number of training epochs
INIT_LR = 1e-3             # initial learning rate
labelList = []             # labels
dicClass = {'cat': 0, 'dog': 1}  # class-name-to-index mapping
labelnum = 2               # number of classes
batch_size = 4

3. Load the data

def loadImageData():
    imageList = []
    listImage = os.listdir(datapath)  # list every image file
    for img in listImage:  # iterate over the images
        labelName = dicClass[img.split('.')[0]]  # map the filename prefix to its label index
        labelList.append(labelName)
        dataImgPath = os.path.join(datapath, img)
        # np.fromfile + cv2.imdecode also copes with non-ASCII paths, unlike cv2.imread
        image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
        # load the image, pre-process it, and store it in the data list
        image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
        image = img_to_array(image)
        imageList.append(image)
    imageList = np.array(imageList, dtype="float") / 255.0  # normalize to [0, 1]
    return imageList

print("Loading data...")
imageArr = loadImageData()
labelList = np.array(labelList)
print("Data loaded.")
print(labelList)
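Note that the label comes entirely from the filename: the Kaggle Dogs vs. Cats images are named like `cat.0.jpg` and `dog.0.jpg`, so the text before the first dot is the class name. A minimal sketch of the mapping the loader relies on (the filenames below are hypothetical examples):

```python
dicClass = {'cat': 0, 'dog': 1}

def label_from_filename(name):
    # the text before the first dot is the class name
    return dicClass[name.split('.')[0]]

print(label_from_filename('cat.0.jpg'))    # 0
print(label_from_filename('dog.123.jpg'))  # 1
```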

4. Define the model

def bn_prelu(x):
    x = BatchNormalization(epsilon=1e-5)(x)
    x = PReLU()(x)
    return x

def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
    inputs_dim = Input(input_shape)
    x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
    x = bn_prelu(x)
    x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = GlobalAveragePooling2D()(x)
    dp_1 = Dropout(0.5)(x)
    fc2 = Dense(out_dims)(dp_1)
    # softmax over the classes -- this pairs with the sparse_categorical_crossentropy
    # loss used below (not sigmoid, which would pair with binary_crossentropy)
    fc2 = Activation('softmax')(fc2)
    model = Model(inputs=inputs_dim, outputs=fc2)
    return model

model = build_model(labelnum)  # build the model
optimizer = Adam(learning_rate=INIT_LR)  # optimizer with the initial learning rate ('lr' is deprecated in TF2)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
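To see why GlobalAveragePooling2D ends up producing 256 features, trace the spatial size of the 100x100 input through the single strided convolution and the three pooling layers (a back-of-the-envelope sketch; `model.summary()` reports the authoritative shapes):

```python
import math

# 'same' padding with stride s: out = ceil(in / s); 2x2 max-pool: out = floor(in / 2)
size = 100
size = math.ceil(size / 2)  # Conv2D stride 2 -> 50
size = size // 2            # MaxPooling2D   -> 25
size = size // 2            # MaxPooling2D   -> 12
size = size // 2            # MaxPooling2D   -> 6
print(size)  # 6: the final 6x6x256 feature maps average-pool down to a 256-vector
```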

5. Split the training and validation sets

trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)
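With `test_size=0.3`, 30% of the data is held out: on the 10,000-image dataset used here that leaves 7,000 images for training and 3,000 for validation (a quick check of the arithmetic):

```python
n = 10000
n_val = int(n * 0.3)   # images held out for validation
n_train = n - n_val    # images left for training
print(n_train, n_val)  # 7000 3000
```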

6. Data augmentation

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(featurewise_center=True,
                                   featurewise_std_normalization=True,
                                   rotation_range=20,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   horizontal_flip=True)
# featurewise_center/featurewise_std_normalization need the dataset statistics,
# so fit() must be called on the training data before flow() is used
train_datagen.fit(trainX)

val_datagen = ImageDataGenerator()  # no augmentation for the validation set
train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=True)

7. Set up the callbacks

checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
                               monitor='val_accuracy', verbose=1,
                               save_best_only=True, mode='max')

reduce = ReduceLROnPlateau(monitor='val_accuracy', patience=10,
                           verbose=1, factor=0.5, min_lr=1e-6)
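Each time `val_accuracy` fails to improve for 10 epochs, ReduceLROnPlateau multiplies the learning rate by `factor=0.5`, never going below `min_lr`. Starting from `INIT_LR = 1e-3`, the first few steps of that schedule look like this (a sketch of the callback's arithmetic, not the callback itself):

```python
# Simulate successive plateau triggers: lr -> max(lr * factor, min_lr)
lr, factor, min_lr = 1e-3, 0.5, 1e-6
steps = []
for _ in range(4):
    steps.append(lr)
    lr = max(lr * factor, min_lr)
print(steps)  # [0.001, 0.0005, 0.00025, 0.000125]
```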

8. Train and save the model

# fit_generator is deprecated in TF2; model.fit accepts generators directly
history = model.fit(train_generator,
                    steps_per_epoch=trainX.shape[0] // batch_size,
                    validation_data=val_generator,
                    epochs=EPOCHS,
                    validation_steps=valX.shape[0] // batch_size,
                    callbacks=[checkpointer, reduce],
                    verbose=1, shuffle=True)

model.save('my_model_.h5')

9. Save the training history

import matplotlib.pyplot as plt

loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
print("Now, we start drawing the loss and acc trend graphs...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "val"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "val"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# optionally shut down a Windows machine 10 seconds after training finishes
os.system("shutdown -s -t 10")


Testing

1. Import dependencies

import cv2
import numpy as np
import time
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model

2. Set global parameters

norm_size = 100
imagelist = []
emotion_labels = {0: 'cat', 1: 'dog'}

3. Load the model

emotion_classifier = load_model("my_model_.h5")
t1 = time.time()

4. Process the image

image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), -1)
# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float") / 255.0

5. Predict the class

pre = np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2 = time.time()
print(emotion)
t3 = t2 - t1
print(t3)
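The argmax step above can be sanity-checked without loading the trained model, using a hypothetical softmax output of the shape `predict()` returns for a single image:

```python
import numpy as np

emotion_labels = {0: 'cat', 1: 'dog'}
scores = np.array([[0.2, 0.8]])  # hypothetical (1, 2) softmax output
pre = int(np.argmax(scores))     # index of the highest score
print(emotion_labels[pre])       # dog
```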
