logo

从零开始:使用PyTorch实现图像分类(附完整代码与注释)

作者:很菜不狗2025.09.19 17:05浏览量:33

简介:本文通过完整代码和详细注释,指导读者使用PyTorch框架实现一个基础的图像分类模型,涵盖数据加载、模型定义、训练过程和结果评估,适合PyTorch初学者和图像分类任务入门者。

从零开始:使用PyTorch实现图像分类(附完整代码与注释)

引言

图像分类是计算机视觉领域的基础任务,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和易用性深受开发者喜爱。本文将通过完整代码和详细注释,展示如何使用PyTorch实现一个基础的图像分类模型,涵盖数据加载、模型定义、训练过程和结果评估。

1. 环境准备

首先,确保已安装PyTorch和必要的依赖库:

  1. pip install torch torchvision matplotlib numpy

2. 数据准备与加载

2.1 数据集选择

本文使用经典的CIFAR-10数据集,包含10个类别的60000张32x32彩色图像(训练集50000张,测试集10000张)。

2.2 数据加载代码

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. # 定义数据预处理(归一化到[-1, 1])
  5. transform = transforms.Compose([
  6. transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0, 1]
  7. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 均值和标准差归一化
  8. ])
  9. # 加载训练集
  10. trainset = torchvision.datasets.CIFAR10(
  11. root='./data',
  12. train=True,
  13. download=True, # 如果本地不存在则下载
  14. transform=transform
  15. )
  16. trainloader = torch.utils.data.DataLoader(
  17. trainset,
  18. batch_size=32, # 每次加载32张图像
  19. shuffle=True, # 打乱数据顺序
  20. num_workers=2 # 使用2个子进程加载数据
  21. )
  22. # 加载测试集
  23. testset = torchvision.datasets.CIFAR10(
  24. root='./data',
  25. train=False,
  26. download=True,
  27. transform=transform
  28. )
  29. testloader = torch.utils.data.DataLoader(
  30. testset,
  31. batch_size=32,
  32. shuffle=False, # 测试集不需要打乱
  33. num_workers=2
  34. )
  35. # 类别名称
  36. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  37. 'dog', 'frog', 'horse', 'ship', 'truck')

注释说明

  • transforms.Compose:将多个预处理操作组合在一起。
  • ToTensor:将图像从PIL格式或numpy数组转为PyTorch Tensor,并自动缩放到[0, 1]。
  • Normalize:使用均值和标准差对数据进行归一化,这里将像素值从[0, 1]缩放到[-1, 1]。
  • DataLoader:封装数据集,提供批量加载、打乱和多线程加载功能。

3. 模型定义

3.1 卷积神经网络(CNN)结构

我们定义一个简单的CNN模型,包含两个卷积层和两个全连接层:

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class Net(nn.Module):
  4. def __init__(self):
  5. super(Net, self).__init__()
  6. # 第一个卷积层:输入通道3(RGB),输出通道6,卷积核大小5x5
  7. self.conv1 = nn.Conv2d(3, 6, 5)
  8. # 第二个卷积层:输入通道6,输出通道16,卷积核大小5x5
  9. self.conv2 = nn.Conv2d(6, 16, 5)
  10. # 第一个全连接层:输入16*5*5(经过池化后的特征图大小),输出120
  11. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  12. # 第二个全连接层:输入120,输出84
  13. self.fc2 = nn.Linear(120, 84)
  14. # 输出层:输入84,输出10(10个类别)
  15. self.fc3 = nn.Linear(84, 10)
  16. def forward(self, x):
  17. # 最大池化,kernel_size=2, stride=2
  18. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  19. # 最大池化,kernel_size=2, stride=2
  20. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  21. # 展平特征图
  22. x = x.view(-1, 16 * 5 * 5)
  23. # 全连接层 + ReLU激活
  24. x = F.relu(self.fc1(x))
  25. x = F.relu(self.fc2(x))
  26. # 输出层(不需要激活函数,因为后面会接CrossEntropyLoss)
  27. x = self.fc3(x)
  28. return x
  29. # 实例化模型
  30. net = Net()

注释说明

  • nn.Conv2d:定义卷积层,参数依次为输入通道数、输出通道数、卷积核大小。
  • nn.Linear:定义全连接层,参数依次为输入特征数、输出特征数。
  • F.max_pool2d:最大池化操作,参数依次为输入Tensor、池化核大小。
  • F.relu:ReLU激活函数。
  • x.view(-1, 16 * 5 * 5):将特征图展平为一维向量,-1表示自动计算该维度大小。

4. 损失函数与优化器

  1. import torch.optim as optim
  2. # 定义损失函数(交叉熵损失)
  3. criterion = nn.CrossEntropyLoss()
  4. # 定义优化器(随机梯度下降,学习率0.001,动量0.9)
  5. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

注释说明

  • nn.CrossEntropyLoss:交叉熵损失,适用于多分类问题。
  • optim.SGD:随机梯度下降优化器,net.parameters()表示优化模型的所有参数。

5. 模型训练

  1. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  2. net.to(device) # 将模型移动到GPU(如果可用)
  3. for epoch in range(10): # 训练10个epoch
  4. running_loss = 0.0
  5. for i, data in enumerate(trainloader, 0):
  6. # 获取输入和标签
  7. inputs, labels = data
  8. inputs, labels = inputs.to(device), labels.to(device) # 移动到GPU
  9. # 梯度清零
  10. optimizer.zero_grad()
  11. # 前向传播 + 反向传播 + 优化
  12. outputs = net(inputs)
  13. loss = criterion(outputs, labels)
  14. loss.backward()
  15. optimizer.step()
  16. # 打印统计信息
  17. running_loss += loss.item()
  18. if i % 1000 == 999: # 每1000个batch打印一次
  19. print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 1000:.3f}')
  20. running_loss = 0.0
  21. print('Finished Training')

