The predict Process in YOLOv8

Flowchart Overview

The YOLOv8 pipeline is shown in the flowchart below. The execution flow is driven by stream_inference, and inference splits into two cases: single frame and video stream. Here a single frame is used for testing, with an input size of (1920, 1200). Inference uses yolov8n-seg.yaml together with yolov8-seg.pt, so the inference output channel count is 32 rather than the 64 described in the article; if yolov8l is used instead, the output channel count becomes 64.

(Figure: flowchart of YOLOv8's stream_inference predict pipeline)

1. setup_model
Check whether the model has already been set up; if not, set it up first.
2. setup_source
Configure the parameters of the input source.
3. warmup
Warm up the model by running a dummy forward pass, so that kernels and buffers are initialized before real inference begins (this is unrelated to the learning-rate warmup used during training). A minimal sketch follows.
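
As a rough illustration, warming up boils down to one throwaway forward pass; a minimal sketch (warmup_sketch is a hypothetical helper, not the actual ultralytics AutoBackend.warmup API):

    import torch

    def warmup_sketch(model: torch.nn.Module, imgsz=(1, 3, 640, 640)) -> None:
        """Run one dummy forward pass so CUDA kernels and workspace buffers
        are allocated before the first real inference is timed."""
        device = next(model.parameters()).device
        with torch.no_grad():
            model(torch.zeros(*imgsz, device=device))
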
4. preprocess
The input image is 1920x1200x3. Preprocessing mainly consists of the following steps (sketched below):
(1) pre_transform: letterbox the image, i.e., resize it while preserving the aspect ratio and pad the remainder;
(2) color and channel conversion: BGR to RGB, and BHWC to BCHW;
(3) make the array contiguous so it is efficient to read from memory, then normalize pixel values to [0, 1].
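
Putting the three steps together, a minimal sketch (assuming OpenCV and NumPy; preprocess_sketch is a hypothetical helper, not the ultralytics LetterBox class):

    import cv2
    import numpy as np

    def preprocess_sketch(img_bgr: np.ndarray, new_shape=(416, 640)) -> np.ndarray:
        """Letterbox to new_shape=(h, w), convert BGR->RGB and HWC->CHW,
        make the buffer contiguous, and normalize to [0, 1]."""
        h, w = img_bgr.shape[:2]
        r = min(new_shape[0] / h, new_shape[1] / w)              # aspect-preserving scale
        nh, nw = round(h * r), round(w * r)
        resized = cv2.resize(img_bgr, (nw, nh), interpolation=cv2.INTER_LINEAR)
        canvas = np.full((*new_shape, 3), 114, dtype=np.uint8)   # gray letterbox padding
        top, left = (new_shape[0] - nh) // 2, (new_shape[1] - nw) // 2
        canvas[top:top + nh, left:left + nw] = resized
        x = canvas[..., ::-1].transpose(2, 0, 1)                 # BGR->RGB, HWC->CHW
        return np.ascontiguousarray(x, dtype=np.float32) / 255.0

For the 1920x1200 test image, letterboxing to (416, 640) is consistent with the shapes seen later: 52x80 + 26x40 + 13x20 = 5460 candidates across the three strides, and a (1, 32, 104, 160) mask prototype tensor at stride 4.
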
5. inference
(1) visualize decides whether the feature map produced by each layer of the model should be visualized and saved;
(2) forward runs the forward pass: it calls the base class's predict function, which in turn calls _predict_once; during inference the forward method of each module is invoked, while all module initialization happens in each module's __init__ function. A simplified sketch follows below.
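
For orientation, _predict_once is essentially a loop over the module list that routes each module's input according to its "from" index m.f; a simplified sketch (profiling and visualization hooks omitted):

    def _predict_once_sketch(self, x):
        """Walk self.model in order; m.f == -1 means "take the previous
        module's output", otherwise fetch earlier outputs saved in y."""
        y = []  # saved intermediate outputs (None where not needed later)
        for m in self.model:
            if m.f != -1:  # input does not come from the previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
            x = m(x)  # invoke the module's forward method
            y.append(x if m.i in self.save else None)
        return x
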
6. postprocess
(1) Post-process the predictions: run NMS over the candidate boxes, then build the instance masks.

    def postprocess(self, preds, img, orig_imgs):
        """Applies non-max suppression and processes detections for each image in an input batch."""
        # preds[0] => (1, 116, 5460), made up of three parts:
        # (1, 4, 5460)  => box coordinates stored as (x, y, w, h), where x and y are the box center
        # (1, 80, 5460) => scores for the 80 classes
        # (1, 32, 5460) => 32 mask coefficients per candidate
        # p => (7, 38) in this test
        p = ops.non_max_suppression(
            preds[0],
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=len(self.model.names),
            classes=self.args.classes,
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        # proto holds the mask prototypes, proto => (1, 32, 104, 160)
        proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]  # tuple if PyTorch model or array if exported
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i]
            img_path = self.batch[0][i]
            if not len(pred):  # save empty boxes
                masks = None
            elif self.args.retina_masks:
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                # proto[i] holds the mask prototypes: the smallest-resolution base masks that detections combine (see the sketch after this function)
                # pred[:, 6:] holds the trailing 32 mask coefficients of each detection
                # pred[:, :4] holds the leading 4 values, i.e., the box coordinates
                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        return results
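
The heart of ops.process_mask is a linear combination of the 32 prototype masks with each detection's 32 coefficients, followed by a sigmoid; a simplified sketch (the real function additionally crops each mask to its box and can upsample to the network input size):

    import torch

    def process_mask_sketch(proto: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor:
        """proto: (32, mh, mw) prototype masks; coeffs: (n, 32) per-detection
        coefficients. Returns (n, mh, mw) boolean masks, one per detection."""
        c, mh, mw = proto.shape
        masks = (coeffs @ proto.view(c, -1)).sigmoid().view(-1, mh, mw)
        return masks > 0.5  # binarize at 0.5
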

How the NMS algorithm works

  1. Sort the boxes in set H by score, select the highest-scoring box m, and move it from H to set M.
  2. Iterate over the boxes remaining in H and compute each one's IoU with m; if the IoU exceeds a threshold (typically around 0.5; this function defaults to 0.45), the box is considered to overlap m and is removed from H.
  3. Repeat from step 1 until H is empty; the boxes collected in M are the final detections. A bare-bones sketch follows.
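A bare-bones implementation of these three steps, for illustration only (the library delegates the core loop to torchvision.ops.nms):

    import torch

    def nms_sketch(boxes: torch.Tensor, scores: torch.Tensor, iou_thres: float = 0.45) -> torch.Tensor:
        """boxes: (N, 4) in xyxy format; scores: (N,). Returns kept indices."""
        order = scores.argsort(descending=True)  # step 1: H sorted by score
        keep = []
        while order.numel() > 0:
            m = order[0]
            keep.append(m.item())                # move the top box m into M
            rest = order[1:]
            if rest.numel() == 0:
                break
            # step 2: IoU of m against every box remaining in H
            lt = torch.maximum(boxes[m, :2], boxes[rest, :2])
            rb = torch.minimum(boxes[m, 2:], boxes[rest, 2:])
            inter = (rb - lt).clamp(min=0).prod(dim=1)
            area_m = (boxes[m, 2] - boxes[m, 0]) * (boxes[m, 3] - boxes[m, 1])
            area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
            iou = inter / (area_m + area_r - inter)
            order = rest[iou <= iou_thres]       # step 3: drop overlaps, then repeat
        return torch.tensor(keep, dtype=torch.long)

The full non_max_suppression below wraps this core loop with per-image confidence filtering, optional multi-label handling, a per-class coordinate offset, and a time limit:
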
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
    in_place=True,
    rotated=False,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels.
        in_place (bool): If True, the input prediction tensor will be modified in place.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """

    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    # The input is prediction = preds[0], the first element of the 2-element list returned by inference
    # prediction => (1, 116, 5460), made up of three parts:
    # (1, 4, 5460)  => box coordinates stored as (x, y, w, h), where x and y are the box center
    # (1, 80, 5460) => scores for the 80 classes
    # (1, 32, 5460) => 32 mask coefficients per candidate

    bs = prediction.shape[0]  # batch size; 1 for a single input image
    nc = nc or (prediction.shape[1] - 4)  # number of classes; 80 for the default YOLO models
    nm = prediction.shape[1] - nc - 4  # number of mask coefficients
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates: drop boxes whose best class score is below conf_thres (0.25 by default)

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84); in this test the box count is 5460 rather than 6300
    if not rotated:
        if in_place:
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy coordinate conversion

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs  # here 6 = 4 (box) + 1 (conf) + 1 (class index j)
    for xi, x in enumerate(prediction):  # image index, image inference; process the batch one image at a time
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        # xc[xi] is a boolean mask: True keeps the candidate at that position, False drops it;
        # here x => (55, 116), i.e., 55 candidate boxes survive the confidence filter
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        # box => (55, 4)
        # cls => (55, 80)
        # mask => (55, 32)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)  # best class per box, i.e., the maximum of each row
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]  # drop boxes whose best-class confidence is below conf_thres (0.25 by default)

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes: scale each class index by a large offset (max_wh) so boxes of different classes never overlap
        scores = x[:, 4]  # scores
        if rotated:
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = nms_rotated(boxes, scores, iou_thres)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class), shifting each class into its own coordinate range
            # run the core NMS, where boxes => (N, 4) and scores => (N)
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections; kept box indices, sorted by descending score

        # # Experimental
        # merge = False  # use merge-NMS
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
        #     if redundant:
        #         i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]  # store the result; here output[0] => (7, 38): 7 boxes remain, each row holding 4 (box) + 1 (conf) + 1 (class) + 32 (mask coefficients) = 38 values
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return output