Mastering UNet Image Segmentation from Scratch: A Complete Walkthrough of Dataset Creation, Training, and Inference
2025.09.18 17:15 · Summary: This article explains how to perform semantic image segmentation with a UNet model, covering the full workflow from building a custom dataset to model training and inference testing. It is aimed at deep-learning enthusiasts with some programming experience.
I. Introduction: Why UNet for Semantic Image Segmentation?
UNet (U-shaped Network) is a classic convolutional neural network architecture proposed by Olaf Ronneberger et al. in 2015. Although it was designed for medical image segmentation, its distinctive encoder-decoder structure with skip connections has made it a popular choice across a wide range of semantic segmentation tasks. Compared with other models, UNet offers the following advantages:
- Lightweight: a relatively small parameter count, well suited to training on small datasets
- Multi-scale feature fusion: skip connections preserve low-level features and improve segmentation accuracy
- End-to-end training: the network directly outputs a segmentation map with the same spatial size as the input
This article walks through the full pipeline: building a custom dataset from scratch, training a UNet model, and running inference tests. Everything is implemented with PyTorch and assumes some familiarity with Python and deep learning.
II. Dataset Creation: From Raw Images to Annotation Files
1. Preparing the Data
First, gather the raw images and their corresponding annotation masks. A mask is a single-channel grayscale image with the same dimensions as its source image, where each pixel value encodes a class. In a binary task, for example, background pixels are 0 and target pixels are 1.
Recommendations:
- Images and masks should correspond one-to-one; use matching file name prefixes (e.g. img_001.jpg pairs with mask_001.png)
- Resize all images to a uniform size such as 256×256 or 512×512 to simplify processing
- Split the data into training, validation, and test sets, at a suggested ratio of 7:2:1 (a split helper is sketched below)
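For instance, the following minimal sketch performs the 7:2:1 split; the directory layout and the img_/mask_ naming scheme are assumptions carried over from the example above:

```python
import os
import random
import shutil

random.seed(42)
images = sorted(os.listdir('data/all/images'))
random.shuffle(images)
n = len(images)
splits = {
    'train': images[:int(0.7 * n)],
    'val': images[int(0.7 * n):int(0.9 * n)],
    'test': images[int(0.9 * n):],
}
for split, files in splits.items():
    os.makedirs(f'data/{split}/images', exist_ok=True)
    os.makedirs(f'data/{split}/masks', exist_ok=True)
    for name in files:
        # img_001.jpg -> mask_001.png, per the naming convention above
        mask_name = name.replace('img_', 'mask_').replace('.jpg', '.png')
        shutil.copy(os.path.join('data/all/images', name), f'data/{split}/images/{name}')
        shutil.copy(os.path.join('data/all/masks', mask_name), f'data/{split}/masks/{mask_name}')
```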
2. Recommended Annotation Tools
- Labelme: an open-source tool supporting polygons, rectangles, and other annotation modes; it exports JSON files, which must then be converted to mask images
- CVAT: a professional video/image annotation platform with team-collaboration support
- Photoshop: draw masks by hand; workable for simple scenes
Example: after annotating with Labelme, the following Python code converts a JSON file into a mask image:
```python
import json

import cv2
import numpy as np
from PIL import Image

def json_to_mask(json_path, output_path):
    with open(json_path) as f:
        data = json.load(f)
    height = data['imageHeight']
    width = data['imageWidth']
    mask = np.zeros((height, width), dtype=np.uint8)
    for shape in data['shapes']:
        points = shape['points']
        label = shape['label']
        # For simplicity, every annotated polygon is filled with class 1.
        # In a multi-class task, map each label to its own class index instead.
        cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], 1)
    # Scale to 0/255 so the mask is visible when opened as an image
    Image.fromarray(mask * 255).save(output_path)

# Usage example
json_to_mask('labelme_output.json', 'mask.png')
```
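To convert a whole folder of Labelme exports in one pass, a simple loop such as the following works (the directory names here are hypothetical):

```python
import glob
import os

os.makedirs('data/all/masks', exist_ok=True)
for json_path in glob.glob('labelme_annotations/*.json'):
    out_name = os.path.basename(json_path).replace('.json', '.png')
    json_to_mask(json_path, os.path.join('data/all/masks', out_name))
```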
3. Data Augmentation
To improve the model's generalization, augment the training data. Common techniques include:
- Random horizontal/vertical flips
- Random rotation (±15 degrees)
- Random scaling (0.9-1.1×)
- Random brightness/contrast adjustment
- Additive Gaussian noise
A PyTorch implementation sketch covering the geometric transforms:
```python
import random

import torchvision.transforms as transforms
from torchvision.transforms import functional as F

class CustomAugmentation:
    def __init__(self):
        self.augmentations = [
            self.random_flip,
            self.random_rotation,
            self.random_scale,
        ]

    def __call__(self, img, mask):
        for augment in self.augmentations:
            img, mask = augment(img, mask)
        return img, mask

    def random_flip(self, img, mask):
        if random.random() > 0.5:
            img = F.hflip(img)
            mask = F.hflip(mask)
        if random.random() > 0.5:
            img = F.vflip(img)
            mask = F.vflip(mask)
        return img, mask

    def random_rotation(self, img, mask):
        angle = random.uniform(-15, 15)
        img = F.rotate(img, angle, interpolation=transforms.InterpolationMode.BILINEAR)
        # rotate defaults to nearest-neighbor, which keeps mask values valid
        mask = F.rotate(mask, angle, fill=0)
        return img, mask

    def random_scale(self, img, mask):
        scale = random.uniform(0.9, 1.1)
        w, h = img.size  # PIL reports (width, height)
        new_h, new_w = int(h * scale), int(w * scale)
        img = F.resize(img, (new_h, new_w))
        # Nearest-neighbor keeps mask values as valid class indices
        mask = F.resize(mask, (new_h, new_w), interpolation=transforms.InterpolationMode.NEAREST)
        if scale >= 1.0:
            # Randomly crop back to the original size
            i = random.randint(0, new_h - h)
            j = random.randint(0, new_w - w)
            img = F.crop(img, i, j, h, w)
            mask = F.crop(mask, i, j, h, w)
        else:
            # Pad back to the original size; the original random crop
            # would fail here, since the scaled image is smaller
            img = F.pad(img, [0, 0, w - new_w, h - new_h])
            mask = F.pad(mask, [0, 0, w - new_w, h - new_h])
        return img, mask
```
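A quick usage sketch, assuming the example file names from earlier, shows how the image and its mask receive identical geometric transforms:

```python
from PIL import Image

aug = CustomAugmentation()
img = Image.open('data/train/images/img_001.jpg').convert('RGB')
mask = Image.open('data/train/masks/mask_001.png').convert('L')
img_aug, mask_aug = aug(img, mask)  # transforms stay aligned across the pair
```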
III. Implementing UNet: From Architecture Design to Code
1. UNet Architecture Overview
UNet consists of an encoder (the downsampling path) and a decoder (the upsampling path), with skip connections fusing features at multiple scales:
- Encoder: four downsampling stages, doubling the channel count at each stage
- Decoder: four upsampling stages, halving the channel count at each stage
- Skip connections: concatenate encoder feature maps with the corresponding upsampled decoder features
2. PyTorch Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count, so DoubleConv
            # reduces it via a halved mid_channels
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are NCHW; pad x1 so it matches x2 when the input size
        # is not divisible by 16
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # With bilinear upsampling, the deepest feature maps carry half the
        # channels so the concatenated skip connections line up correctly
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```
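A quick sanity check confirms that the output has the same spatial size as the input, with one channel per class:

```python
model = UNet(n_channels=3, n_classes=2)
x = torch.randn(1, 3, 256, 256)  # dummy batch: N=1, RGB, 256x256
print(model(x).shape)  # expected: torch.Size([1, 2, 256, 256])
```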
IV. Model Training: From Data Loading to Optimization Strategy
1. Implementing the Data Loader
```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class SegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        # img_001.jpg -> mask_001.png; adjust to your own naming scheme
        mask_name = self.images[idx].replace('img_', 'mask_').replace('.jpg', '.png')
        mask_path = os.path.join(self.mask_dir, mask_name)
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # single-channel grayscale
        if self.transform:
            image, mask = self.transform(image, mask)
        # Scale pixels to [0, 1] and move channels first (HWC -> CHW)
        image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
        # Masks must hold class indices 0..n_classes-1; the binary masks saved
        # by json_to_mask above use 0/255, so map 255 back to 1. For masks that
        # already store class indices, use np.array(mask) directly.
        mask = torch.from_numpy((np.array(mask) > 0).astype(np.int64))
        return image, mask

# Usage example
train_transform = CustomAugmentation()  # the augmentation defined earlier
train_dataset = SegmentationDataset('data/train/images', 'data/train/masks', transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
# Validation data uses no augmentation
val_dataset = SegmentationDataset('data/val/images', 'data/val/masks')
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
```
2. The Training Script
```python
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_model(model, train_loader, val_loader, epochs=50):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()  # multi-class loss; also fine for binary with n_classes=2
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        for images, masks in progress_bar:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
        train_loss /= len(train_loader)
        # Validation phase
        val_loss = validate(model, val_loader, criterion)
        scheduler.step(val_loss)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        # Keep the checkpoint with the lowest validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

def validate(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
    return val_loss / len(val_loader)

# Initialize the model (binary segmentation: background + foreground)
model = UNet(n_channels=3, n_classes=2)
# Start training
train_model(model, train_loader, val_loader)
```
V. Inference: From Model Loading to Result Visualization
1. Loading the Model and Preprocessing
```python
import torchvision.transforms as transforms
from PIL import Image

def load_model(model_path, n_channels=3, n_classes=2):
    model = UNet(n_channels, n_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)  # keep model and inputs on the same device
    model.eval()
    return model

def preprocess_image(image_path, target_size=(256, 256)):
    image = Image.open(image_path).convert('RGB')
    if image.size != target_size:
        image = image.resize(target_size, Image.BILINEAR)
    # Preprocessing must match training: there we only scaled pixels to
    # [0, 1], which is exactly what ToTensor does
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    return transform(image).unsqueeze(0)  # add a batch dimension
```
2. Inference and Post-processing
```python
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

def predict(model, image_tensor):
    with torch.no_grad():
        output = model(image_tensor.to(device))
        # Pick the class with the highest logit at each pixel
        pred = torch.argmax(output.squeeze(0), dim=0).cpu().numpy()
    return pred

def visualize_prediction(image, pred_mask, original_mask=None):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 3, 2)
    # Color map: class 0 -> black, class 1 -> red
    cmap = mcolors.ListedColormap(['black', 'red'])
    plt.imshow(pred_mask, cmap=cmap)
    plt.title('Predicted Mask')
    plt.axis('off')
    if original_mask is not None:
        plt.subplot(1, 3, 3)
        plt.imshow(original_mask, cmap=cmap)
        plt.title('Ground-Truth Mask')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Usage example
model = load_model('best_model.pth')
image_tensor = preprocess_image('test_image.jpg')
pred_mask = predict(model, image_tensor)
# Show the results (assuming a ground-truth mask is available)
original_image = Image.open('test_image.jpg')
original_mask = Image.open('test_mask.png').convert('L')
visualize_prediction(original_image, pred_mask, np.array(original_mask))
```
VI. Further Optimization Tips
Choosing a loss function:
- Binary tasks: BCEWithLogitsLoss (with a single-channel output)
- Multi-class tasks: CrossEntropyLoss
- Class imbalance: weighted CrossEntropy or Dice Loss (sketched below)
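As a reference, here is one common soft-Dice formulation for the binary setup used in this article; this sketch is not from the original post, and the smooth term is a standard stabilizer against empty masks:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits: (N, 2, H, W); targets: (N, H, W) with values in {0, 1}
        probs = F.softmax(logits, dim=1)[:, 1]  # foreground probability
        targets = targets.float()
        intersection = (probs * targets).sum(dim=(1, 2))
        union = probs.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

# Often combined with cross-entropy, e.g.:
# loss = nn.CrossEntropyLoss()(logits, masks) + DiceLoss()(logits, masks)
```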
Model improvements:
- Use a ResNet or EfficientNet encoder backbone (see the library sketch below)
- Add attention mechanisms (e.g. CBAM)
- Adopt the atrous (dilated) convolution structure of DeepLabv3+
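If you would rather not hand-roll a pretrained encoder, the third-party segmentation_models_pytorch library (assumed installed via pip install segmentation-models-pytorch) provides a drop-in UNet with swappable backbones; a minimal sketch:

```python
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name='resnet34',      # or e.g. 'efficientnet-b0'
    encoder_weights='imagenet',   # ImageNet-pretrained encoder
    in_channels=3,
    classes=2,
)
```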
Training techniques:
- Use learning-rate warmup
- Implement early stopping
- Use mixed-precision training (AMP, sketched below)
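For mixed precision, PyTorch's built-in AMP utilities slot into the inner loop from Section IV roughly like this; the sketch reuses model, criterion, optimizer, and train_loader from above:

```python
scaler = torch.cuda.amp.GradScaler()
for images, masks in train_loader:
    images, masks = images.to(device), masks.to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        outputs = model(images)
        loss = criterion(outputs, masks)
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
```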
Deployment optimization:
- Model quantization (INT8)
- TensorRT acceleration
- ONNX export (sketched below)
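As a starting point for deployment, ONNX export can look like the following sketch (file names are placeholders; opset 11 is a common choice for UNet-style models):

```python
model = load_model('best_model.pth').cpu()
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(
    model, dummy, 'unet.onnx',
    input_names=['input'], output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=11,
)
```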
VII. Summary and Outlook
This article has walked through the complete workflow of semantic image segmentation with UNet, including custom dataset creation, model implementation, training optimization, and inference testing. Working through it, readers will learn:
- How to prepare a dataset that meets the requirements of semantic segmentation
- How to implement the classic UNet architecture
- How to design an effective training strategy
- How to run inference and visualize the results
Future work could explore:
- 3D medical image segmentation (e.g. with 3D UNet)
- Real-time semantic segmentation (e.g. with lightweight models)
- Weakly-supervised and semi-supervised learning methods
We hope this article serves as a practical introduction for deep-learning beginners and a useful hands-on reference for researchers.