Win上使用自己的数据集跑通DEtection TRansformer(DETR)
目录:前言|NO.1 准备自制数据集|No.2 安装环境|No.3 配置参数|最后。摘要:前言:《End-to-End Object Detection with Transformers》,论文:https://arxiv.org/abs/2005.12872v3 ,Code:https://github.com/facebookresearch/detr 。NO.1 准备自制数据集:这里假设原来为VOC数据集,需要转换成coco数据集。
前言
《End-to-End Object Detection with Transformers》
论文:https://arxiv.org/abs/2005.12872v3
Code:https://github.com/facebookresearch/detr
NO.1 准备自制数据集
这里假设原来为VOC数据集,需要转换成coco数据集:
①:将YOLO格式的txt标注文件转换成VOC格式的xml标注文件
from xml.dom.minidom import Document
import os
import cv2
def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to *parent* and return the new element."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(element)
    return element


def makexml(picPath, txtPath, xmlPath):
    """Convert YOLO-format txt label files into VOC-format xml annotation files.

    Suggested layout: inside your annotation folder create three sub-folders
    named picture, txt and xml.

    :param picPath: folder holding the ``.jpg`` images
    :param txtPath: folder holding the YOLO ``.txt`` label files
    :param xmlPath: folder the generated ``.xml`` files are written to
    :return: None

    Entries in *txtPath* that are not ``.txt`` files, and label files whose
    matching ``<stem>.jpg`` is missing or unreadable (for example
    ``classes.txt``), are skipped with a message instead of aborting the run.
    A class index that is not in the mapping below raises ``KeyError``.
    """
    # Maps the class index written in the txt file to the class name.
    # It must match your own classes.txt, in the same order.
    dic = {'0': "A",
           '1': "B",
           }
    for name in os.listdir(txtPath):
        stem, ext = os.path.splitext(name)
        if ext.lower() != ".txt":
            continue

        image_path = os.path.join(picPath, stem + ".jpg")
        if not os.path.isfile(image_path):
            print("Skipping %s: image %s not found" % (name, image_path))
            continue
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals an undecodable file by returning None.
            print("Skipping %s: image %s could not be read" % (name, image_path))
            continue
        Pheight, Pwidth, Pdepth = img.shape

        with open(os.path.join(txtPath, name), encoding='utf-8') as txtFile:
            txtList = txtFile.readlines()

        xmlBuilder = Document()
        annotation = xmlBuilder.createElement("annotation")
        xmlBuilder.appendChild(annotation)
        _append_text_element(xmlBuilder, annotation, "folder", "driving_annotation_dataset")
        _append_text_element(xmlBuilder, annotation, "filename", stem + ".jpg")

        size = xmlBuilder.createElement("size")
        _append_text_element(xmlBuilder, size, "width", Pwidth)
        _append_text_element(xmlBuilder, size, "height", Pheight)
        _append_text_element(xmlBuilder, size, "depth", Pdepth)
        annotation.appendChild(size)

        for row in txtList:
            fields = row.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no box.
                continue
            # Fields after the class index: x-centre, y-centre, width, height,
            # each multiplied by the image width/height below.
            cx, cy, bw, bh = (float(value) for value in fields[1:5])

            obj_node = xmlBuilder.createElement("object")
            _append_text_element(xmlBuilder, obj_node, "name", dic[fields[0]])
            _append_text_element(xmlBuilder, obj_node, "pose", "Unspecified")
            _append_text_element(xmlBuilder, obj_node, "truncated", "0")
            _append_text_element(xmlBuilder, obj_node, "difficult", "0")

            bndbox = xmlBuilder.createElement("bndbox")
            # Centre/size fractions -> pixel corners; every corner is shifted
            # by +1 before truncation to int.
            _append_text_element(xmlBuilder, bndbox, "xmin",
                                 int((cx * Pwidth + 1) - bw * 0.5 * Pwidth))
            _append_text_element(xmlBuilder, bndbox, "ymin",
                                 int((cy * Pheight + 1) - bh * 0.5 * Pheight))
            _append_text_element(xmlBuilder, bndbox, "xmax",
                                 int((cx * Pwidth + 1) + bw * 0.5 * Pwidth))
            _append_text_element(xmlBuilder, bndbox, "ymax",
                                 int((cy * Pheight + 1) + bh * 0.5 * Pheight))
            obj_node.appendChild(bndbox)
            annotation.appendChild(obj_node)

        # The handle's encoding must agree with the encoding named in the XML
        # declaration, otherwise non-ASCII text produces a mislabelled file.
        with open(os.path.join(xmlPath, stem + ".xml"), 'w', encoding='utf-8') as f:
            xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
