深度学习实战:使用PyTorch实现图像分类(含完整代码与注释)
2025.09.18 17:51浏览量:77简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,并附有详细注释,适合PyTorch初学者和深度学习开发者参考。
深度学习实战:使用PyTorch实现图像分类(含完整代码与注释)
一、引言
图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计受到开发者青睐。本文将通过一个完整的图像分类案例,展示如何使用PyTorch实现从数据加载、模型构建到训练评估的全流程,代码均附有详细注释。
二、环境准备
2.1 安装依赖
pip install torch torchvision matplotlib numpy
torch: PyTorch核心库torchvision: 提供计算机视觉相关工具(数据集、模型、变换)matplotlib: 用于可视化训练过程numpy: 数值计算基础库
2.2 硬件要求
- CPU或GPU(推荐NVIDIA GPU+CUDA加速)
- 至少4GB内存(数据集较小时)
三、完整代码实现
3.1 数据准备与预处理
import torchfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoader# 定义数据变换(训练集和测试集不同)train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.RandomRotation(10), # 随机旋转(-10°~+10°)transforms.ToTensor(), # 转为Tensor并归一化到[0,1]transforms.Normalize((0.5,), (0.5,)) # 标准化到[-1,1]])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 加载MNIST数据集(手写数字识别)train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=train_transform)test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=test_transform)# 创建数据加载器(批量大小64,4个工作进程)train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,num_workers=4)test_loader = DataLoader(test_dataset,batch_size=64,shuffle=False,num_workers=4)
关键点说明:
transforms.Compose: 组合多个数据增强操作RandomHorizontalFlip/RandomRotation: 提升模型泛化能力Normalize: 使用均值和标准差进行标准化(MNIST的均值=0.5,标准差=0.5)DataLoader: 实现批量加载、随机打乱和多线程加速
3.2 模型定义(CNN)
import torch.nn as nnimport torch.nn.functional as Fclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1:输入1通道,输出16通道,3x3卷积核self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)# 卷积层2:输入16通道,输出32通道self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 128) # 输入尺寸需计算self.fc2 = nn.Linear(128, 10) # 输出10类# 池化层和Dropoutself.pool = nn.MaxPool2d(2, 2)self.dropout = nn.Dropout(0.25)def forward(self, x):# 卷积层1 + ReLU + 池化x = self.pool(F.relu(self.conv1(x)))# 卷积层2 + ReLU + 池化x = self.pool(F.relu(self.conv2(x)))# 展平特征图x = x.view(-1, 32 * 7 * 7)# 全连接层 + Dropoutx = self.dropout(F.relu(self.fc1(x)))x = self.fc2(x)return x
模型结构解析:
- 输入:28x28灰度图像(1通道)
- 卷积块1:3x3卷积→ReLU→2x2最大池化(输出16x14x14)
- 卷积块2:同上(输出32x7x7)
- 全连接层:3277→128→10(输出类别概率)
- Dropout:防止过拟合(训练时随机丢弃25%神经元)
3.3 训练流程
def train_model():# 初始化模型、损失函数和优化器model = CNN()criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 移动模型到GPU(如果可用)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)# 训练循环for epoch in range(10): # 训练10个epochmodel.train()running_loss = 0.0for i, (images, labels) in enumerate(train_loader):# 移动数据到设备images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 统计损失running_loss += loss.item()if i % 100 == 99: # 每100个batch打印一次print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')running_loss = 0.0# 保存模型torch.save(model.state_dict(), 'mnist_cnn.pth')print('训练完成,模型已保存')
训练要点:
CrossEntropyLoss: 多分类交叉熵损失Adam优化器: 自适应学习率,适合初学者model.train(): 启用Dropout等训练专用层zero_grad(): 清除上一步的梯度state_dict(): 保存模型参数(不包含结构)
3.4 评估与预测
def evaluate_model():model = CNN()model.load_state_dict(torch.load('mnist_cnn.pth'))device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)model.eval() # 切换到评估模式correct = 0total = 0with torch.no_grad(): # 禁用梯度计算for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集准确率: {100 * correct / total:.2f}%')# 单张图片预测示例def predict_image(image_path):from PIL import Imageimport numpy as np# 加载并预处理图片(假设是28x28灰度图)image = Image.open(image_path).convert('L')image = image.resize((28, 28))transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])image = transform(image).unsqueeze(0) # 添加batch维度# 加载模型并预测model = CNN()model.load_state_dict(torch.load('mnist_cnn.pth'))model.eval()with torch.no_grad():output = model(image)_, predicted = torch.max(output.data, 1)print(f'预测结果: {predicted.item()}')
四、关键优化技巧
- 学习率调度:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)# 在每个epoch后调用 scheduler.step()
- 早停机制:监控验证集损失,当连续3个epoch未改善时停止训练。
- 模型微调:使用预训练模型(如ResNet)的卷积基,仅训练最后的全连接层。
五、常见问题解决方案
- CUDA内存不足:减小
batch_size或使用torch.cuda.empty_cache() - 过拟合:增加Dropout比例、添加L2正则化或使用数据增强
- 收敛慢:尝试不同的学习率(如1e-3、5e-4)或优化器(SGD+Momentum)
六、扩展应用
- 多标签分类:修改输出层为
nn.Sigmoid(),使用BCELoss - 自定义数据集:继承
torch.utils.data.Dataset实现__len__和__getitem__ - 分布式训练:使用
torch.nn.parallel.DistributedDataParallel
七、总结
本文通过MNIST手写数字分类案例,完整展示了PyTorch实现图像分类的流程。关键步骤包括:
- 数据加载与增强
- CNN模型构建
- 训练循环与优化
- 模型评估与部署
读者可基于此代码框架,替换数据集和模型结构以适应不同任务(如CIFAR-10分类)。建议进一步探索:
- 更复杂的模型架构(ResNet、EfficientNet)
- 混合精度训练加速
- 使用TensorBoard可视化训练过程
完整代码已通过PyTorch 1.12和Python 3.8验证,可直接运行。

发表评论
登录后可评论,请前往 登录 或 注册