【深度学习基础(3)】初识神经网络之深度学习hello world

文章目录

  • 一. 训练Keras中的MNIST数据集
  • 二. 工作流程
    • 1. 构建神经网络
    • 2. 准备图像数据
    • 3. 训练模型
    • 4. 利用模型进行预测
    • 5. (新数据上)评估模型精度

本节将首先给出一个神经网络示例,引出如下概念。了解完本节后,可以对神经网络在代码上的实现有一个整体的了解。

本节相关概念:

  • 样本
  • 标签
  • 层(layer)
  • 数据蒸馏
  • 密集连接
  • 10路softmax分类层
  • 编译(compilation)步骤的3个参数
  • 损失值、精度
  • 过拟合

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。

在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。我们将使用MNIST数据集。你可以将“解决”MNIST问题看作深度学习的“Hello World”,用来验证你的算法正在按预期运行。下图给出了MNIST数据集的一些样本。

在这里插入图片描述

说明

在机器学习中,分类问题中的某个类别叫作类(class),数据点叫作样本(sample)与某个样本对应的类叫作标签(label)(即描述:样本属于哪个类别)。

 

你不需要现在就尝试在计算机上运行这个例子。之后的文章会具体分析。

 

一. 训练Keras中的MNIST数据集

from tensorflow.keras.datasets import mnist 

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。

图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。

 
看一下训练数据:

>>> train_images.shape 
 (60000, 28, 28) 
>>> len(train_labels) 
60000 
>>> train_labels 
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

 

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28) 
>>> len(test_labels) 
10000 
>>> test_labels 
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

 

二. 工作流程

我们的工作流程如下:
首先,将训练数据(train_images和train_labels)输入神经网络;
然后,神经网络学习将图像和标签关联在一起;
最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。
 

1. 构建神经网络

下面我们来构建神经网络,如下:

from tensorflow import keras 
from tensorflow.keras import layers 

model = keras.Sequential(
						 [ layers.Dense(512, activation="relu"), 
						 layers.Dense(10, activation="softmax") ])

 

神经网络的核心组件是层(layer)

具体来说,层从输入数据中提取表示。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)。
 

本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。

第2层是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。每个概率值表示当前数字图像属于10个数字类别中某一个的概率。
 

在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数

  • 优化器(optimizer):模型基于训练数据来自我更新的机制,其目的是提高模型性能。
  • 损失函数(loss function):模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。
  • 在训练和测试过程中需要监控的指标(metric):本例只关心精度(accuracy),即正确分类的图像所占比例。后面两章会详细介绍损失函数和优化器的确切用途。

如下代码展示了编译步骤。


model.compile(
			  optimizer="rmsprop", 
			  loss="sparse_categorical_crossentropy", 
			  metrics=["accuracy"]
			  )

 

2. 准备图像数据

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0,255]。我们将把它变换为一个float32数组,其形状为(60000, 28 *28),取值范围是[0, 1]。

下面准备图像数据,如代码所示。

train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype("float32") / 255 

test_images = test_images.reshape((10000, 28 * 28)) 
test_images = test_images.astype("float32") / 255

 

3. 训练模型

在Keras中,通过调用模型的fit方法调用数据,训练模型。

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128) 
Epoch 1/5 
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273 Epoch 2/5 
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss),另一个是模型在训练数据上的精度(acc)。我们很快就在训练数据上达到了0.989(98.9%)的精度。

现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(如下代码)。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

 

4. 利用模型进行预测

>>> test_digits = test_images[0:10] 
>>> predictions = model.predict(test_digits) 
>>> predictions[0] 
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06, 2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01, 2.6657624e-08, 3.8127661e-07], dtype=float32)

如上代码我们对11个test_images图片进行预测,是什么数字,我们拿到第一个图片预测的概率数组,其中索引为7时,概率最大(0.99999106,几乎等于1),所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax() 
7 
>>> predictions[0][7] 
0.99999106

这里我们检查测试标签是否与之一致:

>>> test_labels[0] 

7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如下代码所示。

 

5. (新数据上)评估模型精度

