How to Fine-Tune the SAM Model: A Complete Guide from Environment Setup to Training Implementation

Supplement 1:

(January 2, 2025)

Many readers have asked what format the data annotations should be in, so Supplement 1 was added to answer this.

Running the demo script provided at the end of this post generates a sample dataset in the expected annotation format.

python sam-data-setup.py

Under the dataset directory, place an images folder, a masks folder, and annotations.txt:

The images folder holds the original images (randomly generated here). Put your own data in this folder.


The masks folder holds the corresponding mask images, with the file suffix changed accordingly. Put the label masks that correspond to your own data in this folder.


annotations.txt stores the bounding-box coordinates for each image.
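
For reference, each line of annotations.txt follows the format filename,x1,y1,x2,y2 (comma-separated, one object per image). The coordinates below are illustrative values of the kind the demo script produces:

sample_0.jpg,113,161,263,311
sample_1.jpg,98,140,240,282
sample_2.jpg,162,120,294,252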


Introduction

Segment Anything Model (SAM) is a powerful image segmentation model from Meta AI. Although the pretrained model performs well, fine-tuning may be needed in specific domains (such as medical imaging or industrial inspection) to achieve better performance. This post walks through how to fine-tune SAM, covering environment setup, data preparation, and the training implementation.

Table of Contents

  1. Environment Setup
  2. Project Structure
  3. Data Preparation
  4. Model Fine-Tuning
  5. Training Process
  6. Notes and Optimization Tips

1. Environment Setup

First, we need to set up the correct Python environment and dependencies. A virtual environment is recommended for managing them:

# Create and activate a virtual environment
python -m venv sam_env
# Windows:
.\sam_env\Scripts\activate
# Linux/Mac:
source sam_env/bin/activate

# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install opencv-python
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install numpy matplotlib

# Download the pretrained model
# Windows PowerShell:
Invoke-WebRequest -Uri "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" -OutFile "sam_vit_b_01ec64.pth"
# Linux/Mac:
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
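
As a quick sanity check (optional; a minimal sketch assuming the packages above installed cleanly), you can verify that PyTorch sees the GPU and that the checkpoint loads:

import torch
from segment_anything import sam_model_registry

print(torch.__version__, 'CUDA available:', torch.cuda.is_available())
sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
print('SAM loaded:', sum(p.numel() for p in sam.parameters()) // 1_000_000, 'M parameters')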

2. Project Structure

The recommended project structure is as follows:

project_root/
├── stamps/
│   ├── images/         # training images
│   ├── masks/          # segmentation masks
│   └── annotations.txt # bounding-box annotations
├── checkpoints/        # model checkpoints
├── sam-data-setup.py   # data preparation script
└── sam-finetune.py     # training script

3. Data Preparation

To train the model, we need to prepare the following data:

  • training images
  • segmentation masks
  • bounding-box annotations

Here is the implementation of the data preparation script:

import os
import numpy as np
import cv2
from pathlib import Path

def create_project_structure():
    """创建项目所需的目录结构"""
    directories = [
        './stamps/images',
        './stamps/masks',
        './checkpoints'
    ]
    
    for dir_path in directories:
        Path(dir_path).mkdir(parents=True, exist_ok=True)
    
    return directories

def create_sample_data(num_samples=5):
    """创建示例训练数据"""
    annotations = []
    
    for i in range(num_samples):
        # Create a sample image
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        center_x = np.random.randint(150, 350)
        center_y = np.random.randint(150, 350)
        radius = np.random.randint(50, 100)
        
        # Draw the object
        cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
        
        # Create the mask
        mask = np.zeros((500, 500), dtype=np.uint8)
        cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        
        # Save the files
        cv2.imwrite(f'./stamps/images/sample_{i}.jpg', image)
        cv2.imwrite(f'./stamps/masks/sample_{i}_mask.png', mask)
        
        # Compute the bounding box
        x1 = max(0, center_x - radius)
        y1 = max(0, center_y - radius)
        x2 = min(500, center_x + radius)
        y2 = min(500, center_y + radius)
        
        annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
    
    # Save the annotation file
    with open('./stamps/annotations.txt', 'w') as f:
        f.writelines(annotations)
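
To run this excerpt on its own, a minimal driver is enough (the full script under Quick Deployment wires the same calls up with extra logging):

if __name__ == '__main__':
    create_project_structure()
    create_sample_data()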

4. Model Fine-Tuning

4.1 Dataset Class Implementation

First, implement a custom dataset class:

import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from segment_anything.utils.transforms import ResizeLongestSide

class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = ResizeLongestSide(1024)
        
        # Load annotations
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        # Load and preprocess the image
        image = cv2.imread(os.path.join(self.image_dir, ann['image']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(self.mask_dir, 
                         ann['image'].replace('.jpg', '_mask.png')), 
                         cv2.IMREAD_GRAYSCALE)
        mask = mask.astype(np.float32) / 255.0
        
        # Prepare the image
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        input_image = input_image.astype(np.float32) / 255.0
        input_image = torch.from_numpy(input_image).permute(2, 0, 1)
        
        # Normalize with ImageNet statistics
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        # Prepare the bounding box and mask
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
        
        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch
        }
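
A quick way to sanity-check the dataset (a sketch; the paths assume the sample data generated in Section 3):

dataset = StampDataset('./stamps/images', './stamps/masks', './stamps/annotations.txt')
sample = dataset[0]
print(sample['image'].shape)   # torch.Size([3, 1024, 1024]) for the square sample images
print(sample['bbox'].shape)    # torch.Size([1, 4])
print(sample['mask'].shape)    # torch.Size([1, 500, 500])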

4.2 Training Function Implementation

The core implementation of the training function:

from segment_anything import sam_model_registry
from torch.utils.data import DataLoader

def train_sam(
    model_type='vit_b',
    checkpoint_path='sam_vit_b_01ec64.pth',
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize the model
    sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam_model.to(device)
    
    # Prepare the data and optimizer
    dataset = StampDataset(image_dir='./stamps/images',
                          mask_dir='./stamps/masks',
                          bbox_file='./stamps/annotations.txt')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss()
    
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Prepare the data
            input_image = batch['image'].to(device)
            original_size = batch['original_size']
            bbox = batch['bbox'].to(device)
            gt_mask = batch['mask'].to(device)
            
            # Forward pass (the image encoder and prompt encoder stay frozen)
            with torch.no_grad():
                image_embedding = sam_model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
            
            # Generate mask predictions
            mask_predictions, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            
            # Postprocess: upscale the masks to the original image size.
            # The DataLoader collates the (h, w) tuple into two tensors, so
            # unpack the first sample (assumes one size per batch).
            upscaled_masks = sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=(int(original_size[0][0]), int(original_size[1][0]))
            ).to(device)
            
            binary_masks = torch.sigmoid(upscaled_masks)
            
            # Compute the loss and optimize (MSE on sigmoid outputs; BCE or Dice losses are common alternatives for segmentation)
            loss = loss_fn(binary_masks, gt_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        # Epoch statistics
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
        
        # Save a checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_file = f'./checkpoints/sam_finetuned_epoch_{epoch+1}.pth'
            torch.save(sam_model.state_dict(), checkpoint_file)
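
Called with its defaults, this trains on the sample data from Section 3. The parameters can be overridden as needed, e.g. (values are illustrative):

train_sam(num_epochs=20, batch_size=1, learning_rate=5e-6)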

5. Training Process

The complete training procedure is:

  1. Prepare the environment and data:
python sam-data-setup.py


  2. Start training:
python sam-finetune.py


6. Notes and Optimization Tips

  1. Data preprocessing:

    • Make sure image data types are correct (float32)
    • Apply appropriate data normalization
    • Keep image sizes consistent
  2. Training optimization:

    • Adjust batch_size to fit GPU memory
    • Tune the learning rate appropriately
    • Consider a learning-rate scheduler
    • Add evaluation on a validation set
    • Implement early stopping
  3. Possible improvements:

    • Add data augmentation
    • Use a different loss function (see the sketch after this list)
    • Implement multi-GPU training
    • Add training-process visualization
    • Add model validation and testing
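
On the loss-function point above: plain MSE works for this toy example, but segmentation models are usually trained with binary cross-entropy combined with a Dice term. A minimal sketch of a drop-in replacement for loss_fn (an illustration, not part of the original scripts; note it takes the raw upscaled_masks logits rather than the sigmoid output):

import torch
import torch.nn.functional as F

def bce_dice_loss(pred_logits, target, eps=1e-6):
    """BCE + Dice on raw mask logits; target is a float mask in [0, 1]."""
    bce = F.binary_cross_entropy_with_logits(pred_logits, target)
    probs = torch.sigmoid(pred_logits)
    inter = (probs * target).sum(dim=(-2, -1))
    union = probs.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    dice = 1 - (2 * inter + eps) / (union + eps)
    return bce + dice.mean()

# In the training loop:
# loss = bce_dice_loss(upscaled_masks, gt_mask)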

7. Model Prediction and Visualization

After fine-tuning, we need a convenient way to run predictions with the model and visualize the results. Here is the complete implementation:

7.1 Predictor Class Implementation

First, we wrap a predictor class that handles model loading, image preprocessing, and prediction:

class SAMPredictor:
    def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
        self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
        self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam_model.to(self.device)
        self.transform = ResizeLongestSide(1024)

This class provides a simple interface for loading the model and running predictions. Its main responsibilities include:

  • model loading and device configuration
  • image preprocessing
  • mask prediction
  • postprocessing

7.2 Visualization Function

To present the prediction results clearly, we implement a visualization function:

def visualize_prediction(image, mask, bbox, confidence, save_path=None):
    plt.figure(figsize=(15, 5))
    # Show the original image, the predicted mask, and the overlay
    ...

This function displays, side by side:

  • the original image (with the bounding box drawn)
  • the predicted segmentation mask
  • an overlay of the mask on the image

7.3 Usage Example

Here is a complete example of how to use these tools:

# Initialize the predictor
predictor = SAMPredictor("./checkpoints/sam_finetuned_final.pth")

# Read a test image
image = cv2.imread("test_image.jpg")
bbox = [x1, y1, x2, y2]  # bounding-box coordinates in the original image (the predictor rescales them internally)

# Predict
mask, confidence = predictor.predict(image, bbox)

# Visualize
visualize_prediction(image, mask, bbox, confidence, "result.png")


7.4 Notes

When using the predictor, keep the following points in mind:

  1. Input image handling:

    • Make sure the image format is correct (RGB)
    • Keep image sizes consistent
    • Use the correct data type and value range
  2. Bounding-box format:

    • Use the [x1, y1, x2, y2] format
    • Make sure the coordinates lie within the image
    • Coordinate values are floats
  3. Performance optimization:

    • Batch prediction
    • GPU memory management
    • Result caching (see the embedding-caching sketch at the end of Section 7.5)

7.5 Possible Improvements

  1. Batch processing:
def predict_batch(self, images, bboxes):
    results = []
    for image, bbox in zip(images, bboxes):
        mask, conf = self.predict(image, bbox)
        results.append((mask, conf))
    return results
  2. Multiple-bounding-box support:
def predict_multiple_boxes(self, image, bboxes):
    masks = []
    for bbox in bboxes:
        mask, _ = self.predict(image, bbox)
        masks.append(mask)
    return np.stack(masks)
  3. Interactive visualization:
def interactive_visualization(image, predictor):
    def onclick(event):
        if event.button == 1:  # left click
            bbox = [event.xdata-50, event.ydata-50, 
                   event.xdata+50, event.ydata+50]
            mask, conf = predictor.predict(image, bbox)
            # pass the confidence through; visualize_prediction requires it
            visualize_prediction(image, mask, bbox, conf)
    
    fig, ax = plt.subplots()
    ax.imshow(image)
    fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()
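
On the result-caching point from Section 7.4: predict_multiple_boxes above re-runs the expensive image encoder for every box, even though the embedding depends only on the image. Caching the embedding and reusing it across boxes is the main optimization. A sketch of such a method for SAMPredictor (an illustration under those assumptions, not part of the original class):

def predict_boxes_cached(self, image, bboxes):
    """Encode the image once, then decode one mask per box."""
    input_image, original_size, _ = self.preprocess_image(image)
    with torch.no_grad():
        # Expensive step, done exactly once per image
        image_embedding = self.sam_model.image_encoder(input_image)
        masks = []
        for bbox in bboxes:
            box = torch.tensor(self.resize_bbox(bbox, original_size),
                               dtype=torch.float, device=self.device).unsqueeze(0)
            sparse, dense = self.sam_model.prompt_encoder(points=None, boxes=box, masks=None)
            preds, _ = self.sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse,
                dense_prompt_embeddings=dense,
                multimask_output=False,
            )
            up = self.sam_model.postprocess_masks(
                preds, input_size=input_image.shape[-2:], original_size=original_size)
            masks.append((torch.sigmoid(up) > 0.5)[0, 0].cpu().numpy())
    return masks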

These tools and examples should help you understand and use the fine-tuned SAM model. You can optimize and extend them further to fit your specific needs.

Conclusion

With the steps above, we have implemented a fine-tuning pipeline for SAM. This implementation can serve as a starting point, to be optimized and improved for your needs. In practice, you may need to adjust the data preprocessing, loss function, and training strategy for the task at hand.

A few recommendations for use:

  1. Ensure the quality of your training data
  2. Set training parameters sensibly
  3. Save checkpoints regularly
  4. Monitor the training process
  5. Use data augmentation where appropriate

I hope this tutorial helps with your project! Questions and discussion are welcome.

References

  1. Segment Anything official repository
  2. PyTorch documentation
  3. SAM paper: Segment Anything
  4. torchvision documentation

Quick Deployment:

Download the three scripts below, set up the environment, and run them in order:

# sam-data-setup.py
import os
import numpy as np
import cv2
from pathlib import Path

def create_project_structure():
    """创建项目所需的目录结构"""
    # 创建主目录
    directories = [
        './stamps/images',
        './stamps/masks',
        './checkpoints'
    ]
    
    for dir_path in directories:
        Path(dir_path).mkdir(parents=True, exist_ok=True)
    
    return directories

def create_sample_data(num_samples=5):
    """创建示例训练数据"""
    # 创建示例图像和掩码
    annotations = []
    
    for i in range(num_samples):
        # Create a sample image (500x500)
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        # Add a sample stamp (a circle at a random position)
        center_x = np.random.randint(150, 350)
        center_y = np.random.randint(150, 350)
        radius = np.random.randint(50, 100)
        
        # Draw the stamp
        cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
        
        # Create the corresponding mask
        mask = np.zeros((500, 500), dtype=np.uint8)
        cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        
        # Save the image and mask
        image_path = f'./stamps/images/sample_{i}.jpg'
        mask_path = f'./stamps/masks/sample_{i}_mask.png'
        
        cv2.imwrite(image_path, image)
        cv2.imwrite(mask_path, mask)
        
        # Compute the bounding box
        x1 = max(0, center_x - radius)
        y1 = max(0, center_y - radius)
        x2 = min(500, center_x + radius)
        y2 = min(500, center_y + radius)
        
        # Append to the annotation list
        annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
    
    # Save the annotation file
    with open('./stamps/annotations.txt', 'w') as f:
        f.writelines(annotations)

def main():
    print("Creating project structure...")
    directories = create_project_structure()
    for dir_path in directories:
        print(f"Created directory: {dir_path}")
    
    print("\nCreating sample training data...")
    create_sample_data()
    print("Sample data created!")
    
    print("\nProject structure:")
    for root, dirs, files in os.walk('./stamps'):
        level = root.replace('./stamps', '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        sub_indent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{sub_indent}{f}")

if __name__ == '__main__':
    main()

# sam-finetune.py
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
import cv2
import os

class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = ResizeLongestSide(1024)  # SAM default size
        
        # Load bbox annotations
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        # Load image
        image = cv2.imread(os.path.join(self.image_dir, ann['image']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Load mask
        mask_name = ann['image'].replace('.jpg', '_mask.png')
        mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
        mask = mask.astype(np.float32) / 255.0
        
        # Prepare image
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        
        # Convert to float32 and normalize to 0-1 range
        input_image = input_image.astype(np.float32) / 255.0
        
        # Convert to tensor and normalize according to ImageNet stats
        input_image = torch.from_numpy(input_image).permute(2, 0, 1).contiguous()
        
        # Apply ImageNet normalization
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        # Prepare bbox
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
        
        # Prepare mask
        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
        
        return {
            'image': input_image.float(),  # ensure float tensor
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch
        }

def train_sam(
    model_type='vit_b',
    checkpoint_path='sam_vit_b_01ec64.pth',
    image_dir='./stamps/images',
    mask_dir='./stamps/masks',
    bbox_file='./stamps/annotations.txt',
    output_dir='./checkpoints',
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5
):
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model
    sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam_model.to(device)
    
    # Prepare dataset
    dataset = StampDataset(image_dir, mask_dir, bbox_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Setup optimizer
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
    
    # Loss function
    loss_fn = torch.nn.MSELoss()
    
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move inputs to device
            input_image = batch['image'].to(device)
            original_size = batch['original_size']
            bbox = batch['bbox'].to(device)
            gt_mask = batch['mask'].to(device)
            
            # Print shapes and types for debugging
            if batch_idx == 0 and epoch == 0:
                print(f"Input image shape: {input_image.shape}")
                print(f"Input image type: {input_image.dtype}")
                print(f"Input image range: [{input_image.min():.2f}, {input_image.max():.2f}]")
            
            # Get image embedding (without gradient)
            with torch.no_grad():
                image_embedding = sam_model.image_encoder(input_image)
                
                # Get prompt embeddings
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
            
            # Generate mask prediction
            mask_predictions, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            
            # Upscale masks to the original image size. The DataLoader collates
            # the (h, w) tuple into two tensors, so unpack the first sample
            # (assumes all images in the batch share one original size).
            upscaled_masks = sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=(int(original_size[0][0]), int(original_size[1][0]))
            ).to(device)
            
            # Convert to binary mask
            binary_masks = torch.sigmoid(upscaled_masks)
            
            # Calculate loss
            loss = loss_fn(binary_masks, gt_mask)
            
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
        
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_file = os.path.join(output_dir, f'sam_finetuned_epoch_{epoch+1}.pth')
            torch.save(sam_model.state_dict(), checkpoint_file)
            print(f'Checkpoint saved: {checkpoint_file}')
    
    # Save final model
    final_checkpoint = os.path.join(output_dir, 'sam_finetuned_final.pth')
    torch.save(sam_model.state_dict(), final_checkpoint)
    print(f'Final model saved to {final_checkpoint}')

if __name__ == '__main__':
    # Create output directory if it doesn't exist
    os.makedirs('./checkpoints', exist_ok=True)
    
    # Start training
    train_sam()

# sam-predict.py  (the third script; the original does not name it, so any filename works)
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
import cv2
from pathlib import Path

class SAMPredictor:
    def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
        """
        Initialize the SAM predictor
        Args:
            checkpoint_path: path to the model weights
            model_type: model type ("vit_h", "vit_l", "vit_b")
            device: device to use ("cuda" or "cpu")
        """
        self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
        print(f"Using device: {self.device}")
        
        # Load the model
        self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam_model.to(self.device)
        
        # Create the image transform
        self.transform = ResizeLongestSide(1024)
    
    def resize_bbox(self, bbox, original_size, target_size=(1024, 1024)):
        """
        Rescale bounding-box coordinates to match the resized image
        Args:
            bbox: original box [x1, y1, x2, y2]
            original_size: original image size (height, width)
            target_size: target image size (height, width)
        Returns:
            resized_bbox: the rescaled bounding box
        """
        orig_h, orig_w = original_size
        target_h, target_w = target_size
        
        # Compute the scale factors
        scale_x = target_w / orig_w
        scale_y = target_h / orig_h
        
        # Rescale the coordinates
        x1, y1, x2, y2 = bbox
        resized_bbox = [
            x1 * scale_x,
            y1 * scale_y,
            x2 * scale_x,
            y2 * scale_y
        ]
        
        return resized_bbox
        
    def preprocess_image(self, image):
        """Preprocess the input image"""
        # Keep the original size
        original_size = image.shape[:2]
        
        # Make sure the image is RGB
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
        # Resize the image
        image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
        
        # Convert to float32 and scale to [0, 1]
        input_image = image.astype(np.float32) / 255.0
        
        # Convert to tensor and add a batch dimension
        input_image = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0)
        
        # Normalize with ImageNet statistics
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        return input_image.to(self.device), original_size, image
        
    def predict(self, image, bbox):
        """
        Predict a segmentation mask for a single image
        Args:
            image: image as a numpy array
            bbox: bounding box in [x1, y1, x2, y2] format
        Returns:
            binary_mask: the binarized segmentation mask
            confidence: the prediction confidence
        """
        # Preprocess the image
        input_image, original_size, resized_image = self.preprocess_image(image)
        
        # Rescale the bounding box
        resized_bbox = self.resize_bbox(bbox, original_size)
        print(resized_bbox, image.shape, resized_image.shape)  # debug output
        
        # Prepare the bounding box
        bbox_torch = torch.tensor(resized_bbox, dtype=torch.float, device=self.device).unsqueeze(0)
        
        # Get the image embedding
        with torch.no_grad():
            image_embedding = self.sam_model.image_encoder(input_image)
            
            # Get the prompt embeddings
            sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
                points=None,
                boxes=bbox_torch,
                masks=None,
            )
            
            # Generate mask predictions
            mask_predictions, iou_predictions = self.sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            
            # Postprocess the masks
            upscaled_masks = self.sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size
            ).to(self.device)
            
            # Convert to a binary mask
            binary_mask = torch.sigmoid(upscaled_masks) > 0.5
            
        return binary_mask[0, 0].cpu().numpy(), iou_predictions[0, 0].item()

def visualize_prediction(image, mask, bbox, confidence, save_path=None):
    """
    Visualize the prediction result
    Args:
        image: original image
        mask: predicted mask
        bbox: bounding-box coordinates
        confidence: prediction confidence
        save_path: path to save the figure (optional)
    """
    # Create the figure
    plt.figure(figsize=(15, 5))
    
    # Show the original image
    plt.subplot(131)
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title('Original Image')
    # Draw the bounding box
    x1, y1, x2, y2 = map(int, bbox)
    plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-', linewidth=2)
    plt.axis('off')
    
    # Show the predicted mask
    plt.subplot(132)
    plt.imshow(mask, cmap='gray')
    plt.title(f'Predicted Mask\nConfidence: {confidence:.2f}')
    plt.axis('off')
    
    # Show the overlay
    plt.subplot(133)
    overlay = image.copy()
    overlay[mask > 0] = overlay[mask > 0] * 0.7 + np.array([0, 255, 0], dtype=np.uint8) * 0.3
    plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    plt.title('Overlay')
    plt.axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        print(f"Result saved to: {save_path}")
    
    plt.show()

def main():
    # Configuration
    checkpoint_path = "./checkpoints/sam_finetuned_final.pth"  # use the fine-tuned model
    test_image_path = "./stamps/images/sample_0.jpg"
    output_dir = "./predictions"
    
    # Create the output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    # Initialize the predictor
    predictor = SAMPredictor(checkpoint_path)
    
    # Read the test image
    image = cv2.imread(test_image_path)
    # image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
    
    # Read a bounding box (an example box is used here; in practice you may read it from your annotation file)
    with open('./stamps/annotations.txt', 'r') as f:
        first_line = f.readline().strip()
        _, x1, y1, x2, y2 = first_line.split(',')
        bbox = [float(x1), float(y1), float(x2), float(y2)]
        print(bbox)
    
    # Predict
    mask, confidence = predictor.predict(image, bbox)
    
    # Visualize the result
    save_path = str(Path(output_dir) / "prediction_result.png")
    visualize_prediction(image, mask, bbox, confidence, save_path)

if __name__ == "__main__":
    main()

Output:


Supplement 2:

The sections above fine-tune only the mask decoder. Below is supplementary code that also fine-tunes the image encoder.

Notes:

  • Fine-tuning the encoder requires considerably more compute and training time
  • A larger training set is needed to avoid overfitting
  • Use a validation set to monitor performance and guard against model degradation
  • More training epochs may be needed to reach convergence
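
If full-encoder training does not fit in memory, a common middle ground is to unfreeze only the last few transformer blocks of the image encoder. A sketch (assumes the standard SAM ViT encoder, which exposes its transformer layers as image_encoder.blocks; not part of the script below):

# Freeze the whole image encoder, then unfreeze its last two blocks
for p in model.image_encoder.parameters():
    p.requires_grad = False
for block in model.image_encoder.blocks[-2:]:
    for p in block.parameters():
        p.requires_grad = True

# Hand only the trainable encoder parameters to the optimizer
encoder_params = [p for p in model.image_encoder.parameters() if p.requires_grad]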

import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import cv2
import os
from tqdm import tqdm
import logging
import json
from datetime import datetime

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform if transform else ResizeLongestSide(1024)
        
        # Load the annotation file
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        # Read the image
        image_path = os.path.join(self.image_dir, ann['image'])
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Read the mask
        mask_name = ann['image'].replace('.jpg', '_mask.png')
        mask_path = os.path.join(self.mask_dir, mask_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")
        mask = mask.astype(np.float32) / 255.0
        
        # Prepare the image
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        input_image = input_image.astype(np.float32) / 255.0
        
        # Convert to tensor and apply ImageNet normalization
        input_image = torch.from_numpy(input_image).permute(2, 0, 1)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        # Prepare the bbox
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float)
        
        # Prepare the mask
        mask_torch = torch.from_numpy(mask).float()
        
        return {
            'image': input_image,
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch,
            'image_path': image_path  # for debugging
        }

class SAMFineTuner:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.setup_model()
        self.setup_datasets()
        self.setup_training()
        
        # Create the output directory
        os.makedirs(config['output_dir'], exist_ok=True)
        
        # Save the config
        config_path = os.path.join(config['output_dir'], 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)
    
    def setup_model(self):
        logger.info(f"Loading SAM model: {self.config['model_type']}")
        self.model = sam_model_registry[self.config['model_type']](
            checkpoint=self.config['checkpoint_path']
        )
        self.model.to(self.device)
    
    def setup_datasets(self):
        logger.info("Setting up datasets")
        self.train_dataset = StampDataset(
            self.config['train_image_dir'],
            self.config['train_mask_dir'],
            self.config['train_bbox_file']
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=self.config['num_workers'],
            pin_memory=True
        )
        
        if self.config.get('val_bbox_file'):
            self.val_dataset = StampDataset(
                self.config['val_image_dir'],
                self.config['val_mask_dir'],
                self.config['val_bbox_file']
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=self.config['batch_size'],
                shuffle=False,
                num_workers=self.config['num_workers'],
                pin_memory=True
            )
    
    def setup_training(self):
        logger.info("Setting up training components")
        # Use separate learning rates for the encoder and the decoder
        self.optimizer = torch.optim.Adam([
            {'params': self.model.image_encoder.parameters(), 
             'lr': self.config['encoder_lr']},
            {'params': self.model.mask_decoder.parameters(), 
             'lr': self.config['decoder_lr']}
        ])
        
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True
        )
        
        self.loss_fn = torch.nn.MSELoss()
        self.scaler = GradScaler()
        
        # Track the best model
        self.best_loss = float('inf')
    
    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0
        
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch + 1}')
        for batch_idx, batch in enumerate(pbar):
            # Move data to the device
            input_image = batch['image'].to(self.device)
            bbox = batch['bbox'].to(self.device)
            gt_mask = batch['mask'].to(self.device)
            
            self.optimizer.zero_grad()
            
            with autocast():
                # Forward pass
                image_embedding = self.model.image_encoder(input_image)
                
                with torch.no_grad():
                    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                        points=None,
                        boxes=bbox,
                        masks=None,
                    )
                
                mask_predictions, _ = self.model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                
                # Unpack the collated (h, w) size tensors (assumes a uniform size per batch)
                upscaled_masks = self.model.postprocess_masks(
                    mask_predictions,
                    input_size=input_image.shape[-2:],
                    original_size=(int(batch['original_size'][0][0]), int(batch['original_size'][1][0]))
                ).to(self.device)
                
                binary_masks = torch.sigmoid(upscaled_masks)
                loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
            
            # Backward pass with gradient scaling and clipping
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
        
        return total_loss / len(self.train_loader)
    
    @torch.no_grad()
    def validate(self):
        if not hasattr(self, 'val_loader'):
            return None
            
        self.model.eval()
        total_loss = 0
        
        for batch in tqdm(self.val_loader, desc='Validating'):
            input_image = batch['image'].to(self.device)
            bbox = batch['bbox'].to(self.device)
            gt_mask = batch['mask'].to(self.device)
            
            with autocast():
                image_embedding = self.model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
                
                mask_predictions, _ = self.model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                
                upscaled_masks = self.model.postprocess_masks(
                    mask_predictions,
                    input_size=input_image.shape[-2:],
                    original_size=(int(batch['original_size'][0][0]), int(batch['original_size'][1][0]))
                ).to(self.device)
                
                binary_masks = torch.sigmoid(upscaled_masks)
                loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
                
                total_loss += loss.item()
        
        return total_loss / len(self.val_loader)
    
    def save_checkpoint(self, epoch, loss, is_best=False):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'config': self.config
        }
        
        # Save the latest checkpoint
        checkpoint_path = os.path.join(
            self.config['output_dir'],
            f'checkpoint_epoch_{epoch+1}.pth'
        )
        torch.save(checkpoint, checkpoint_path)
        
        # If this is the best model so far, save an extra copy
        if is_best:
            best_path = os.path.join(self.config['output_dir'], 'best_model.pth')
            torch.save(checkpoint, best_path)
            logger.info(f"Saved best model with loss: {loss:.4f}")
    
    def train(self):
        logger.info("Starting training")
        for epoch in range(self.config['num_epochs']):
            # Train one epoch
            train_loss = self.train_epoch(epoch)
            logger.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}")
            
            # Validate
            val_loss = self.validate()
            if val_loss is not None:
                logger.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")
                self.scheduler.step(val_loss)
                
                # Check whether this is the best model
                is_best = val_loss < self.best_loss
                if is_best:
                    self.best_loss = val_loss
            else:
                is_best = False
                self.scheduler.step(train_loss)
            
            # Save a checkpoint
            if (epoch + 1) % self.config['save_interval'] == 0:
                self.save_checkpoint(
                    epoch,
                    val_loss if val_loss is not None else train_loss,
                    is_best
                )

def main():
    # Training configuration
    config = {
        'model_type': 'vit_b',
        'checkpoint_path': 'sam_vit_b_01ec64.pth',
        'train_image_dir': './data/train/images',
        'train_mask_dir': './data/train/masks',
        'train_bbox_file': './data/train/annotations.txt',
        'val_image_dir': './data/val/images',
        'val_mask_dir': './data/val/masks',
        'val_bbox_file': './data/val/annotations.txt',
        'output_dir': f'./outputs/sam_finetune_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
        'num_epochs': 50,
        'batch_size': 4,
        'num_workers': 4,
        'encoder_lr': 1e-6,
        'decoder_lr': 1e-5,
        'save_interval': 5
    }
    
    # Create the trainer and start training
    trainer = SAMFineTuner(config)
    trainer.train()

if __name__ == '__main__':
    main()
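
Note that save_checkpoint stores a dictionary (model weights plus optimizer, scheduler, and config state) rather than a bare state_dict, so loading differs from the earlier scripts. A minimal sketch for resuming from best_model.pth (the path follows from the config above):

import torch
from segment_anything import sam_model_registry

checkpoint = torch.load('best_model.pth', map_location='cpu')
model = sam_model_registry[checkpoint['config']['model_type']](checkpoint=None)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")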