logo

从零开始:使用PyTorch实现图像分类(含完整代码与注释)

作者:很菜不狗2025.09.19 11:28浏览量:17

简介:本文详细讲解如何使用PyTorch框架实现图像分类任务,包含数据加载、模型构建、训练与评估的全流程代码,并附有逐行注释说明,适合PyTorch初学者和进阶开发者参考。

使用PyTorch实现图像分类(全流程详解)

图像分类是计算机视觉的基础任务之一,PyTorch作为主流深度学习框架,提供了简洁高效的API支持。本文将通过完整的代码示例,从数据准备到模型部署,详细讲解如何使用PyTorch实现图像分类。

一、环境准备与依赖安装

首先需要安装PyTorch及相关依赖库。推荐使用conda或pip安装:

  1. # 使用conda创建虚拟环境
  2. conda create -n pytorch_img_cls python=3.8
  3. conda activate pytorch_img_cls
  4. # 安装PyTorch(根据CUDA版本选择)
  5. pip install torch torchvision torchaudio
  6. # 安装其他依赖
  7. pip install numpy matplotlib tqdm

二、数据集准备与预处理

我们将使用CIFAR-10数据集(包含10个类别的6万张32x32彩色图像)作为示例。PyTorch的torchvision模块提供了便捷的数据加载方式:

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. # 定义数据预处理流程
  5. transform = transforms.Compose([
  6. transforms.ToTensor(), # 将PIL图像转换为Tensor,并归一化到[0,1]
  7. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
  8. ])
  9. # 加载训练集和测试集
  10. trainset = torchvision.datasets.CIFAR10(
  11. root='./data',
  12. train=True,
  13. download=True, # 自动下载数据集
  14. transform=transform
  15. )
  16. testset = torchvision.datasets.CIFAR10(
  17. root='./data',
  18. train=False,
  19. download=True,
  20. transform=transform
  21. )
  22. # 创建数据加载器(支持批量加载和随机打乱)
  23. batch_size = 32
  24. trainloader = torch.utils.data.DataLoader(
  25. trainset,
  26. batch_size=batch_size,
  27. shuffle=True,
  28. num_workers=2
  29. )
  30. testloader = torch.utils.data.DataLoader(
  31. testset,
  32. batch_size=batch_size,
  33. shuffle=False,
  34. num_workers=2
  35. )
  36. # 类别名称
  37. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  38. 'dog', 'frog', 'horse', 'ship', 'truck')

关键点说明

  1. transforms.Compose:组合多个数据预处理操作
  2. Normalize:使用(均值,标准差)参数进行标准化,这里使用[0.5,0.5,0.5]将像素值从[0,1]映射到[-1,1]
  3. DataLoader:支持批量加载、随机打乱和多线程加载,num_workers建议设置为CPU核心数的一半

三、模型构建:CNN网络设计

我们将实现一个经典的CNN模型,包含卷积层、池化层和全连接层:

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class CNN(nn.Module):
  4. def __init__(self):
  5. super(CNN, self).__init__()
  6. # 卷积层1:输入通道3(RGB),输出通道32,3x3卷积核
  7. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  8. # 卷积层2:输入通道32,输出通道64,3x3卷积核
  9. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  10. # 最大池化层:2x2窗口,步长2
  11. self.pool = nn.MaxPool2d(2, 2)
  12. # 全连接层1:输入64*8*8(经过两次池化后尺寸),输出512
  13. self.fc1 = nn.Linear(64 * 8 * 8, 512)
  14. # 全连接层2:输入512,输出10(类别数)
  15. self.fc2 = nn.Linear(512, 10)
  16. # Dropout层:防止过拟合
  17. self.dropout = nn.Dropout(0.25)
  18. def forward(self, x):
  19. # 第一层卷积+ReLU+池化
  20. x = self.pool(F.relu(self.conv1(x)))
  21. # 第二层卷积+ReLU+池化
  22. x = self.pool(F.relu(self.conv2(x)))
  23. # 展平特征图
  24. x = x.view(-1, 64 * 8 * 8)
  25. # 全连接层+ReLU+Dropout
  26. x = self.dropout(F.relu(self.fc1(x)))
  27. # 输出层(不使用激活函数,因为CrossEntropyLoss包含Softmax)
  28. x = self.fc2(x)
  29. return x
  30. # 初始化模型
  31. model = CNN()
  32. print(model) # 打印模型结构

