从零开始:使用PyTorch实现图像分类(含完整代码与注释)
2025.09.18 17:01浏览量:0简介:本文详细介绍如何使用PyTorch框架实现一个完整的图像分类模型,包含数据加载、模型构建、训练与评估全流程,并提供逐行代码注释,适合初学者快速入门。
从零开始:使用PyTorch实现图像分类(含完整代码与注释)
一、引言
图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计,成为研究者与开发者的首选工具。本文将通过一个完整的CIFAR-10数据集分类案例,详细讲解如何使用PyTorch实现图像分类,包含数据加载、模型构建、训练与评估全流程,并提供逐行代码注释。
二、环境准备
2.1 依赖安装
pip install torch torchvision matplotlib numpy
torch
:PyTorch核心库torchvision
:提供计算机视觉工具(数据集、模型架构、图像变换)matplotlib
:用于可视化训练过程numpy
:数值计算基础库
2.2 硬件要求
- CPU:建议Intel i5及以上
- GPU(可选):NVIDIA显卡(CUDA支持可加速训练)
- 内存:8GB以上(CIFAR-10数据集约150MB)
三、完整代码实现
3.1 导入必要库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
torch.nn
:定义神经网络层与模型torch.optim
:优化器(如SGD、Adam)torchvision.transforms
:图像预处理(归一化、裁剪等)
3.2 数据加载与预处理
# 定义数据增强与归一化
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像转为Tensor,范围[0,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
])
# 加载CIFAR-10训练集与测试集
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=32, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
testset, batch_size=32, shuffle=False, num_workers=2)
# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
关键点:
Compose
:组合多个变换操作Normalize
:参数为(均值,标准差),CIFAR-10是RGB三通道DataLoader
:shuffle=True
打乱训练数据,num_workers
加速数据加载
3.3 定义卷积神经网络(CNN)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,5x5卷积核
self.pool = nn.MaxPool2d(2, 2) # 2x2最大池化
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) # 输出10类
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x))) # 卷积+ReLU+池化
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5) # 展平为向量
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
net = CNN()
架构解析:
- 卷积层1:32x32输入 → 6通道特征图(28x28)
- 池化层1:特征图 → 14x14
- 卷积层2:14x14 → 16通道(10x10)
- 池化层2:特征图 → 5x5
- 全连接层:5x5x16=400维 → 120 → 84 → 10(输出类别)
3.4 定义损失函数与优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 随机梯度下降
CrossEntropyLoss
:适用于多分类问题,组合了LogSoftmax
和NLLLoss
SGD
:参数momentum=0.9
可加速收敛
3.5 训练模型
for epoch in range(10): # 训练10个epoch
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 梯度清零
optimizer.zero_grad()
# 前向传播+反向传播+优化
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 打印统计信息
running_loss += loss.item()
if i % 200 == 199: # 每200个batch打印一次
print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 200:.3f}')
running_loss = 0.0
print('Training finished')
训练逻辑:
optimizer.zero_grad()
:清除上一步的梯度loss.backward()
:计算梯度optimizer.step()
:更新参数
3.6 测试模型
correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1) # 取概率最大的类别
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy on test set: {100 * correct / total:.2f}%')
评估指标:
- 准确率(Accuracy)= 正确预测数 / 总样本数
3.7 可视化训练过程(可选)
# 记录每个epoch的损失(需在训练循环中修改代码)
train_losses = []
test_accuracies = []
# 修改训练循环以记录损失
for epoch in range(10):
epoch_loss = 0.0
for i, data in enumerate(trainloader, 0):
# ...(原训练代码)
epoch_loss += loss.item()
train_losses.append(epoch_loss / len(trainloader))
# 测试集评估
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
# ...(原测试代码)
test_accuracies.append(100 * correct / total)
# 绘制曲线
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()
四、关键优化建议
- 数据增强:添加随机裁剪、水平翻转提升泛化能力
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
- 学习率调度:使用
torch.optim.lr_scheduler
动态调整学习率scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# 在每个epoch后调用scheduler.step()
- 模型保存与加载:
torch.save(net.state_dict(), 'model.pth') # 保存模型参数
net.load_state_dict(torch.load('model.pth')) # 加载参数
五、总结与扩展
本文通过CIFAR-10数据集展示了PyTorch实现图像分类的完整流程,核心步骤包括:
- 数据加载与预处理
- CNN模型定义
- 训练与优化
- 测试与评估
扩展方向:
- 尝试更深的网络(如ResNet)
- 使用预训练模型(Transfer Learning)
- 部署到移动端(PyTorch Mobile)
通过理解本例的代码逻辑,读者可快速迁移到其他图像分类任务(如MNIST、ImageNet),并进一步探索目标检测、语义分割等高级任务。
发表评论
登录后可评论,请前往 登录 或 注册