>>> test_loss, test_acc = model.evaluate(test_images, test_labels) 
>>> print(f"test_acc: {test_acc}") 
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。

过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差。

第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。

 

之后的文章我们将详细描述每一个步骤的原理,并且将学到张量(输入模型的数据存储对象)、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/592523.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】Yolov8使用心得

兜兜转转,原本以为和yolov没啥关系了,但是新公司又回到了yolov侧。 yolov8 可以用pip的方式安装,以package的三方软件包形式,隐藏了很多细节。当然你也可以从git上把全套代码down下来。 1.分类模型 1.1 改错误 位置&#xff1a…

区块链扩容:水平扩展 vs.垂直扩展

1. 引言 随着Rollups 的兴起,区块链扩容一直集中在模块化(modular)vs. 整体式(monolithic)之争。 如今,模块化与整体式这种一分为二的心理模型,已不适合于当前的扩容场景。本文,将展…

Python机器学习手册:从预处理到深度学习的实际解决方案

书籍:Machine Learning with Python Cookbook: Practical Solutions from Preprocessing to Deep Learning 作者:Kyle Gallatin,Chris Albon 出版:OReilly Media 书籍下载-《Python机器学习手册:从预处理到深度学习…

ASP.NET网上车辆档案管理系统

摘 要 本文采用基于Web的Asp.net技术,并与sql server 2000数据库相结合,研发了一套车辆档案管理系统。该系统扩展性好,易于维护。简化了车辆档案设计流程,去除了冗余信息。汽车销售企业可以通过本系统完成整个销售及售后所有档案…

IoTDB 入门教程 基础篇⑦——数据库管理工具 | DBeaver 连接 IoTDB

文章目录 一、前文二、下载iotdb-jdbc三、安装DBeaver3.1 DBeaver 下载3.2 DBeaver 安装 四、安装驱动五、连接数据库六、参考 一、前文 IoTDB入门教程——导读 二、下载iotdb-jdbc 下载地址org/apache/iotdb/iotdb-jdbc:https://maven.proxy.ustclug.org/maven2/o…

微信小程序 uniapp家庭食谱菜谱食材网上商城系统小程序ko137

随着生活节奏的不断加快,越来越多的人因为工作忙而没有时间自己出去订购喜欢的菜品。随着Internet的飞速发展,网络已经成为我们日常生活中必不可少的部分,越来越多的人也接受了电子商务这种快捷、方便的交易方式。网上订餐其独有的便捷性和直…

计算机网络——Dijkstra路由算法

实验目的 实现基于 Dijkstra 算法的路由软件 实验内容 网络拓扑如图所示 实验过程 先编写开辟应该图的空间,然后给点映射数字,构建图。程序获取用户输入的学号,构建图中边的权值。接下来程序从用户输入获取最短路径的搜索起点&#xff0…

Docker 中的 Nginx 服务为什么要启用 HTTPS

一安装容器 1 安装docker-20.10.17 2 安装所需的依赖 sudo yum install -y yum-utils device-mapper-persistent-data lvm23 添加Docker官方仓库 sudo yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo4 安装Docker CE 20.10.17 s…

【React】React-redux多组件间的状态传递

效果(部分完整代码在最底部): 编写 Person 组件 上面的 Count 组件,已经在前面几篇写过了,也可以直接翻到最底部看 首先我们需要在 containers 文件夹下编写 Person 组件的容器组件 首先我们需要编写 index.jsx 文件…

STM32G474 CMAKE VSCODE 开发环境搭建

本篇博文尝试搭建 stm32g474 的开发环境 一. 工具安装 1. 关于 MinGW、OpenOCD、Zadig 这些工具的下载和安装见 JlinkOpenOCDSTM32 Vscode 下载和调试环境搭建_vscode openocd stm32 jlink-CSDN博客 2. 导出一个 STM32 的 CMAKE 工程,这里略过。 3. 安装 ninja …

QT5之windowswidget_菜单栏+工具栏_核心控件_浮动窗口_模态对话框_标准对话框/文本对话框

