logo

从零掌握UNet图像分割:数据集制作、训练与推理全流程详解

作者:php是最好的2025.09.18 17:15浏览量:0

简介:本文详细介绍如何使用UNet模型进行图像语义分割,涵盖从自定义数据集制作、模型训练到推理测试的全流程,适合有一定编程基础的深度学习爱好者。

从零掌握UNet图像分割:数据集制作、训练与推理全流程详解

一、引言:为什么选择UNet进行图像语义分割?

UNet(U-shaped Network)是2015年由Olaf Ronneberger等人提出的经典卷积神经网络架构,专为医学图像分割任务设计,但因其独特的编码器-解码器结构和跳跃连接(skip connections),被广泛应用于各类图像语义分割场景。相比其他模型,UNet具有以下优势:

  • 轻量化:参数量相对较少,适合小规模数据集训练
  • 多尺度特征融合:通过跳跃连接保留低级特征,提升分割精度
  • 端到端训练:可直接输出与输入图像尺寸相同的分割结果

本文将详细介绍如何使用UNet模型,从零开始制作自定义数据集,完成模型训练,并进行推理测试。整个流程基于PyTorch框架实现,适合有一定Python和深度学习基础的读者。

二、数据集制作:从原始图像到标注文件

1. 数据集准备

首先需要准备原始图像和对应的标注掩码(mask)。标注掩码是一张与原始图像尺寸相同的单通道灰度图,其中每个像素值代表对应的类别。例如,在二分类任务中,背景像素值为0,目标像素值为1。

建议

  • 图像和标注文件应一一对应,文件名建议使用相同前缀(如img_001.jpg对应mask_001.png
  • 图像尺寸建议统一为256×256或512×512,便于模型处理
  • 划分训练集、验证集和测试集,比例建议为7:1:2

2. 标注工具推荐

  • Labelme:开源标注工具,支持多边形、矩形等多种标注方式,可导出JSON格式标注文件,需转换为掩码图
  • CVAT:专业视频/图像标注平台,支持团队协作
  • Photoshop:手动绘制掩码,适合简单场景

示例:使用Labelme标注后,通过以下Python代码将JSON文件转换为掩码图:

  1. import json
  2. import numpy as np
  3. from PIL import Image
  4. import os
  5. def json_to_mask(json_path, output_path):
  6. with open(json_path) as f:
  7. data = json.load(f)
  8. height = data['imageHeight']
  9. width = data['imageWidth']
  10. mask = np.zeros((height, width), dtype=np.uint8)
  11. for shape in data['shapes']:
  12. points = shape['points']
  13. label = shape['label']
  14. # 简单将所有标注区域设为1(实际应根据label区分类别)
  15. # 更复杂的实现可使用多边形填充算法
  16. # 这里仅作示例
  17. cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], 1)
  18. Image.fromarray(mask * 255).save(output_path)
  19. # 使用示例
  20. json_to_mask('labelme_output.json', 'mask.png')

3. 数据增强

为提升模型泛化能力,建议对训练数据进行增强。常用方法包括:

  • 随机水平/垂直翻转
  • 随机旋转(±15度)
  • 随机缩放(0.9-1.1倍)
  • 随机亮度/对比度调整
  • 添加高斯噪声

PyTorch实现示例

  1. import torchvision.transforms as transforms
  2. from torchvision.transforms import functional as F
  3. import random
  4. class CustomAugmentation:
  5. def __init__(self):
  6. self.augmentations = [
  7. self.random_flip,
  8. self.random_rotation,
  9. self.random_scale
  10. ]
  11. def __call__(self, img, mask):
  12. for augment in self.augmentations:
  13. img, mask = augment(img, mask)
  14. return img, mask
  15. def random_flip(self, img, mask):
  16. if random.random() > 0.5:
  17. img = F.hflip(img)
  18. mask = F.hflip(mask)
  19. if random.random() > 0.5:
  20. img = F.vflip(img)
  21. mask = F.vflip(mask)
  22. return img, mask
  23. def random_rotation(self, img, mask):
  24. angle = random.uniform(-15, 15)
  25. img = F.rotate(img, angle)
  26. mask = F.rotate(mask, angle, fill=(0,))
  27. return img, mask
  28. def random_scale(self, img, mask):
  29. scale = random.uniform(0.9, 1.1)
  30. h, w = img.size[1], img.size[0]
  31. new_h, new_w = int(h * scale), int(w * scale)
  32. img = F.resize(img, (new_h, new_w))
  33. mask = F.resize(mask, (new_h, new_w), Image.NEAREST)
  34. # 随机裁剪回原尺寸
  35. i, j = random.randint(0, new_h - h), random.randint(0, new_w - w)
  36. img = F.crop(img, i, j, h, w)
  37. mask = F.crop(mask, i, j, h, w)
  38. return img, mask

三、UNet模型实现:从架构设计到代码实现

1. UNet架构解析

