VGG16进行微调,训练mnist数据集_vgg16 tensorflow 2 mnist-程序员宅基地

技术标签: tensorflow  机器学习  深度学习  神经网络  

直接用mnist训练VGG16,得到的准确率只有0.11左右,所以我用keras内置的用imagenet训练的VGG16进行微调,训练mnist,准确率达到0.94,得到了很大的提升。

from keras.applications import VGG16
from keras.datasets import mnist
from keras.utils import to_categorical
from keras import models
from keras.layers.core import Dense,Flatten,Dropout
import cv2
import numpy as np

#加载数据
(x_train,y_train),(x_test,y_test)=mnist.load_data()
#VGG16模型,权重由ImageNet训练而来,模型的默认输入尺寸是224x224,但是最小是48x48
#修改数据集的尺寸、将灰度图像转换为rgb图像
x_train=[cv2.cvtColor(cv2.resize(i,(48,48)),cv2.COLOR_GRAY2BGR)for i in x_train]
x_test=[cv2.cvtColor(cv2.resize(i,(48,48)),cv2.COLOR_GRAY2BGR)for i in x_test]
#第一步:通过np.newaxis函数把每一个图片增加一个维度变成(1,48,48,3)。所以就有了程序中的arr[np.newaxis]。
#第二步:通过np.concatenate把每个数组连接起来组成一个新的x_train数组,连接后的x_train数组shape为(10000,48,48,3)
x_train=np.concatenate([arr[np.newaxis]for arr in x_train])
x_test=np.concatenate([arr[np.newaxis]for arr in x_test])


x_train=x_train.astype("float32")/255
x_train=x_train.reshape((60000,48,48,3))
x_test=x_test.astype("float32")/255
x_test=x_test.reshape((10000,48,48,3))
y_train=to_categorical(y_train)
y_test=to_categorical(y_test)

#划出验证集
x_val=x_train[:10000]
y_val=y_train[:10000]
x_train=x_train[10000:]
y_train=y_train[10000:]

#建立模型
conv_base=VGG16(weights='imagenet',
				include_top=False,
				input_shape=(48,48,3))
conv_base.trainable=False
model=models.Sequential()
model.add(conv_base)
model.add(Flatten())
model.add(Dense(4096,activation="relu"))
model.add(Dropout(0.5))
# layer 14
model.add(Dense(4096, activation="relu"))
model.add(Dropout(0.5))
# layer 15
model.add(Dense(10,activation="softmax"))
model.summary()

#编译模型
model.compile(optimizer="rmsprop",loss="categorical_crossentropy",metrics=["accuracy"])

#训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5,validation_data=(x_val,y_val))

#评估模型
test_loss,test_acc=model.evaluate(x_test,y_test,batch_size=64)
print("The accuracy is:"+str(test_acc))

在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_41078023/article/details/105509700

智能推荐

CC2530定时器1使能_t1cctl0-程序员宅基地

文章浏览阅读5.8k次,点赞4次,收藏19次。CC2530定时器1CC2530有5个定时器(定时器1,定时器2,定时器3,定时器4,睡眠定时器),定时器1是一个16位定时器,具有一个可编程的分频器,一个16位周期值,和五个各自可编程的计数器/捕获通道,每个都有一个16位比较值。..._t1cctl0

jsp中9个隐含对象-程序员宅基地

文章浏览阅读63次。在JSP中一共有9个隐含对象,这个9个对象我可以在JSP中直接使用。因为在service方法已经对这个九个隐含对象进行声明及赋值,所以可以在JSP中直接使用。 - pageContext 类型:PageContext 代表:当前页面的上下文 作用:可以获取到页面中的其他隐含对象,同时它还是一个域对象。 - request 类型:HttpServlet..._jsp中9个隐含对象

Android开发之打包APK详解_安卓打包apk-程序员宅基地

文章浏览阅读1.6w次,点赞23次,收藏141次。Android开发之打包APK详解_安卓打包apk

springcloudgateway踩过404的坑笔记整理_gateway 404-程序员宅基地

文章浏览阅读3.7k次,点赞2次,收藏6次。错误回顾网关的配置server: port: 80spring: application: name: api-gateway#springCloudgateway配置项对相应 GatewayPropweties cloud: # 网关配置 gateway: # 路由配置:对应RouteDefinition数组 routes: - id: desk-route #路由的编号,保证是唯一的 # _gateway 404

c++备战CCF之力扣简单题(数组中两元素的最大乘积)_c++实现一行数中任意两个乘积最大-程序员宅基地

