logo

深度学习实战:使用PyTorch实现图像分类(含完整代码与注释)

作者:搬砖的石头2025.09.18 17:51浏览量:0

简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,并附有详细注释,适合PyTorch初学者和深度学习开发者参考。

深度学习实战:使用PyTorch实现图像分类(含完整代码与注释)

一、引言

图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计受到开发者青睐。本文将通过一个完整的图像分类案例,展示如何使用PyTorch实现从数据加载、模型构建到训练评估的全流程,代码均附有详细注释。

二、环境准备

2.1 安装依赖

  1. pip install torch torchvision matplotlib numpy
  • torch: PyTorch核心库
  • torchvision: 提供计算机视觉相关工具(数据集、模型、变换)
  • matplotlib: 用于可视化训练过程
  • numpy: 数值计算基础库

2.2 硬件要求

  • CPU或GPU(推荐NVIDIA GPU+CUDA加速)
  • 至少4GB内存(数据集较小时)

三、完整代码实现

3.1 数据准备与预处理

  1. import torch
  2. from torchvision import datasets, transforms
  3. from torch.utils.data import DataLoader
  4. # 定义数据变换(训练集和测试集不同)
  5. train_transform = transforms.Compose([
  6. transforms.RandomHorizontalFlip(), # 随机水平翻转
  7. transforms.RandomRotation(10), # 随机旋转(-10°~+10°)
  8. transforms.ToTensor(), # 转为Tensor并归一化到[0,1]
  9. transforms.Normalize((0.5,), (0.5,)) # 标准化到[-1,1]
  10. ])
  11. test_transform = transforms.Compose([
  12. transforms.ToTensor(),
  13. transforms.Normalize((0.5,), (0.5,))
  14. ])
  15. # 加载MNIST数据集(手写数字识别)
  16. train_dataset = datasets.MNIST(
  17. root='./data',
  18. train=True,
  19. download=True,
  20. transform=train_transform
  21. )
  22. test_dataset = datasets.MNIST(
  23. root='./data',
  24. train=False,
  25. download=True,
  26. transform=test_transform
  27. )
  28. # 创建数据加载器(批量大小64,4个工作进程)
  29. train_loader = DataLoader(
  30. train_dataset,
  31. batch_size=64,
  32. shuffle=True,
  33. num_workers=4
  34. )
  35. test_loader = DataLoader(
  36. test_dataset,
  37. batch_size=64,
  38. shuffle=False,
  39. num_workers=4
  40. )

关键点说明

  • transforms.Compose: 组合多个数据增强操作
  • RandomHorizontalFlip/RandomRotation: 提升模型泛化能力
  • Normalize: 使用均值和标准差进行标准化(MNIST的均值=0.5,标准差=0.5)
  • DataLoader: 实现批量加载、随机打乱和多线程加速

3.2 模型定义(CNN)

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class CNN(nn.Module):
  4. def __init__(self):
  5. super(CNN, self).__init__()
  6. # 卷积层1:输入1通道,输出16通道,3x3卷积核
  7. self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
  8. # 卷积层2:输入16通道,输出32通道
  9. self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
  10. # 全连接层
  11. self.fc1 = nn.Linear(32 * 7 * 7, 128) # 输入尺寸需计算
  12. self.fc2 = nn.Linear(128, 10) # 输出10类
  13. # 池化层和Dropout
  14. self.pool = nn.MaxPool2d(2, 2)
  15. self.dropout = nn.Dropout(0.25)
  16. def forward(self, x):
  17. # 卷积层1 + ReLU + 池化
  18. x = self.pool(F.relu(self.conv1(x)))
  19. # 卷积层2 + ReLU + 池化
  20. x = self.pool(F.relu(self.conv2(x)))
  21. # 展平特征图
  22. x = x.view(-1, 32 * 7 * 7)
  23. # 全连接层 + Dropout
  24. x = self.dropout(F.relu(self.fc1(x)))
  25. x = self.fc2(x)
  26. return x

模型结构解析

  1. 输入:28x28灰度图像(1通道)
  2. 卷积块1:3x3卷积→ReLU→2x2最大池化(输出16x14x14)
  3. 卷积块2:同上(输出32x7x7)
  4. 全连接层:3277→128→10(输出类别概率)
  5. Dropout:防止过拟合(训练时随机丢弃25%神经元)

