logo

使用PyTorch构建图像分类系统:完整代码与深度解析

作者:渣渣辉2025.09.19 17:05浏览量:0

简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,每行代码均附有详细注释,适合PyTorch初学者及有一定基础的开发者参考。

使用PyTorch构建图像分类系统:完整代码与深度解析

图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,提供了简洁高效的API支持。本文将通过完整代码示例,展示如何使用PyTorch实现从数据准备到模型部署的全流程,所有代码均包含详细注释,确保读者能够理解每个步骤的实现原理。

一、环境准备与依赖安装

首先需要安装PyTorch及相关依赖库。推荐使用conda创建虚拟环境:

  1. conda create -n pytorch_img_cls python=3.8
  2. conda activate pytorch_img_cls
  3. pip install torch torchvision matplotlib numpy

关键依赖说明:

  • torch:PyTorch核心库
  • torchvision:提供计算机视觉常用数据集和模型架构
  • matplotlib:用于可视化训练过程
  • numpy:数值计算基础库

二、数据集准备与预处理

1. 使用CIFAR-10数据集

CIFAR-10包含10个类别的60000张32x32彩色图像,分为50000张训练集和10000张测试集。

  1. import torchvision
  2. import torchvision.transforms as transforms
  3. # 定义数据预处理流程
  4. transform = transforms.Compose([
  5. transforms.ToTensor(), # 将PIL图像转换为Tensor,并归一化到[0,1]
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
  7. ])
  8. # 加载训练集
  9. trainset = torchvision.datasets.CIFAR10(
  10. root='./data',
  11. train=True,
  12. download=True,
  13. transform=transform
  14. )
  15. trainloader = torch.utils.data.DataLoader(
  16. trainset,
  17. batch_size=32,
  18. shuffle=True,
  19. num_workers=2
  20. )
  21. # 加载测试集
  22. testset = torchvision.datasets.CIFAR10(
  23. root='./data',
  24. train=False,
  25. download=True,
  26. transform=transform
  27. )
  28. testloader = torch.utils.data.DataLoader(
  29. testset,
  30. batch_size=32,
  31. shuffle=False,
  32. num_workers=2
  33. )
  34. # 类别名称
  35. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  36. 'dog', 'frog', 'horse', 'ship', 'truck')

关键点解析

  • transforms.Compose:组合多个数据预处理操作
  • ToTensor():将HWC格式的PIL图像转换为CHW格式的Tensor
  • Normalize:使用均值和标准差进行标准化,这里使用(0.5,0.5,0.5)将像素值映射到[-1,1]区间
  • DataLoader:实现批量加载、数据打乱和多线程加载

2. 自定义数据集加载

对于自定义数据集,可以继承torch.utils.data.Dataset类:

  1. from torch.utils.data import Dataset
  2. import os
  3. from PIL import Image
  4. class CustomImageDataset(Dataset):
  5. def __init__(self, img_dir, transform=None):
  6. self.img_labels = []
  7. self.img_paths = []
  8. self.transform = transform
  9. # 遍历目录,假设子目录名为类别名
  10. for class_name in os.listdir(img_dir):
  11. class_path = os.path.join(img_dir, class_name)
  12. if os.path.isdir(class_path):
  13. for img_name in os.listdir(class_path):
  14. self.img_paths.append(os.path.join(class_path, img_name))
  15. self.img_labels.append(classes.index(class_name))
  16. def __len__(self):
  17. return len(self.img_paths)
  18. def __getitem__(self, idx):
  19. img_path = self.img_paths[idx]
  20. image = Image.open(img_path)
  21. label = self.img_labels[idx]
  22. if self.transform:
  23. image = self.transform(image)
  24. return image, label

三、模型架构设计