UNet由编码器(下采样路径)和解码器(上采样路径)组成,通过跳跃连接融合多尺度特征:

  • 编码器:4次下采样,每次通道数翻倍
  • 解码器:4次上采样,每次通道数减半
  • 跳跃连接:将编码器的特征图与解码器的上采样特征图拼接

2. PyTorch实现代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class DoubleConv(nn.Module):
  5. """(convolution => [BN] => ReLU) * 2"""
  6. def __init__(self, in_channels, out_channels):
  7. super().__init__()
  8. self.double_conv = nn.Sequential(
  9. nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  10. nn.BatchNorm2d(out_channels),
  11. nn.ReLU(inplace=True),
  12. nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
  13. nn.BatchNorm2d(out_channels),
  14. nn.ReLU(inplace=True)
  15. )
  16. def forward(self, x):
  17. return self.double_conv(x)
  18. class Down(nn.Module):
  19. """Downscaling with maxpool then double conv"""
  20. def __init__(self, in_channels, out_channels):
  21. super().__init__()
  22. self.maxpool_conv = nn.Sequential(
  23. nn.MaxPool2d(2),
  24. DoubleConv(in_channels, out_channels)
  25. )
  26. def forward(self, x):
  27. return self.maxpool_conv(x)
  28. class Up(nn.Module):
  29. """Upscaling then double conv"""
  30. def __init__(self, in_channels, out_channels, bilinear=True):
  31. super().__init__()
  32. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) if bilinear else nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
  33. self.conv = DoubleConv(in_channels, out_channels)
  34. def forward(self, x1, x2):
  35. x1 = self.up(x1)
  36. # 输入是CHW
  37. diffY = x2.size()[2] - x1.size()[2]
  38. diffX = x2.size()[3] - x1.size()[3]
  39. x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
  40. diffY // 2, diffY - diffY // 2])
  41. x = torch.cat([x2, x1], dim=1)
  42. return self.conv(x)
  43. class OutConv(nn.Module):
  44. def __init__(self, in_channels, out_channels):
  45. super(OutConv, self).__init__()
  46. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
  47. def forward(self, x):
  48. return self.conv(x)
  49. class UNet(nn.Module):
  50. def __init__(self, n_channels, n_classes, bilinear=True):
  51. super(UNet, self).__init__()
  52. self.n_channels = n_channels
  53. self.n_classes = n_classes
  54. self.bilinear = bilinear
  55. self.inc = DoubleConv(n_channels, 64)
  56. self.down1 = Down(64, 128)
  57. self.down2 = Down(128, 256)
  58. self.down3 = Down(256, 512)
  59. self.down4 = Down(512, 1024)
  60. self.up1 = Up(1024, 512, bilinear)
  61. self.up2 = Up(512, 256, bilinear)
  62. self.up3 = Up(256, 128, bilinear)
  63. self.up4 = Up(128, 64, bilinear)
  64. self.outc = OutConv(64, n_classes)
  65. def forward(self, x):
  66. x1 = self.inc(x)
  67. x2 = self.down1(x1)
  68. x3 = self.down2(x2)
  69. x4 = self.down3(x3)
  70. x5 = self.down4(x4)
  71. x = self.up1(x5, x4)
  72. x = self.up2(x, x3)
  73. x = self.up3(x, x2)
  74. x = self.up4(x, x1)
  75. logits = self.outc(x)
  76. return logits

四、模型训练:从数据加载到优化策略

1. 数据加载器实现

  1. from torch.utils.data import Dataset, DataLoader
  2. from PIL import Image
  3. import torch
  4. import os
  5. class SegmentationDataset(Dataset):
  6. def __init__(self, img_dir, mask_dir, transform=None):
  7. self.img_dir = img_dir
  8. self.mask_dir = mask_dir
  9. self.transform = transform
  10. self.images = os.listdir(img_dir)
  11. def __len__(self):
  12. return len(self.images)
  13. def __getitem__(self, idx):
  14. img_path = os.path.join(self.img_dir, self.images[idx])
  15. mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '.png'))
  16. image = Image.open(img_path).convert('RGB')
  17. mask = Image.open(mask_path).convert('L') # 转换为灰度图
  18. if self.transform:
  19. image, mask = self.transform(image, mask)
  20. # 归一化并添加通道维度
  21. image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
  22. mask = torch.from_numpy(np.array(mask)).long() # 分类任务使用long类型
  23. return image, mask
  24. # 使用示例
  25. train_transform = CustomAugmentation() # 前文定义的数据增强
  26. train_dataset = SegmentationDataset('data/train/images', 'data/train/masks', transform=train_transform)
  27. train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)

