5.2.标签分配和Loss计算

5.2.1. 计算Loss的模块和流程

loss的运算流程如下,当aux_headAGM启用的时候,aux_headfpnaux_fpn获取featmap随后输出预测,在detach_epoch(需要自己设置的参数,在训练了detach_epoch后标签分配将由检测头自己进行)内,使用AGM的输出来对head的预测值进行标签分配。

先根据输入的大小获取priors的网格,然后将AGM预测的分数和检测框根据prior进行标签分配,把分配结果提交给head,再使用head的输出和gt计算loss最后反向传播完成一步迭代。

5.2.2. 计算Loss

这里我们先“假装”已经知道封装好的函数在做什么,先关注整体的计算流程:

def loss(self, preds, gt_meta, aux_preds=None):
    """Compute losses.
    Args:
        preds (Tensor): Prediction output. head的输出
        gt_meta (dict): Ground truth information. 包含gt位置和标签的字典,还有原图像的数据
        aux_preds (tuple[Tensor], optional): Auxiliary head prediction output.
        如果AGM还没有detach,会用AGM的输出进行标签分配

    Returns:
        loss (Tensor): Loss tensor.
        loss_states (dict): State dict of each loss.
    """
    # 把gt相关的数据分离出来,这两个数据都是list,长度为batchsize的大小
    # 每个list中都包含了它们各自对应的图像上的gt和label
    gt_bboxes = gt_meta["gt_bboxes"]
    gt_labels = gt_meta["gt_labels"]
    # 一会要传送到GPU上
    device = preds.device
    # 得到本次loss计算的batch数,pred是3维tensor,可以参见第一部分关于推理的介绍
    batch_size = preds.shape[0]

    # 对"img"信息取shape就可以得到图像的长宽,这里请看DataSet和DataLoader了解训练数据的详细格式
    # 所有图片都会在前处理中被resize成网络的输入大小,不足的则直接加zero padding
    input_height, input_width = gt_meta["img"].shape[2:]

    # 因为稍后要布置priors,这里要计算出feature map的大小
    # 如果修改了输入或者采样率,输入无法被stride整除,要取整
    # 默认是对齐的,应该为[40,40],[20,20],[10,10],[5,5]
    featmap_sizes = [
        (math.ceil(input_height / stride), math.ceil(input_width) / stride)
        for stride in self.strides
    ]

    # get grid cells of one image
    # 在不同大小的stride上放置一组priors,默认四个检测头也就是四个不同尺寸的stride
    # 最后返回的是tensor维度是[batchsize,strideW*strideH,4]
    # 其中每一个都是[x,y,strideH,strideW]的结构,当featmap不是正方形的时候两个stride不相等
    mlvl_center_priors = [
        self.get_single_level_center_priors(
            batch_size,
            featmap_sizes[i],
            stride,
            dtype=torch.float32,
            device=device,
        )
        for i, stride in enumerate(self.strides)
    ]

    # 按照第二个维度拼接后的prior的维度是[batchsize,40x40+20x20+10x10+5x5=2125,4]
    # 其中四个值为[cx,cy,strideW,stridH]
    center_priors = torch.cat(mlvl_center_priors, dim=1)

    # 把预测值拆分成分类和框回归
    # cls_preds的维度是[batchsize,2125*class_num],reg_pred是[batchsize,2125*4*(reg_max+1)]
    cls_preds, reg_preds = preds.split(
        [self.num_classes, 4 * (self.reg_max + 1)], dim=-1
    )

    # 对reg_preds进行积分得到位置预测,reg_reds表示的是一条边的离散分布,
    # 积分就得到位置(对于离散来说就是加权求和),distribution_project()是Integral函数,稍后讲解
    # 乘以stride(在center_priors的最后两个位置)后,就得到[dl,dr,dt,db]在原图的长度了
    # dis_preds是[batchsize,2125,4],其中4为上述的中心到检测框四条边的距离
    dis_preds = self.distribution_project(reg_preds) * center_priors[..., 2, None]
    # 把[dl,dr,dt,db]根据prior的位置转化成框的左上角点和右下角点方便计算iou
    decoded_bboxes = distance2bbox(center_priors[..., :2], dis_preds)

    # 如果启用了辅助训练模块,将用其结果进行标签分配
    if aux_preds is not None:
        # use auxiliary head to assign
        aux_cls_preds, aux_reg_preds = aux_preds.split(
            [self.num_classes, 4 * (self.reg_max + 1)], dim=-1
        )
        aux_dis_preds = (
            self.distribution_project(aux_reg_preds) * center_priors[..., 2, None]
        )
        aux_decoded_bboxes = distance2bbox(center_priors[..., :2], aux_dis_preds)

        # 可以去看multi_apply的实现,是一个稍微有点复杂的map()方法
        # 每次给一张图片进行分配,应该是为了避免显存溢出,代码写起来可读性也更高
        # 应该有并行优化,因此不用太担心效率问题
        batch_assign_res = multi_apply(
            self.target_assign_single_img,
            aux_cls_preds.detach(),
            center_priors,
            aux_decoded_bboxes.detach(),
            gt_bboxes,
            gt_labels,
        )
    else:
        # use self prediction to assign
        # multi_apply将参数中的函数作用在后面的每一个可迭代对象上,一次处理批量数据
        # 因为target_assign_single_img一次只能分配一张图片
        # 并且由于显存限制,有时候无法一次处理整个batch
        batch_assign_res = multi_apply(
            self.target_assign_single_img,
            cls_preds.detach(),
            center_priors,
            decoded_bboxes.detach(),
            gt_bboxes,
            gt_labels,
        )

    # 根据分配结果计算loss,这个函数稍后会介绍
    loss, loss_states = self._get_loss_from_assign(
        cls_preds, reg_preds, decoded_bboxes, batch_assign_res
    )

    # 加入辅助训练模块的loss,这可以让网络在初期收敛的更快
    if aux_preds is not None:
        aux_loss, aux_loss_states = self._get_loss_from_assign(
            aux_cls_preds, aux_reg_preds, aux_decoded_bboxes, batch_assign_res
        )
        loss = loss + aux_loss
        for k, v in aux_loss_states.items():
            loss_states["aux_" + k] = v
    return loss, loss_states

