logo

从零开始:使用PyTorch实现图像分类(含完整代码与注释)

作者:有好多问题2025.09.18 17:01浏览量:0

简介:本文详细介绍如何使用PyTorch框架实现一个完整的图像分类模型,包含数据加载、模型构建、训练与评估全流程,并提供逐行代码注释,适合初学者快速入门。

从零开始:使用PyTorch实现图像分类(含完整代码与注释)

一、引言

图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计,成为研究者与开发者的首选工具。本文将通过一个完整的CIFAR-10数据集分类案例,详细讲解如何使用PyTorch实现图像分类,包含数据加载、模型构建、训练与评估全流程,并提供逐行代码注释。

二、环境准备

2.1 依赖安装

  1. pip install torch torchvision matplotlib numpy
  • torch:PyTorch核心库
  • torchvision:提供计算机视觉工具(数据集、模型架构、图像变换)
  • matplotlib:用于可视化训练过程
  • numpy:数值计算基础库

2.2 硬件要求

  • CPU:建议Intel i5及以上
  • GPU(可选):NVIDIA显卡(CUDA支持可加速训练)
  • 内存:8GB以上(CIFAR-10数据集约150MB)

三、完整代码实现

3.1 导入必要库

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. import torchvision.transforms as transforms
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  • torch.nn:定义神经网络层与模型
  • torch.optim:优化器(如SGD、Adam)
  • torchvision.transforms:图像预处理(归一化、裁剪等)

3.2 数据加载与预处理

  1. # 定义数据增强与归一化
  2. transform = transforms.Compose([
  3. transforms.ToTensor(), # 将PIL图像转为Tensor,范围[0,1]
  4. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
  5. ])
  6. # 加载CIFAR-10训练集与测试集
  7. trainset = torchvision.datasets.CIFAR10(
  8. root='./data', train=True, download=True, transform=transform)
  9. trainloader = torch.utils.data.DataLoader(
  10. trainset, batch_size=32, shuffle=True, num_workers=2)
  11. testset = torchvision.datasets.CIFAR10(
  12. root='./data', train=False, download=True, transform=transform)
  13. testloader = torch.utils.data.DataLoader(
  14. testset, batch_size=32, shuffle=False, num_workers=2)
  15. # 类别名称
  16. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  17. 'dog', 'frog', 'horse', 'ship', 'truck')

关键点

  • Compose:组合多个变换操作
  • Normalize:参数为(均值,标准差),CIFAR-10是RGB三通道
  • DataLoadershuffle=True打乱训练数据,num_workers加速数据加载

3.3 定义卷积神经网络(CNN)

  1. class CNN(nn.Module):
  2. def __init__(self):
  3. super(CNN, self).__init__()
  4. self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,5x5卷积核
  5. self.pool = nn.MaxPool2d(2, 2) # 2x2最大池化
  6. self.conv2 = nn.Conv2d(6, 16, 5)
  7. self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层
  8. self.fc2 = nn.Linear(120, 84)
  9. self.fc3 = nn.Linear(84, 10) # 输出10类
  10. def forward(self, x):
  11. x = self.pool(torch.relu(self.conv1(x))) # 卷积+ReLU+池化
  12. x = self.pool(torch.relu(self.conv2(x)))
  13. x = x.view(-1, 16 * 5 * 5) # 展平为向量
  14. x = torch.relu(self.fc1(x))
  15. x = torch.relu(self.fc2(x))
  16. x = self.fc3(x)
  17. return x
  18. net = CNN()

架构解析

  1. 卷积层1:32x32输入 → 6通道特征图(28x28)
  2. 池化层1:特征图 → 14x14
  3. 卷积层2:14x14 → 16通道(10x10)
  4. 池化层2:特征图 → 5x5
  5. 全连接层:5x5x16=400维 → 120 → 84 → 10(输出类别)

