【计算机视觉】YOLOv8的predict过程
YOLOv8的predict过程
YOLOv8的predict过程
流程图示意
yolov8的流程图如下图所示,其中执行流程在于stream_inference,inference会分为单帧和视频流两种情况;这里使用单帧进行测试,测试的size为(1920,1200);在进行推理时,使用的是 yolov8n-seg.yaml 以及 yolov8-seg.pt,这样推理的 output channel = 32而不是文章中介绍的64,如果使用的时 yolov8l,则输出的channel就变成了64。
1.setup_model
先判断model有没有成功进行setup,如果没有则进行setup
2.setup_source
进行资源参数的配置
3.warmup
对模型进行预热,主要是为了优化学习率
4.preprocess
现在输入图片的数据为1920x1200x3,进行模型的前处理时主要包括的内容如下:
(1)pre_transform,将img进行等比例缩放,会进行padding等操作
(2)颜色格式和通道的变换,将BGR转换成RGB,BHWC转换成BCHW
(3)将数据格式变成内存容易读取的格式;归一化操作
5.inference
(1)visualize判断是否需要对模型的每一层输出的特征图进行可视化,能够进行保存
(2)forward前向推理,会调用基类的predict函数,进而调用_predict_once函数;在进行推理时,会调用各个模块的forward方法,初始化均在各个模块的 __init__函数中
6.postprocess
(1)图像的后处理,主要是进行多个预测框的nms算法以及mask的处理
def postprocess(self, preds, img, orig_imgs):
"""Applies non-max suppression and processes detections for each image in an input batch."""
# preds[0] => (1, 116, 5460),其中分为三个部分
# (1, 4, 5460) => 目标坐标框,4表示了框的位置,以(xywh)进行存储,x和y表示的是框中心位置
# (1, 80, 5460) => 80个类别,每个类别判定的概率
# (1, 32, 5460) => 32个mask
# p => (7, 38)
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
# proto 是 mask proto,proto => (1, 32, 104, 160)
proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i]
img_path = self.batch[0][i]
if not len(pred): # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
# proto[i] 表示的是 mask proto,是最小尺寸的mask,即mask的基量
# pred[:, 6:] 表示的是后面32个mask
# pred[:, :4] 表示的是前面4个值,即box的位置信息
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results
nms算法的原理
- 将集合H中的框进行排序,选出分数最高的框m,从集合H从移动到集合M
- 遍历H中的框,分别与框m计算交并比(IoU),如果高于某个阈值(一般为0~0.5),则认为此框与m重叠,将此框从集合H中去除
- 迭代第1布,直到H为空,集合M中的框为所需
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nc=0, # number of classes (optional)
max_time_img=0.05,
max_nms=30000,
max_wh=7680,
in_place=True,
rotated=False,
):
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
Args:
prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
containing the predicted boxes, classes, and masks. The tensor should be in the format
output by a model, such as YOLO.
conf_thres (float): The confidence threshold below which boxes will be filtered out.
Valid values are between 0.0 and 1.0.
iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
Valid values are between 0.0 and 1.0.
classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
agnostic (bool): If True, the model is agnostic to the number of classes, and all
classes will be considered as one.
multi_label (bool): If True, each box may have multiple labels.
labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
list contains the apriori labels for a given image. The list should be in the format
output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
max_det (int): The maximum number of boxes to keep after NMS.
nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
max_time_img (float): The maximum time (seconds) for processing one image.
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
max_wh (int): The maximum box width and height in pixels.
in_place (bool): If True, the input prediction tensor will be modified in place.
Returns:
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
"""
# Checks
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
# 输入值为 prediction = preds[0],即推理结果的 list{2} 中的第一个
# prediction => (1, 116, 5460),其中分为三个部分
# (1, 4, 5460) => 目标坐标框,4表示了框的位置,以(xywh)进行存储,x和y表示的是框中心位置
# (1, 80, 5460) => 80个类别,每个类别判定的概率
# (1, 32, 5460) => 32个mask
bs = prediction.shape[0] # batch size,输入如果是单张图片就是1
nc = nc or (prediction.shape[1] - 4) # number of classes,默认yolo使用的是80
nm = prediction.shape[1] - nc - 4 # number of mask,mask的数量
mi = 4 + nc # mask start index
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates,对小于0.25的类别概率进行过滤
# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 2.0 + max_time_img * bs # seconds to quit after
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84),这里测试时不是6300而是5460
if not rotated:
if in_place:
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
else:
prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy,坐标转换
t = time.time()
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs # 这里的 6 = 4(boxes) + 1(conf) + 1(j)
for xi, x in enumerate(prediction): # image index, image inference,从prediction的list当中逐帧读取
# Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
# xc[xi]的内容是list,list当中是True或者是False;对于x而言,如果是True,则该位置的元素保留,否则剔除;
# 此时 x => (55, 116),即经过过滤之后还剩下55个框
x = x[xc[xi]] # confidence,
# Cat apriori labels if autolabelling
if labels and len(labels[xi]) and not rotated:
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Detections matrix nx6 (xyxy, conf, cls)
# box => (55, 4)
# cls => (55, 80)
# mask => (55, 32)
box, cls, mask = x.split((4, nc, nm), 1)
if multi_label:
i, j = torch.where(cls > conf_thres)
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only
conf, j = cls.max(1, keepdim=True) # 取每个类的最大值,即每行的最大值
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] # 将阈值小于0.45的值过滤掉,一般情况下会选择置信度0.4~0.5
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
if n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes,将每个类别乘以一个很大的数字
scores = x[:, 4] # scores
if rotated:
boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr
i = nms_rotated(boxes, scores, iou_thres)
else:
boxes = x[:, :4] + c # boxes (offset by class),增加一个offset
# 执行内部的nms流程,其中boxes => (N,4), scores => (N),
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections,过滤之后的bounding boxes索引,按照降序进行排序
# # Experimental
# merge = False # use merge-NMS
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# from .metrics import box_iou
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
# weights = iou * scores[None] # box weights
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
# redundant = True # require redundant detections
# if redundant:
# i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i] # 将输出保存到output当中,output => (7, 38),表明所检测的图片当中还存在7个框,38则是记录了这7个框的情况
if (time.time() - t) > time_limit:
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return output
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)