Mastering UNet Image Segmentation from Scratch: A Complete Walkthrough of Dataset Creation, Training, and Inference
2025.09.18 17:15 — Summary: This article explains in detail how to use the UNet model for semantic image segmentation, covering the full pipeline from custom dataset creation through model training to inference testing. It is aimed at deep learning enthusiasts with some programming background.
I. Introduction: Why Choose UNet for Semantic Image Segmentation?
UNet (U-shaped Network) is a classic convolutional neural network architecture proposed by Olaf Ronneberger et al. in 2015. Originally designed for medical image segmentation, its distinctive encoder-decoder structure with skip connections has made it widely used across all kinds of semantic segmentation tasks. Compared with other models, UNet offers the following advantages:
- Lightweight: relatively few parameters, suitable for training on small datasets
- Multi-scale feature fusion: skip connections preserve low-level features, improving segmentation accuracy
- End-to-end training: outputs a segmentation map with the same spatial size as the input
This article walks through using a UNet model from scratch: building a custom dataset, training the model, and running inference tests. The entire pipeline is implemented in PyTorch and assumes some familiarity with Python and deep learning.
II. Dataset Creation: From Raw Images to Annotation Files
1. Dataset Preparation
First, prepare the raw images and their corresponding annotation masks. A mask is a single-channel grayscale image with the same dimensions as its source image, where each pixel value encodes a class. In a binary task, for example, background pixels are 0 and target pixels are 1.
Recommendations:
- Images and masks should correspond one-to-one; use matching filenames (e.g. img_001.jpg paired with mask_001.png)
- Resize all images to a uniform size such as 256×256 or 512×512 to simplify processing
- Split the data into training, validation, and test sets, e.g. at a 7:2:1 ratio
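A split like the 7:2:1 ratio above can be done with a few lines of standard-library Python. The function and filenames below are illustrative (any directory layout works, as long as images and masks share names):

```python
import random


def split_dataset(filenames, ratios=(0.7, 0.2, 0.1), seed=42):
    """Shuffle filenames deterministically and split into train/val/test lists."""
    names = sorted(filenames)
    random.Random(seed).shuffle(names)
    n = len(names)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    train = names[:n_train]
    val = names[n_train:n_train + n_val]
    test = names[n_train + n_val:]  # remainder goes to the test set
    return train, val, test

# Usage example with hypothetical filenames
files = [f'img_{i:03d}.jpg' for i in range(100)]
train, val, test = split_dataset(files)
print(len(train), len(val), len(test))  # 70 20 10
```

Fixing the seed makes the split reproducible across runs, which matters when you later compare training configurations.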
2. Recommended Annotation Tools
- Labelme: open-source annotation tool supporting polygon, rectangle, and other annotation modes; exports JSON files that must be converted to mask images
- CVAT: professional video/image annotation platform with team-collaboration support
- Photoshop: draw masks by hand, suitable for simple scenes
Example: after annotating with Labelme, the following Python code converts a JSON file into a mask image:
```python
import json

import cv2
import numpy as np
from PIL import Image


def json_to_mask(json_path, output_path):
    with open(json_path) as f:
        data = json.load(f)
    height = data['imageHeight']
    width = data['imageWidth']
    mask = np.zeros((height, width), dtype=np.uint8)
    for shape in data['shapes']:
        points = shape['points']
        label = shape['label']
        # For simplicity, fill every annotated region with 1.
        # A real implementation should map each label to its own class index.
        cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], 1)
    Image.fromarray(mask * 255).save(output_path)

# Usage example
json_to_mask('labelme_output.json', 'mask.png')
```
3. Data Augmentation
To improve generalization, augment the training data. Common methods include:
- Random horizontal/vertical flips
- Random rotation (±15 degrees)
- Random scaling (0.9-1.1×)
- Random brightness/contrast adjustment
- Additive Gaussian noise
A PyTorch implementation example:
```python
import random

import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode


class CustomAugmentation:
    def __init__(self):
        self.augmentations = [
            self.random_flip,
            self.random_rotation,
            self.random_scale,
        ]

    def __call__(self, img, mask):
        for augment in self.augmentations:
            img, mask = augment(img, mask)
        return img, mask

    def random_flip(self, img, mask):
        if random.random() > 0.5:
            img = F.hflip(img)
            mask = F.hflip(mask)
        if random.random() > 0.5:
            img = F.vflip(img)
            mask = F.vflip(mask)
        return img, mask

    def random_rotation(self, img, mask):
        angle = random.uniform(-15, 15)
        img = F.rotate(img, angle)
        # Nearest-neighbor keeps mask values as valid class indices
        mask = F.rotate(mask, angle, interpolation=InterpolationMode.NEAREST, fill=0)
        return img, mask

    def random_scale(self, img, mask):
        scale = random.uniform(0.9, 1.1)
        w, h = img.size  # PIL size is (width, height)
        new_h, new_w = int(h * scale), int(w * scale)
        img = F.resize(img, (new_h, new_w))
        mask = F.resize(mask, (new_h, new_w), interpolation=InterpolationMode.NEAREST)
        # If the scaled image is smaller than the original, pad it first
        # so the random crop back to (h, w) is always valid
        pad_h, pad_w = max(h - new_h, 0), max(w - new_w, 0)
        if pad_h or pad_w:
            img = F.pad(img, [0, 0, pad_w, pad_h])
            mask = F.pad(mask, [0, 0, pad_w, pad_h], fill=0)
            new_h, new_w = max(new_h, h), max(new_w, w)
        i = random.randint(0, new_h - h)
        j = random.randint(0, new_w - w)
        img = F.crop(img, i, j, h, w)
        mask = F.crop(mask, i, j, h, w)
        return img, mask
```
III. UNet Model Implementation: From Architecture Design to Code
1. UNet Architecture Overview
UNet consists of an encoder (downsampling path) and a decoder (upsampling path), with skip connections fusing multi-scale features:
- Encoder: four downsampling stages, doubling the channel count each time
- Decoder: four upsampling stages, halving the channel count each time
- Skip connections: concatenate encoder feature maps with the decoder's upsampled feature maps
2. PyTorch Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count, so DoubleConv
            # reduces it through mid_channels after concatenation
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                         kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so it matches x2 spatially (tensors are NCHW)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # With bilinear upsampling the bottleneck uses half the channels,
        # so the concatenated skip + upsampled tensors line up correctly
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```
IV. Model Training: From Data Loading to Optimization Strategy
1. Data Loader Implementation
```python
import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image


class SegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir,
                                 self.images[idx].replace('.jpg', '.png'))
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')  # grayscale

        if self.transform:
            image, mask = self.transform(image, mask)

        # Normalize the image and move channels first (HWC -> CHW)
        image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
        # Masks saved as 0/255 are mapped to class indices 0/1;
        # CrossEntropyLoss expects long-typed class indices
        mask = torch.from_numpy((np.array(mask) > 127).astype(np.int64))
        return image, mask

# Usage example
train_transform = CustomAugmentation()  # the augmentation class defined above
train_dataset = SegmentationDataset('data/train/images', 'data/train/masks',
                                    transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
```
2. Training Script
```python
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_model(model, train_loader, val_loader, epochs=50):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()  # multi-class segmentation
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        for images, masks in progress_bar:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
        train_loss /= len(train_loader)

        # Validation phase
        val_loss = validate(model, val_loader, criterion)
        scheduler.step(val_loss)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')


def validate(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
    return val_loss / len(val_loader)

# Initialize the model (binary task: background + foreground)
model = UNet(n_channels=3, n_classes=2)
# Start training (val_loader is built the same way as train_loader)
train_model(model, train_loader, val_loader)
```
V. Inference: From Model Loading to Result Visualization
1. Model Loading and Preprocessing
```python
import torch
import torchvision.transforms as transforms
from PIL import Image


def load_model(model_path, n_channels=3, n_classes=2):
    model = UNet(n_channels, n_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def preprocess_image(image_path, target_size=(256, 256)):
    image = Image.open(image_path).convert('RGB')
    if image.size != target_size:
        image = image.resize(target_size, Image.BILINEAR)
    # Preprocessing must match training: ToTensor scales pixels to [0, 1],
    # the same as the /255.0 normalization in SegmentationDataset
    transform = transforms.Compose([transforms.ToTensor()])
    return transform(image).unsqueeze(0)  # add batch dimension
```
2. Inference and Post-processing
```python
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import torch
from PIL import Image


def predict(model, image_tensor):
    with torch.no_grad():
        output = model(image_tensor.to(device))
        pred = torch.argmax(output.squeeze(0), dim=0).cpu().numpy()
    return pred


def visualize_prediction(image, pred_mask, original_mask=None):
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    # Color map: class 0 -> black, class 1 -> red
    cmap = mcolors.ListedColormap(['black', 'red'])
    plt.imshow(pred_mask, cmap=cmap)
    plt.title('Predicted Mask')
    plt.axis('off')

    if original_mask is not None:
        plt.subplot(1, 3, 3)
        plt.imshow(original_mask, cmap=cmap)
        plt.title('Ground Truth Mask')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Usage example
model = load_model('best_model.pth')
image_tensor = preprocess_image('test_image.jpg')
pred_mask = predict(model, image_tensor)

# Display results (assuming the test image has a ground-truth mask)
original_image = Image.open('test_image.jpg')
original_mask = Image.open('test_mask.png').convert('L')
visualize_prediction(original_image, pred_mask, np.array(original_mask))
```
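Beyond visual inspection, segmentation quality is usually reported with IoU (intersection over union) and the Dice coefficient. A minimal NumPy sketch, assuming binary masks holding class indices 0/1 (the function names are my own):

```python
import numpy as np


def iou_score(pred, target, eps=1e-7):
    """Intersection over Union for binary masks (values 0/1)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return (intersection + eps) / (union + eps)


def dice_score(pred, target, eps=1e-7):
    """Dice coefficient: 2|A∩B| / (|A| + |B|)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2 * intersection + eps) / (pred.sum() + target.sum() + eps)

# Usage example
pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [0, 0]])
print(round(iou_score(pred, gt), 2))   # 0.5
print(round(dice_score(pred, gt), 2))  # 0.67
```

The small `eps` term keeps the scores defined when both masks are empty.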
VI. Advanced Optimization Tips
Loss function selection:
- Binary tasks: BCEWithLogitsLoss
- Multi-class tasks: CrossEntropyLoss
- Class imbalance: weighted CrossEntropy or Dice Loss
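For the class-imbalance case, a soft Dice loss over softmax probabilities is one common formulation (there are several variants; this minimal sketch and its `smooth` parameter are one choice among them):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceLoss(nn.Module):
    """Soft Dice loss over softmax probabilities and one-hot targets."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits: (N, C, H, W); targets: (N, H, W) with class indices
        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)
        one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and spatial dims, per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice.mean()

# Usage example: a near-perfect prediction gives a loss near 0
logits = torch.full((1, 2, 4, 4), -10.0)
logits[:, 0] = 10.0  # strongly predict class 0 everywhere
targets = torch.zeros(1, 4, 4, dtype=torch.long)
loss = DiceLoss()(logits, targets)
print(loss.item() < 0.01)  # True
```

Because Dice is computed per class and averaged, rare foreground classes contribute as much to the loss as the dominant background; in practice it is often combined with CrossEntropy.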
Model improvements:
- Use ResNet or EfficientNet as the encoder backbone
- Add attention mechanisms (e.g. CBAM)
- Use atrous (dilated) convolutions as in DeepLabv3+
Training tricks:
- Use learning-rate warmup
- Implement early stopping
- Use mixed-precision training (AMP)
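Early stopping needs no framework support; a minimal pure-Python helper (the class name and `step` API are my own invention) could look like this, dropping naturally into the epoch loop of the training script above:

```python
class EarlyStopping:
    """Stop training when validation loss stops improving."""
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Usage example: the loss plateaus, so we stop after `patience` bad epochs
stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7]
stop_epoch = next(i for i, l in enumerate(losses) if stopper.step(l))
print(stop_epoch)  # 5
```

Setting `min_delta` above zero ignores improvements too small to matter, which avoids resetting the counter on noise.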
Deployment optimization:
- Model quantization (INT8)
- TensorRT acceleration
- ONNX export
VII. Summary and Outlook
This article covered the complete UNet semantic-segmentation workflow: custom dataset creation, model implementation, training optimization, and inference testing. By working through it, readers should be able to:
- Prepare a semantic-segmentation dataset in the required format
- Implement the classic UNet architecture
- Design an effective training strategy
- Run inference and visualize the results
Future directions to explore:
- 3D medical image segmentation (e.g. with 3D UNet)
- Real-time semantic segmentation (e.g. with lightweight models)
- Weakly- and semi-supervised learning methods
Hopefully this article serves as a practical introductory guide for deep-learning beginners and a useful hands-on reference for researchers.