获取每个priors坐标的函数get_single_level_center_priors()如下,注意一定要先学习FCOS式检测器的原理,否则理解loss计算和推理将会有很大的阻碍。

# 在feature map上布置一组prior
# prior就是框分布的回归起点,将以prior的位置作为目标中心,预测四个值形成检测框
def get_single_level_center_priors(
    self, batch_size, featmap_size, stride, dtype, device
):
    """Generate centers of a single stage feature map.
    Args:
        batch_size (int): Number of images in one batch.
        featmap_size (tuple[int]): height and width of the feature map
        stride (int): down sample stride of the feature map
        dtype (obj:`torch.dtype`): data type of the tensors
        device (obj:`torch.device`): device of the tensors
    Return:
        priors (Tensor): center priors of a single level feature map.
    """
    h, w = featmap_size
    # arange()会生成一个一维tensor,和range()差不多,步长默认为1
    # 乘上stride后,就得到了划分的网格宽度和长度
    x_range = (torch.arange(w, dtype=dtype, device=device)) * stride
    y_range = (torch.arange(h, dtype=dtype, device=device)) * stride
    # 根据网格长宽生成一组二维坐标
    y, x = torch.meshgrid(y_range, x_range)
    # 展平成一维
    y = y.flatten()
    x = x.flatten()
    # 扩充出一个strides的tensor,稍后给每一个prior都加上其对应的下采样倍数
    strides = x.new_full((x.shape[0],), stride)
    # 把得到的prior按照以下顺序叠成二维tensor,即原图上的坐标,采样倍数
    proiors = torch.stack([x, y, strides, strides], dim=-1)
    # 一次处理一个batch,所以unsqueeze增加一个batch的维度
    # 然后把得到的prior复制到同个batch的其他位置上
    return proiors.unsqueeze(0).repeat(batch_size, 1, 1)

5.2.3. 为一张图片进行标签分配并获得正负样本

这里有两部分,分别是target_assign_single_img() sample()(被前者调用,生成正负样本)。刚刚已经看到,target_assign_single_img()_get_loss_from_assign()调用。

