Image Classification with TensorFlow 2.0+
ztj100 · 2024-11-27
Abstract
This article uses a CNN to classify images from the Dogs vs. Cats dataset: 10,000 images in total (5,000 cats and 5,000 dogs). The model is a custom CNN built for TensorFlow 2.0 and above. Along the way you will see the common ingredients of an image-classification pipeline, including:
1. Image augmentation
2. Splitting the data into training and validation sets
3. Saving the best model with ModelCheckpoint
4. Adjusting the learning rate with ReduceLROnPlateau
5. Plotting the loss and accuracy curves and saving them as JPG images
Network walkthrough
Training
1. Import dependencies
import os
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from tensorflow import keras
# Use the public tensorflow.keras namespace consistently; mixing it with
# tensorflow.python.keras can produce incompatible layer/model classes.
from tensorflow.keras import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers import PReLU, Activation
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import img_to_array
2. Set global parameters
norm_size = 100        # size (in pixels) of the images fed to the network
datapath = 'train'     # root directory of the images
EPOCHS = 100           # number of training epochs
INIT_LR = 1e-3         # initial learning rate
labelList = []         # labels
dicClass = {'cat': 0, 'dog': 1}  # class-name-to-index mapping
labelnum = 2           # number of classes
batch_size = 4
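The loader in the next step derives each label from the filename prefix via img.split('.')[0], so it assumes a flat Kaggle Dogs vs. Cats style layout like the one sketched below (the exact filenames are illustrative; only the cat./dog. prefix matters):
# Assumed layout of the 'train' directory (filenames are illustrative):
# train/
#     cat.0.jpg
#     cat.1.jpg
#     ...
#     dog.0.jpg
#     dog.4999.jpg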
3. Load the data
def loadImageData():
    imageList = []
    listImage = os.listdir(datapath)  # list every image file
    for img in listImage:             # iterate over the images
        labelName = dicClass[img.split('.')[0]]  # map the filename prefix ('cat'/'dog') to its label index
        print(labelName)
        labelList.append(labelName)
        dataImgPath = os.path.join(datapath, img)
        print(dataImgPath)
        # load the image, pre-process it, and store it in the data list
        image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
        image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
        image = img_to_array(image)
        imageList.append(image)
    imageList = np.array(imageList, dtype="float") / 255.0  # normalize the pixel values to [0, 1]
    return imageList
print("Loading data...")
imageArr = loadImageData()
labelList = np.array(labelList)
print("Data loaded")
print(labelList)
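Before moving on, a quick optional sanity check (my addition, not part of the original script) confirms the array shapes and the class balance:
print(imageArr.shape)          # expected: (10000, norm_size, norm_size, 3)
print(labelList.shape)         # expected: (10000,)
print(np.bincount(labelList))  # roughly 5000 cats and 5000 dogs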
4. Define the model
def bn_prelu(x):
    x = BatchNormalization(epsilon=1e-5)(x)
    x = PReLU()(x)
    return x

def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
    inputs_dim = Input(input_shape)
    x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
    x = bn_prelu(x)
    x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = GlobalAveragePooling2D()(x)
    dp_1 = Dropout(0.5)(x)
    fc2 = Dense(out_dims)(dp_1)
    fc2 = Activation('softmax')(fc2)  # softmax over the two classes, paired with sparse_categorical_crossentropy
    model = Model(inputs=inputs_dim, outputs=fc2)
    return model

model = build_model(labelnum)  # build the model
optimizer = Adam(learning_rate=INIT_LR)  # Adam optimizer with the initial learning rate (learning_rate replaces the deprecated lr argument)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
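With the model compiled, a one-line check (optional) prints each layer's output shape and the parameter counts, which is useful for verifying that GlobalAveragePooling2D keeps the network small:
model.summary()  # prints layer output shapes and the total number of trainable parameters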
5. Split the data into training and validation sets
trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)
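If you want the 70/30 split to preserve the 50/50 cat/dog ratio in both subsets, scikit-learn's stratify argument is an optional variation on the line above:
# Optional: stratified split keeps the class balance identical in train and validation
trainX, valX, trainY, valY = train_test_split(
    imageArr, labelList, test_size=0.3, random_state=42, stratify=labelList)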
6. Data augmentation
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(featurewise_center=True,
                                   featurewise_std_normalization=True,
                                   rotation_range=20,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   horizontal_flip=True)
train_datagen.fit(trainX)  # featurewise_center/std_normalization need statistics computed from the training images
val_datagen = ImageDataGenerator(featurewise_center=True,
                                 featurewise_std_normalization=True)  # no augmentation on the validation set, but the same normalization as training
val_datagen.fit(trainX)  # reuse the training-set statistics so validation images are preprocessed consistently
train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=True)
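To check the augmentation visually, one batch can be drawn from train_generator and plotted (an optional sketch; after featurewise normalization the pixel values fall outside [0, 1], so they are clipped purely for display):
import matplotlib.pyplot as plt
batchX, batchY = next(train_generator)  # one augmented batch of batch_size images
for i in range(len(batchX)):
    plt.subplot(1, len(batchX), i + 1)
    plt.imshow(batchX[i].clip(0, 1))    # clip to [0, 1] for display only
    plt.title(str(int(batchY[i])))      # 0 = cat, 1 = dog
    plt.axis('off')
plt.show()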
7. Set up the callbacks
checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
                               monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
reduce = ReduceLROnPlateau(monitor='val_accuracy', patience=10,
                           verbose=1,
                           factor=0.5,
                           min_lr=1e-6)
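If you would rather stop training once validation accuracy has clearly plateaued instead of always running the full 100 epochs, an EarlyStopping callback can be added alongside the two above (an optional extra, not in the original script):
from tensorflow.keras.callbacks import EarlyStopping
# Stop when val_accuracy has not improved for 20 epochs and restore the best weights
early_stop = EarlyStopping(monitor='val_accuracy', patience=20, mode='max',
                           restore_best_weights=True, verbose=1)
# then pass callbacks=[checkpointer, reduce, early_stop] to the training call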
8. Train and save the model
# Note: on newer TF 2.x releases fit_generator is deprecated; model.fit accepts generators directly.
history = model.fit_generator(train_generator,
                              steps_per_epoch=trainX.shape[0] // batch_size,
                              validation_data=val_generator,
                              epochs=EPOCHS,
                              validation_steps=valX.shape[0] // batch_size,
                              callbacks=[checkpointer, reduce],
                              verbose=1, shuffle=True)
model.save('my_model_.h5')
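Since ModelCheckpoint keeps only the epoch with the best val_accuracy, that checkpoint can be reloaded after training and re-evaluated on the validation generator; a minimal sketch assuming the filenames above (on TF 2.0 itself use evaluate_generator instead of evaluate):
from tensorflow.keras.models import load_model
best_model = load_model('weights_best_simple_model.hdf5')  # weights of the best epoch
val_loss, val_acc = best_model.evaluate(val_generator, steps=valX.shape[0] // batch_size)
print('best checkpoint: val_loss=%.4f  val_acc=%.4f' % (val_loss, val_acc))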
9. Save the training history plots
import os
loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
import matplotlib.pyplot as plt
print("Now,we start drawing the loss and acc trends graph...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# On Windows, shut the machine down 10 seconds after training finishes (optional)
os.system("shutdown -s -t 10")
(Figures: the generated accuracy and loss trend graphs, WW_acc.jpg and WW_loss.jpg.)
Complete training code:
import os
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers import PReLU, Activation
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import img_to_array
norm_size=100
datapath='train'
EPOCHS =100
INIT_LR = 1e-3
labelList=[]
dicClass={'cat':0,'dog':1}
labelnum=2
batch_size = 4
def loadImageData():
    imageList = []
    listImage = os.listdir(datapath)
    for img in listImage:
        labelName = dicClass[img.split('.')[0]]
        print(labelName)
        labelList.append(labelName)
        dataImgPath = os.path.join(datapath, img)
        print(dataImgPath)
        # load the image, pre-process it, and store it in the data list
        image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
        image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
        image = img_to_array(image)
        imageList.append(image)
    imageList = np.array(imageList, dtype="float") / 255.0
    return imageList
print("Loading data...")
imageArr = loadImageData()
labelList = np.array(labelList)
print("Data loaded")
print(labelList)
def bn_prelu(x):
    x = BatchNormalization(epsilon=1e-5)(x)
    x = PReLU()(x)
    return x

def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
    inputs_dim = Input(input_shape)
    x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
    x = bn_prelu(x)
    x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
    x = bn_prelu(x)
    x = GlobalAveragePooling2D()(x)
    dp_1 = Dropout(0.5)(x)
    fc2 = Dense(out_dims)(dp_1)
    fc2 = Activation('softmax')(fc2)  # softmax output, paired with sparse_categorical_crossentropy
    model = Model(inputs=inputs_dim, outputs=fc2)
    return model

model = build_model(labelnum)
optimizer = Adam(learning_rate=INIT_LR)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
trainX,valX,trainY,valY = train_test_split(imageArr,labelList, test_size=0.3, random_state=42)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(featurewise_center=True,
                                   featurewise_std_normalization=True,
                                   rotation_range=20,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   horizontal_flip=True)
train_datagen.fit(trainX)  # featurewise_* options need statistics from the training images
val_datagen = ImageDataGenerator(featurewise_center=True,
                                 featurewise_std_normalization=True)  # no augmentation, same normalization as training
val_datagen.fit(trainX)  # reuse the training-set statistics
train_generator = train_datagen.flow(trainX, trainY, batch_size=batch_size, shuffle=True)
val_generator = val_datagen.flow(valX, valY, batch_size=batch_size, shuffle=True)
checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
                               monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
reduce = ReduceLROnPlateau(monitor='val_accuracy', patience=10,
                           verbose=1,
                           factor=0.5,
                           min_lr=1e-6)
# Note: on newer TF 2.x releases fit_generator is deprecated; model.fit accepts generators directly.
history = model.fit_generator(train_generator,
                              steps_per_epoch=trainX.shape[0] // batch_size,
                              validation_data=val_generator,
                              epochs=EPOCHS,
                              validation_steps=valX.shape[0] // batch_size,
                              callbacks=[checkpointer, reduce],
                              verbose=1, shuffle=True)
model.save('my_model_.h5')
print(history)
import os
loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
import matplotlib.pyplot as plt
print("Now,we start drawing the loss and acc trends graph...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# On Windows, shut the machine down 10 seconds after training finishes (optional)
os.system("shutdown -s -t 10")
Testing
1. Import dependencies
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import time
2. Set global parameters
norm_size = 100
imagelist=[]
emotion_labels = {
0: 'cat',
1: 'dog'
}
3. Load the model
emotion_classifier = load_model("my_model_.h5")
t1 = time.time()
4. Pre-process the image
image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), -1)
# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float") / 255.0
5. Predict the class
pre = np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2=time.time()
print(emotion)
t3=t2-t1
print(t3)
Complete test code:
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import time
norm_size=100
imagelist=[]
emotion_labels = {
0: 'cat',
1: 'dog'
}
emotion_classifier=load_model("my_model_.h5")
t1=time.time()
image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), -1)
# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float") / 255.0
pre=np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2=time.time()
print(emotion)
t3=t2-t1
print(t3)
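The same preprocessing can be looped over a whole folder of test images; a minimal sketch, assuming every file under test/ is a readable image:
import os
for name in sorted(os.listdir('test')):
    img = cv2.imdecode(np.fromfile(os.path.join('test', name), dtype=np.uint8), -1)
    img = cv2.resize(img, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
    img = img_to_array(img) / 255.0
    pred = np.argmax(emotion_classifier.predict(np.expand_dims(img, axis=0)))
    print(name, emotion_labels[pred])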