1. 基础CNN模型

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class SimpleCNN(nn.Module):
  4. def __init__(self):
  5. super(SimpleCNN, self).__init__()
  6. # 输入通道3(RGB),输出通道32,3x3卷积核
  7. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  8. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  9. self.pool = nn.MaxPool2d(2, 2) # 2x2最大池化
  10. self.fc1 = nn.Linear(64 * 8 * 8, 512) # CIFAR-10经过两次池化后为8x8
  11. self.fc2 = nn.Linear(512, 10) # 10个输出类别
  12. self.dropout = nn.Dropout(0.25)
  13. def forward(self, x):
  14. # 第一层卷积+ReLU+池化
  15. x = self.pool(F.relu(self.conv1(x)))
  16. # 第二层卷积+ReLU+池化
  17. x = self.pool(F.relu(self.conv2(x)))
  18. # 展平特征图
  19. x = x.view(-1, 64 * 8 * 8)
  20. # 全连接层+ReLU+Dropout
  21. x = self.dropout(F.relu(self.fc1(x)))
  22. # 输出层
  23. x = self.fc2(x)
  24. return x

架构解析

  • 两个卷积层提取空间特征,每个卷积层后接ReLU激活函数和最大池化
  • 两个全连接层完成分类,中间加入Dropout防止过拟合
  • 输入32x32x3图像,经过两次2x2池化后变为8x8x64特征图

2. 使用预训练模型

PyTorch提供了多种预训练模型,可通过torchvision.models加载:

  1. import torchvision.models as models
  2. def get_pretrained_model(model_name='resnet18', pretrained=True, num_classes=10):
  3. if model_name == 'resnet18':
  4. model = models.resnet18(pretrained=pretrained)
  5. # 修改最后一层全连接网络
  6. num_ftrs = model.fc.in_features
  7. model.fc = nn.Linear(num_ftrs, num_classes)
  8. elif model_name == 'vgg16':
  9. model = models.vgg16(pretrained=pretrained)
  10. num_ftrs = model.classifier[6].in_features
  11. model.classifier[6] = nn.Linear(num_ftrs, num_classes)
  12. else:
  13. raise ValueError("Unsupported model name")
  14. return model

四、训练流程实现

1. 完整训练代码

  1. import torch
  2. import torch.optim as optim
  3. from tqdm import tqdm # 进度条库
  4. def train_model(model, trainloader, testloader, criterion, optimizer, num_epochs=10):
  5. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  6. model.to(device)
  7. for epoch in range(num_epochs):
  8. # 训练阶段
  9. model.train()
  10. running_loss = 0.0
  11. correct = 0
  12. total = 0
  13. # 使用tqdm显示进度条
  14. train_loop = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
  15. for inputs, labels in train_loop:
  16. inputs, labels = inputs.to(device), labels.to(device)
  17. # 梯度清零
  18. optimizer.zero_grad()
  19. # 前向传播
  20. outputs = model(inputs)
  21. loss = criterion(outputs, labels)
  22. # 反向传播和优化
  23. loss.backward()
  24. optimizer.step()
  25. # 统计信息
  26. running_loss += loss.item()
  27. _, predicted = torch.max(outputs.data, 1)
  28. total += labels.size(0)
  29. correct += (predicted == labels).sum().item()
  30. # 更新进度条信息
  31. train_loop.set_postfix(loss=running_loss/(train_loop.n+1),
  32. acc=100.*correct/total)
  33. # 测试阶段
  34. test_loss, test_acc = evaluate_model(model, testloader, criterion, device)
  35. print(f'Epoch {epoch+1}, Train Loss: {running_loss/len(trainloader):.4f}, '
  36. f'Train Acc: {100*correct/total:.2f}%, '
  37. f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
  38. def evaluate_model(model, testloader, criterion, device):
  39. model.eval()
  40. test_loss = 0.0
  41. correct = 0
  42. total = 0
  43. with torch.no_grad():
  44. for inputs, labels in testloader:
  45. inputs, labels = inputs.to(device), labels.to(device)
  46. outputs = model(inputs)
  47. loss = criterion(outputs, labels)
  48. test_loss += loss.item()
  49. _, predicted = torch.max(outputs.data, 1)
  50. total += labels.size(0)
  51. correct += (predicted == labels).sum().item()
  52. return test_loss/len(testloader), 100*correct/total
  53. # 初始化模型
  54. model = SimpleCNN()
  55. # 或者使用预训练模型
  56. # model = get_pretrained_model('resnet18')
  57. # 定义损失函数和优化器
  58. criterion = nn.CrossEntropyLoss()
  59. optimizer = optim.Adam(model.parameters(), lr=0.001)
  60. # 开始训练
  61. train_model(model, trainloader, testloader, criterion, optimizer, num_epochs=10)

2. 关键训练参数说明

  • 学习率:控制参数更新步长,常用值为0.001(Adam)或0.01(SGD)
  • 批量大小:影响内存使用和梯度估计稳定性,CIFAR-10常用32或64
  • 优化器选择
    • Adam:自适应学习率,收敛快
    • SGD+Momentum:可能获得更好泛化性能
  • 损失函数:分类任务通常使用交叉熵损失

五、模型评估与可视化

1. 混淆矩阵实现

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn.metrics import confusion_matrix
  4. import seaborn as sns
  5. def plot_confusion_matrix(model, testloader, classes, device):
  6. model.eval()
  7. all_labels = []
  8. all_preds = []
  9. with torch.no_grad():
  10. for inputs, labels in testloader:
  11. inputs, labels = inputs.to(device), labels.to(device)
  12. outputs = model(inputs)
  13. _, predicted = torch.max(outputs, 1)
  14. all_labels.extend(labels.cpu().numpy())
  15. all_preds.extend(predicted.cpu().numpy())
  16. cm = confusion_matrix(all_labels, all_preds)
  17. plt.figure(figsize=(10,8))
  18. sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
  19. xticklabels=classes, yticklabels=classes)
  20. plt.xlabel('Predicted')
  21. plt.ylabel('True')
  22. plt.title('Confusion Matrix')
  23. plt.show()
  24. # 调用示例
  25. plot_confusion_matrix(model, testloader, classes, device)