文章浏览阅读87次。例如nth_element(arr, arr+5, arr+10);是将从小到大排序后应该在arr[5]的元素放在arr[5]这个位置上。若要寻找第k大的数,nth_element(数组名,数组名+元素个数-k,数组名+元素个数)其用法为:函数语句:nth_element(数组名,数组名+第k个元素,数组名+元素个数)冒泡是每次将0~n-i范围内的最大数放在arr[n-1-i]位置,i代表了冒泡的次数。c++的STL里也有快速选择的函数nth_element()快速选择是用于分开较大的数和较小的数。_c++实现一行数中任意两个乘积最大

113-Linux_安装c/c++开发库及连接mysql数据库_linux安装c++ mysql库-程序员宅基地

文章浏览阅读786次。安装开发c/c++的库,命令:==apt install _linux安装c++ mysql库

随便推点

ESP32移植LVGL并将LVGL外部输入设备设置为物理按键_lvgl adc-button-程序员宅基地

文章浏览阅读3.1k次。前情提示:在上篇博文中,我将一个使用GUI-Guider生成的图形界面移植到了ESP32设备上。显示成功,详情参见:如何将使用GUI-Guider生成的LVGL移植到ESP32https://blog.csdn.net/QTRPio/article/details/124120432背景:但是我使用GUI-Guider生成的图形界面还包含了几个图片按键: 条码识别; 手势识别;..._lvgl adc-button

BLE协议架构概述(1)_ble 协议-程序员宅基地

文章浏览阅读1.3w次,点赞2次,收藏12次。BLE 协议架构总体上分成3块,从下到上分别是:控制器(Controller),主机(Host) 和应用端(Apps);3者可以在同一芯片类实现,也可以分不同芯片内实现,控制器(Controller)是处理射频数据解析,接收和发送,主机(Host)是控制不同设备之间如何进行数据交换;应用端(Apps)实现具体应用。控制器ControllerController实现射频相关的模拟和数_ble 协议

围观了张一鸣近10年的微博,我整理了这231条干货_张一鸣微博干货-程序员宅基地

文章浏览阅读3.7k次,点赞73次,收藏284次。本文转载自 仟语仟寻,作者 霍仟这几天抽空把张一鸣的所有微博看了一遍,发现2010年的微博最好,就是他30岁左右的时候,那时候刚创业没多久,在微博上认真分享自己的思考和观点。到了2012年附近,开始做今日头条,每天都是大量转发今日头条上的文章到微博,干货变少了。后面就更新得越来越少。我从他的微博中,试图找到他成功的钥匙,得到了一些只言片语,但是仅仅是这些只言片语,我都觉得对我的启发很大。他从南开大学毕业,妻子是大学同学,毕业后去过微软,后来从微软离职,然后去过饭否,应该跟过王兴王慧文一阵子,后来_张一鸣微博干货

UniDAC使用教程(二):数据更新-程序员宅基地

文章浏览阅读881次。2019独角兽企业重金招聘Python工程师标准>>> ..._unidac 帮助

yum报:[Errno 14] curl#6 - “Could not resolve host: mirrors.cqu.edu.cn;Unknown error“未知的错误,正在尝试其他镜像类错误_could not resolve host: mirrors.cqu.edu.cn; unknow-程序员宅基地

文章浏览阅读1.1w次,点赞9次,收藏15次。报错信息:[root@localhost ~]# yum install bind bind-utils.x86_64 -y已加载插件:fastestmirror, langpacksCould not retrieve mirrorlist http://mirrorlist.centos.org/?release=7&arch=x86_64&repo=os&infra=stock error was14: curl#6 - "Could not resolve hos_could not resolve host: mirrors.cqu.edu.cn; unknown error

自动驾驶简介 转自: 智车科技_智能驾驶-程序员宅基地

文章浏览阅读484次。技术分级自动驾驶技术分为多个等级,目前国内外产业界采用较多的为美国汽车工程师协会(SAE)和美国高速公路安全管理局(NHTSA)推出的分类标准。按照SAE的标准,自动驾驶汽车视智能化、自动化程度水平分为6个等级:无自动化(L0)、驾驶支援(L1)、部分自动化(L2)、有条件自动化(L3)、高度自动化(L4)和完全自动化(L5)。两种不同分类标准的主要区别在于完全自动驾驶场景下,SAE更加细分了自动..._智能驾驶

推荐文章

热门文章

相关标签