YOLOv8
损失包括两大部分:类别损失 和位置损失。
对于类别损失, YOLOv8采用了和RetinaNet、FCOS等相同的策略,使用sigmoid函数来计算每个类别的概率,并计算全局的类别损失,其学习标签是由TOOD给出的target_scores
,其中,正样本的类别标签就是IoU值,而负样本处全是0。对于这种情况,一个常用的策略是使用Variable Focal loss(VFL), 比如YOLOv6和PP-YOLOE都是这么做的,但YOLOv8则采用简单的BCE。
代码分析
TaskAlignedAssigner
select_topk_candidates
这个函数select_topk_candidates
的目的是基于给定的指标选择前k个候选锚点。
下面是每一行代码的解释:
-
函数签名部分定义了该函数的输入和输出。
- 输入
metrics
是一个Tensor,表示锚点与真实对象之间的某些度量,例如重叠或得分。 largest
是一个布尔标志,决定是选择度量的最大值还是最小值。topk_mask
是一个可选的布尔tensor,指示应考虑哪些顶部k值。
- 输入
-
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
: 使用torch.topk
从最后一个维度选择前self.topk
个度量,并返回这些值及其索引。 -
if topk_mask is None: ...
: 如果没有提供topk_mask
,则计算一个新的topk_mask
,该mask对应于度量值大于self.eps
的顶部k值。 -
topk_idxs.masked_fill_(~topk_mask, 0)
: 使用逻辑非~
反转topk_mask
,然后使用masked_fill_
将topk_idxs
中对应于false值的索引位置设置为0。 -
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
: 创建一个形状与metrics
相同的零tensor,用于计数选择的锚点。 -
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
: 创建一个与topk_idxs
中的一部分具有相同形状的全1张量。 -
for k in range(self.topk): ...
: 通过循环将topk_idxs
中的索引添加到count_tensor
中。对于每个k值,都会增加相应位置的计数。 -
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
: 使用scatter_add_
将1加到count_tensor
中,这些1的位置由topk_idxs
中的索引决定。 -
count_tensor.masked_fill_(count_tensor > 1, 0)
: 如果count_tensor
中的任何值大于1(意味着某些锚点被选择了多次),则将这些值设置为0。 -
return count_tensor.to(metrics.dtype)
: 将count_tensor
转换为与metrics
相同的数据类型并返回。
总体来说,该函数的目标是基于度量选择前k个候选锚点,然后返回一个指示哪些锚点被选择的张量。
select_candidates_in_gts
该函数select_candidates_in_gts
的目的是选择落在给定的真实边界框(gt_bboxes)内部的锚框中心点。
下面是每一行代码的解释:
-
n_anchors = xy_centers.shape[0]
: 获取锚框的数量。 -
bs, n_boxes, _ = gt_bboxes.shape
: 获取输入gt_bboxes的形状,其中bs
表示批量大小,n_boxes
表示每张图片上真实边界框的数量。 -
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
: 这行代码首先通过view
方法改变gt_bboxes的形状以便操作,然后使用chunk
方法将其分割为两部分,分别为左上角(lt
)和右下角(rb
)坐标。 -
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2)
: 这里计算了两个差值:xy_centers[None] - lt
:锚框中心与真实边界框左上角的差。rb - xy_centers[None]
:真实边界框右下角与锚框中心的差。
然后使用torch.cat
沿着dim=2连接这两个差值。
-
bbox_deltas = ... .view(bs, n_boxes, n_anchors, -1)
: 使用view
改变bbox_deltas
的形状,使其符合(batch size, 真实边界框的数量, 锚框的数量, 4)。 -
return bbox_deltas.amin(3).gt_(eps)
:bbox_deltas.amin(3)
:在最后一个维度上寻找每个bbox_deltas
的最小值。gt_(eps)
: 检查这些最小值是否大于给定的阈值eps
。如果是,则该锚框中心被认为落在了真实边界框内部。
所以,该函数的输出是一个布尔张量,表示哪些锚框中心落在了真实边界框的内部。
get_box_metrics
这个函数get_box_metrics
的主要目的是计算预测的边界框与真实边界框之间的对齐度度量。度量的计算基于两个关键因素:预测框的分类分数和预测框与真实边界框之间的IoU(交并比)。
现在,我将对这个函数的每一行代码进行详细解释:
-
na = pd_bboxes.shape[-2]
: 获取预测边界框的数量,也就是每张图片上的锚框数量。 -
mask_gt = mask_gt.bool()
: 将mask_gt
转化为布尔类型。mask_gt
是一个掩码,标明哪些锚框与真实的目标物体有重叠。 -
定义
overlaps
和bbox_scores
为全零的张量。这些张量的目的是存储每个真实目标物体与所有锚框之间的IoU和分类得分。 -
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)
: 初始化一个指示张量,用于在后面获取每个锚框的正确分类分数。 -
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)
: 为每个样本设置相应的批次索引。 -
ind[1] = gt_labels.squeeze(-1)
: 获取每个真实物体的标签作为索引。 -
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]
: 使用ind
索引从pd_scores
中提取与真实目标物体重叠的锚框的分类分数。 -
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
: 对预测边界框进行变形,使其与gt_bboxes
具有相同的形状,以便计算IoU。 -
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
: 对真实边界框进行变形,使其与pd_bboxes
具有相同的形状。 -
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
: 计算每个真实边界框与所有锚框之间的IoU。这里使用了CIoU(Complete Intersection over Union),它是IoU的一种变体,考虑了中心点、宽高等因素,提供了更全面的匹配度量。 -
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
: 计算对齐度量,它是基于分类分数的权值和IoU的权值的乘积。self.alpha
和self.beta
是超参数,用于平衡这两个因素的权重。 -
返回
align_metric
和overlaps
。
总结,这个函数的目的是为每个真实目标物体与所有锚框之间计算一个对齐度量,这个度量是基于预测框的分类得分和预测框与真实边界框之间的IoU来得出的。