# 标签分配时的运算不会被记录,我们只是在计算cost并进行匹配
# 需要特别注意,这个函数只为一张图片,即一个样本进行标签分配!
@torch.no_grad()
def target_assign_single_img(
    self, cls_preds, center_priors, decoded_bboxes, gt_bboxes, gt_labels
):
    """Compute classification, regression, and objectness targets for
    priors in a single image.
    这些参数和第四部分介绍的label assigner差不多,稍后也会调用assign
    主要是为能够进行批处理而增加了一些代码
    Args:
        cls_preds (Tensor): Classification predictions of one image,
            a 2D-Tensor with shape [num_priors, num_classes]
        center_priors (Tensor): All priors of one image, a 2D-Tensor with
            shape [num_priors, 4] in [cx, xy, stride_w, stride_y] format.
        decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
            a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
            br_x, br_y] format.
        gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
            with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (Tensor): Ground truth labels of one image, a Tensor
            with shape [num_gts].
    """

    num_priors = center_priors.size(0)
    device = center_priors.device
    # 一些前处理,把数据都移到gpu里面
    gt_bboxes = torch.from_numpy(gt_bboxes).to(device)
    gt_labels = torch.from_numpy(gt_labels).to(device)
    num_gts = gt_labels.size(0)
    gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype)

    # dist_targets是最终用来计算回归损失的tensor,具体过程看后面
    bbox_targets = torch.zeros_like(center_priors)
    dist_targets = torch.zeros_like(center_priors)

    # 把label扩充成one-hot向量
    labels = center_priors.new_full(
        (num_priors,), self.num_classes, dtype=torch.long
    )
    label_scores = center_priors.new_zeros(labels.shape, dtype=torch.float)

    # No target,啥也不用管,全都是负样本,返回回去就行
    if num_gts == 0:
        return labels, label_scores, bbox_targets, dist_targets, 0

    # class的输出要映射到0-1之间,看之前对head构建conv layer就可以发现最后的分类输出后面没带激活函数
    # assign参见第四部分关于dsl_assigner的介绍
    assign_result = self.assigner.assign(
        cls_preds.sigmoid(), center_priors, decoded_bboxes, gt_bboxes, gt_labels
    )
    # 调用采样函数,获得正负样本,这个函数在下面马上介绍
    pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds = self.sample(
        assign_result, gt_bboxes
    )

    # 当前进行分配的这个图片上正样本的数目
    num_pos_per_img = pos_inds.size(0)
    # 把分配到了gt的那些prior预测的检测框和gt的iou算出来,稍后用于QFL的计算
    pos_ious = assign_result.max_overlaps[pos_inds]

    if len(pos_inds) > 0:
        # bbox_targets就是最终用来和gt计算回归损失的东西了(reg分支),维度为[2125,4]
        # 不过这里还需要一步转化,因为bbox_target是检测框四条边和他对应的prior的偏移量
        # 因此要转换成原图上的框(绝对位置),和gt进行回归损失计算
        bbox_targets[pos_inds, :] = pos_gt_bboxes
        # 
        dist_targets[pos_inds, :] = (
            bbox2distance(center_priors[pos_inds, :2], pos_gt_bboxes)
            / center_priors[pos_inds, None, 2]
        )
        dist_targets = dist_targets.clamp(min=0, max=self.reg_max - 0.1)
        # 上面计算回归,这里就是得到用于计算类别损失的,把那些匹配到的prior利用pos_inds索引筛出来
        labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
        label_scores[pos_inds] = pos_ious
    return (
        labels,
        label_scores,
        bbox_targets,
        dist_targets,
        num_pos_per_img,
    )

sample()的实现很简单,没分配到标签的priors都是负样本,这也是动态标签分配的优点所在:

def sample(self, assign_result, gt_bboxes):
    """Sample positive and negative bboxes."""
    # 分配到标签的priors索引,注意正样本和负样本的大小总和应该为prior的数目
    # 对于320x320,是2125
    pos_inds = (
        torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        .squeeze(-1)
        .unique()
    )
    # 没分配到标签的priors索引
    neg_inds = (
        torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        .squeeze(-1)
        .unique()
    )

    # -----------------------------------------------------------------------
    # @TODO:
    # 这里有疑问,不知道为什么之前在dsl_assigner里面要将分配到标签的prior索引变成2
    # 其实不是直接设置成1就可以了吗?然后这里又重新把index-1,有点不太明白
    # 如果有看懂了的或者知道为什么,请联系我!
    pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
    #------------------------------------------------------------------------

    if gt_bboxes.numel() == 0:
        # hack for index error case
        assert pos_assigned_gt_inds.numel() == 0
        pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
    else:
        if len(gt_bboxes.shape) < 2:
            gt_bboxes = gt_bboxes.view(-1, 4)
        # pos_gt_bboxes大小为[正样本数,4],4就是框的位置了
        pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
    return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds