持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第31天,点击查看活动详情
前言
本文将介绍如何在 PyTorch 中构建一个简单的卷积神经网络,并训练它使用 MNIST 数据集识别手写数字,这将可以被看做是图像识别的 “Hello, World!”;
MNIST 包含 70,000 张手写数字图像:60,000 张用于训练,10,000 张用于测试。这些图像是灰度的,28x28 像素,居中以减少预处理并更快地开始。
配置环境
在本文中,我们将使用 PyTorch 训练卷积神经网络来识别 MNIST 的手写数字。 PyTorch 是一个非常流行的深度学习框架,如 Tensorflow、CNTK 和 Caffe2。但与这些其他框架不同,PyTorch 具有动态执行图,这意味着计算图是动态创建的。
1 | py复制代码import torch |
这里关于 PyTorch 的环境搭建就不再赘述了;
PyTorch 的官方文档链接:PyTorch documentation,在这里不仅有 API的说明还有一些经典的实例可供参考,中文文档点这!
准备数据集
完成环境导入之后,我们可以继续准备我们将使用的数据。
但在此之前,我们将定义我们将用于实验的超参数。在这里,epoch
的数量定义了我们将在整个训练数据集上循环多少次,而 learning_rate
和 momentum
是我们稍后将使用的优化器的超参数。
1 | py复制代码n_epochs = 3 |
对于可重复的实验,我们必须为任何使用随机数生成的值设置随机种子:numpy
和 random
;
而且,由于 cuDNN 使用非确定性算法,可以通过设置 torch.backends.cudnn.enabled = False
禁用该算法。
现在我们还需要数据集 DataLoaders,这就是 TorchVision 发挥作用的地方。它让我们以方便的方式使用加载 MNIST 数据集。下面用于 Normalize()
转换的值 0.1307 和 0.3081 是 MNIST 数据集的全局平均值和标准差,我们将在此处将它们作为给定值。
1 | py复制代码train_loader = torch.utils.data.DataLoader( |
1 | ruby复制代码Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz |
TIP: 如果你可以接受等待时间的话,可以改动 download=True
,不然的话,就自己先下载,然后在设置路径;
PyTorch 的 DataLoader
包含一些有趣的选项,而不是数据集和批量大小。例如,我们可以使用 num_workers > 1
来使用子进程异步加载数据或使用固定 RAM(via pin_memory)来加速 RAM 到 GPU 的传输。
使用数据集
接下来使用一下 test_loader
:
1 | py复制代码examples = enumerate(test_loader) |
所以一个测试数据批次是一个形状张量:这意味着我们有 1000 个 28x28 像素的灰度示例(即没有 rgb 通道,因此只有一个)。可以使用 matplotlib 绘制其中的一些:
1 | py复制代码import matplotlib.pyplot as plt |
后记
当你完成上述工作后,且一切正常,那么你的准备工作就完成了!接下来,就是要构建一个简单的卷积神经网络,并训练它使用 MNIST 数据集识别手写数字;
📝 上篇精讲:【项目实战】—— SSM 图书管理系统
💖 我是 𝓼𝓲𝓭𝓲𝓸𝓽,期待你的关注;
👍 创作不易,请多多支持;
本文转载自: 掘金