博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras实现CIFAR-10分类
阅读量:6124 次
发布时间:2019-06-21

本文共 3510 字,大约阅读时间需要 11 分钟。

  仅仅为了学习Keras的使用,使用一个四层的全连接网络对MNIST数据集进行分类,网络模型各层结点数为:3072: : 1024 : 512:10;

  使用50000张图片进行训练,10000张测试:

precision    recall  f1-score   support    airplane       0.61      0.69      0.65      1000  automobile       0.69      0.67      0.68      1000        bird       0.43      0.49      0.45      1000         cat       0.40      0.32      0.36      1000        dear       0.49      0.50      0.50      1000         dog       0.45      0.48      0.47      1000        frog       0.58      0.65      0.61      1000       horse       0.63      0.60      0.62      1000        ship       0.72      0.66      0.69      1000       truck       0.63      0.58      0.60      1000   micro avg       0.56      0.56      0.56     10000   macro avg       0.56      0.56      0.56     10000weighted avg       0.56      0.56      0.56     10000

训练过程中,损失和正确率曲线:

  可以看到,训练集的损失在一直降低,而测试集的损失出现大范围波动,并趋于上升,说明在一些epoch之后,出现过拟合;

  训练集的正确率也在一直上升,并接近100%;而测试集的正确率达到50%就趋于平稳了;

Training_Loss_and_Accuracy_CIFAR10.png

代码:

#!/usr/bin/env python# -*- coding: utf-8 -*-# @Time : 19-5-9"""implement classification for CIFAR-10 with Keras"""__author__ = 'Zhen Chen'# import the necessary packagesfrom sklearn.preprocessing import LabelBinarizerfrom sklearn.metrics import classification_reportfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGDfrom keras.datasets import cifar10import matplotlib.pyplot as pltimport numpy as npimport argparse# construct the argument parse and parse the argumentsparser = argparse.ArgumentParser()parser.add_argument("-o", "--output", default="./Training Loss and Accuracy_CIFAR10.png")args = parser.parse_args()# load the training and testing data, scale it into the range [0, 1],# then reshape the design matrixprint("[INFO] loading CIFAR-10 data...")((trainX, trainY), (testX, testY)) = cifar10.load_data()trainX = trainX.astype("float") / 255.0testX = testX.astype("float") / 255.0trainX = trainX.reshape((trainX.shape[0], 3072))testX = testX.reshape((testX.shape[0], 3072))# convert the labels from integers to vectorslb = LabelBinarizer()trainY = lb.fit_transform(trainY)testY = lb.fit_transform(testY)# initialize the label names for the CIFAR-10 datasetlabelNames = ["airplane", "automobile", "bird", "cat", "dear", "dog", "frog", "horse", "ship", "truck"]# define the 2072-1024-512-10 architecture Kerasmodel = Sequential()model.add(Dense(1024, input_shape=(3072,), activation="relu"))model.add(Dense(512, activation="relu"))model.add(Dense(10, activation="softmax"))# train the model using SGDprint("[INFO] training network...")sgd = SGD(0.01)model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=32)model.save_weights('SGD_100_32.h5')# evaluate the networkprint("[INFO] evaluating network...")predictions = model.predict(testX, batch_size=32)print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), target_names=labelNames))# plot the training losss and accuracyplt.style.use("ggplot")plt.figure()plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc")plt.title("Training Loss and Accuracy on CIRFAR-10")plt.xlabel("Epoch #")plt.ylabel("Loss/Accuracy")plt.legend()plt.savefig(args.output)

转载于:https://www.cnblogs.com/chenzhen0530/p/10837622.html

你可能感兴趣的文章
网卡驱动程序之框架(一)
查看>>
css斜线
查看>>
Windows phone 8 学习笔记(3) 通信
查看>>
重新想象 Windows 8 Store Apps (18) - 绘图: Shape, Path, Stroke, Brush
查看>>
Revit API找到风管穿过的墙(当前文档和链接文档)
查看>>
Scroll Depth – 衡量页面滚动的 Google 分析插件
查看>>
Windows 8.1 应用再出发 - 视图状态的更新
查看>>
自己制作交叉编译工具链
查看>>
Qt Style Sheet实践(四):行文本编辑框QLineEdit及自动补全
查看>>
[物理学与PDEs]第3章习题1 只有一个非零分量的磁场
查看>>
深入浅出NodeJS——数据通信,NET模块运行机制
查看>>
onInterceptTouchEvent和onTouchEvent调用时序
查看>>
android防止内存溢出浅析
查看>>
4.3.3版本之引擎bug
查看>>
SQL Server表分区详解
查看>>
使用FMDB最新v2.3版本教程
查看>>
SSIS从理论到实战,再到应用(3)----SSIS包的变量,约束,常用容器
查看>>
STM32启动过程--启动文件--分析
查看>>
垂死挣扎还是涅槃重生 -- Delphi XE5 公布会归来感想
查看>>
淘宝的几个架构图
查看>>