if __name__ == "__main__":
    # Each folder path below must keep its trailing "/".
    picPath = "./VOC2007/JPEGImage/"  # folder holding the images
    txtPath = "./VOC2007/yolo/"  # folder holding the txt label files
    xmlPath = "./VOC2007/Annotations_train/"  # folder the xml files are saved to
    makexml(picPath=picPath, txtPath=txtPath, xmlPath=xmlPath)
②:对train、val、test分别转换成json文件(以test为例)
# -*- coding=utf-8 -*-
#!/usr/bin/python
import sys
import os
import shutil
import numpy as np
import json
import xml.etree.ElementTree as ET
# First id handed out to bounding-box annotations.
START_BOUNDING_BOX_ID = 1

# The category table does not have to be created in advance: the converter
# adds any category it meets in the annotations.
# If necessary, pre-define category and its id, e.g. for Pascal VOC:
# PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
#                          "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
#                          "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
#                          "motorbike": 14, "person": 15, "pottedplant": 16,
#                          "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}
PRE_DEFINE_CATEGORIES = dict(cat=1, dog=2)
def get(root, name):
    """Return every element matching *name* under *root* (possibly an empty list)."""
    return root.findall(name)
def get_and_check(root, name, length):
    """Find the elements matching *name* under *root* and validate their count.

    :param root: element to search in
    :param name: tag/path passed to ``findall``
    :param length: expected number of matches; values <= 0 disable the count check
    :return: the single element when *length* is 1, otherwise the list of matches
    :raises NotImplementedError: when nothing matches, or when the number of
        matches differs from a positive *length*
    """
    matches = root.findall(name)
    count = len(matches)
    if count == 0:
        raise NotImplementedError('Can not find %s in %s.'%(name, root.tag))
    if length > 0 and count != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.'%(name, length, count))
    return matches[0] if length == 1 else matches
# Derive the unique image id from the file name.
def get_filename_as_int(filename):
    """Return the integer encoded by *filename* once its extension is removed.

    :param filename: file name such as ``"000123.jpg"``
    :return: the stem parsed as an int
    :raises NotImplementedError: when the stem is not an integer (the
        underlying ``ValueError``/``TypeError`` is chained as the cause)
    """
    stem = filename
    try:
        stem = os.path.splitext(filename)[0]
        return int(stem)
    except (TypeError, ValueError) as err:
        # Catch only conversion failures; a bare ``except`` would also
        # swallow KeyboardInterrupt and SystemExit.
        raise NotImplementedError('Filename %s is supposed to be an integer.' % (stem,)) from err
def convert(xml_list, xml_dir, json_file):
    '''Convert VOC xml annotation files into a single COCO-style json file.

    :param xml_list: names of the XML files to convert (inside xml_dir)
    :param xml_dir: folder holding the XML files
    :param json_file: path of the json file to write
    :return: None

    Image ids are assigned as 1, 2, 3, ... in the order of *xml_list*.
    Categories missing from PRE_DEFINE_CATEGORIES are added to that shared
    mapping with the next unused id.
    '''
    # Basic structure of the annotation file.
    json_dict = {"images": [],
                 "type": "instances",
                 "annotations": [],
                 "categories": []}
    # Same dict object as the module constant, so newly met categories are
    # recorded in it.
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for num, line in enumerate(xml_list):
        line = line.strip()
        print("buddy~ Processing {}".format(line))
        # Parse the XML file.
        root = ET.parse(os.path.join(xml_dir, line)).getroot()
        # The image name is derived from the XML file name rather than read
        # from the <path>/<filename> tags.
        filename = os.path.splitext(line)[0] + '.jpg'
        # The file name does not have to be a number: use the position instead.
        image_id = num + 1
        size = get_and_check(root, 'size', 1)
        # Basic image information.
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        json_dict['images'].append({'file_name': filename,
                                    'height': height,
                                    'width': width,
                                    'id': image_id})
        # Segmentation is currently not supported; only boxes are converted.
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category not in categories:
                # Next unused id. len(categories) would repeat an existing id
                # whenever the predefined ids start at 1.
                categories[category] = max(categories.values(), default=0) + 1
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            # xmin/ymin are shifted by one; xmax/ymax are used as stored.
            xmin = int(get_and_check(bndbox, 'xmin', 1).text) - 1
            ymin = int(get_and_check(bndbox, 'ymin', 1).text) - 1
            xmax = int(get_and_check(bndbox, 'xmax', 1).text)
            ymax = int(get_and_check(bndbox, 'ymax', 1).text)
            # Kept as asserts so callers still see AssertionError.
            assert xmax > xmin, 'xmax must be greater than xmin in %s' % line
            assert ymax > ymin, 'ymax must be greater than ymin in %s' % line
            o_width = xmax - xmin
            o_height = ymax - ymin
            annotation = dict()
            annotation['area'] = o_width * o_height
            annotation['iscrowd'] = 0
            annotation['image_id'] = image_id
            annotation['bbox'] = [xmin, ymin, o_width, o_height]
            annotation['category_id'] = category_id
            annotation['id'] = bnd_id
            annotation['ignore'] = 0
            # Box outline used as the segmentation polygon
            # (points listed counter-clockwise).
            annotation['segmentation'] = [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]]
            json_dict['annotations'].append(annotation)
            bnd_id = bnd_id + 1
    # Write the category table.
    for cate, cid in categories.items():
        json_dict['categories'].append({'supercategory': cate, 'id': cid, 'name': cate})
    # Export to json; the with-block closes the file even if writing fails.
    with open(json_file, 'w') as json_fp:
        json_fp.write(json.dumps(json_dict, indent=4))