2. 训练过程可视化

  1. def plot_training_curve(train_losses, test_losses, train_accs, test_accs):
  2. plt.figure(figsize=(12, 4))
  3. plt.subplot(1, 2, 1)
  4. plt.plot(train_losses, label='Train Loss')
  5. plt.plot(test_losses, label='Test Loss')
  6. plt.xlabel('Epoch')
  7. plt.ylabel('Loss')
  8. plt.legend()
  9. plt.subplot(1, 2, 2)
  10. plt.plot(train_accs, label='Train Accuracy')
  11. plt.plot(test_accs, label='Test Accuracy')
  12. plt.xlabel('Epoch')
  13. plt.ylabel('Accuracy (%)')
  14. plt.legend()
  15. plt.tight_layout()
  16. plt.show()
  17. # 需要在训练过程中记录这些指标
  18. # 示例数据
  19. epochs = range(1, 11)
  20. train_losses = [2.3, 1.8, 1.5, 1.2, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
  21. test_losses = [2.1, 1.7, 1.4, 1.1, 0.95, 0.85, 0.78, 0.72, 0.68, 0.65]
  22. train_accs = [45, 58, 65, 70, 75, 78, 80, 82, 84, 85]
  23. test_accs = [50, 62, 68, 72, 76, 78, 80, 81, 82, 83]
  24. plot_training_curve(train_losses, test_losses, train_accs, test_accs)

六、模型部署建议

  1. 模型导出:使用torch.save保存模型参数

    1. torch.save(model.state_dict(), 'cifar_classifier.pth')
  2. 推理脚本示例

    1. def predict_image(image_path, model, transform, classes, device):
    2. image = Image.open(image_path)
    3. image = transform(image).unsqueeze(0).to(device)
    4. model.eval()
    5. with torch.no_grad():
    6. output = model(image)
    7. _, predicted = torch.max(output.data, 1)
    8. return classes[predicted.item()]
  3. 性能优化技巧

    • 使用混合精度训练(torch.cuda.amp)
    • 模型量化减少内存占用
    • 使用TensorRT加速推理

七、常见问题解决方案

  1. 训练不收敛

    • 检查学习率是否过大
    • 确认数据预处理是否正确
    • 尝试不同的优化器
  2. 过拟合问题

    • 增加数据增强
    • 添加Dropout层
    • 使用L2正则化
  3. GPU内存不足

    • 减小批量大小
    • 使用梯度累积
    • 清理缓存(torch.cuda.empty_cache())

本文完整代码可在GitHub获取,建议读者从简单CNN开始实践,逐步尝试预训练模型和更复杂的架构。通过调整超参数和观察训练曲线,可以深入理解深度学习模型的工作原理。

相关文章推荐

发表评论