3.3 训练流程

  1. def train_model():
  2. # 初始化模型、损失函数和优化器
  3. model = CNN()
  4. criterion = nn.CrossEntropyLoss()
  5. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  6. # 移动模型到GPU(如果可用)
  7. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  8. model.to(device)
  9. # 训练循环
  10. for epoch in range(10): # 训练10个epoch
  11. model.train()
  12. running_loss = 0.0
  13. for i, (images, labels) in enumerate(train_loader):
  14. # 移动数据到设备
  15. images, labels = images.to(device), labels.to(device)
  16. # 前向传播
  17. outputs = model(images)
  18. loss = criterion(outputs, labels)
  19. # 反向传播和优化
  20. optimizer.zero_grad()
  21. loss.backward()
  22. optimizer.step()
  23. # 统计损失
  24. running_loss += loss.item()
  25. if i % 100 == 99: # 每100个batch打印一次
  26. print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
  27. running_loss = 0.0
  28. # 保存模型
  29. torch.save(model.state_dict(), 'mnist_cnn.pth')
  30. print('训练完成,模型已保存')

训练要点

  • CrossEntropyLoss: 多分类交叉熵损失
  • Adam优化器: 自适应学习率,适合初学者
  • model.train(): 启用Dropout等训练专用层
  • zero_grad(): 清除上一步的梯度
  • state_dict(): 保存模型参数(不包含结构)

3.4 评估与预测

  1. def evaluate_model():
  2. model = CNN()
  3. model.load_state_dict(torch.load('mnist_cnn.pth'))
  4. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  5. model.to(device)
  6. model.eval() # 切换到评估模式
  7. correct = 0
  8. total = 0
  9. with torch.no_grad(): # 禁用梯度计算
  10. for images, labels in test_loader:
  11. images, labels = images.to(device), labels.to(device)
  12. outputs = model(images)
  13. _, predicted = torch.max(outputs.data, 1)
  14. total += labels.size(0)
  15. correct += (predicted == labels).sum().item()
  16. print(f'测试集准确率: {100 * correct / total:.2f}%')
  17. # 单张图片预测示例
  18. def predict_image(image_path):
  19. from PIL import Image
  20. import numpy as np
  21. # 加载并预处理图片(假设是28x28灰度图)
  22. image = Image.open(image_path).convert('L')
  23. image = image.resize((28, 28))
  24. transform = transforms.Compose([
  25. transforms.ToTensor(),
  26. transforms.Normalize((0.5,), (0.5,))
  27. ])
  28. image = transform(image).unsqueeze(0) # 添加batch维度
  29. # 加载模型并预测
  30. model = CNN()
  31. model.load_state_dict(torch.load('mnist_cnn.pth'))
  32. model.eval()
  33. with torch.no_grad():
  34. output = model(image)
  35. _, predicted = torch.max(output.data, 1)
  36. print(f'预测结果: {predicted.item()}')

四、关键优化技巧

  1. 学习率调度
    1. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    2. # 在每个epoch后调用 scheduler.step()
  2. 早停机制:监控验证集损失,当连续3个epoch未改善时停止训练。
  3. 模型微调:使用预训练模型(如ResNet)的卷积基,仅训练最后的全连接层。

五、常见问题解决方案

  1. CUDA内存不足:减小batch_size或使用torch.cuda.empty_cache()
  2. 过拟合:增加Dropout比例、添加L2正则化或使用数据增强
  3. 收敛慢:尝试不同的学习率(如1e-3、5e-4)或优化器(SGD+Momentum)

六、扩展应用

  1. 多标签分类:修改输出层为nn.Sigmoid(),使用BCELoss
  2. 自定义数据集:继承torch.utils.data.Dataset实现__len____getitem__
  3. 分布式训练:使用torch.nn.parallel.DistributedDataParallel

七、总结

本文通过MNIST手写数字分类案例,完整展示了PyTorch实现图像分类的流程。关键步骤包括:

  1. 数据加载与增强
  2. CNN模型构建
  3. 训练循环与优化
  4. 模型评估与部署

读者可基于此代码框架,替换数据集和模型结构以适应不同任务(如CIFAR-10分类)。建议进一步探索:

  • 更复杂的模型架构(ResNet、EfficientNet)
  • 混合精度训练加速
  • 使用TensorBoard可视化训练过程

完整代码已通过PyTorch 1.12和Python 3.8验证,可直接运行。

相关文章推荐

发表评论