From Scratch: A Complete PyTorch Image Classification Pipeline with Annotated Code
Summary: Using complete, annotated code examples, this article walks through image classification on the CIFAR-10 dataset with PyTorch, covering the full pipeline of data loading, model construction, training, optimization, and evaluation. It is aimed at deep learning beginners and practitioners alike.
Introduction
Image classification is one of the core tasks in computer vision, and PyTorch, a mainstream deep learning framework, provides a flexible and efficient toolchain for it. Using CIFAR-10 classification as a running example, this article walks through the entire pipeline from data preparation to model deployment, with line-by-line commented code suitable for developers at any level.
1. Environment Setup
1.1 Basic Environment Configuration
```python
# Version requirements
# Python 3.8+
# PyTorch 2.0+
# torchvision 0.15+
# CUDA 11.7+ (for GPU acceleration)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```
Key points: pin down version requirements and auto-detect the available hardware, so the rest of the training pipeline has a solid foundation.
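For reproducible experiments, it can also help to fix the random seeds up front. A minimal sketch (the seed value 42 is an arbitrary choice):

```python
import random

def set_seed(seed: int = 42):
    """Fix random seeds for Python, NumPy, and PyTorch (CPU and GPU)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
```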
1.2 Data Preprocessing
```python
# Define data augmentation and normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),        # random crop augmentation
    transforms.RandomHorizontalFlip(),           # random horizontal flip
    transforms.ToTensor(),                       # convert to tensor
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))  # CIFAR-10 mean/std
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])
```
Design principle: the training set uses augmentation to improve generalization, while the test set is only normalized, keeping evaluation consistent.
2. Data Loading
2.1 Dataset Preparation
```python
# Download and load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform_test)

# Create data loaders
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

# Class name mapping
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```
Optimization tips: set num_workers to speed up data loading and adjust batch_size to your hardware; see the sketch below for further DataLoader settings.
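As a rough guide (the values below are illustrative assumptions, not measurements), pin_memory and persistent_workers can cut host-to-GPU transfer cost and worker startup overhead:

```python
# Hypothetical tuning example: adjust the values to your hardware
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,          # increase if GPU memory allows
    shuffle=True,
    num_workers=4,           # typically around the number of physical CPU cores
    pin_memory=True,         # faster host-to-GPU copies when using CUDA
    persistent_workers=True  # keep workers alive between epochs
)
```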
3. Model Architecture
3.1 A Basic CNN
```python
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            # input 3x32x32 -> output 64x32x32
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 64x16x16
            # -> 128x16x16
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 128x8x8
            # -> 256x8x8
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)   # -> 256x4x4
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten the feature maps
        x = self.classifier(x)
        return x
```
Architecture notes:
- Three convolutional blocks, each consisting of convolution + BatchNorm + ReLU + pooling
- Two fully connected layers, with Dropout in between to curb overfitting
- Roughly 4.6M trainable parameters, dominated by the first linear layer (the snippet after this list verifies the count)
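A quick way to verify the parameter count claimed above:

```python
# Count trainable parameters to sanity-check the architecture
model = SimpleCNN()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")  # approximately 4.58M
```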
3.2 Loading a Pretrained Model
```python
def load_pretrained_model():
    # Load an ImageNet-pretrained ResNet18
    # (torchvision 0.15+ uses the `weights` argument; `pretrained=True` is deprecated)
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    # Freeze all backbone parameters
    for param in model.parameters():
        param.requires_grad = False
    # Replace the final fully connected layer
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)
    return model
```
Transfer learning tips:
- Freeze the shallow feature extractor
- Train only the classification head for fast convergence
- Well suited to scenarios with limited data
- Note that ImageNet-pretrained ResNets expect roughly 224×224 inputs, so adding transforms.Resize(224) to the pipelines usually improves transfer results; a fine-tuning variant is sketched below
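If more data or compute is available, a common next step is to unfreeze the last residual stage and fine-tune it with a smaller learning rate. A hedged sketch (the learning rates are illustrative, not tuned values):

```python
def load_finetune_model():
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze the last residual stage for fine-tuning
    for param in model.layer4.parameters():
        param.requires_grad = True
    model.fc = nn.Linear(model.fc.in_features, 10)
    # Separate learning rates: small for the backbone stage, larger for the new head
    optimizer = optim.SGD([
        {'params': model.layer4.parameters(), 'lr': 1e-4},
        {'params': model.fc.parameters(), 'lr': 1e-3},
    ], momentum=0.9)
    return model, optimizer
```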
4. Training Pipeline
4.1 Training Function
```python
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # training mode
                dataloader = train_loader
            else:
                model.eval()    # evaluation mode
                dataloader = test_loader

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the data
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()

                # Forward pass; track gradients only in training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimization only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects.double() / len(dataloader.dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Checkpoint the best model on validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'best_model.pth')

    print(f'Best val Acc: {best_acc:.4f}')
    return model
```
Key mechanisms:
- Explicit switching between training and evaluation modes
- Gradients enabled only in the training phase via torch.set_grad_enabled
- Learning rate scheduler stepped once per epoch
- Best-model checkpointing on validation accuracy (a reload sketch follows this list; mixed precision training is covered separately in Section 6.2)
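To pick up the saved checkpoint later, for evaluation or continued training, a minimal sketch:

```python
# Reload the best checkpoint saved by train_model
model = SimpleCNN().to(device)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()  # switch to evaluation mode before inference
```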
4.2 Main Training Flow
```python
def main():
    # Initialize the model
    model = SimpleCNN().to(device)
    # model = load_pretrained_model().to(device)  # switch to the pretrained model

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train the model
    model = train_model(model, criterion, optimizer, scheduler, num_epochs=25)

    # Save the final model
    torch.save(model.state_dict(), 'final_model.pth')

if __name__ == '__main__':
    # The guard is required when num_workers > 0 on platforms that spawn subprocesses
    main()
```
Hyperparameter suggestions:
- Initial learning rate: 0.1 for the CNN, 0.001 for the pretrained model
- Momentum: 0.9
- Schedule: decay the learning rate by 10x every 7 epochs (an alternative cosine schedule is sketched below)
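StepLR is not the only option; cosine annealing often works well on CIFAR-10. A hedged alternative (setting T_max to the epoch count is an assumption, not a tuned value):

```python
# Alternative: smooth cosine decay of the learning rate over the full run
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
```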
5. Evaluation and Visualization
5.1 Model Evaluation
```python
def evaluate_model(model_path):
    model = SimpleCNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    correct = 0
    total = 0
    class_correct = [0.0] * 10
    class_total = [0.0] * 10

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Per-class statistics
            c = (predicted == labels)
            for i in range(len(labels)):
                label = labels[i].item()
                class_correct[label] += c[i].item()
                class_total[label] += 1

    print(f'Accuracy: {100 * correct / total:.2f}%')
    # Print per-class accuracy
    for i in range(10):
        print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
```
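Beyond per-class accuracy, a confusion matrix makes systematic misclassifications (e.g., cat vs. dog) easier to spot. A minimal sketch using only NumPy and matplotlib, reusing the model and test_loader from above:

```python
def plot_confusion_matrix(model):
    cm = np.zeros((10, 10), dtype=int)
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            preds = model(images).argmax(dim=1).cpu()
            # Rows are true classes, columns are predicted classes
            for t, p in zip(labels, preds):
                cm[t.item(), p.item()] += 1
    plt.imshow(cm, cmap='Blues')
    plt.xticks(range(10), classes, rotation=45)
    plt.yticks(range(10), classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.colorbar()
    plt.show()
```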
5.2 Error Analysis Visualization
```python
def visualize_errors(model_path, num_images=6):
    model = SimpleCNN().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    dataiter = iter(test_loader)
    images, labels = next(dataiter)
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

    # Move to CPU and convert to NumPy
    images = images.cpu().numpy()

    # Plot: title shows prediction (true label); green = correct, red = wrong
    fig = plt.figure(figsize=(10, 10))
    for idx in range(num_images):
        ax = fig.add_subplot(1, num_images, idx + 1, xticks=[], yticks=[])
        img = images[idx].transpose((1, 2, 0))
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2023, 0.1994, 0.2010])
        img = std * img + mean  # undo normalization
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.set_title(f'{classes[preds[idx]]}\n({classes[labels[idx]]})',
                     color=("green" if preds[idx] == labels[idx] else "red"))
    plt.show()
```
6. Advanced Optimization Techniques
6.1 Learning Rate Finder
```python
import math

def find_lr(model, optimizer, criterion, init_value=1e-7, final_value=10., beta=0.98):
    """Exponentially increase the LR each batch and record the smoothed loss."""
    num = len(train_loader) - 1
    mult = (final_value / init_value) ** (1 / num)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    avg_loss = 0.
    best_loss = float('inf')
    batch_num = 0
    losses = []
    log_lrs = []
    model.train()

    for inputs, labels in train_loader:
        batch_num += 1
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Exponentially smoothed loss with bias correction
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** batch_num)

        # Stop if the loss explodes relative to the best seen so far
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            break
        # Track the best (lowest) smoothed loss
        if smoothed_loss < best_loss:
            best_loss = smoothed_loss

        # Record the current learning rate and loss
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update the learning rate
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
        if lr > 1e2 or smoothed_loss > 1e7:
            break

    plt.plot(log_lrs[10:-5], losses[10:-5])
    plt.xlabel('log10(lr)')
    plt.ylabel('Loss')
    plt.show()
```
6.2 Mixed Precision Training
```python
from torch.cuda.amp import GradScaler, autocast
# (In PyTorch 2.x the same API is also available under torch.amp)

def train_with_amp(model, criterion, optimizer, num_epochs=25):
    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Run the forward pass in mixed precision
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            # Scale the loss to avoid underflow in float16 gradients
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
```
7. Deployment Preparation
7.1 Exporting the Model to ONNX
```python
def export_to_onnx(model_path, output_path='model.onnx'):
    # Export on CPU so the model and dummy input live on the same device
    model = SimpleCNN()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    # Create a random input with the expected shape
    dummy_input = torch.randn(1, 3, 32, 32)

    # Export the model
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}}
    )
    print(f'Model exported to {output_path}')
```
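To sanity-check the exported file, one option is to run it with onnxruntime and inspect the output (this assumes onnxruntime is installed; it is not used elsewhere in this article):

```python
import onnxruntime as ort

def verify_onnx(output_path='model.onnx'):
    session = ort.InferenceSession(output_path)
    x = np.random.randn(1, 3, 32, 32).astype(np.float32)
    # Feed the input under the name used during export
    outputs = session.run(None, {'input': x})
    print('ONNX output shape:', outputs[0].shape)  # expect (1, 10)
```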
7.2 Performance Optimization Tips
- Quantization-aware training: use the torch.quantization module to reduce model size (a simpler post-training variant is sketched below)
- TensorRT acceleration: convert the ONNX model into a TensorRT engine
- Multi-threaded inference: set torch.set_num_threads() to optimize CPU inference
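As a quick illustration, post-training dynamic quantization, a lighter-weight alternative to full quantization-aware training, can shrink the linear layers in a few lines:

```python
# Post-training dynamic quantization: Linear weights stored as int8
model = SimpleCNN()
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
torch.save(quantized.state_dict(), 'quantized_model.pth')
```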
8. Full Code Integration

```python
# Full integrated code (see the GitHub repository)
# Includes:
# 1. Complete implementations of all modules
# 2. Training/evaluation scripts
# 3. Visualization utilities
# 4. Deployment interface
```
9. Summary and Outlook
Using the CIFAR-10 classification task, this article has walked through the complete PyTorch image classification workflow. Key takeaways:
- Core methods for data augmentation, model construction, and training optimization
- Advanced techniques such as transfer learning and mixed precision training
- Reusable code templates and utility functions
Future directions:
- Try more advanced architectures (e.g., Vision Transformers)
- Explore self-supervised pretraining methods
- Optimize distributed training on large-scale datasets