模型设计要点

  1. 输入尺寸:CIFAR-10图像为32x32,经过两次2x2池化后变为8x8
  2. 通道数:从3(RGB)逐步增加到32→64,增强特征表达能力
  3. 激活函数:使用ReLU加速收敛
  4. 正则化:Dropout层随机丢弃25%的神经元

四、训练流程实现

完整的训练循环包含前向传播、损失计算、反向传播和参数更新:

  1. import torch.optim as optim
  2. from tqdm import tqdm # 进度条库
  3. # 定义损失函数和优化器
  4. criterion = nn.CrossEntropyLoss()
  5. optimizer = optim.Adam(model.parameters(), lr=0.001)
  6. # 设备配置(GPU加速)
  7. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  8. model.to(device)
  9. # 训练参数
  10. num_epochs = 10
  11. for epoch in range(num_epochs):
  12. running_loss = 0.0
  13. correct = 0
  14. total = 0
  15. # 训练模式(启用Dropout和BatchNorm)
  16. model.train()
  17. # 使用tqdm显示进度条
  18. loop = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
  19. for i, (inputs, labels) in enumerate(loop):
  20. # 移动数据到设备
  21. inputs, labels = inputs.to(device), labels.to(device)
  22. # 梯度清零
  23. optimizer.zero_grad()
  24. # 前向传播
  25. outputs = model(inputs)
  26. # 计算损失
  27. loss = criterion(outputs, labels)
  28. # 反向传播
  29. loss.backward()
  30. # 更新参数
  31. optimizer.step()
  32. # 统计信息
  33. running_loss += loss.item()
  34. _, predicted = torch.max(outputs.data, 1)
  35. total += labels.size(0)
  36. correct += (predicted == labels).sum().item()
  37. # 更新进度条信息
  38. loop.set_postfix(loss=running_loss/(i+1), acc=100*correct/total)
  39. # 打印每个epoch的统计信息
  40. train_loss = running_loss / len(trainloader)
  41. train_acc = 100 * correct / total
  42. print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
  43. # 验证模式(禁用Dropout和BatchNorm的训练行为)
  44. model.eval()
  45. val_loss = 0.0
  46. correct = 0
  47. total = 0
  48. with torch.no_grad(): # 不计算梯度
  49. for inputs, labels in testloader:
  50. inputs, labels = inputs.to(device), labels.to(device)
  51. outputs = model(inputs)
  52. loss = criterion(outputs, labels)
  53. val_loss += loss.item()
  54. _, predicted = torch.max(outputs.data, 1)
  55. total += labels.size(0)
  56. correct += (predicted == labels).sum().item()
  57. val_loss = val_loss / len(testloader)
  58. val_acc = 100 * correct / total
  59. print(f'Epoch {epoch+1}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
  60. print('训练完成')

训练技巧说明

  1. model.train()model.eval():切换模型模式,影响BatchNorm和Dropout的行为
  2. 学习率选择:Adam优化器通常使用0.001作为初始学习率
  3. 进度条:tqdm库提供可视化训练进度
  4. 设备选择:自动检测并使用GPU加速

五、模型评估与可视化

训练完成后,我们需要评估模型性能并可视化结果:

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. # 函数:显示一批图像及其预测结果
  4. def imshow(img):
  5. img = img / 2 + 0.5 # 反归一化
  6. npimg = img.numpy()
  7. plt.imshow(np.transpose(npimg, (1, 2, 0)))
  8. plt.show()
  9. # 获取一个批次的图像和标签
  10. dataiter = iter(testloader)
  11. images, labels = next(dataiter)
  12. images, labels = images.to(device), labels.to(device)
  13. # 预测
  14. outputs = model(images)
  15. _, predicted = torch.max(outputs, 1)
  16. # 显示图像和真实/预测标签
  17. imshow(torchvision.utils.make_grid(images.cpu()))
  18. print('真实标签:', ' '.join(f'{classes[labels[j]]}' for j in range(batch_size)))
  19. print('预测标签:', ' '.join(f'{classes[predicted[j]]}' for j in range(batch_size)))
  20. # 绘制训练曲线
  21. def plot_metrics(train_losses, train_accs, val_losses, val_accs):
  22. plt.figure(figsize=(12, 4))
  23. plt.subplot(1, 2, 1)
  24. plt.plot(train_losses, label='Train Loss')
  25. plt.plot(val_losses, label='Val Loss')
  26. plt.xlabel('Epoch')
  27. plt.ylabel('Loss')
  28. plt.legend()
  29. plt.subplot(1, 2, 2)
  30. plt.plot(train_accs, label='Train Acc')
  31. plt.plot(val_accs, label='Val Acc')
  32. plt.xlabel('Epoch')
  33. plt.ylabel('Accuracy (%)')
  34. plt.legend()
  35. plt.tight_layout()
  36. plt.show()
  37. # 这里需要在实际代码中收集各epoch的loss和acc
  38. # 示例数据(实际应从训练过程中收集)
  39. train_losses = [2.3, 1.8, 1.5, 1.2, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
  40. train_accs = [30, 45, 55, 65, 70, 75, 78, 80, 82, 84]
  41. val_losses = [2.1, 1.7, 1.4, 1.1, 0.95, 0.85, 0.78, 0.72, 0.68, 0.65]
  42. val_accs = [35, 50, 60, 68, 72, 76, 79, 81, 83, 85]
  43. plot_metrics(train_losses, train_accs, val_losses, val_accs)

六、模型保存与加载

训练好的模型可以保存为.pth文件,方便后续使用:

  1. # 保存模型
  2. PATH = './cifar_net.pth'
  3. torch.save(model.state_dict(), PATH)
  4. # 加载模型
  5. def load_model():
  6. # 重新初始化模型
  7. model = CNN()
  8. model.to(device)
  9. # 加载保存的权重
  10. model.load_state_dict(torch.load(PATH))
  11. # 设置为评估模式
  12. model.eval()
  13. return model
  14. # 使用加载的模型进行预测
  15. loaded_model = load_model()
  16. # ... 进行预测的代码与之前相同

七、进阶优化建议

  1. 数据增强:在transforms中添加随机裁剪、水平翻转等操作提升模型泛化能力

    1. transform_train = transforms.Compose([
    2. transforms.RandomHorizontalFlip(),
    3. transforms.RandomCrop(32, padding=4),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    6. ])
  2. 学习率调度:使用StepLR或ReduceLROnPlateau动态调整学习率

    1. scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    2. # 在每个epoch后调用scheduler.step()
  3. 更先进的模型:尝试ResNet、EfficientNet等预训练模型

    1. from torchvision.models import resnet18
    2. model = resnet18(pretrained=False, num_classes=10)
  4. 分布式训练:使用torch.nn.DataParallel进行多GPU训练

    1. if torch.cuda.device_count() > 1:
    2. print(f"使用 {torch.cuda.device_count()} 个GPU")
    3. model = nn.DataParallel(model)

八、完整代码整合

将上述所有部分整合为一个完整的可运行脚本:

  1. # 完整代码见前文各部分,此处省略重复内容
  2. # 建议将代码分为以下几个文件:
  3. # 1. model.py (定义CNN类)
  4. # 2. train.py (训练流程)
  5. # 3. utils.py (辅助函数如imshow)
  6. # 4. evaluate.py (模型评估)

九、总结与展望

本文详细介绍了使用PyTorch实现图像分类的完整流程,从数据加载到模型部署。关键要点包括:

  1. 数据预处理的重要性:标准化和增强操作显著影响模型性能
  2. CNN架构设计:合理的卷积层、池化层和全连接层组合
  3. 训练技巧:学习率选择、优化器选择和正则化方法
  4. 评估与可视化:准确监控训练过程,及时发现过拟合/欠拟合

未来可以探索的方向包括:

  • 使用Transformer架构(如ViT)进行图像分类
  • 尝试半监督或自监督学习方法
  • 部署模型到移动端或边缘设备

通过本文的学习,读者应该能够掌握PyTorch实现图像分类的核心技术,并能够根据实际需求调整模型结构和训练策略。

相关文章推荐

发表评论

活动