Preface: This article walks through my switch from YOLOv5 to YOLOX for training, under the premise that the images and annotation files already exist in YOLOv5's txt format. Every guide I had found on training YOLOX with custom data required repackaging the annotations into VOC or COCO format, right down to the exact folder layout, which is a real hassle. My previous tasks were all trained with YOLOv5, so the images and labels were already in place, and I had no desire to reshape them into VOC or COCO form; hence this article.

The YOLOv5 dataset directory layout (matching the paths used in the code below):

road/
├── images/
│   ├── train/
│   └── val/
└── labels/
    ├── train/
    └── val/

1. Generate XML-format annotations from the YOLOv5 labels

When generating the XML annotation files from YOLOv5's txt labels, watch out for the following:
1. YOLOv5 labels are normalized (c_x, c_y, w, h) values: box center and size as fractions of the image width and height (see the short sketch below).
2. YOLOv5 does not need labels for background images, i.e. they simply have no txt file, but YOLOX training does; the converter therefore writes an XML with no objects for such images.
3. Image file names must not contain spaces: YOLOv5 trains and validates fine with them, but YOLOX raises an error during validation.
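
To make point 1 concrete: a label line such as 2 0.506667 0.553333 0.490667 0.658667 is the class index followed by the normalized box center and size, and turning it into VOC-style pixel corners boils down to the following (a minimal sketch; the full script below additionally shifts by +1 because VOC coordinates are 1-based):

def yolo_to_voc_bbox(cx, cy, w, h, img_w, img_h):
    # cx, cy, w, h are fractions of the image width/height
    xmin = int((cx - w / 2) * img_w)
    ymin = int((cy - h / 2) * img_h)
    xmax = int((cx + w / 2) * img_w)
    ymax = int((cy + h / 2) * img_h)
    return xmin, ymin, xmax, ymax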
Here is the full conversion script, saved as yolotxt2xml.py:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/09/14 11:14
# @Author  : lishanlu
# @File    : yolotxt2xml.py
# @Software: PyCharm
# @Description:

from __future__ import absolute_import, print_function, division
import os
from xml.dom.minidom import Document
import cv2


'''
import xml
xml.dom.minidom.Document().writexml()
def writexml(self,
             writer: Any,
             indent: str = "",
             addindent: str = "",
             newl: str = "",
             encoding: Any = None) -> None
'''


class YOLO2VOCConvert:
    def __init__(self, txts_path, xmls_path, imgs_path, classes_str_list):
        self.txts_path = txts_path   # directory of the yolo-format txt labels
        self.xmls_path = xmls_path   # directory where the converted voc-format xml labels are saved
        self.imgs_path = imgs_path   # image directory; image sizes and names are read and stored in the xml files
        self.classes = classes_str_list  # list of class names

    # Collect all class ids that appear in the txt files; yolo-format labels store classes as integers 0, 1, ...
    # When writer is True, save the collected classes to './Annotations/classes.txt'
    def search_all_classes(self, writer=False):
        # Read each txt label file and pull out every object's annotation
        all_names = set()
        txts = os.listdir(self.txts_path)
        # keep only the files whose suffix is txt
        txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
        txts = [txt for txt in txts if not txt.split('.')[0] == "classes"]  # skip the classes.txt file
        print(len(txts), txts)
        # 11 ['0002030.txt', '0002031.txt', ... '0002039.txt', '0002040.txt']
        for txt in txts:
            txt_file = os.path.join(self.txts_path, txt)
            with open(txt_file, 'r') as f:
                objects = f.readlines()
                for object in objects:
                    object = object.strip().split(' ')
                    print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']
                    all_names.add(int(object[0]))
            # print(objects)  # ['2 0.506667 0.553333 0.490667 0.658667\n', '0 0.496000 0.285333 0.133333 0.096000\n', '8 0.501333 0.412000 0.074667 0.237333\n']

        print("所有的类别标签:", all_names, "共标注数据集:%d张" % len(txts))

        # Write the class ids collected from the label files to './Annotations/classes.txt'
        # if writer:
        #     with open('./Annotations/classes.txt', 'w') as f:
        #         for label in all_names:
        #             f.write(str(label) + '\n')

        return list(all_names)

    def yolo2voc(self):
        """
        可以转换图片和txtlabel数量不匹配的情况,即有些图片是背景
        :return:
        """
        # 创建一个保存xml标签文件的文件夹
        if not os.path.exists(self.xmls_path):
            os.makedirs(self.xmls_path)

        for img_name in os.listdir(self.imgs_path):
            # read the image to get its size
            print("Reading image:", img_name)
            try:
                img = cv2.imread(os.path.join(self.imgs_path, img_name))
                height_img, width_img, depth_img = img.shape
                print(height_img, width_img, depth_img)  # h is the number of rows (image height), w the number of columns (image width)
            except Exception as e:
                print("%s read fail, %s"%(img_name, e))
                continue
            txt_name = os.path.splitext(img_name)[0] + '.txt'
            txt_file = os.path.join(self.txts_path, txt_name)
            all_objects = []
            if os.path.exists(txt_file):
                with open(txt_file, 'r') as f:
                    objects = f.readlines()
                    for object in objects:
                        object = object.strip().split(' ')
                        all_objects.append(object)
                        print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']
            # build the xml annotation
            xmlBuilder = Document()
            # create the root tag, annotation
            annotation = xmlBuilder.createElement("annotation")

            # attach the root tag to the document
            xmlBuilder.appendChild(annotation)

            # create the folder tag
            folder = xmlBuilder.createElement("folder")
            # its content is the folder that holds the images, e.g. JPEGImages
            folderContent = xmlBuilder.createTextNode(self.imgs_path.split('/')[-1])  # tag content
            folder.appendChild(folderContent)  # put the text into the tag
            annotation.appendChild(folder)  # attach the filled folder tag under the annotation root

            # create the filename tag; its content is the image name, e.g. 000250.jpg
            filename = xmlBuilder.createElement("filename")
            filenameContent = xmlBuilder.createTextNode(img_name)  # use the actual image name, extension included
            filename.appendChild(filenameContent)
            annotation.appendChild(filename)

            # store the image shape in the xml
            size = xmlBuilder.createElement("size")
            # size child tag: width
            width = xmlBuilder.createElement("width")
            widthContent = xmlBuilder.createTextNode(str(width_img))
            width.appendChild(widthContent)
            size.appendChild(width)  # add width as a child of size
            # size child tag: height
            height = xmlBuilder.createElement("height")
            heightContent = xmlBuilder.createTextNode(str(height_img))  # xml text nodes always hold strings
            height.appendChild(heightContent)
            size.appendChild(height)  # add height as a child of size
            # size child tag: depth
            depth = xmlBuilder.createElement("depth")
            depthContent = xmlBuilder.createTextNode(str(depth_img))
            depth.appendChild(depthContent)
            size.appendChild(depth)  # add depth as a child of size
            annotation.appendChild(size)  # add size as a child of annotation

            # each object_info is one annotated target, e.g. ['2', '0.506667', '0.553333', '0.490667', '0.658667']
            for object_info in all_objects:
                # create the object tag that holds this target's label info
                object = xmlBuilder.createElement("object")
                # name tag: the class name looked up from the class index
                imgName = xmlBuilder.createElement("name")
                imgNameContent = xmlBuilder.createTextNode(self.classes[int(object_info[0])])
                imgName.appendChild(imgNameContent)
                object.appendChild(imgName)  # add name as a child of object

                # pose tag
                pose = xmlBuilder.createElement("pose")
                poseContent = xmlBuilder.createTextNode("Unspecified")
                pose.appendChild(poseContent)
                object.appendChild(pose)

                # truncated tag
                truncated = xmlBuilder.createElement("truncated")
                truncatedContent = xmlBuilder.createTextNode("0")
                truncated.appendChild(truncatedContent)
                object.appendChild(truncated)

                # difficult tag
                difficult = xmlBuilder.createElement("difficult")
                difficultContent = xmlBuilder.createTextNode("0")
                difficult.appendChild(difficultContent)
                object.appendChild(difficult)

                # convert the coordinates first:
                # (objx_center, objy_center, obj_width, obj_height) -> (xmin, ymin, xmax, ymax)
                x_center = float(object_info[1]) * width_img + 1   # +1: voc coordinates are 1-based
                y_center = float(object_info[2]) * height_img + 1
                xminVal = int(x_center - 0.5 * float(object_info[3]) * width_img)  # object_info entries are strings
                yminVal = int(y_center - 0.5 * float(object_info[4]) * height_img)
                xmaxVal = int(x_center + 0.5 * float(object_info[3]) * width_img)
                ymaxVal = int(y_center + 0.5 * float(object_info[4]) * height_img)

                # create the bndbox tag with four children (xmin, ymin, xmax, ymax):
                # in voc format a box is the top-left corner (xmin, ymin) plus the
                # bottom-right corner (xmax, ymax)
                bndbox = xmlBuilder.createElement("bndbox")
                xmin = xmlBuilder.createElement("xmin")
                xminContent = xmlBuilder.createTextNode(str(xminVal))
                xmin.appendChild(xminContent)
                bndbox.appendChild(xmin)
                ymin = xmlBuilder.createElement("ymin")
                yminContent = xmlBuilder.createTextNode(str(yminVal))
                ymin.appendChild(yminContent)
                bndbox.appendChild(ymin)
                xmax = xmlBuilder.createElement("xmax")
                xmaxContent = xmlBuilder.createTextNode(str(xmaxVal))
                xmax.appendChild(xmaxContent)
                bndbox.appendChild(xmax)
                ymax = xmlBuilder.createElement("ymax")
                ymaxContent = xmlBuilder.createTextNode(str(ymaxVal))
                ymax.appendChild(ymaxContent)
                bndbox.appendChild(ymax)

                object.appendChild(bndbox)
                annotation.appendChild(object)  # add object as a child of annotation
            f = open(os.path.join(self.xmls_path, os.path.splitext(txt_name)[0] + '.xml'), 'w', encoding='utf-8')
            xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
            f.close()


if __name__ == '__main__':
    imgs_path1 = 'F:/Dataset/road/images/val'        # ['train', 'val']
    txts_path1 = 'F:/Dataset/road/labels/val'        # ['train', 'val']
    xmls_path1 = 'F:/Dataset/road/xmls/val'          # ['train', 'val']
    classes_str_list = ['road_crack','road_sag']     # class name

    yolo2voc_obj1 = YOLO2VOCConvert(txts_path1, xmls_path1, imgs_path1, classes_str_list)
    labels = yolo2voc_obj1.search_all_classes()
    print('labels: ', labels)
    yolo2voc_obj1.yolo2voc()
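
Before moving on, it is worth spot-checking one generated file. A minimal sketch using only the standard library (the path below is illustrative; point it at any xml the script produced):

import xml.etree.ElementTree as ET

def show_xml(xml_file):
    root = ET.parse(xml_file).getroot()
    size = root.find('size')
    print('size:', size.find('width').text, 'x', size.find('height').text)
    for obj in root.iter('object'):
        box = obj.find('bndbox')
        coords = [box.find(k).text for k in ('xmin', 'ymin', 'xmax', 'ymax')]
        print(obj.find('name').text, coords)

show_xml('F:/Dataset/road/xmls/val/0002030.xml')  # illustrative path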

After converting both train and val, the directory gains an xmls folder alongside images and labels:

road/
├── images/
│   ├── train/
│   └── val/
├── labels/
│   ├── train/
│   └── val/
└── xmls/
    ├── train/
    └── val/

2. Define the dataset reader

For a high-level tour of the whole YOLOX project and its training flow, see my other article on YOLOX training internals.
Go into the YOLOX root directory.
The dataset readers are defined under yolox/data/datasets/: there is a COCO-style reader, a VOC-style reader, and the mosaic augmentation also lives in this folder. New readers are added here as well, so create a file yolo_style.py with the following code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/12/23 9:13
# @Author  : lishanlu
# @File    : yolo_style.py
# @Software: PyCharm
# @Description: read the yolo-style dataset whose labels are the generated xml files

from __future__ import absolute_import, print_function, division
import os
import os.path
import pickle
import xml.etree.ElementTree as ET

import cv2
import numpy as np

from yolox.evaluators.voc_eval import voc_eval

from .datasets_wrapper import Dataset
from pathlib import Path
import glob
from tqdm import tqdm
from PIL import Image, ExifTags
import torch


class AnnotationTransform(object):

    """Transforms a annotation into a Tensor of bbox coords and label index
    Initilized with a dictionary lookup of classnames to indexes

    Arguments:
        classes_name: (str, str, ...): dictionary lookup of classnames -> indexes
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
        height (int): height
        width (int): width
    """
    def __init__(self, classes_name, keep_difficult=True):
        self.class_to_ind = dict(zip(classes_name, range(len(classes_name))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes  [bbox coords, class name]
        """
        res = np.empty((0, 5))
        for obj in target.iter("object"):
            difficult = obj.find("difficult")
            if difficult is not None:
                difficult = int(difficult.text) == 1
            else:
                difficult = False
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.strip()
            bbox = obj.find("bndbox")

            pts = ["xmin", "ymin", "xmax", "ymax"]
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # scale height or width
                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]
        width = int(target.find("size").find("width").text)
        height = int(target.find("size").find("height").text)
        img_info = (height, width)

        return res, img_info


"""
generation yolo style dataloader.
"""
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp']  # acceptable image suffixes
# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def img2xml_paths(img_paths):
    # Define xml paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'xmls' + os.sep  # /images/, /xmls/ substrings
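    # e.g. '.../images/val/0001.jpg' -> '.../xmls/val/0001.xml': swap the first
    # '/images/' for '/xmls/' and the trailing file extension for 'xml'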
    return ['xml'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation in (6, 8):  # 270 or 90 degree rotation swaps width and height
            s = (s[1], s[0])
    except Exception:
        pass

    return s


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


class YOLODetection(Dataset):

    """
    YOLO Style Detection Dataset Object (read label from yolo style XML)

    input is image, target is annotation

    Args:
        data_dir (string): filepath to data folder.
        classes (string, string, ....): class string names.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        preproc (callable, optional): transformation to perform on the input image
        target_transform (callable, optional): transformation to perform on the target `annotation`
            (eg: take in caption string, return tensor of word indices)
        dataset_name (string, optional): which dataset to load  (default: 'yolo_dataset')
    """

    def __init__(
        self,
        data_dir,
        classes,
        image_sets=['train'],
        img_size=(416, 416),
        preproc=None,
        dataset_name="yolo_dataset",
        cache=False,
    ):
        super().__init__(img_size)
        self.root = data_dir
        self.image_set = image_sets
        self.img_size = img_size
        self.preproc = preproc
        self._classes = classes
        self.target_transform = AnnotationTransform(self._classes, keep_difficult=True)
        self.name = dataset_name
        for name in image_sets:
            rootpath = self.root
            image_dir = os.path.join(rootpath, 'images', name)
            self.image_files = [os.path.join(image_dir, image_name) for image_name in os.listdir(image_dir)]
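            # for the val set, also dump the image ids to <root>/val.txt;
            # voc_eval reads it back as the imagesetfile during evaluation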
            if name == 'val':
                self.val_ids = [os.path.splitext(image_name)[0] for image_name in os.listdir(image_dir)]
                with open(os.path.join(rootpath, name+'.txt'), 'w') as f:
                    for id in self.val_ids:
                        f.write(id+'\n')
        self.xml_files = img2xml_paths(self.image_files)           # list, xml file path
        self.annotations = self._load_xml_annotations()
        self.imgs = None
        if cache:
            self._cache_images()

    def __len__(self):
        return len(self.image_files)

    def _load_xml_annotations(self):
        return [self.load_anno_from_ids(_ids) for _ids in range(len(self.xml_files))]

    def _cache_images(self):
        pass

    def load_anno_from_ids(self, index):
        xml_file = self.xml_files[index]
        target = ET.parse(xml_file).getroot()

        assert self.target_transform is not None
        res, img_info = self.target_transform(target)
        height, width = img_info

        r = min(self.img_size[0] / height, self.img_size[1] / width)
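        # pre-scale the ground-truth boxes by the same letterbox ratio used in
        # load_resized_img, so targets match the resized image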
        res[:, :4] *= r
        resized_info = (int(height * r), int(width * r))

        return (res, img_info, resized_info)

    def load_anno(self, index):
        return self.annotations[index][0]

    def load_resized_img(self, index):
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)

        return resized_img

    def load_image(self, index):
        img = cv2.imread(self.image_files[index], cv2.IMREAD_COLOR)
        assert img is not None

        return img

    def pull_item(self, index):
        """Returns the original image and target at an index for mixup

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            img, target
        """
        if self.imgs is not None:
            target, img_info, resized_info = self.annotations[index]
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)
            target, img_info, _ = self.annotations[index]

        return img, target, img_info, index

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        img, target, img_info, img_id = self.pull_item(index)      # target coords here are (x1, y1, x2, y2, cls)

        ### show read image and label.
        # from PIL import Image,ImageDraw
        # from matplotlib import pyplot as plt
        # img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # draw = ImageDraw.Draw(img)
        # for j in range(target.shape[0]):
        #     name = int(target[j][4])
        #     left = int(target[j][0])
        #     top = int(target[j][1])
        #     right = int(target[j][2])
        #     bottom = int(target[j][3])
        #     draw.text((left+10, top+10), f'{name}', fill='blue')
        #     draw.rectangle((left, top, right, bottom), outline='red', width=2)
        # plt.imshow(img)
        # plt.show()

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)  # after preproc, target coords are (cls, cx, cy, w, h)

            # from PIL import Image,ImageDraw
            # from matplotlib import pyplot as plt
            # img = np.transpose(img.astype(np.uint8), (1, 2, 0))
            # img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # draw = ImageDraw.Draw(img)
            # for j in range(target.shape[0]):
            #     name = int(target[j][0])
            #     left = int(target[j][1]-target[j][3]/2)
            #     top = int(target[j][2]-target[j][4]/2)
            #     right = int(target[j][1]+target[j][3]/2)
            #     bottom = int(target[j][2]+target[j][4]/2)
            #     draw.text((left+10, top+10), f'{name}', fill='blue')
            #     draw.rectangle((left, top, right, bottom), outline='red', width=2)
            # plt.imshow(img)
            # plt.show()

        return img, target, img_info, img_id

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
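        # COCO-style IoU thresholds: 0.50, 0.55, ..., 0.95 (10 values)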
        IouTh = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
        mAPs = []
        for iou in IouTh:
            mAP = self._do_python_eval(output_dir, iou)
            mAPs.append(mAP)

        print("--------------------------------------------------------------")
        print("map_5095:", np.mean(mAPs))
        print("map_50:", mAPs[0])
        print("--------------------------------------------------------------")
        return np.mean(mAPs), mAPs[0]

    def _get_voc_results_file_template(self):
        filename = "comp4_det_test" + "_{:s}.txt"
        filedir = os.path.join(self.root, "results")
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        self.ids = [os.path.splitext(os.path.split(image_file)[1])[0] for image_file in self.image_files]
        for cls_ind, cls in enumerate(self._classes):
            if cls == "__background__":
                continue
            print("Writing {} VOC results file".format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, "wt") as f:
                for im_ind, index in enumerate(self.ids):
                    dets = all_boxes[cls_ind][im_ind]
                    if len(dets) == 0:  # dets is either an empty list or an (N, 5) array
                        continue
                    for k in range(dets.shape[0]):
                        f.write(
                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
                                index,
                                dets[k, -1],
                                dets[k, 0] + 1,
                                dets[k, 1] + 1,
                                dets[k, 2] + 1,
                                dets[k, 3] + 1,
                            )
                        )

    def _do_python_eval(self, output_dir="output", iou=0.5):
        rootpath = self.root
        name = self.image_set[0]
        annopath = os.path.join(rootpath, "xmls", name, "{:s}.xml")
        imagesetfile = os.path.join(rootpath, name + ".txt")
        cachedir = os.path.join(self.root, "annotations_cache")
        if not os.path.exists(cachedir):
            os.makedirs(cachedir)
        aps = []
        # The PASCAL VOC metric changed in 2010
        # use_07_metric = True if int(self._year) < 2010 else False
        use_07_metric = True
        print("Eval IoU : {:.2f}".format(iou))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):

            if cls == "__background__":
                continue

            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename,
                annopath,
                imagesetfile,
                cls,
                cachedir,
                ovthresh=iou,
                use_07_metric=use_07_metric,
            )
            aps += [ap]
            if iou == 0.5:
                print("AP for {} = {:.4f}".format(cls, ap))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
        if iou == 0.5:
            print("Mean AP = {:.4f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("Results:")
            for ap in aps:
                print("{:.3f}".format(ap))
            print("{:.3f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("")
            print("--------------------------------------------------------------")
            print("Results computed with the **unofficial** Python eval code.")
            print("Results should be very close to the official MATLAB eval code.")
            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
            print("-- Thanks, The Management")
            print("--------------------------------------------------------------")

        return np.mean(aps)
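
Before wiring the class into training, a quick interactive check helps (a sketch run from the YOLOX root; the data path and class names follow the road example above):

from yolox.data.datasets.yolo_style import YOLODetection

dataset = YOLODetection(data_dir='F:/Dataset/road',
                        classes=('road_crack', 'road_sag'),
                        image_sets=['train'],
                        img_size=(640, 640))
print(len(dataset))
img, target, img_info, img_id = dataset[0]   # target rows are (x1, y1, x2, y2, cls)
print(img.shape, target.shape, img_info)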

Once this file is in place, don't forget to register it in yolox/data/datasets/__init__.py by adding: from .yolo_style import YOLODetection
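For reference, the import block of yolox/data/datasets/__init__.py then looks roughly like this (the pre-existing imports may differ slightly between YOLOX versions):

from .coco import COCODataset
from .coco_classes import COCO_CLASSES
from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
from .mosaicdetection import MosaicDetection
from .voc import VOCDetection
from .yolo_style import YOLODetection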

3. Define the training configuration file

Under exps/example/ create a new task directory, e.g. road, and inside it a file yolox_road.py. This file defines the Exp class used for training; it inherits from the Exp class in yolox_base.py under yolox/exp/ and mainly sets the model parameters, the dataset and augmentation parameters, and builds the dataloaders. Example code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/12/23 8:58
# @Author  : lishanlu
# @File    : yolox_road.py
# @Software: PyCharm
# @Description:

from __future__ import absolute_import, print_function, division
import os
import torch
import torch.nn as nn
import torch.distributed as dist

from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        # ------------ model config -------------------#
        self.num_classes = 2                      # set this to the number of classes in your data
        self.depth = 0.67
        self.width = 0.75

        # ---------------- dataloader config ---------------- #
        # set worker to 4 for shorter dataloader init time
        self.data_num_workers = 4
        self.input_size = (640, 640)  # (height, width)
        # Actual multiscale ranges: [640-5*32, 640+5*32].
        # To disable multiscale training, set the
        # self.multiscale_range to 0.
        self.multiscale_range = 5
        # You can uncomment this line to specify a multiscale range
        # self.random_size = (14, 26)
        self.data_dir = 'your data rootdir'         # data root directory (the folder holding images/ and xmls/)
        self.classes_name = ('class1','class2')      # class names, in the same order as in yolotxt2xml.py
        self.dataset_name = 'yolo_dataset'          # dataset name; no need to change

        # --------------- transform config ----------------- #
        self.mosaic_prob = 1.0
        self.mixup_prob = 1.0
        self.hsv_prob = 1.0
        self.flip_prob = 0.5
        self.degrees = 5.0
        self.translate = 0.1
        self.mosaic_scale = (0.5, 1.5)
        self.mixup_scale = (0.5, 1.5)
        self.shear = 2.0
        self.perspective = 0.0
        self.enable_mixup = False

        # --------------  training config --------------------- #
        self.warmup_epochs = 5
        self.max_epoch = 300
        self.warmup_lr = 0
        self.basic_lr_per_img = 0.01 / 64.0
        self.scheduler = "yoloxwarmcos"
        self.milestones = [70, 120, 180, 300]    # only used with the multi-step LR schedule
        self.gamma = 0.1                    # only used with the multi-step LR schedule
        self.no_aug_epochs = 300
        self.min_lr_ratio = 0.05
        self.ema = True

        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.print_interval = 10
        self.eval_interval = 1
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # -----------------  testing config ------------------ #
        self.test_size = (640, 640)
        self.test_conf = 0.01
        self.nmsthre = 0.65

    def get_model(self):
        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
            head = YOLOXHead(self.num_classes, self.width,
                             in_channels=in_channels)  # strides=[8,16,32], in_channels=in_channels
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
        from yolox.data import (
            YOLODetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
            worker_init_reset_seed,
        )

        from yolox.utils import (
            wait_for_the_master,
            get_local_rank,
        )

        local_rank = get_local_rank()

        with wait_for_the_master(local_rank):
            dataset = YOLODetection(data_dir=self.data_dir,
                                    classes=self.classes_name,
                                    image_sets=['train'],
                                    img_size=self.input_size,
                                    preproc=TrainTransform(
                                        max_labels=50,
                                        flip_prob=self.flip_prob,
                                        hsv_prob=self.hsv_prob),
                                    dataset_name=self.dataset_name,
                                    cache=cache_img)

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=120,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob),
            degrees=self.degrees,
            translate=self.translate,
            mosaic_scale=self.mosaic_scale,
            mixup_scale=self.mixup_scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
            mosaic_prob=self.mosaic_prob,
            mixup_prob=self.mixup_prob,
        )
        # import pdb;pdb.set_trace()
        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.data import YOLODetection, ValTransform

        valdataset = YOLODetection(
            data_dir=self.data_dir,
            classes=self.classes_name,
            image_sets=['val'],
            img_size=self.test_size,
            preproc=ValTransform(legacy=legacy),
            dataset_name=self.dataset_name
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.evaluators import VOCEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator

    def get_lr_scheduler(self, lr, iters_per_epoch, **kwargs):
        from yolox.utils import LRScheduler

        scheduler = LRScheduler(
            self.scheduler,
            lr,
            iters_per_epoch,
            self.max_epoch,
            warmup_epochs=self.warmup_epochs,
            warmup_lr_start=self.warmup_lr,
            no_aug_epochs=self.no_aug_epochs,
            min_lr_ratio=self.min_lr_ratio,
            **kwargs
        )
        return scheduler

4. Launch training

Create a shell script, train.sh, with the following:

python tools/train.py \
--experiment-name yolox_road \
--batch-size 48 \
--devices 0 \
--exp_file exps/example/road/yolox_road.py \
--fp16 \
--ckpt pre_train/yolox_m.pth

Run bash ./train.sh to start training.
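
Checkpoints are written to YOLOX_outputs/yolox_road/ by default. After training, the stock YOLOX demo script gives a quick visual sanity check (a sketch: the test image path is illustrative, and the flags are those of the upstream tools/demo.py):

python tools/demo.py image \
-f exps/example/road/yolox_road.py \
-c YOLOX_outputs/yolox_road/best_ckpt.pth \
--path test.jpg \
--conf 0.3 --nms 0.45 --tsize 640 \
--save_result --device gpu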