菜单栏工具栏 新建工程基类是QMainWindow 1、 2、 3、 点.pro文件&#xff0c;添加配置 因为之后用到lambda&#xff1b; 在.pro文件添加配置c11 CONFIG c11 #不能加分号 添加头文件 #include <QMenuBar>//菜单栏的头文件 主窗口代码mainwindow.cpp文件 #include &q…

深入理解分布式事务⑨ ---->MySQL 事务的实现原理 之 MySQL 中的XA 事务(基本原理、流程分析、事务语法、简单例子演示)详解

目录 MySQL 事务的实现原理 之 MySQL 中的XA 事务&#xff08;基本原理、流程分析、事务语法、简单例子演示&#xff09;详解MySQL 中的 XA 事务1、XA 事务的基本原理1-1&#xff1a;XA 事务模型图&#xff1a;1-2&#xff1a;XA 事务模型的两阶段提交操作&#xff1a;Prepare …

「 网络安全常用术语解读 」通用漏洞报告框架CVRF详解

1. 背景 ICASI在推进多供应商协调漏洞披露方面处于领先地位&#xff0c;引入了通用漏洞报告框架&#xff08;Common Vulnerability Reporting Format&#xff0c;CVRF&#xff09;标准&#xff0c;制定了统一安全事件响应计划&#xff08;USIRP&#xff09;的原则&#xff0c;…

mysql 指定根目录 迁移根目录

mysql 指定根目录 迁移根目录 1、问题描述2、问题分析3、解决方法3.1、初始化mysql前就手动指定mysql根目录为一个大的分区(支持动态扩容)&#xff0c;事前就根本上解决mysql根目录空间不够问题3.1.0、方法思路3.1.1、卸载mariadb3.1.2、下载Mysql安装包3.1.3、安装Mysql 8.353…

ASP.NET 两种开发模式

1》》WebForm 开发模式 1. 服务器端控件 2. 一般处理程序html静态页Ajax 3. 一般处理程序html模板 如下图 2》》MVC 太复杂的系统&#xff0c;会造成Controller 过复杂。 后来就诞生了 MVP、MVVM等模式

腾讯云CentOS7使用Docker安装ElasticSearch与Kibana详细教程

文章目录 一、安装ElasticSearch二、安装Kibana 一、安装ElasticSearch 使用Docker拉取ElasticSearch镜像 这里版本选择的是7.15.2 docker pull docker.elastic.co/elasticsearch/elasticsearch:7.15.22. 查看ElasticSearch的镜像id docker images3. 创建ElasticSearch容器 …

目标跟踪—卡尔曼滤波

目标跟踪—卡尔曼滤波 卡尔曼滤波引入 滤波是将信号中特定波段频率滤除的操作&#xff0c;是抑制和防止干扰的一项重要措施。是根据观察某一随机过程的结果&#xff0c;对另一与之有关的随机过程进行估计的概率理论与方法。 历史上最早考虑的是维纳滤波&#xff0c;后来R.E.卡…

nn.GRU层输出:state与output的关系

在 GRU&#xff08;Gated Recurrent Unit&#xff09;中&#xff0c;output 和 state 都是由 GRU 层的循环计算产生的&#xff0c;它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。 GRU 的基本公式 GRU 的核心计算包括更新门&#xff08;update gat…

从零开始学AI绘画,万字Stable Diffusion终极教程(四)

【第4期】图生图 欢迎来到SD的终极教程&#xff0c;这是我们的第四节课 这套课程分为六节课&#xff0c;会系统性的介绍sd的全部功能&#xff0c;让你打下坚实牢靠的基础 1.SD入门 2.关键词 3.Lora模型 4.图生图 5.controlnet 6.知识补充 在前面的课程中&#xff0c;我…

QT:QT窗口(一)

文章目录 菜单栏创建菜单栏在菜单栏中添加菜单创建菜单项添加分割线 工具栏创建工具栏设置停靠位置创建工具栏的同时指定停靠位置使用QToolBar类提供的setAllowedAreas函数来设置停靠位置 设置浮动属性设置移动属性 状态栏状态栏的创建在状态栏中显示实时消息在状态栏中显示永久…