3.4 定义损失函数与优化器

  1. criterion = nn.CrossEntropyLoss() # 交叉熵损失
  2. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 随机梯度下降
  • CrossEntropyLoss:适用于多分类问题,组合了LogSoftmaxNLLLoss
  • SGD:参数momentum=0.9可加速收敛

3.5 训练模型

  1. for epoch in range(10): # 训练10个epoch
  2. running_loss = 0.0
  3. for i, data in enumerate(trainloader, 0):
  4. inputs, labels = data
  5. # 梯度清零
  6. optimizer.zero_grad()
  7. # 前向传播+反向传播+优化
  8. outputs = net(inputs)
  9. loss = criterion(outputs, labels)
  10. loss.backward()
  11. optimizer.step()
  12. # 打印统计信息
  13. running_loss += loss.item()
  14. if i % 200 == 199: # 每200个batch打印一次
  15. print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 200:.3f}')
  16. running_loss = 0.0
  17. print('Training finished')

训练逻辑

  1. optimizer.zero_grad():清除上一步的梯度
  2. loss.backward():计算梯度
  3. optimizer.step():更新参数

3.6 测试模型

  1. correct = 0
  2. total = 0
  3. with torch.no_grad(): # 禁用梯度计算
  4. for data in testloader:
  5. images, labels = data
  6. outputs = net(images)
  7. _, predicted = torch.max(outputs.data, 1) # 取概率最大的类别
  8. total += labels.size(0)
  9. correct += (predicted == labels).sum().item()
  10. print(f'Accuracy on test set: {100 * correct / total:.2f}%')

评估指标

  • 准确率(Accuracy)= 正确预测数 / 总样本数

3.7 可视化训练过程(可选)

  1. # 记录每个epoch的损失(需在训练循环中修改代码)
  2. train_losses = []
  3. test_accuracies = []
  4. # 修改训练循环以记录损失
  5. for epoch in range(10):
  6. epoch_loss = 0.0
  7. for i, data in enumerate(trainloader, 0):
  8. # ...(原训练代码)
  9. epoch_loss += loss.item()
  10. train_losses.append(epoch_loss / len(trainloader))
  11. # 测试集评估
  12. correct = 0
  13. total = 0
  14. with torch.no_grad():
  15. for data in testloader:
  16. # ...(原测试代码)
  17. test_accuracies.append(100 * correct / total)
  18. # 绘制曲线
  19. plt.figure(figsize=(10, 5))
  20. plt.subplot(1, 2, 1)
  21. plt.plot(train_losses, label='Training Loss')
  22. plt.xlabel('Epoch')
  23. plt.ylabel('Loss')
  24. plt.legend()
  25. plt.subplot(1, 2, 2)
  26. plt.plot(test_accuracies, label='Test Accuracy')
  27. plt.xlabel('Epoch')
  28. plt.ylabel('Accuracy (%)')
  29. plt.legend()
  30. plt.show()

四、关键优化建议

  1. 数据增强:添加随机裁剪、水平翻转提升泛化能力
    1. transform = transforms.Compose([
    2. transforms.RandomHorizontalFlip(),
    3. transforms.RandomCrop(32, padding=4),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    6. ])
  2. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率
    1. scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    2. # 在每个epoch后调用scheduler.step()
  3. 模型保存与加载
    1. torch.save(net.state_dict(), 'model.pth') # 保存模型参数
    2. net.load_state_dict(torch.load('model.pth')) # 加载参数

五、总结与扩展

本文通过CIFAR-10数据集展示了PyTorch实现图像分类的完整流程,核心步骤包括:

  1. 数据加载与预处理
  2. CNN模型定义
  3. 训练与优化
  4. 测试与评估

扩展方向

  • 尝试更深的网络(如ResNet)
  • 使用预训练模型(Transfer Learning)
  • 部署到移动端(PyTorch Mobile)

通过理解本例的代码逻辑,读者可快速迁移到其他图像分类任务(如MNIST、ImageNet),并进一步探索目标检测、语义分割等高级任务。

相关文章推荐

发表评论