if __name__ == '__main__':
    # Folder (under the current working directory) holding the xml files.
    anno_file = "test_annotations"
    root_path = os.getcwd()
    xml_dir = os.path.join(root_path, anno_file)
    xml_labels = os.listdir(xml_dir)
    # Shuffle in place; every file is still converted (no split is taken).
    np.random.shuffle(xml_labels)
    xml_list = xml_labels
    json_file = './instances_test2017.json'
    convert(xml_list, xml_dir, json_file)
形成:
coco
--annotations
--instances_train2017.json
--instances_val2017.json
--instances_test2017.json
--train2017 #放train图片
--val2017 #放val图片
--test2017 #放test图片
No.2 安装环境
①:这里Anaconda管理、新建、激活环境不再赘述
②:安装DETR运行所需的环境(安装方式有很多,这里选择在PyCharm的terminal中安装)
#示例
pip install -r E:\x\xx\...\requirements.txt
其中:
①:git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI&egg=pycocotools(下载)
#cd进入PythonAPI文件夹
python setup.py build_ext install
②:git+https://github.com/cocodataset/panopticapi.git#egg=panopticapi(下载)
pip install .\panopticapi-master.zip
No.3 配置参数
①(可选):
import torch

# Checkpoint downloaded from the official DETR release.
pretrained_weights = torch.load("./detr-r50-e632da11.pth")
# Assumes 2 object classes (change to your own count); the extra 1 is the background.
num_class = 2 + 1
head_size = num_class + 1
model_state = pretrained_weights["model"]
# Resize the classification head in place to the new number of outputs.
model_state["class_embed.weight"].resize_(head_size, 256)
model_state["class_embed.bias"].resize_(head_size)
torch.save(pretrained_weights, 'detr_r50_%d.pth' % num_class)
②:detr.py中
num_classes = 3 if args.dataset_file != 'coco' else 3 # both branches are set to 3 here
③:main.py中
parser.add_argument('--dataset_file', default='coco')
# The keyword is ``default``; the misspelling ``defaut`` makes add_argument raise TypeError.
parser.add_argument('--coco_path', default='coco数据集路径', type=str)
parser.add_argument('--resume', default='①生成的detr_r50_3.pth(可不填)', help='resume from checkpoint')
# Adjust the remaining arguments as needed.
最后
①:
self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
……
……
……
TypeError: ‘numpy.float64’ object cannot be interpreted as an integer
可参考解决方案:
对环境位置site-packages\pycocotools中的cocoeval.py进行修改(具体查看自己报错位置)
# The two original lines (kept commented out) pass ``np.round(...) + 1`` as the
# sample count of np.linspace, which fails with
# "TypeError: 'numpy.float64' object cannot be interpreted as an integer".
#self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
#self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True)
# Replacement: the same ranges with the counts written as plain integers.
self.iouThrs = np.linspace(.5, 0.95, 10, endpoint=True)
self.recThrs = np.linspace(.0, 1.00, 101, endpoint=True)
②:Plot
在.\util\plot_utils.py最后添加
if __name__ == '__main__':
    # Precision/recall plot built from every *.pth file under ../outputs/eval.
    files = list(Path('../outputs/eval').glob('*.pth'))
    plot_precision_recall(files)
    plt.show()
    # Curves for the selected fields, read from log.txt under ../outputs/log/.
    plot_logs(
        logs=Path('../outputs/log/'),
        fields=('class_error', 'loss_bbox_unscaled', 'mAP'),
        ewm_col=0,
        log_name='log.txt',
    )
    plt.show()
更多推荐
所有评论(0)