前言

《End-to-End Object Detection with Transformers》
论文:https://arxiv.org/abs/2005.12872v3
Code:https://github.com/facebookresearch/detr


No.1 准备自制数据集

这里假设原始标注为YOLO格式(txt),需要先转换成VOC格式(xml),再转换成COCO格式(json):
①:将YOLO格式的txt标注转换成VOC格式的xml

from xml.dom.minidom import Document
import os
import cv2
 
def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to *parent* and return the new element."""
    element = doc.createElement(tag)
    element.appendChild(doc.createTextNode(str(text)))
    parent.appendChild(element)
    return element


def makexml(picPath, txtPath, xmlPath):
    """Convert YOLO-format txt label files into VOC-format xml annotation files.

    For every ``*.txt`` file in *txtPath*, the image with the same stem and a
    ``.jpg`` extension is read from *picPath* to obtain its size, and an xml
    file with the same stem is written to *xmlPath*. Entries that are not
    ``.txt`` files, and labels whose image cannot be read, are skipped.

    Args:
        picPath: Folder containing the ``.jpg`` images.
        txtPath: Folder containing the YOLO txt label files.
        xmlPath: Existing folder that receives the generated xml files.

    Raises:
        KeyError: If a label line uses a class id that is missing from ``dic``.
        IndexError, ValueError: If a label line does not hold a class id
            followed by four numbers.
    """
    # Maps the YOLO class id to the class name. It must match your own
    # classes.txt file, in the same order.
    dic = {'0': "A",
           '1': "B",
           }
    for name in os.listdir(txtPath):
        stem, ext = os.path.splitext(name)
        if ext.lower() != ".txt":
            # Only label files are converted; ignore anything else in the folder.
            continue

        with open(os.path.join(txtPath, name), encoding="utf-8") as txtFile:
            # Blank lines (e.g. a trailing newline) carry no annotation.
            labels = [row.split() for row in txtFile if row.strip()]

        img = cv2.imread(os.path.join(picPath, stem + ".jpg"))
        if img is None:
            # cv2.imread signals a missing/unreadable image by returning None.
            print("Skipping {}: cannot read image {}.jpg".format(name, stem))
            continue
        Pheight, Pwidth, Pdepth = img.shape

        xmlBuilder = Document()
        annotation = xmlBuilder.createElement("annotation")
        xmlBuilder.appendChild(annotation)

        _append_text_element(xmlBuilder, annotation, "folder", "driving_annotation_dataset")
        _append_text_element(xmlBuilder, annotation, "filename", stem + ".jpg")

        size = xmlBuilder.createElement("size")
        _append_text_element(xmlBuilder, size, "width", Pwidth)
        _append_text_element(xmlBuilder, size, "height", Pheight)
        _append_text_element(xmlBuilder, size, "depth", Pdepth)
        annotation.appendChild(size)

        for fields in labels:
            # YOLO line: class id, then box centre x/y and box width/height,
            # all four expressed as fractions of the image size.
            cx, cy, bw, bh = (float(value) for value in fields[1:5])

            obj_node = xmlBuilder.createElement("object")
            _append_text_element(xmlBuilder, obj_node, "name", dic[fields[0]])
            _append_text_element(xmlBuilder, obj_node, "pose", "Unspecified")
            _append_text_element(xmlBuilder, obj_node, "truncated", "0")
            _append_text_element(xmlBuilder, obj_node, "difficult", "0")

            # The +1 shifts the pixel coordinates to a 1-based origin before
            # the half extent is subtracted/added and the result truncated.
            bndbox = xmlBuilder.createElement("bndbox")
            _append_text_element(xmlBuilder, bndbox, "xmin",
                                 int((cx * Pwidth + 1) - bw * 0.5 * Pwidth))
            _append_text_element(xmlBuilder, bndbox, "ymin",
                                 int((cy * Pheight + 1) - bh * 0.5 * Pheight))
            _append_text_element(xmlBuilder, bndbox, "xmax",
                                 int((cx * Pwidth + 1) + bw * 0.5 * Pwidth))
            _append_text_element(xmlBuilder, bndbox, "ymax",
                                 int((cy * Pheight + 1) + bh * 0.5 * Pheight))
            obj_node.appendChild(bndbox)

            annotation.appendChild(obj_node)

        # Write with the same encoding that the xml declaration announces.
        with open(os.path.join(xmlPath, stem + ".xml"), 'w', encoding='utf-8') as f:
            xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
 
if __name__ == "__main__":
    # Keep the trailing "/" on every directory path.
    makexml(
        picPath="./VOC2007/JPEGImage/",  # folder containing the images
        txtPath="./VOC2007/yolo/",  # folder containing the txt label files
        xmlPath="./VOC2007/Annotations_train/",  # folder receiving the xml files
    )

②:将train、val、test的xml标注分别转换成COCO格式的json文件(以test为例)

# -*- coding=utf-8 -*-
#!/usr/bin/python

import sys
import os
import shutil
import numpy as np
import json
import xml.etree.ElementTree as ET

# First id assigned to a bounding-box annotation.
START_BOUNDING_BOX_ID = 1
# The category table does not have to be created in advance: convert() adds
# every category name it meets in the XML files that is not listed yet.
# If necessary, pre-define category and its id
# PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
#                           "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
#                           "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
#                           "motorbike": 14, "person": 15, "pottedplant": 16,
#                           "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}
# Keys are compared with the <name> text of each <object> in the XML files.
PRE_DEFINE_CATEGORIES = {"cat": 1, "dog": 2}


def get(root, name):
    """Return the list of elements under *root* matching *name* (may be empty)."""
    return root.findall(name)


def get_and_check(root, name, length):
    """Find the elements matching *name* under *root* and validate their count.

    Args:
        root: Element to search in.
        name: Tag (or path) handed to ``findall``.
        length: Expected number of matches; values <= 0 disable the count check.

    Returns:
        The single matching element when ``length == 1``, otherwise the list
        of matches.

    Raises:
        NotImplementedError: If nothing matches, or if ``length > 0`` and the
            number of matches differs from it.
    """
    matches = root.findall(name)
    found = len(matches)
    if not found:
        raise NotImplementedError('Can not find %s in %s.'%(name, root.tag))
    if length > 0 and found != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.'%(name, length, found))
    return matches[0] if length == 1 else matches


# 得到图片唯一标识号
def get_filename_as_int(filename):
    """Return the image id encoded in *filename*.

    The extension is dropped and the remaining stem is parsed as an integer,
    e.g. ``"000123.jpg"`` gives ``123``.

    Raises:
        NotImplementedError: If the stem is not an integer (or *filename* is
            not a path-like string). The underlying error is chained as the
            cause.
    """
    try:
        return int(os.path.splitext(filename)[0])
    except (TypeError, ValueError) as err:
        # Catch only conversion failures so that KeyboardInterrupt/SystemExit
        # still propagate, and report the name exactly as it was passed in.
        raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename)) from err


def convert(xml_list, xml_dir, json_file):
    """Convert VOC-style XML annotations into one COCO-style JSON file.

    Images receive the ids 1, 2, ... in the order of *xml_list*; the image
    file name is the XML file name with its extension replaced by ``.jpg``.

    :param xml_list: names of the XML files to convert
    :param xml_dir: folder that contains the XML files
    :param json_file: path of the JSON file to write
    :return: None
    """
    # Skeleton of the COCO annotation file.
    json_dict = {"images": [],
                 "type": "instances",
                 "annotations": [],
                 "categories": []}
    # NOTE: this is the module-level dict itself, so categories discovered
    # below are also added to PRE_DEFINE_CATEGORIES.
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for num, line in enumerate(xml_list):
        line = line.strip()
        print("buddy~ Processing {}".format(line))
        root = ET.parse(os.path.join(xml_dir, line)).getroot()
        filename = os.path.splitext(line)[0] + '.jpg'
        # The image id is the 1-based position in xml_list, so file names do
        # not have to be numeric.
        image_id = num + 1
        size = get_and_check(root, 'size', 1)
        json_dict['images'].append({
            'file_name': filename,
            'height': int(get_and_check(size, 'height', 1).text),
            'width': int(get_and_check(size, 'width', 1).text),
            'id': image_id,
        })
        # One COCO annotation per <object> element.
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category not in categories:
                # len(categories) alone can clash with an id that is already
                # taken (e.g. pre-defined ids starting at 1), so move on to
                # the next free id in that case.
                used_ids = set(categories.values())
                new_id = len(categories)
                while new_id in used_ids:
                    new_id += 1
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            # The top-left corner is shifted by one pixel (1-based to 0-based).
            xmin = int(get_and_check(bndbox, 'xmin', 1).text) - 1
            ymin = int(get_and_check(bndbox, 'ymin', 1).text) - 1
            xmax = int(get_and_check(bndbox, 'xmax', 1).text)
            ymax = int(get_and_check(bndbox, 'ymax', 1).text)
            assert xmax > xmin, 'xmax must be greater than xmin in %s' % line
            assert ymax > ymin, 'ymax must be greater than ymin in %s' % line
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            json_dict['annotations'].append({
                'area': o_width*o_height,
                'iscrowd': 0,
                'image_id': image_id,
                'bbox': [xmin, ymin, o_width, o_height],
                'category_id': category_id,
                'id': bnd_id,
                'ignore': 0,
                # Box outline used as the segmentation polygon
                # (corner order: counter-clockwise).
                'segmentation': [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]],
            })
            bnd_id = bnd_id + 1

    # Emit the category table.
    for cate, cid in categories.items():
        json_dict['categories'].append({'supercategory': cate, 'id': cid, 'name': cate})
    # The context manager closes the file even if serialisation fails.
    with open(json_file, 'w') as json_fp:
        json_fp.write(json.dumps(json_dict, indent=4))


if __name__ == '__main__':
    # The XML annotations are read from ./test_annotations under the current
    # working directory.
    anno_file = "test_annotations"
    xml_dir = os.path.join(os.getcwd(), anno_file)

    # Shuffle the file names in place; image ids follow this shuffled order.
    xml_list = os.listdir(xml_dir)
    np.random.shuffle(xml_list)
    convert(xml_list, xml_dir, './instances_test2017.json')

最终形成如下目录结构:

coco
    --annotations
      --instances_train2017.json
      --instances_val2017.json
      --instances_test2017.json
    --train2017 #放train图片
    --val2017   #放val图片
    --test2017  #放test图片

No.2 安装环境

①:这里Anaconda管理、新建、激活环境不再赘述
② :安装DETR运行必要环境(方案很多,这里:PyCharm中的terminal中安装)

#示例
pip install -r E:\x\xx\...\requirements.txt

其中:
①:git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI&egg=pycocotools(下载源码后本地安装)

#cd进入PythonAPI文件夹
python setup.py build_ext install

②:git+https://github.com/cocodataset/panopticapi.git#egg=panopticapi(下载压缩包后本地安装)

pip install .\panopticapi-master.zip

No.3 配置参数

①(可选):

import torch

# Checkpoint downloaded from the official DETR release.
# map_location="cpu" keeps the script usable on machines without a GPU even
# if the checkpoint was saved from CUDA tensors.
pretrained_weights = torch.load("./detr-r50-e632da11.pth", map_location="cpu")

# Two object classes are assumed here; change the 2 to your own class count.
# The extra +1 stands for the background class.
num_class = 2 + 1
# Resize the classification head in place to num_class + 1 outputs.
# NOTE(review): the additional output is presumably DETR's "no object"
# class — confirm against models/detr.py.
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)

# Saved as detr_r50_<num_class>.pth (detr_r50_3.pth with the values above).
torch.save(pretrained_weights,'detr_r50_%d.pth'%num_class)

②:detr.py中

num_classes = 3 if args.dataset_file != 'coco' else 3  # both branches are set to 3 here

③:main.py中

    parser.add_argument('--dataset_file', default='coco')
    # Root folder of the COCO-format dataset; replace the placeholder default
    # with your own path. (The keyword must be spelled `default`; a misspelled
    # keyword makes add_argument raise TypeError.)
    parser.add_argument('--coco_path', default='coco数据集路径', type=str)
    # Optional: the detr_r50_3.pth checkpoint generated in step ①.
    parser.add_argument('--resume', default='①生成的detr_r50_3.pth(可不填)', help='resume from checkpoint')
    # Adjust the remaining arguments as needed.

最后

①:评估(调用pycocotools)时可能出现如下报错:

self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
……
……
……
TypeError: ‘numpy.float64’ object cannot be interpreted as an integer

可参考解决方案:
对环境位置site-packages\pycocotools中的cocoeval.py进行修改(具体查看自己报错位置)

# np.round(...) + 1 yields a numpy.float64, which np.linspace rejects as the
# sample count (the TypeError shown earlier), so the counts are written out
# as plain integers instead.
#self.iouThrs = np.linspace(.5, 0.95, np.round((0.95 - .5) / .05) + 1, endpoint=True)
#self.recThrs = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True)
self.iouThrs = np.linspace(.5, 0.95, 10, endpoint=True)
self.recThrs = np.linspace(.0, 1.00, 101, endpoint=True)

②:Plot
在.\util\plot_utils.py最后添加

if __name__ == '__main__':
    # Precision/recall plot built from every *.pth file in ../outputs/eval.
    eval_files = list(Path('../outputs/eval').glob('*.pth'))
    plot_precision_recall(eval_files)
    plt.show()

    # Plot of the selected fields read from log.txt under ../outputs/log/.
    log_fields = ('class_error', 'loss_bbox_unscaled', 'mAP')
    plot_logs(logs=Path('../outputs/log/'), fields=log_fields, ewm_col=0, log_name='log.txt')
    plt.show()
Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