2. 训练脚本实现

  1. import torch.optim as optim
  2. from tqdm import tqdm
  3. import numpy as np
  4. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  5. def train_model(model, train_loader, val_loader, epochs=50):
  6. model = model.to(device)
  7. criterion = nn.CrossEntropyLoss() # 多分类任务
  8. optimizer = optim.Adam(model.parameters(), lr=1e-4)
  9. scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
  10. best_val_loss = float('inf')
  11. for epoch in range(epochs):
  12. model.train()
  13. train_loss = 0
  14. progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
  15. for images, masks in progress_bar:
  16. images, masks = images.to(device), masks.to(device)
  17. optimizer.zero_grad()
  18. outputs = model(images)
  19. loss = criterion(outputs, masks)
  20. loss.backward()
  21. optimizer.step()
  22. train_loss += loss.item()
  23. progress_bar.set_postfix(loss=loss.item())
  24. train_loss /= len(train_loader)
  25. # 验证阶段
  26. val_loss = validate(model, val_loader, criterion)
  27. scheduler.step(val_loss)
  28. print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
  29. if val_loss < best_val_loss:
  30. best_val_loss = val_loss
  31. torch.save(model.state_dict(), 'best_model.pth')
  32. def validate(model, val_loader, criterion):
  33. model.eval()
  34. val_loss = 0
  35. with torch.no_grad():
  36. for images, masks in val_loader:
  37. images, masks = images.to(device), masks.to(device)
  38. outputs = model(images)
  39. loss = criterion(outputs, masks)
  40. val_loss += loss.item()
  41. return val_loss / len(val_loader)
  42. # 初始化模型
  43. model = UNet(n_channels=3, n_classes=2) # 假设二分类任务
  44. # 开始训练
  45. train_model(model, train_loader, val_loader)

五、推理测试:从模型加载到结果可视化

1. 模型加载与预处理

  1. def load_model(model_path, n_channels=3, n_classes=2):
  2. model = UNet(n_channels, n_classes)
  3. model.load_state_dict(torch.load(model_path))
  4. model.eval()
  5. return model
  6. def preprocess_image(image_path, target_size=(256, 256)):
  7. image = Image.open(image_path).convert('RGB')
  8. if image.size != target_size:
  9. image = image.resize(target_size, Image.BILINEAR)
  10. transform = transforms.Compose([
  11. transforms.ToTensor(),
  12. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  13. ])
  14. return transform(image).unsqueeze(0) # 添加batch维度

2. 推理与后处理

  1. import matplotlib.pyplot as plt
  2. import matplotlib.colors as mcolors
  3. def predict(model, image_tensor):
  4. with torch.no_grad():
  5. output = model(image_tensor.to(device))
  6. pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
  7. return pred
  8. def visualize_prediction(image, pred_mask, original_mask=None):
  9. plt.figure(figsize=(12, 6))
  10. plt.subplot(1, 3, 1)
  11. plt.imshow(image)
  12. plt.title('Original Image')
  13. plt.axis('off')
  14. plt.subplot(1, 3, 2)
  15. # 创建颜色映射(0:黑色,1:红色)
  16. cmap = mcolors.ListedColormap(['black', 'red'])
  17. plt.imshow(pred_mask, cmap=cmap)
  18. plt.title('Predicted Mask')
  19. plt.axis('off')
  20. if original_mask is not None:
  21. plt.subplot(1, 3, 3)
  22. plt.imshow(original_mask, cmap=cmap)
  23. plt.title('Original Mask')
  24. plt.axis('off')
  25. plt.tight_layout()
  26. plt.show()
  27. # 使用示例
  28. model = load_model('best_model.pth')
  29. image_tensor = preprocess_image('test_image.jpg')
  30. pred_mask = predict(model, image_tensor)
  31. # 显示结果(假设有原始图像和标注)
  32. original_image = Image.open('test_image.jpg')
  33. original_mask = Image.open('test_mask.png').convert('L')
  34. visualize_prediction(original_image, pred_mask, np.array(original_mask))

六、进阶优化建议

  1. 损失函数选择

    • 二分类任务:BCEWithLogitsLoss
    • 多分类任务:CrossEntropyLoss
    • 类别不平衡时:加权CrossEntropy或Dice Loss
  2. 模型改进

    • 使用ResNet或EfficientNet作为编码器 backbone
    • 添加注意力机制(如CBAM)
    • 使用DeepLabv3+的空洞卷积结构
  3. 训练技巧

    • 使用学习率预热(warmup)
    • 实现早停(Early Stopping)
    • 使用混合精度训练(AMP)
  4. 部署优化

    • 模型量化(INT8)
    • TensorRT加速
    • ONNX格式导出

七、总结与展望

本文详细介绍了使用UNet进行图像语义分割的完整流程,包括自定义数据集制作、模型实现、训练优化和推理测试。通过实践,读者可以掌握:

  • 如何准备符合要求的语义分割数据集
  • 如何实现经典的UNet架构
  • 如何设计有效的训练策略
  • 如何进行模型推理和结果可视化

未来工作可以探索:

  • 3D医学图像分割(如使用3D UNet)
  • 实时语义分割(如使用轻量化模型)
  • 弱监督/半监督学习方法

希望本文能为深度学习初学者提供实用的入门指南,也为研究人员提供有价值的实践参考。

相关文章推荐

发表评论