
From Scratch: A Complete PyTorch Image Classification Pipeline, with Annotated Code

Author: 半吊子全栈工匠 · 2025.09.19 11:29

Summary: Through complete, heavily commented code examples, this article walks through image classification on the CIFAR-10 dataset with PyTorch, covering the full pipeline of data loading, model construction, training and optimization, and evaluation. Suitable for deep learning beginners and practitioners alike.

Image Classification with PyTorch: Complete Code and Detailed Walkthrough

Introduction

Image classification is one of the core tasks in computer vision, and PyTorch, as a mainstream deep learning framework, provides a flexible and efficient toolchain for it. Using CIFAR-10 classification as a running example, this article walks through the full pipeline from data preparation to model deployment. The code carries line-by-line comments and should suit developers at any level.

1. Environment Setup

1.1 Basic Environment Configuration

```python
# Version requirements
# Python 3.8+
# PyTorch 2.0+
# torchvision 0.15+
# CUDA 11.7+ (for GPU acceleration)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```

Key points: pin down the version requirements and auto-detect the available hardware, laying the groundwork for training.

1.2 Data Preprocessing

```python
# Define data augmentation and normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # random crop with padding
    transforms.RandomHorizontalFlip(),      # random horizontal flip
    transforms.ToTensor(),                  # convert to tensor
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))  # CIFAR-10 mean/std
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])
```

Design principle: the training set uses augmentation to improve generalization; the test set is only normalized, keeping evaluation consistent.
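The normalization constants above come from the CIFAR-10 training split itself. Below is a minimal sketch of the per-channel computation, using a random stand-in batch so it runs without downloading the dataset; with the real data you would stack the 50,000 training images into one tensor instead.

```python
import torch

def channel_stats(images: torch.Tensor):
    """Per-channel mean and std over an NCHW batch of images in [0, 1]."""
    mean = images.mean(dim=(0, 2, 3))  # average over batch, height, width
    std = images.std(dim=(0, 2, 3))
    return mean, std

# Stand-in batch; substitute the stacked CIFAR-10 training tensors here.
batch = torch.rand(64, 3, 32, 32)
mean, std = channel_stats(batch)
print(mean, std)  # three values each, one per RGB channel
```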

2. Data Loading

2.1 Dataset Preparation

```python
# Download and load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

# Create the data loaders
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

# Class name mapping
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```

Tuning tips: set num_workers to speed up data loading, and adjust batch_size to fit your hardware.
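A quick sanity check is to pull one batch and inspect its shapes before training. The sketch below swaps in a random TensorDataset so it runs without the CIFAR-10 download; with the real train_loader the shapes come out the same.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random stand-in with CIFAR-10's shapes: 3x32x32 images, 10 classes.
images = torch.rand(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=128, shuffle=True)

x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
```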

3. Model Architecture

3.1 A Basic CNN

```python
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            # input 3x32x32 -> output 64x32x32
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 64x16x16
            # -> 128x16x16
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 128x8x8
            # -> 256x8x8
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)   # -> 256x4x4
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten the feature maps
        x = self.classifier(x)
        return x
```

Architecture Notes

  • 3 convolutional blocks, each: conv + BN + ReLU + max-pool
  • 2 fully connected layers, with Dropout in between to curb overfitting
  • Roughly 4.58M trainable parameters, most of them in the first fully connected layer
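The parameter count is easy to verify. The sketch below rebuilds the same layer stack as an nn.Sequential stand-in for SimpleCNN (with nn.Flatten in place of the view() call) and tallies its trainable parameters:

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Same layers as SimpleCNN's features + classifier.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Dropout(0.5), nn.Linear(256 * 4 * 4, 1024), nn.ReLU(),
    nn.Dropout(0.5), nn.Linear(1024, 10),
)
print(count_parameters(model))  # 4577290, dominated by the 4096->1024 Linear
```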

3.2 Loading a Pretrained Model

```python
def load_pretrained_model():
    # Load an ImageNet-pretrained ResNet18
    # (older torchvision: models.resnet18(pretrained=True))
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze all backbone parameters
    for param in model.parameters():
        param.requires_grad = False
    # Replace the final fully connected layer
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)
    return model
```

Transfer Learning Tips

  • Freeze the shallow feature extractor
  • Train only the classification layer for fast convergence
  • Well suited to scenarios with little data
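Because frozen parameters still show up in model.parameters(), a common companion step is to hand the optimizer only the trainable subset. A minimal sketch with a tiny stand-in backbone in place of ResNet18:

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in backbone; with ResNet18 this is everything except model.fc.
backbone = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
for param in backbone.parameters():
    param.requires_grad = False  # freeze, as in load_pretrained_model()
head = nn.Linear(512, 10)        # the new classification layer
model = nn.Sequential(backbone, head)

# Only parameters that still require gradients go to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.001, momentum=0.9)
print(len(trainable))  # 2: the head's weight and bias
```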

4. Training Pipeline

4.1 Training Function

```python
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()             # training mode
                dataloader = train_loader
            else:
                model.eval()              # evaluation mode
                dataloader = test_loader  # the test set doubles as validation here
            running_loss = 0.0
            running_corrects = 0
            # Iterate over the data
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Zero the gradients
                optimizer.zero_grad()
                # Forward pass; track gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # Backward pass + optimization, training phase only
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects.double() / len(dataloader.dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Checkpoint the best model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'best_model.pth')
    print(f'Best val Acc: {best_acc:.4f}')
    return model
```

Key Mechanisms

  • Train/validation mode switching (model.train() / model.eval())
  • Gradient tracking toggled per phase via torch.set_grad_enabled
  • Learning rate scheduler integration
  • Best-model checkpointing

4.2 Main Training Flow

```python
def main():
    # Initialize the model
    model = SimpleCNN().to(device)
    # model = load_pretrained_model().to(device)  # switch to the pretrained model
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    # Train the model
    model = train_model(model, criterion, optimizer, scheduler, num_epochs=25)
    # Save the final model
    torch.save(model.state_dict(), 'final_model.pth')

if __name__ == '__main__':
    main()  # the guard matters when DataLoader uses num_workers > 0
```

Hyperparameter Suggestions

  • Initial learning rate: 0.1 (CNN from scratch), 0.001 (pretrained model)
  • Momentum: 0.9
  • Schedule: decay by 10x every 7 epochs
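The decay schedule can be checked in isolation. The sketch below steps a StepLR against a throwaway model and records the learning rate per epoch, confirming the tenfold drop every 7 epochs:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 10)  # throwaway model; only its parameters matter here
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(21):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()   # normally one full training epoch runs here
    scheduler.step()
print(lrs[0], lrs[7], lrs[14])  # 0.1, then ~0.01, then ~0.001
```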

5. Evaluation and Visualization

5.1 Model Evaluation

```python
def evaluate_model(model_path):
    model = SimpleCNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    correct = 0
    total = 0
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Per-class statistics
            c = (predicted == labels)
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    print(f'Accuracy: {100 * correct / total:.2f}%')
    # Print per-class accuracy
    for i in range(10):
        print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
```
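The per-class bookkeeping above generalizes naturally to a confusion matrix. A small sketch with hand-made tensors; the real predicted/labels batches from test_loader would drop in directly:

```python
import torch

def confusion_matrix(preds, labels, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(labels.tolist(), preds.tolist()):
        cm[t, p] += 1
    return cm

preds = torch.tensor([0, 1, 1, 2])   # stand-in model predictions
labels = torch.tensor([0, 1, 2, 2])  # stand-in ground truth
cm = confusion_matrix(preds, labels, num_classes=3)
per_class_acc = cm.diag().float() / cm.sum(dim=1).float()
print(per_class_acc.tolist())  # [1.0, 1.0, 0.5]
```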

5.2 Error Analysis and Visualization

```python
def visualize_errors(model_path, num_images=6):
    model = SimpleCNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
    # Move to CPU and convert to numpy
    images = images.cpu().numpy()
    # Plot predictions, with the true label in parentheses
    fig = plt.figure(figsize=(10, 10))
    for idx in range(num_images):
        ax = fig.add_subplot(1, num_images, idx + 1, xticks=[], yticks=[])
        img = images[idx].transpose((1, 2, 0))
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2023, 0.1994, 0.2010])
        img = std * img + mean  # undo normalization
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.set_title(f'{classes[preds[idx]]}\n({classes[labels[idx]]})',
                     color=("green" if preds[idx] == labels[idx] else "red"))
    plt.show()
```

6. Advanced Optimization Techniques

6.1 Learning Rate Finder

```python
import math

def find_lr(model, optimizer, criterion, init_value=1e-7, final_value=10., beta=0.98):
    num = len(train_loader) - 1
    mult = (final_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = 0.
    batch_num = 0
    losses = []
    log_lrs = []
    model.train()
    for inputs, labels in train_loader:
        batch_num += 1
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Exponentially smoothed loss, with bias correction
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** batch_num)
        # Stop when the loss explodes
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            break
        # Track the best (lowest) loss seen so far
        if smoothed_loss < best_loss or batch_num == 1:
            best_loss = smoothed_loss
        # Record current learning rate and loss
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        # Backward pass
        loss.backward()
        optimizer.step()
        # Increase the learning rate
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    plt.plot(log_lrs[10:-5], losses[10:-5])
    plt.xlabel('log10(lr)')
    plt.ylabel('Loss')
    plt.show()
```

6.2 Mixed-Precision Training

```python
from torch.cuda.amp import GradScaler, autocast  # torch.amp in newer releases

def train_with_amp(model, criterion, optimizer, num_epochs=25):
    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast():  # forward pass in mixed precision
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()  # scale loss to avoid fp16 underflow
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
```

7. Deployment Preparation

7.1 Exporting to ONNX

```python
def export_to_onnx(model_path, output_path='model.onnx'):
    model = SimpleCNN()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    # Dummy input on the same device as the model (CPU here)
    dummy_input = torch.randn(1, 3, 32, 32)
    # Export the model
    torch.onnx.export(model,
                      dummy_input,
                      output_path,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})
    print(f'Model exported to {output_path}')
```

7.2 Performance Optimization Suggestions

  1. Quantization-aware training: use the torch.quantization module to shrink the model
  2. TensorRT acceleration: convert the ONNX model into a TensorRT engine
  3. Multi-threaded inference: tune torch.set_num_threads() for CPU inference
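As a taste of point 1, post-training dynamic quantization (a lighter-weight cousin of quantization-aware training) converts Linear layers to int8 weights with a single call. The sketch below applies it to a stand-in classifier head; the full QAT workflow lives in the same torch.quantization module:

```python
import torch
import torch.nn as nn

# Stand-in for SimpleCNN's classifier; dynamic quantization targets nn.Linear.
model = nn.Sequential(nn.Linear(256 * 4 * 4, 1024), nn.ReLU(), nn.Linear(1024, 10))
model.eval()

quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8  # store Linear weights as int8
)
x = torch.rand(1, 256 * 4 * 4)
out = quantized(x)
print(out.shape)  # same interface as the float model, smaller weights
```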

8. Full Code Consolidation

```python
# Full code consolidation (see the GitHub repository)
# Includes:
# 1. Complete implementations of all modules
# 2. Training/evaluation scripts
# 3. Visualization tools
# 4. Deployment interface
```
9. Summary and Outlook

Through the CIFAR-10 classification task, this article has walked through a complete PyTorch image classification pipeline. Key takeaways:

  1. Core methods for data augmentation, model construction, and training optimization
  2. Advanced techniques such as transfer learning and mixed-precision training
  3. Reusable code templates and utility functions

Directions for further work:

  • Try more advanced architectures (Vision Transformer, etc.)
  • Explore self-supervised pretraining
  • Optimize distributed training on large-scale datasets

Full code and documentation: see the [GitHub repository link], which includes a Jupyter Notebook tutorial and training-log analysis tools.
