【论文阅读】Directional Connectivity-based Segmentation of Medical Images(可以跑通代码)
CVPR2023有效地将方向子空间从共享潜在空间中解耦可以显著增强基于连通性网络中的特征表示。
论文:Directional Connectivity-based Segmentation of Medical Images
代码:https://github.com/zyun-y/dconnnet
摘要
出发点:生物标志分割中的解剖学一致性对许多医学图像分析任务至关重要。
之前工作的问题:以往的连通性工作忽略了潜在空间中丰富的信道方向的信息。
证明:有效地将方向子空间从共享潜在空间中解耦可以显著增强基于连通性网络中的特征表示。
提出:一种用于分割的定向连通性建模方案,该方案解耦、跟踪和利用跨网络的方向信息。
介绍
介绍了基于像素分类和基于连通性的模型之间潜在的空间差异。前者仅突出分类特征,eg:边界。后者包含方向信息,例如:边界像素之间的水平连接。
将两组潜在特征(范畴性和方向性)在DconnNet的潜在空间中的流向用T - SNE进行可视化。它们先被解缠,然后在一个投影的共享流形中有效地融合,基于聚类的结果进行颜色的渲染。
这个是普通的分割掩码变成连通性掩码的示意图。每一个原来一个像素的位置包含了周围8个像素的mask值。一个一个对应即可。感觉好像这个图中间那个错了,中间像素的C1是positive。
方法
由于不同像素类别和方向之间的连通性,基于连通性的网络的潜在空间中存在两组特征:类别信息和方向信息。每一组特征在隐空间中形成其特定的子空间。两个子空间是高度耦合的。我们证明了方向空间的有效解缠和有效利用可以增强连通性模型中的整体特征表示。
Pretrained ResNet:提取特征。
SDE:特征信息和方向信息解耦。
IFD:特征信息与方向信息融合。
效果
利用T - SNE对DconnNet在SDE模块前后的隐通道嵌入进行可视化。( b )中的颜色表示无监督聚类结果。当应用于SDE时,通道嵌入自然地分组为几个不同的部分。
结论
其核心思想是将方向子空间从共享的潜在空间中解耦出来,并利用提取的方向特征来增强整体的数据表示。
- 通过与其他先进方法的统计比较,显示了DconnNet的整体性能更好。
- 通过在一个拓扑敏感的数据集上定性和定量地将DconnNet与其他方法进行比较,展示了其保留拓扑结构的能力。
- 通过对DconnNet的隐空间进行可视化,揭示了方向子空间的解纠缠过程
跑通代码
数据集方面
#作者的数据读取中是读取的3通道图像和二值的mask图像,我们写一个数据集读取的函数能让他输出读取的3通道图像和二值的mask图像变成的tensor就可以。自己可以加一点数据增强。
#root_path 是数据集的地址,fold_json存储了10折的图片的名称,fold_num是选取哪一个折作为验证集,就是十折交叉验证的内容,image_size最后resize的图片大小,mode是训练还是验证,augmentation_prob就是数据增强的概率
#大家按照自己的数据集,写一个数据集的函数,之后跑别人代码的时候直接用就可以
import os
import random
from random import shuffle
import numpy as np
import torch
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np
import json
from .GetDataset_CHASE import connectivity_matrix
class ImageFolder(data.Dataset):
def __init__(self, image_root,label_root,json_path, fold=1, image_size=400, mode='train', augmentation_prob=0.4):
"""Initializes image paths and preprocessing module."""
self.root = image_root
with open(json_path, 'r') as load_f:
self.fold_data = json.load(load_f)
self.data_list=[]
if mode == 'train':
for i in range(1, 11):
if i != fold:
self.data_list += self.fold_data['Fold ' + str(i)]
elif mode == 'val':
self.data_list = self.fold_data['Fold ' + str(fold)]
else:
raise ValueError("数据类型只有train和val")
self.image_size = image_size
self.label_root = label_root
self.mode = mode
# self.RotationDegree = [0,90,180,270]
self.augmentation_prob = augmentation_prob
print("image count in {} path :{}".format(self.mode,len(self.data_list)))
def __getitem__(self, index):
"""Reads an image from a file and preprocesses it and returns."""
image_path = os.path.join(self.root,self.data_list[index])
GT_path = os.path.join(self.label_root ,self.data_list[index])
image = Image.open(image_path).convert('RGB')
GT = Image.open(GT_path).convert('1')
aspect_ratio = image.size[0]/image.size[1]#weight/height
Transform = []
ResizeRange = random.randint(500,525)
Transform.append(T.Resize((ResizeRange,int(ResizeRange*aspect_ratio))))
p_transform = random.random()
if (self.mode == 'train') and p_transform <= self.augmentation_prob:
RotationRange = random.randint(-10,10)
Transform.append(T.RandomRotation((RotationRange,RotationRange)))
CropRange = random.randint(500,525)
Transform.append(T.CenterCrop((CropRange,int(CropRange*aspect_ratio))))
Transform = T.Compose(Transform)
image = Transform(image)
GT = Transform(GT)
ShiftRange_left = random.randint(0,20)
ShiftRange_upper = random.randint(0,20)
ShiftRange_right = image.size[0] - random.randint(0,20)
ShiftRange_lower = image.size[1] - random.randint(0,20)
image = image.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
GT = GT.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
#
# if random.random() < 0.5:
# image = F.hflip(image)
# GT = F.hflip(GT)
#
# if random.random() < 0.5:
# image = F.vflip(image)
# GT = F.vflip(GT)
Transform = T.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02)
image = Transform(image)
Transform =[]
Transform.append(T.Resize([256,256]))
Transform.append(T.ToTensor())
Transform = T.Compose(Transform)
image = Transform(image)
GT = Transform(GT)
mean = [0.1591, 0.1591, 0.1591]
std = [0.2593, 0.2593, 0.2593]
Norm_ = T.Normalize(mean, std)
image = Norm_(image)
# images = image
# image = torch.unsqueeze(image,0)
# images = torch.cat([images,image],dim=1)
return image, GT
def __len__(self):
"""Returns the total number of font files."""
return len(self.data_list)
def get_loader(image_path,label_path,image_size, batch_size, json_path, fold =None,num_workers=2, mode='train',augmentation_prob=0.4):
"""Builds and returns Dataloader."""
dataset = ImageFolder(image_root = image_path,label_root=label_path,json_path=json_path, fold = fold,image_size =image_size, mode=mode,augmentation_prob=augmentation_prob)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return data_loader
main函数修改
我直接在函数上写上自己数据集的地址和相关参数了由于时间原因,大家可以把它加入到args参数里面更规范。我是class=1的任务,所以一定把class修改了,源代码中是4.
def main(args):
## K-fold cross validation ##
for exp_id in range(args.folds):
#
train_loader = get_loader(image_path='自己数据集中image的位置',
label_path='自己数据集中mask的位置',
json_path="json文件的位置",
image_size=(256,256),
batch_size=1,
fold=1,
num_workers=8,
mode='train',
augmentation_prob=0.4)
val_loader = get_loader(image_path='自己数据集中image的位置',
label_path='自己数据集中mask的位置',
json_path="json文件的位置",
image_size=(256,256),
batch_size=1,
fold=1,
num_workers=8,
mode='val',
augmentation_prob=0.)
print("Train batch number: %i" % len(train_loader))
print("Test batch number: %i" % len(val_loader))
#### Above: define how you get the data on your own dataset ######
model = DconnNet(num_class=1).cuda()
if args.pretrained:
model.load_state_dict(torch.load(args.pretrained,map_location = torch.device('cpu')))
model = model.cuda()
solver = Solver(args)
solver.train(model, train_loader, val_loader,exp_id+1, num_epochs=args.epochs)
connect_loss.py
我是一个类别的,运行一直报错,是connect_loss.py这个函数它最后有个conn = conn.squeeze()注释掉就可以运行了。
def connectivity_matrix(multimask, class_num):
##### converting segmentation masks to connectivity masks ####
[batch,_,rows, cols] = multimask.shape
# batch = 1
conn = torch.zeros([batch,class_num*8,rows, cols]).cuda()
for i in range(class_num):
mask = multimask[:,i,:,:]
# print(mask.shape)
up = torch.zeros([batch,rows, cols]).cuda()#move the orignal mask to up
down = torch.zeros([batch,rows, cols]).cuda()
left = torch.zeros([batch,rows, cols]).cuda()
right = torch.zeros([batch,rows, cols]).cuda()
up_left = torch.zeros([batch,rows, cols]).cuda()
up_right = torch.zeros([batch,rows, cols]).cuda()
down_left = torch.zeros([batch,rows, cols]).cuda()
down_right = torch.zeros([batch,rows, cols]).cuda()
up[:,:rows-1, :] = mask[:,1:rows,:]
down[:,1:rows,:] = mask[:,0:rows-1,:]
left[:,:,:cols-1] = mask[:,:,1:cols]
right[:,:,1:cols] = mask[:,:,:cols-1]
up_left[:,0:rows-1,0:cols-1] = mask[:,1:rows,1:cols]
up_right[:,0:rows-1,1:cols] = mask[:,1:rows,0:cols-1]
down_left[:,1:rows,0:cols-1] = mask[:,0:rows-1,1:cols]
down_right[:,1:rows,1:cols] = mask[:,0:rows-1,0:cols-1]
conn[:,(i*8)+0,:,:] = mask*down_right
conn[:,(i*8)+1,:,:] = mask*down
conn[:,(i*8)+2,:,:] = mask*down_left
conn[:,(i*8)+3,:,:] = mask*right
conn[:,(i*8)+4,:,:] = mask*left
conn[:,(i*8)+5,:,:] = mask*up_right
conn[:,(i*8)+6,:,:] = mask*up
conn[:,(i*8)+7,:,:] = mask*up_left
conn = conn.float()
# conn = conn.squeeze()
# print(conn.shape)
return conn
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)