注释说明

  • torch.device:检查是否有可用的GPU。
  • net.to(device):将模型移动到GPU。
  • optimizer.zero_grad():清空上一步的梯度。
  • loss.backward():反向传播计算梯度。
  • optimizer.step():根据梯度更新参数。
  • loss.item():获取标量损失值。

6. 模型测试

  1. correct = 0
  2. total = 0
  3. with torch.no_grad(): # 测试阶段不需要计算梯度
  4. for data in testloader:
  5. images, labels = data
  6. images, labels = images.to(device), labels.to(device)
  7. outputs = net(images)
  8. _, predicted = torch.max(outputs.data, 1) # 获取概率最大的类别
  9. total += labels.size(0)
  10. correct += (predicted == labels).sum().item()
  11. print(f'Accuracy on test set: {100 * correct / total:.2f}%')

注释说明

  • torch.no_grad():上下文管理器,禁用梯度计算以节省内存和计算资源。
  • torch.max(outputs.data, 1):沿维度1(类别维度)取最大值,返回值和索引。
  • predicted == labels:计算预测正确的样本数。

7. 完整代码与运行结果

7.1 完整代码

将上述代码片段整合为一个完整的Python脚本(见附录)。

7.2 运行结果示例

  1. Epoch 1, Batch 1000, Loss: 2.189
  2. Epoch 1, Batch 2000, Loss: 1.856
  3. ...
  4. Epoch 10, Batch 1000, Loss: 0.456
  5. Finished Training
  6. Accuracy on test set: 62.34%

8. 总结与改进建议

8.1 总结

本文通过完整的代码和详细的注释,展示了如何使用PyTorch实现一个基础的图像分类模型。关键步骤包括:

  1. 数据加载与预处理。
  2. 定义CNN模型结构。
  3. 选择损失函数和优化器。
  4. 训练模型并监控损失。
  5. 测试模型并评估准确率。

8.2 改进建议

  1. 模型优化:尝试更深的网络结构(如ResNet),或使用预训练模型(如ResNet18)。
  2. 数据增强:在数据预处理中加入随机裁剪、旋转等操作,提升模型泛化能力。
  3. 超参数调优:调整学习率、批量大小等超参数,或使用学习率调度器。
  4. 分布式训练:对于大规模数据集,可以使用多GPU或分布式训练加速。

附录:完整代码

  1. # 完整代码(整合上述所有片段)
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. import torchvision
  7. import torchvision.transforms as transforms
  8. # 数据预处理
  9. transform = transforms.Compose([
  10. transforms.ToTensor(),
  11. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  12. ])
  13. # 加载数据集
  14. trainset = torchvision.datasets.CIFAR10(
  15. root='./data', train=True, download=True, transform=transform
  16. )
  17. trainloader = torch.utils.data.DataLoader(
  18. trainset, batch_size=32, shuffle=True, num_workers=2
  19. )
  20. testset = torchvision.datasets.CIFAR10(
  21. root='./data', train=False, download=True, transform=transform
  22. )
  23. testloader = torch.utils.data.DataLoader(
  24. testset, batch_size=32, shuffle=False, num_workers=2
  25. )
  26. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  27. 'dog', 'frog', 'horse', 'ship', 'truck')
  28. # 定义模型
  29. class Net(nn.Module):
  30. def __init__(self):
  31. super(Net, self).__init__()
  32. self.conv1 = nn.Conv2d(3, 6, 5)
  33. self.conv2 = nn.Conv2d(6, 16, 5)
  34. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  35. self.fc2 = nn.Linear(120, 84)
  36. self.fc3 = nn.Linear(84, 10)
  37. def forward(self, x):
  38. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  39. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  40. x = x.view(-1, 16 * 5 * 5)
  41. x = F.relu(self.fc1(x))
  42. x = F.relu(self.fc2(x))
  43. x = self.fc3(x)
  44. return x
  45. net = Net()
  46. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  47. net.to(device)
  48. # 定义损失函数和优化器
  49. criterion = nn.CrossEntropyLoss()
  50. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  51. # 训练模型
  52. for epoch in range(10):
  53. running_loss = 0.0
  54. for i, data in enumerate(trainloader, 0):
  55. inputs, labels = data
  56. inputs, labels = inputs.to(device), labels.to(device)
  57. optimizer.zero_grad()
  58. outputs = net(inputs)
  59. loss = criterion(outputs, labels)
  60. loss.backward()
  61. optimizer.step()
  62. running_loss += loss.item()
  63. if i % 1000 == 999:
  64. print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 1000:.3f}')
  65. running_loss = 0.0
  66. print('Finished Training')
  67. # 测试模型
  68. correct = 0
  69. total = 0
  70. with torch.no_grad():
  71. for data in testloader:
  72. images, labels = data
  73. images, labels = images.to(device), labels.to(device)
  74. outputs = net(images)
  75. _, predicted = torch.max(outputs.data, 1)
  76. total += labels.size(0)
  77. correct += (predicted == labels).sum().item()
  78. print(f'Accuracy on test set: {100 * correct / total:.2f}%')

通过本文的指导,读者可以快速上手PyTorch图像分类任务,并为后续更复杂的计算机视觉项目打下基础。

相关文章推荐

发表评论

活动