从零开始:UNet图像语义分割实战指南——数据集制作、训练与推理全流程解析
2025.09.18 17:15浏览量:380简介:本文详细介绍如何使用UNet模型进行图像语义分割,从自制数据集的标注与预处理,到模型训练与推理测试,提供完整的代码实现与操作指南,适合初学者与进阶开发者。
一、引言:为什么选择UNet进行语义分割?
UNet(U-shaped Network)由Ronneberger等人在2015年提出,是一种专为医学图像分割设计的全卷积网络(FCN)。其核心思想是通过编码器-解码器结构和跳跃连接实现特征的高效提取与空间信息的保留,尤其适用于小数据集与高分辨率图像。相较于其他模型,UNet具有以下优势:
- 轻量化设计:参数量少,适合边缘设备部署。
- 跳跃连接:融合浅层(空间细节)与深层(语义信息)特征,提升分割精度。
- 扩展性强:可轻松适配不同任务(如卫星图像、工业缺陷检测等)。
本文将围绕UNet展开,从数据集制作到模型训练与推理,提供一套完整的解决方案。
二、数据集制作:从原始图像到标注数据
1. 数据收集与预处理
步骤1:图像采集
- 使用手机或相机拍摄目标场景(如道路、植物、工业零件等),确保光照与角度一致。
- 示例:若需分割道路裂缝,需采集不同光照条件下的裂缝图像。
步骤2:图像标准化
- 统一分辨率(如512×512),避免尺寸差异导致训练不稳定。
- 归一化像素值至[0,1]范围,加速模型收敛。
import cv2import numpy as npdef preprocess_image(image_path, target_size=(512, 512)):image = cv2.imread(image_path)image = cv2.resize(image, target_size)image = image.astype(np.float32) / 255.0 # 归一化return image
2. 标注工具与格式转换
工具选择:
- Labelme:开源标注工具,支持多边形、矩形标注,导出JSON格式。
- CVAT:在线标注平台,适合团队协作。
标注流程:
- 使用Labelme标注目标区域(如裂缝、植物),生成JSON文件。
- 转换为掩码(Mask)图像:
```python
import json
import os
from PIL import Image, ImageDraw
def json_to_mask(json_path, output_path, image_shape=(512, 512)):
with open(json_path) as f:
data = json.load(f)
mask = Image.new('L', image_shape, 0) # 'L'表示灰度图draw = ImageDraw.Draw(mask)for shape in data['shapes']:points = shape['points']if shape['shape_type'] == 'polygon':draw.polygon(points, fill=255) # 填充为白色(255)mask.save(output_path)
**数据集结构**:
dataset/
├── images/
│ ├── img1.jpg
│ └── img2.jpg
└── masks/
├── img1_mask.png
└── img2_mask.png
# 三、UNet模型实现:从编码器到解码器## 1. 模型架构解析UNet的核心是**对称的U型结构**,包含:- **编码器(下采样)**:通过卷积与池化提取高级特征。- **解码器(上采样)**:通过转置卷积恢复空间分辨率。- **跳跃连接**:将编码器的特征图与解码器的上采样结果拼接。## 2. 代码实现(PyTorch)```pythonimport torchimport torch.nn as nnimport torch.nn.functional as Fclass DoubleConv(nn.Module):"""双卷积块:Conv2d + ReLU + Conv2d + ReLU"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class UNet(nn.Module):def __init__(self, n_classes=1):super().__init__()# 编码器self.dconv1 = DoubleConv(3, 64)self.dconv2 = DoubleConv(64, 128)self.dconv3 = DoubleConv(128, 256)self.dconv4 = DoubleConv(256, 512)# 解码器self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)# 输出层self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)def forward(self, x):# 编码器conv1 = self.dconv1(x)pool1 = F.max_pool2d(conv1, 2)conv2 = self.dconv2(pool1)pool2 = F.max_pool2d(conv2, 2)conv3 = self.dconv3(pool2)pool3 = F.max_pool2d(conv3, 2)conv4 = self.dconv4(pool3)# 解码器 + 跳跃连接up3 = self.upconv3(conv4)up3 = torch.cat([up3, conv3], dim=1) # 拼接特征图up2 = self.upconv2(up3)up2 = torch.cat([up2, conv2], dim=1)up1 = self.upconv1(up2)up1 = torch.cat([up1, conv1], dim=1)# 输出out = self.final_conv(up1)return out
四、模型训练:从数据加载到优化
1. 数据加载器(DataLoader)
from torch.utils.data import Dataset, DataLoaderfrom torchvision import transformsclass CustomDataset(Dataset):def __init__(self, image_paths, mask_paths, transform=None):self.image_paths = image_pathsself.mask_paths = mask_pathsself.transform = transformdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = cv2.imread(self.image_paths[idx])image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)if self.transform:image = self.transform(image)mask = self.transform(mask)return image, mask# 示例:创建DataLoadertransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])dataset = CustomDataset(image_paths, mask_paths, transform=transform)dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
2. 训练循环
import torch.optim as optimfrom tqdm import tqdmdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = UNet(n_classes=1).to(device)criterion = nn.BCEWithLogitsLoss() # 二分类任务optimizer = optim.Adam(model.parameters(), lr=1e-4)def train_model(model, dataloader, epochs=50):model.train()for epoch in range(epochs):running_loss = 0.0for images, masks in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}'):images = images.to(device)masks = masks.float().unsqueeze(1).to(device) # 增加通道维度optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}')train_model(model, dataloader)
五、推理测试:从模型加载到可视化
1. 模型保存与加载
# 保存模型torch.save(model.state_dict(), 'unet_model.pth')# 加载模型model = UNet(n_classes=1).to(device)model.load_state_dict(torch.load('unet_model.pth'))model.eval()
2. 推理与可视化
import matplotlib.pyplot as pltdef predict_and_visualize(model, image_path, mask_path=None):image = preprocess_image(image_path)image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)with torch.no_grad():output = model(image_tensor)pred_mask = torch.sigmoid(output).squeeze().cpu().numpy()# 可视化plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.imshow(image)plt.title('Original Image')plt.subplot(1, 2, 2)plt.imshow(pred_mask, cmap='gray')plt.title('Predicted Mask')if mask_path:true_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)plt.figure(figsize=(5, 5))plt.imshow(true_mask, cmap='gray')plt.title('True Mask')plt.show()# 示例调用predict_and_visualize('test_image.jpg', 'test_mask.png')
六、总结与优化建议
- 数据增强:通过旋转、翻转增加数据多样性,提升模型泛化能力。
- 超参数调优:调整学习率、批次大小以优化训练效果。
- 模型轻量化:使用MobileUNet等变体,适配移动端部署。
通过本文的完整流程,读者可快速掌握UNet从数据集制作到模型推理的全过程,为实际项目提供技术支撑。

发表评论
登录后可评论,请前往 登录 或 注册