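I plan to use YOLOv8 as the baseline for hyperspectral object detection, but my input data does not have 3 channels, so the stock code will not run on it. These notes record how to modify the code so that YOLOv8 can train on multi-channel images.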

Taking YOLOv8n training as the example, this is the standard training code given in the official documentation:

from ultralytics import YOLO

# Load a model
model = YOLO('yolov8n.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO('yolov8n.yaml').load('yolov8n.pt')  # build from YAML and transfer weights

# Train the model
results = model.train(data='LED.yaml', epochs=100, imgsz=640)

LED.yaml is my own dataset configuration file.

The dataset's file structure is shown in the figure below:

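I. Configuration file changes

1. Dataset configuration file

First, the dataset configuration file. Create your own dataset configuration file as shown in the figure:

Set the dataset paths and the detection classes. My example has only 1 class.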
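
As a rough sketch of what such a file can contain, the snippet below writes one from Python. The keys path, train, val and names follow the Ultralytics dataset YAML layout; the directory names, the class name led and the helper write_dataset_yaml are hypothetical placeholders, not the values actually used here.

```python
from pathlib import Path

# Hypothetical paths and class name; replace them with your own dataset layout.
DATASET_YAML = """\
path: datasets/LED        # dataset root directory (assumed)
train: images/train       # training images, relative to 'path'
val: images/val           # validation images, relative to 'path'

names:
  0: led                  # single class, index 0 (name assumed)
"""


def write_dataset_yaml(target: str = "LED.yaml") -> Path:
    """Write the dataset configuration text to disk and return its path."""
    out = Path(target)
    out.write_text(DATASET_YAML, encoding="utf-8")
    return out


if __name__ == "__main__":
    print(write_dataset_yaml())
```

The resulting file name is what gets passed as data='LED.yaml' in the training call.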

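2. Model configuration file

Modify the model-definition parameters in the model configuration file:

Change nc to the total number of classes in the data (1, as mentioned above), then add the line ch: 544. This line sets the number of input channels of the first convolution layer, which is the number of channels in my hyperspectral data (544).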
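
The same edit can be scripted. Below is a minimal sketch that rewrites the nc line and adds a ch line in a model YAML held as text. The helper name set_model_channels is hypothetical and the sample text is an abbreviated stand-in for yolov8n.yaml, so treat it as an illustration rather than the required procedure.

```python
import re


def set_model_channels(yaml_text: str, nc: int, ch: int) -> str:
    """Return model YAML text with 'nc' replaced and a 'ch' line added right after it."""
    # Replace the value on the existing top-level 'nc:' line, keeping any trailing comment.
    new_text, n = re.subn(r"(?m)^nc:\s*\d+", f"nc: {nc}", yaml_text, count=1)
    if n == 0:
        raise ValueError("no top-level 'nc:' line found")
    # Drop an existing 'ch:' line (if any) so the key is not duplicated.
    new_text = re.sub(r"(?m)^ch:.*\n?", "", new_text)
    # Insert 'ch:' directly after the 'nc:' line.
    return re.sub(r"(?m)^(nc:.*)$", rf"\1\nch: {ch}  # input channels of the first conv layer", new_text, count=1)


if __name__ == "__main__":
    original = "nc: 80  # number of classes\nscales:\n  n: [0.33, 0.25, 1024]\n"
    print(set_model_channels(original, nc=1, ch=544))
```

Editing the file by hand gives the same result; the script is only convenient when several model variants need the same change.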

3. Training hyperparameter configuration file

Adjust the training hyperparameters as needed:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training

task: detect  # (str) YOLO task, i.e. detect, segment, classify, pose
mode: train  # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark

# Train settings -------------------------------------------------------------------------------------------------------
model:  # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
data:  # (str, optional) path to data file, i.e. coco128.yaml
epochs: 100  # (int) number of epochs to train for
time:  # (float, optional) number of hours to train for, overrides epochs if supplied
patience: 50  # (int) epochs to wait for no observable improvement for early stopping of training
batch: 4  # (int) number of images per batch (-1 for AutoBatch)
imgsz: 640  # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
save: True  # (bool) save train checkpoints and predict results
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
cache: False  # (bool) True/ram, disk or False. Use cache for data loading
device:  # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
workers: 8  # (int) number of worker threads for data loading (per RANK if DDP)
project:  # (str, optional) project name
name:  # (str, optional) experiment name, results saved to 'project/name' directory
exist_ok: False  # (bool) whether to overwrite existing experiment
pretrained: True  # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
optimizer: auto  # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
verbose: True  # (bool) whether to print verbose output
seed: 0  # (int) random seed for reproducibility
deterministic: True  # (bool) whether to enable deterministic mode
single_cls: False  # (bool) train multi-class data as single-class
rect: False  # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False  # (bool) use cosine learning rate scheduler
close_mosaic: 10  # (int) disable mosaic augmentation for final epochs (0 to disable)
resume: False  # (bool) resume training from last checkpoint
amp: True  # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
fraction: 1.0  # (float) dataset fraction to train on (default is 1.0, all images in train set)
profile: False  # (bool) profile ONNX and TensorRT speeds during training for loggers
freeze: None  # (int | list, optional) freeze first n layers, or freeze list of layer indices during training
multi_scale: False   # (bool) Whether to use multi-scale during training
# Segmentation
overlap_mask: True  # (bool) masks should overlap during training (segment train only)
mask_ratio: 4  # (int) mask downsample ratio (segment train only)
# Classification
dropout: 0.0  # (float) use dropout regularization (classify train only)

# Val/Test settings ----------------------------------------------------------------------------------------------------
val: True  # (bool) validate/test during training
split: val  # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
save_json: False  # (bool) save results to JSON file
save_hybrid: False  # (bool) save hybrid version of labels (labels + additional predictions)
conf:  # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
iou: 0.7  # (float) intersection over union (IoU) threshold for NMS
max_det: 300  # (int) maximum number of detections per image
half: False  # (bool) use half precision (FP16)
dnn: False  # (bool) use OpenCV DNN for ONNX inference
plots: True  # (bool) save plots and images during train/val

# Predict settings -----------------------------------------------------------------------------------------------------
source:  # (str, optional) source directory for images or videos
vid_stride: 1  # (int) video frame-rate stride
stream_buffer: False  # (bool) buffer all streaming frames (True) or return the most recent frame (False)
visualize: False  # (bool) visualize model features
augment: False  # (bool) apply image augmentation to prediction sources
agnostic_nms: False  # (bool) class-agnostic NMS
classes:  # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
retina_masks: False  # (bool) use high-resolution segmentation masks
embed:  # (list[int], optional) return feature vectors/embeddings from given layers

# Visualize settings ---------------------------------------------------------------------------------------------------
show: False  # (bool) show predicted images and videos if environment allows
save_frames: False  # (bool) save predicted individual video frames
save_txt: False  # (bool) save results as .txt file
save_conf: False  # (bool) save results with confidence scores
save_crop: False  # (bool) save cropped images with results
show_labels: False  # (bool) show prediction labels, i.e. 'person'
show_conf: True  # (bool) show prediction confidence, i.e. '0.99'
show_boxes: True  # (bool) show prediction boxes
line_width: 1  # (int, optional) line width of the bounding boxes. Scaled to image size if None.

# Export settings ------------------------------------------------------------------------------------------------------
format: torchscript  # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
keras: False  # (bool) use Keras
optimize: False  # (bool) TorchScript: optimize for mobile
int8: False  # (bool) CoreML/TF INT8 quantization
dynamic: False  # (bool) ONNX/TF/TensorRT: dynamic axes
simplify: False  # (bool) ONNX: simplify model
opset:  # (int, optional) ONNX: opset version
workspace: 4  # (int) TensorRT: workspace size (GB)
nms: False  # (bool) CoreML: add NMS

# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.01  # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
lrf: 0.01  # (float) final learning rate (lr0 * lrf)
momentum: 0.937  # (float) SGD momentum/Adam beta1
weight_decay: 0.0005  # (float) optimizer weight decay 5e-4
warmup_epochs: 3.0  # (float) warmup epochs (fractions ok)
warmup_momentum: 0.8  # (float) warmup initial momentum
warmup_bias_lr: 0.1  # (float) warmup initial bias lr
box: 7.5  # (float) box loss gain
cls: 0.5  # (float) cls loss gain (scale with pixels)
dfl: 1.5  # (float) dfl loss gain
pose: 12.0  # (float) pose loss gain
kobj: 1.0  # (float) keypoint obj loss gain
label_smoothing: 0.0  # (float) label smoothing (fraction)
nbs: 64  # (int) nominal batch size
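hsv_h: 0.0  # (float) image HSV-Hue augmentation (fraction); default 0.015, set to 0 for non-RGB data
hsv_s: 0.0  # (float) image HSV-Saturation augmentation (fraction); default 0.7, set to 0 for non-RGB data
hsv_v: 0.0  # (float) image HSV-Value augmentation (fraction); default 0.4, set to 0 for non-RGB data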
degrees: 0.0  # (float) image rotation (+/- deg)
translate: 0.1  # (float) image translation (+/- fraction)
scale: 0.5  # (float) image scale (+/- gain)
shear: 0.0  # (float) image shear (+/- deg)
perspective: 0.0  # (float) image perspective (+/- fraction), range 0-0.001
flipud: 0.0  # (float) image flip up-down (probability)
fliplr: 0.5  # (float) image flip left-right (probability)
mosaic: 1.0  # (float) image mosaic (probability)
mixup: 0.0  # (float) image mixup (probability)
copy_paste: 0.0  # (float) segment copy-paste (probability)
auto_augment: randaugment  # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
erasing: 0.4  # (float) probability of random erasing during classification training (0-1)
crop_fraction: 1.0  # (float) image crop fraction for classification evaluation/inference (0-1)

# Custom config.yaml ---------------------------------------------------------------------------------------------------
cfg:  # (str, optional) for overriding defaults.yaml

# Tracker settings ------------------------------------------------------------------------------------------------------
tracker: botsort.yaml  # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]

The main change here is setting hsv_h, hsv_s and hsv_v to 0 so that no HSV augmentation is applied, since the input is not an RGB image.
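
If you would rather not edit the packaged default.yaml, the same three settings can be passed as keyword overrides to model.train. The sketch below only builds that keyword dictionary; the helper name hsv_free_overrides is hypothetical.

```python
def hsv_free_overrides(**extra):
    """Build keyword overrides that switch off the three HSV augmentations."""
    overrides = {"hsv_h": 0.0, "hsv_s": 0.0, "hsv_v": 0.0}
    overrides.update(extra)  # caller-supplied settings win over the defaults above
    return overrides


if __name__ == "__main__":
    print(hsv_free_overrides(epochs=100, imgsz=640))
```

With ultralytics installed, this would be used as model.train(data='LED.yaml', **hsv_free_overrides(epochs=100, imgsz=640)).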

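II. Data processing code

1. def verify_image_label(args):

Before training, YOLOv8 reads the images and labels and verifies them, storing the label information in a .cache file so that it can be reused during training. With data that has many channels this verification step runs into problems, and because the verification code is wrapped in try/except, a wrong format does not raise an error: the image is silently skipped, and training cannot work. So modify the verify_image_label function in miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/data/utils.py as follows (add import tifffile at the top of that file if it is not already there):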

def verify_image_label(args):
    """Verify one image-label pair."""
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        if os.path.splitext(im_file)[1].lower() == ".tiff":
            im = tifffile.imread(im_file)  # multi-channel TIFF, assumed to be stored as (H, W, C)
            shape = im.shape[:2]  # already hw, so no swap is needed
        else:
            im = Image.open(im_file)
            im.verify()  # PIL verify
            shape = exif_size(im)  # image size (w, h)
            shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        # assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        # if im.format.lower() in ("jpg", "jpeg"):
        #     with open(im_file, "rb") as f:
        #         f.seek(-2, 2)
        #         if f.read() != b"\xff\xd9":  # corrupt JPEG
        #             ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
        #             msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels
                max_cls = lb[:, 0].max()  # max label count
                assert max_cls < num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
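
Because failures end up in the except branch and only surface as nc = 1 plus a warning message, it helps to tally the results before trusting the cache. The helper below is a hypothetical sketch (summarize_verification is not part of ultralytics) that counts the flags in tuples shaped like the return value above.

```python
def summarize_verification(results):
    """Count missing/found/empty/corrupt entries and collect warning messages.

    Each item mirrors the 10-element result of verify_image_label:
    (im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg).
    """
    totals = {"missing": 0, "found": 0, "empty": 0, "corrupt": 0}
    messages = []
    for *_, nm, nf, ne, nc, msg in results:
        totals["missing"] += nm
        totals["found"] += nf
        totals["empty"] += ne
        totals["corrupt"] += nc
        if msg:
            messages.append(msg)
    return totals, messages


if __name__ == "__main__":
    # Hand-written stand-ins for two results: one good pair, one rejected image.
    fake = [
        ("a.tiff", [], (64, 64), [], None, 0, 1, 0, 0, ""),
        [None, None, None, None, None, 0, 0, 0, 1, "WARNING b.tiff: ignoring corrupt image/label"],
    ]
    print(summarize_verification(fake))
```

If the corrupt count equals the number of images, the reader in the try block is rejecting every file and training will have nothing to learn from.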

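2. def load_image(self, i, rect_mode=True):

The image-loading code needs to change, because the default cv2.imread reads every file as a 3-channel image. Modify the load_image function of class BaseDataset(Dataset): in miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/data/base.py as follows (tifffile must be imported in that file as well):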

    def load_image(self, i, rect_mode=True):
        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                try:
                    im = np.load(fn)
                except Exception as e:
                    LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
                    Path(fn).unlink(missing_ok=True)
                    im = cv2.imread(f)  # BGR
            else:  # read image
                if f.endswith('.tiff') or f.endswith('.TIFF'):
                    im = tifffile.imread(f)
                else:
                    im = cv2.imread(f)  # BGR
            if im is None:
                raise FileNotFoundError(f"Image Not Found {f}")

            h0, w0 = im.shape[:2]  # orig hw
            if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
                r = self.imgsz / max(h0, w0)  # ratio
                if r != 1:  # if sizes are not equal
                    w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
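                    # im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
                    # skimage expects output_shape as (rows, cols); preserve_range keeps the original value scale
                    im = resize(im, (h, w), mode="constant", anti_aliasing=True, preserve_range=True).astype(im.dtype)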
            elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
                im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)

            # Add to buffer if training with augmentations
            if self.augment:
                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
                self.buffer.append(i)
                if len(self.buffer) >= self.max_buffer_length:
                    j = self.buffer.pop(0)
                    self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None

            return im, (h0, w0), im.shape[:2]

        return self.ims[i], self.im_hw0[i], self.im_hw[i]

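Remember to add from skimage.transform import resize and use this resize in place of cv2.resize, since it can resize multi-channel images. In my case the replacement is only needed when the image has more than roughly 512 channels; below that, cv2.resize works. I did not track down the reason, but it is probably OpenCV's limit of 512 channels per array (CV_CN_MAX).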
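
An alternative that stays with cv2.resize is to process the channel axis in blocks no wider than the limit. The sketch below only computes the block boundaries. The default of 512 is an assumption about OpenCV's channel limit, and channel_chunks is a hypothetical helper name.

```python
def channel_chunks(num_channels: int, max_channels: int = 512):
    """Yield (start, stop) index pairs that split the channel axis into blocks of at most max_channels."""
    if num_channels <= 0 or max_channels <= 0:
        raise ValueError("channel counts must be positive")
    for start in range(0, num_channels, max_channels):
        yield start, min(start + max_channels, num_channels)


if __name__ == "__main__":
    print(list(channel_chunks(544)))  # prints [(0, 512), (512, 544)]
```

Each block im[:, :, start:stop] could then be resized on its own and the pieces joined again with np.concatenate(..., axis=2). A block that holds a single channel comes back from cv2.resize as a 2-D array, so its channel axis would need to be restored before joining.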

3. def affine_transform(self, img, border):

Data augmentation breaks here: cv2's affine warp only handles 3-channel images, so the warp has to be done differently. First import:

from skimage.transform import warp, AffineTransform

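Then modify the affine_transform function of class RandomPerspective: in miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/data/augment.py. Note that the version below warps each channel separately with cv2.warpAffine and does not itself call warp or AffineTransform: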

    def affine_transform(self, img, border):
        """
        Applies a sequence of affine transformations centered around the image center.

        Args:
            img (ndarray): Input image.
            border (tuple): Border dimensions.

        Returns:
            img (ndarray): Transformed image.
            M (ndarray): Transformation matrix.
            s (float): Scale factor.
        """

        # Center
        C = np.eye(3, dtype=np.float32)

        C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
        C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

        # Perspective
        P = np.eye(3, dtype=np.float32)
        P[2, 0] = random.uniform(-self.perspective, self.perspective)  # x perspective (about y)
        P[2, 1] = random.uniform(-self.perspective, self.perspective)  # y perspective (about x)

        # Rotation and Scale
        R = np.eye(3, dtype=np.float32)
        a = random.uniform(-self.degrees, self.degrees)
        # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
        s = random.uniform(1 - self.scale, 1 + self.scale)
        # s = 2 ** random.uniform(-scale, scale)
        R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

        # Shear
        S = np.eye(3, dtype=np.float32)
        S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # x shear (deg)
        S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # y shear (deg)

        # Translation
        T = np.eye(3, dtype=np.float32)
        T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0]  # x translation (pixels)
        T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1]  # y translation (pixels)

        # Combined rotation matrix
        M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
        # Affine image
        if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
            if self.perspective:
                img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
            else:  # affine
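                # img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
                # Warp each channel on its own, then restack along the channel axis.
                # NumPy slicing and stacking avoid relying on cv2.split/cv2.merge for very wide images.
                transformed_channels = [
                    cv2.warpAffine(np.ascontiguousarray(img[..., c]), M[:2], dsize=self.size, borderValue=(114, 114, 114))
                    for c in range(img.shape[2])
                ]
                img = np.stack(transformed_channels, axis=-1)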
        return img, M, s
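
The per-channel pattern used in the affine branch can be written as a small generic helper. This is a sketch under the assumption that the image is laid out as (H, W, C). apply_per_channel is a hypothetical name, and np.fliplr merely stands in for the real per-channel warp.

```python
import numpy as np


def apply_per_channel(img: np.ndarray, fn) -> np.ndarray:
    """Apply a 2-D image operation to every channel of an (H, W, C) array and restack the results."""
    if img.ndim != 3:
        raise ValueError("expected an array shaped (H, W, C)")
    channels = [fn(np.ascontiguousarray(img[:, :, c])) for c in range(img.shape[2])]
    return np.stack(channels, axis=-1)


if __name__ == "__main__":
    cube = np.arange(2 * 3 * 5).reshape(2, 3, 5)
    flipped = apply_per_channel(cube, np.fliplr)  # stand-in for a per-channel warp
    print(flipped.shape)
```

The loop costs one call per band, which is slow for hundreds of channels but keeps every call inside what a 2-D routine is known to support.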

4. class LetterBox:

This is another piece of augmentation code that fails. Change the __call__(self, labels=None, image=None) function of class LetterBox in miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/data/augment.py to the following:

    def __call__(self, labels=None, image=None):
        """Return updated labels and image with added border."""
        if labels is None:
            labels = {}
        img = labels.get("img") if image is None else image
        shape = img.shape[:2]  # current shape [height, width]
        new_shape = labels.pop("rect_shape", self.new_shape)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        # Compute padding
        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
        if self.auto:  # minimum rectangle
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)  # wh padding
        elif self.scaleFill:  # stretch
            dw, dh = 0.0, 0.0
            new_unpad = (new_shape[1], new_shape[0])
            ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

        if self.center:
            dw /= 2  # divide padding into 2 sides
            dh /= 2

        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))

        if img.ndim == 3 and img.shape[2] > 3:  # multi-channel: pad with NumPy instead of cv2.copyMakeBorder
            border_img = np.full((img.shape[0] + top + bottom, img.shape[1] + left + right, img.shape[2]), 114, dtype=img.dtype)
            border_img[top:img.shape[0]+top, left:img.shape[1]+left] = img
            img = border_img
        else:
            img = cv2.copyMakeBorder(
                img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
            )  # add border
        if labels.get("ratio_pad"):
            labels["ratio_pad"] = (labels["ratio_pad"], (left, top))  # for evaluation

        if len(labels):
            labels = self._update_labels(labels, ratio, dw, dh)
            labels["img"] = img
            labels["resized_shape"] = new_shape
            return labels
        else:
            return img
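
The multi-channel border step can also be expressed with np.pad, which pads only the two spatial axes. A minimal sketch, assuming an (H, W, C) array; pad_constant is a hypothetical helper and is not part of the LetterBox class.

```python
import numpy as np


def pad_constant(img: np.ndarray, top: int, bottom: int, left: int, right: int, value: int = 114) -> np.ndarray:
    """Pad an (H, W, C) array on its two spatial axes with a constant value, leaving channels untouched."""
    return np.pad(img, ((top, bottom), (left, right), (0, 0)), mode="constant", constant_values=value)


if __name__ == "__main__":
    cube = np.zeros((4, 6, 10), dtype=np.uint8)
    print(pad_constant(cube, 1, 2, 3, 4).shape)
```

This produces the same layout as the border_img assignment above: the original pixels sit at rows top onward and columns left onward, surrounded by the fill value.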

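III. Plotting code

1. def plot_images

The plotting code also assumes 3 channels by default, which makes many functions fail. In the plot_images function in miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/utils/plotting.py, change this line: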

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init

to:

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), images[0].shape[0]), 255, dtype=np.uint8)  # init

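2. class Annotator:

The __init__ of this class prepares the image for visualization, but a multi-channel image cannot be saved as a picture, so it raises an error. You can simply use the first 3 channels as the visualization image, or pick other channels as you prefer. At line 120 (in my installed version), right after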

if self.pil:  # use PIL

add these two lines:

            if im.shape[2] > 3:  # not RGB
                im = im[:, :, :3]
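
If other bands give a more useful picture than the first three, the selection can be done with a small helper like the sketch below. It assumes an (H, W, C) array. The band indices and the name bands_to_preview are illustrative, and the min-max stretch is one possible scaling rather than the method used in plotting.py.

```python
import numpy as np


def bands_to_preview(im: np.ndarray, bands=(0, 1, 2)) -> np.ndarray:
    """Pick three bands from an (H, W, C) cube and rescale them to a uint8 image for plotting."""
    picked = im[:, :, list(bands)].astype(np.float32)
    lo, hi = float(picked.min()), float(picked.max())
    if hi > lo:
        picked = (picked - lo) / (hi - lo)  # stretch to [0, 1]
    else:
        picked = np.zeros_like(picked)  # flat image: avoid dividing by zero
    return (picked * 255).round().astype(np.uint8)


if __name__ == "__main__":
    cube = np.arange(4 * 4 * 10, dtype=np.float32).reshape(4, 4, 10)
    print(bands_to_preview(cube, bands=(9, 4, 0)).shape)
```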

That's it. At this point training should more or less run.

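IV. Validation after training

When training finishes there is one more validation pass, and it also fails unless the code is changed. First add a hyperparameter named channel to the file miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/cfg/default.yaml, as shown in the figure:

Then, in def __call__(self, trainer=None, model=None): of class BaseValidator: in miniconda3/envs/yolov8/lib/python3.9/site-packages/ultralytics/engine/validator.py, change: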

model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

to:

model.warmup(imgsz=(1 if pt else self.args.batch, self.args.channel, imgsz, imgsz))  # warmup
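
To make the effect of the change concrete, the sketch below reproduces only the shape expression passed to warmup. It assumes channel is set to the number of input channels (544 in this example). warmup_shape is a hypothetical helper and not part of the validator.

```python
def warmup_shape(pt: bool, batch: int, channel: int, imgsz: int):
    """Return the (N, C, H, W) shape of the dummy input used to warm up the model."""
    n = 1 if pt else batch  # mirrors '1 if pt else self.args.batch' in the validator
    return (n, channel, imgsz, imgsz)


if __name__ == "__main__":
    print(warmup_shape(pt=True, batch=4, channel=544, imgsz=640))  # prints (1, 544, 640, 640)
```

With the original hard-coded 3, the dummy input would have 3 channels while the first convolution expects 544, which is why the unmodified validator fails.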
