Handwritten digit recognition is the "Hello, world" of deep learning: given an image of a handwritten digit, the model classifies it as one of the ten digits 0 to 9.
1. Create a new project (PyCharm) and install Paddle with pip
python -m pip install paddlepaddle==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
2. Import paddle and check its version
import paddle
print(paddle.__version__)
3. Before starting
Install the Python libraries $matplotlib$ and $numpy$. matplotlib is used to visualize images, and numpy is used to process data.
# using pip to install matplotlib and numpy
python -m pip install matplotlib numpy
4. Recognition
The MNIST dataset contains 60,000 training images, 10,000 test images, and the corresponding classification label files. Each image is a handwritten digit from 0 to 9 with a resolution of 28*28.
Deep learning tasks are generally divided into the following core steps:
(1) Dataset definition and loading;
(2) Model networking;
(3) Model training and evaluation;
(4) Model inference.
4.1 Dataset Definition and Loading
Paddle ships common computer vision datasets under $paddle.vision.datasets$, including MNIST, Cifar10, Cifar100, and FashionMNIST. In this task, the training set and test set of MNIST are loaded. The training set is used to train the model, and the test set is used to evaluate its performance.
import paddle
from paddle.vision.transforms import Normalize
transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
# Download the dataset and initialize the Dataset objects
train_dataset = paddle.vision.datasets.MNIST(mode="train", transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode="test", transform=transform)
# Print the number of images in each dataset
print(
    "{} images in train_dataset, {} images in test_dataset".format(
        len(train_dataset), len(test_dataset)
    )
)
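To make the transform above concrete: with mean=[127.5] and std=[127.5], each raw pixel value p in the range 0 to 255 becomes (p - 127.5) / 127.5, which lies between -1 and 1. The snippet below is a minimal pure-Python sketch of that arithmetic only, not Paddle's implementation; the helper name normalize_pixel is hypothetical.

```python
def normalize_pixel(value, mean=127.5, std=127.5):
    """Hypothetical helper: apply (value - mean) / std to one pixel."""
    return (value - mean) / std


# Darkest, middle, and brightest raw pixel values
for raw in (0, 127.5, 255):
    print(raw, "->", normalize_pixel(raw))
```

Centering the inputs around zero in this way is a common preprocessing choice for image classifiers.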
4.2 Model Networking
PaddlePaddle offers several ways to build a model's network: you can use one of its built-in models directly or define a custom network.
This task is relatively simple, and an ordinary neural network can already achieve high accuracy. This task uses $LeNet$, one of the classic CV models built into $paddle.vision.models$. The "num_classes" argument is the number of classification categories, set here to 10 (the digits 0 to 9).
In addition, $paddle.summary$ makes it easy to print the network's structure and parameter information.
# Build the model and initialize the network
lenet = paddle.vision.models.LeNet(num_classes=10)
# Print the network structure and parameter information
paddle.summary(lenet, (1, 1, 28, 28))
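As a rough cross-check of what $paddle.summary$ reports, the shapes and parameter counts can be worked out by hand. The sketch below assumes a LeNet layout of two convolution layers (3*3 with padding 1, then 5*5 without padding), each followed by 2*2 max pooling, and three fully connected layers (120, 84, 10 outputs). That layout is an assumption about the built-in model, and the helper names are hypothetical, so compare the result against the summary printed on your machine.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(in_channels, out_channels, kernel):
    """Weights plus one bias per output channel."""
    return out_channels * in_channels * kernel * kernel + out_channels


def linear_params(in_features, out_features):
    """Weights plus one bias per output feature."""
    return in_features * out_features + out_features


# Assumed layer layout (not read from Paddle itself)
size = 28
size = conv_out(size, kernel=3, padding=1)  # conv1
size = conv_out(size, kernel=2, stride=2)   # max pool 1
size = conv_out(size, kernel=5)             # conv2
size = conv_out(size, kernel=2, stride=2)   # max pool 2
flattened = 16 * size * size                # features fed to the first linear layer

total_params = (
    conv_params(1, 6, 3)
    + conv_params(6, 16, 5)
    + linear_params(flattened, 120)
    + linear_params(120, 84)
    + linear_params(84, 10)
)
print("flattened features:", flattened)
print("total parameters:", total_params)
```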
4.3 Model Training and Evaluation
4.3.1 Model Training
(1) Wrap the model: use $paddle.Model$ to wrap the network into an instance that can be trained, evaluated, and used for inference through Paddle's high-level API, which simplifies the later steps.
(2) Configure training: use $paddle.Model.prepare$ to set up the loss function, optimizer, and evaluation metric. Paddle provides optimizer APIs under $paddle.optimizer$, loss function APIs among the Loss layers of $paddle.nn$, and evaluation metric APIs under $paddle.metric$.
(3) Configure the loop parameters and start training: the parameters include the data source, the batch size, and the number of training epochs. Once the call is executed, the training loop runs automatically.
Because this is a classification task, the loss function is the common CrossEntropyLoss, the optimizer is Adam, and the Accuracy metric computes the model's accuracy on the training set.
# Wrap the model for the training, evaluation, and inference steps that follow
model = paddle.Model(lenet)
# Configure training: set the optimizer, loss function, and evaluation metric
model.prepare(
    paddle.optimizer.Adam(parameters=model.parameters()),
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(),
)
# Start training
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
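For intuition about the loss being minimized during training, here is a minimal pure-Python sketch of softmax cross-entropy for a single sample. It illustrates the formula only and is not Paddle's implementation; the function names and the example scores are made up for illustration.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities (shifted by the max for stability)."""
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Negative log-probability assigned to the true class."""
    return -math.log(softmax(logits)[label])


# Made-up scores for a 3-class problem whose true class is 0
confident_right = cross_entropy([5.0, 0.0, 0.0], label=0)
confident_wrong = cross_entropy([0.0, 5.0, 0.0], label=0)
print(confident_right, confident_wrong)
```

The loss is small when the true class receives the highest score and grows as the model becomes confidently wrong, which is the signal the optimizer uses to update the parameters.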
4.3.2 Model Evaluation
After the model is trained, use the test dataset defined earlier to evaluate its performance. When evaluation finishes, it outputs the model's loss value (loss) and accuracy (acc) on the test set.
# Evaluate the model
model.evaluate(test_dataset, batch_size=64, verbose=1)
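The acc value reported by the evaluation is the fraction of samples whose highest-scoring class equals the true label. A minimal pure-Python sketch of that top-1 accuracy calculation follows; the function name and the sample scores are hypothetical, and this is not how Paddle computes the metric internally.

```python
def top1_accuracy(batch_scores, labels):
    """Fraction of samples whose highest-scoring class matches the label."""
    correct = 0
    for scores, label in zip(batch_scores, labels):
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        if predicted == label:
            correct += 1
    return correct / len(labels)


# Three made-up samples with two classes each; the last one is misclassified
example_scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
example_labels = [1, 0, 0]
print(top1_accuracy(example_scores, example_labels))
```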
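4.4 Model Inference
4.4.1 Saving the Model
In Paddle, call $paddle.Model.save$ to save the model. In the path below, output is the name of the folder the model is saved in, and mnist is the name of the saved model file.
# Save the model; the folder is created automatically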
model.save("./output/mnist")
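4.4.2 Loading the Model and Running Inference
To run inference, call $paddle.Model.load$ to load the model, then use $paddle.Model.predict_batch$ to perform the inference.
Load the saved parameter file output/mnist into the model network created earlier, select one image from the test set, test_dataset[0], as input, run inference, and print the result. The inference result matches the visualized image.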
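The steps described above might look like the following sketch. It assumes the model and test_dataset objects created in the earlier sections, that each dataset item is an (image, label) pair with an image of shape (1, 28, 28), and that predict_batch returns a list whose first element holds the class scores. The helper names to_batch, predicted_label, and run_inference are hypothetical.

```python
import numpy as np


def to_batch(img):
    """Add a leading batch dimension: (1, 28, 28) -> (1, 1, 28, 28)."""
    return np.expand_dims(np.asarray(img, dtype="float32"), axis=0)


def predicted_label(scores):
    """Index of the highest score, as a plain Python int."""
    return int(np.argmax(scores))


def run_inference(model, dataset, path="output/mnist", index=0):
    """Load saved parameters, predict one sample, return (true, predicted)."""
    model.load(path)
    img, label = dataset[index]
    # Assumption: the first element of the returned list holds the class scores
    scores = model.predict_batch(to_batch(img))[0]
    return int(np.asarray(label).ravel()[0]), predicted_label(scores)


# Usage, assuming `model` and `test_dataset` from the previous sections:
# true_label, pred_label = run_inference(model, test_dataset)
# print("true label: {}, pred label: {}".format(true_label, pred_label))
```

To compare the prediction with the picture itself, the selected image can be displayed with matplotlib, for example with plt.imshow(img[0]) after importing pyplot as plt.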
5. Summary
from: https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html
Tips:
Reference