从零开始:使用PyTorch实现图像分类(附完整代码与注释)
2025.09.18 17:43浏览量:0简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含数据加载、模型构建、训练与评估全流程,提供完整可运行的代码及逐行注释,适合PyTorch初学者和深度学习实践者。
从零开始:使用PyTorch实现图像分类(附完整代码与注释)
引言
图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计,成为实现图像分类任务的理想选择。本文将通过一个完整的案例,展示如何使用PyTorch从零开始实现图像分类,涵盖数据准备、模型构建、训练优化和结果评估的全流程,并提供详细注释的完整代码。
1. 环境准备与依赖安装
在开始之前,需要确保已安装Python 3.6+环境,并安装以下依赖库:
pip install torch torchvision matplotlib numpy
PyTorch提供预编译的二进制包,支持CPU和GPU版本。若使用GPU加速,需安装与CUDA版本匹配的PyTorch版本。
2. 数据准备与预处理
2.1 数据集选择
本文使用经典的CIFAR-10数据集,包含10个类别的6万张32x32彩色图像(5万训练集,1万测试集)。PyTorch的torchvision
库提供了便捷的数据加载接口:
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理流程
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像或numpy数组转换为Tensor,并缩放至[0,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化至[-1,1]
])
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=32,
shuffle=True,
num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=32,
shuffle=False,
num_workers=2
)
关键点说明:
transforms.Compose
:将多个预处理操作组合为一个流水线。ToTensor()
:自动将图像从[0,255]
的uint8类型转换为[0,1]
的float32类型Tensor。Normalize
:使用均值和标准差进行归一化,公式为(x - mean) / std
。DataLoader
:提供批量加载、多线程加速和随机打乱功能。
2.2 数据可视化
为验证数据加载是否正确,可随机显示部分训练图像:
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 获取一个批次的图像和标签
dataiter = iter(trainloader)
images, labels = next(dataiter)
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join(f'{trainset.classes[labels[j]]}' for j in range(4)))
3. 模型构建:卷积神经网络(CNN)
3.1 网络架构设计
本文实现一个简化的CNN模型,包含2个卷积层和2个全连接层:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 第一个卷积层:输入通道3(RGB),输出通道6,5x5卷积核
self.conv1 = nn.Conv2d(3, 6, 5)
# 第二个卷积层:输入通道6,输出通道16,5x5卷积核
self.conv2 = nn.Conv2d(6, 16, 5)
# 第一个全连接层:输入16*5*5(经过池化后的特征图尺寸),输出120
self.fc1 = nn.Linear(16 * 5 * 5, 120)
# 第二个全连接层:输入120,输出84
self.fc2 = nn.Linear(120, 84)
# 输出层:输入84,输出10(CIFAR-10类别数)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# 最大池化,kernel_size=2, stride=2
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# 展平特征图
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
架构解析:
- 卷积层:提取空间特征,通过
Conv2d
实现,参数包括输入通道数、输出通道数和卷积核大小。 - 激活函数:使用ReLU引入非线性。
- 池化层:通过
max_pool2d
降低特征图尺寸,减少计算量。 - 全连接层:将特征映射到类别空间。
3.2 损失函数与优化器
import torch.optim as optim
# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 随机梯度下降优化器,学习率0.001,动量0.9
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
选择依据:
- 交叉熵损失:适用于多分类问题,直接比较预测概率分布与真实标签。
- SGD优化器:加入动量项可加速收敛并减少震荡。
4. 模型训练与评估
4.1 训练循环
def train(net, trainloader, criterion, optimizer, epochs=10):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入和标签
inputs, labels = data
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = net(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 统计损失
running_loss += loss.item()
if i % 2000 == 1999: # 每2000个batch打印一次
print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Finished Training')
train(net, trainloader, criterion, optimizer, epochs=10)
关键步骤:
optimizer.zero_grad()
:清除历史梯度,防止梯度累积。loss.backward()
:自动计算梯度。optimizer.step()
:根据梯度更新参数。
4.2 模型评估
def evaluate(net, testloader):
correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1) # 获取概率最大的类别
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy on test set: {accuracy:.2f}%')
return accuracy
evaluate(net, testloader)
评估指标:
- 准确率:正确预测的样本数占总样本数的比例。
5. 完整代码与注释
# 完整代码整合(含详细注释)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 1. 数据准备
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=32, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=32, shuffle=False, num_workers=2
)
# 2. 定义网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,5x5卷积核
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), 2) # 卷积+ReLU+池化
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, 16 * 5 * 5) # 展平
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 4. 训练函数
def train(net, trainloader, criterion, optimizer, epochs=10):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Finished Training')
train(net, trainloader, criterion, optimizer, epochs=10)
# 5. 评估函数
def evaluate(net, testloader):
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy on test set: {accuracy:.2f}%')
return accuracy
evaluate(net, testloader)
6. 实践建议与扩展方向
模型改进:
- 增加卷积层深度(如使用ResNet架构)。
- 引入批归一化(
nn.BatchNorm2d
)加速收敛。 - 使用数据增强(如随机裁剪、水平翻转)提升泛化能力。
超参数调优:
- 学习率调度器(如
torch.optim.lr_scheduler.StepLR
)。 - 网格搜索或贝叶斯优化寻找最优超参数。
- 学习率调度器(如
部署应用:
- 导出模型为TorchScript格式(
torch.jit.trace
)。 - 部署至移动端(通过PyTorch Mobile)或云端服务。
- 导出模型为TorchScript格式(
总结
本文通过一个完整的案例,展示了使用PyTorch实现图像分类的全流程,包括数据加载、模型构建、训练优化和结果评估。代码中提供了详细的注释,帮助读者理解每个步骤的原理和实现细节。对于初学者,建议从简化模型开始,逐步尝试更复杂的架构和技巧;对于进阶用户,可探索分布式训练、模型压缩等高级主题。”
发表评论
登录后可评论,请前往 登录 或 注册