使用PyTorch构建图像分类系统:完整代码与深度解析
2025.09.19 17:05浏览量:0简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,每行代码均附有详细注释,适合PyTorch初学者及有一定基础的开发者参考。
使用PyTorch构建图像分类系统:完整代码与深度解析
图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,提供了简洁高效的API支持。本文将通过完整代码示例,展示如何使用PyTorch实现从数据准备到模型部署的全流程,所有代码均包含详细注释,确保读者能够理解每个步骤的实现原理。
一、环境准备与依赖安装
首先需要安装PyTorch及相关依赖库。推荐使用conda创建虚拟环境:
conda create -n pytorch_img_cls python=3.8
conda activate pytorch_img_cls
pip install torch torchvision matplotlib numpy
关键依赖说明:
torch
:PyTorch核心库torchvision
:提供计算机视觉常用数据集和模型架构matplotlib
:用于可视化训练过程numpy
:数值计算基础库
二、数据集准备与预处理
1. 使用CIFAR-10数据集
CIFAR-10包含10个类别的60000张32x32彩色图像,分为50000张训练集和10000张测试集。
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理流程
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像转换为Tensor,并归一化到[0,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
])
# 加载训练集
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=32,
shuffle=True,
num_workers=2
)
# 加载测试集
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=32,
shuffle=False,
num_workers=2
)
# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
关键点解析:
transforms.Compose
:组合多个数据预处理操作ToTensor()
:将HWC格式的PIL图像转换为CHW格式的TensorNormalize
:使用均值和标准差进行标准化,这里使用(0.5,0.5,0.5)将像素值映射到[-1,1]区间DataLoader
:实现批量加载、数据打乱和多线程加载
2. 自定义数据集加载
对于自定义数据集,可以继承torch.utils.data.Dataset
类:
from torch.utils.data import Dataset
import os
from PIL import Image
class CustomImageDataset(Dataset):
def __init__(self, img_dir, transform=None):
self.img_labels = []
self.img_paths = []
self.transform = transform
# 遍历目录,假设子目录名为类别名
for class_name in os.listdir(img_dir):
class_path = os.path.join(img_dir, class_name)
if os.path.isdir(class_path):
for img_name in os.listdir(class_path):
self.img_paths.append(os.path.join(class_path, img_name))
self.img_labels.append(classes.index(class_name))
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img_path = self.img_paths[idx]
image = Image.open(img_path)
label = self.img_labels[idx]
if self.transform:
image = self.transform(image)
return image, label
三、模型架构设计
1. 基础CNN模型
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 输入通道3(RGB),输出通道32,3x3卷积核
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2) # 2x2最大池化
self.fc1 = nn.Linear(64 * 8 * 8, 512) # CIFAR-10经过两次池化后为8x8
self.fc2 = nn.Linear(512, 10) # 10个输出类别
self.dropout = nn.Dropout(0.25)
def forward(self, x):
# 第一层卷积+ReLU+池化
x = self.pool(F.relu(self.conv1(x)))
# 第二层卷积+ReLU+池化
x = self.pool(F.relu(self.conv2(x)))
# 展平特征图
x = x.view(-1, 64 * 8 * 8)
# 全连接层+ReLU+Dropout
x = self.dropout(F.relu(self.fc1(x)))
# 输出层
x = self.fc2(x)
return x
架构解析:
- 两个卷积层提取空间特征,每个卷积层后接ReLU激活函数和最大池化
- 两个全连接层完成分类,中间加入Dropout防止过拟合
- 输入32x32x3图像,经过两次2x2池化后变为8x8x64特征图
2. 使用预训练模型
PyTorch提供了多种预训练模型,可通过torchvision.models
加载:
import torchvision.models as models
def get_pretrained_model(model_name='resnet18', pretrained=True, num_classes=10):
if model_name == 'resnet18':
model = models.resnet18(pretrained=pretrained)
# 修改最后一层全连接网络
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
elif model_name == 'vgg16':
model = models.vgg16(pretrained=pretrained)
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, num_classes)
else:
raise ValueError("Unsupported model name")
return model
四、训练流程实现
1. 完整训练代码
import torch
import torch.optim as optim
from tqdm import tqdm # 进度条库
def train_model(model, trainloader, testloader, criterion, optimizer, num_epochs=10):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(num_epochs):
# 训练阶段
model.train()
running_loss = 0.0
correct = 0
total = 0
# 使用tqdm显示进度条
train_loop = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
for inputs, labels in train_loop:
inputs, labels = inputs.to(device), labels.to(device)
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
loss.backward()
optimizer.step()
# 统计信息
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 更新进度条信息
train_loop.set_postfix(loss=running_loss/(train_loop.n+1),
acc=100.*correct/total)
# 测试阶段
test_loss, test_acc = evaluate_model(model, testloader, criterion, device)
print(f'Epoch {epoch+1}, Train Loss: {running_loss/len(trainloader):.4f}, '
f'Train Acc: {100*correct/total:.2f}%, '
f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
def evaluate_model(model, testloader, criterion, device):
model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
test_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return test_loss/len(testloader), 100*correct/total
# 初始化模型
model = SimpleCNN()
# 或者使用预训练模型
# model = get_pretrained_model('resnet18')
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 开始训练
train_model(model, trainloader, testloader, criterion, optimizer, num_epochs=10)
2. 关键训练参数说明
- 学习率:控制参数更新步长,常用值为0.001(Adam)或0.01(SGD)
- 批量大小:影响内存使用和梯度估计稳定性,CIFAR-10常用32或64
- 优化器选择:
- Adam:自适应学习率,收敛快
- SGD+Momentum:可能获得更好泛化性能
- 损失函数:分类任务通常使用交叉熵损失
五、模型评估与可视化
1. 混淆矩阵实现
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(model, testloader, classes, device):
model.eval()
all_labels = []
all_preds = []
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
all_labels.extend(labels.cpu().numpy())
all_preds.extend(predicted.cpu().numpy())
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
# 调用示例
plot_confusion_matrix(model, testloader, classes, device)
2. 训练过程可视化
def plot_training_curve(train_losses, test_losses, train_accs, test_accs):
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(test_accs, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.tight_layout()
plt.show()
# 需要在训练过程中记录这些指标
# 示例数据
epochs = range(1, 11)
train_losses = [2.3, 1.8, 1.5, 1.2, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
test_losses = [2.1, 1.7, 1.4, 1.1, 0.95, 0.85, 0.78, 0.72, 0.68, 0.65]
train_accs = [45, 58, 65, 70, 75, 78, 80, 82, 84, 85]
test_accs = [50, 62, 68, 72, 76, 78, 80, 81, 82, 83]
plot_training_curve(train_losses, test_losses, train_accs, test_accs)
六、模型部署建议
模型导出:使用
torch.save
保存模型参数torch.save(model.state_dict(), 'cifar_classifier.pth')
推理脚本示例:
def predict_image(image_path, model, transform, classes, device):
image = Image.open(image_path)
image = transform(image).unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output.data, 1)
return classes[predicted.item()]
性能优化技巧:
- 使用混合精度训练(
torch.cuda.amp
) - 模型量化减少内存占用
- 使用TensorRT加速推理
- 使用混合精度训练(
七、常见问题解决方案
训练不收敛:
- 检查学习率是否过大
- 确认数据预处理是否正确
- 尝试不同的优化器
过拟合问题:
- 增加数据增强
- 添加Dropout层
- 使用L2正则化
GPU内存不足:
- 减小批量大小
- 使用梯度累积
- 清理缓存(
torch.cuda.empty_cache()
)
本文完整代码可在GitHub获取,建议读者从简单CNN开始实践,逐步尝试预训练模型和更复杂的架构。通过调整超参数和观察训练曲线,可以深入理解深度学习模型的工作原理。
发表评论
登录后可评论,请前往 登录 或 注册