Learning the YOLOv5 source code: the check_anchors() function

This function is mainly called from train.py.

To make the code easier to follow, the values of the key variables are printed out along the way.

def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    prefix = colorstr('autoanchor: ')
    print(f'\n{prefix}Analyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect() module
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

The dataset contains 697 images, and imgsz = 640.

prefix = 'autoanchor: ', wrapped by colorstr() in ANSI escape codes (blue + bold), which show up as the raw sequences \x1b[34m\x1b[1m...\x1b[0m in a plain-text log.

m=Detect(
  (m): ModuleList(
    (0): Conv2d(128, 21, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(256, 21, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(512, 21, kernel_size=(1, 1), stride=(1, 1))
  )
)
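A brief aside on those 21 output channels: each Detect head predicts na * (nc + 5) values per grid cell (box xywh, objectness, and one score per class). A minimal check, with na = 3 anchors per scale as in the default YOLOv5 configs and nc = 2 inferred from the channel count:

na, nc = 3, 2                 # 3 anchors per scale; 2 classes inferred from 21 channels
assert na * (nc + 5) == 21    # xywh (4) + objectness (1) + class scores (nc)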

shapes holds each image's (width, height) rescaled so that its longer side equals imgsz = 640; with 697 images (indexed 0 to 696) its shape is (697, 2).

scale holds a per-image random scale factor drawn uniformly from [0.9, 1.1), so its shape is (697, 1).
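A minimal sketch of how shapes and scale come about (the 1280x720 original size is an assumed example, not a value from this dataset):

import numpy as np

orig = np.array([[1280, 720]])                    # assumed original (width, height)
shapes = 640 * orig / orig.max(1, keepdims=True)  # longer side rescaled to imgsz=640
print(shapes)                                     # [[640. 360.]]
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # per-image jitter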

The first s in the for loop is [646.25, 363.51], the element-wise product of the first rows of scale and shapes. Working backwards, this implies the first image's shapes entry is [640, 360] (a 16:9 image) with a scale factor of roughly 1.0098.

The first l is [[0, 0.50703, 0.45451, 0.51719, 0.24792]] — one dataset label in the format [class, x_center, y_center, width, height], with coordinates normalized to [0, 1].

wh stores the width and height of every ground-truth box in the training data, scaled to the jittered image sizes. Its shape is (N, 2), where N is the total number of gt boxes; here N = 697, so each of the 697 images carries exactly one box.

wh = tensor([[334.23145,  90.12132],
             [350.72150, 153.06581],
             [349.53928, 101.28985],
             ...,
             [349.43201, 202.97702],
             [358.07135, 103.62663],
             [356.24854, 120.78210]])
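As a quick sanity check (not part of the original code), the first row of wh can be reproduced by hand from the values printed above:

s = [646.25, 363.51]      # shapes[0] * scale[0], the first jittered image size
box = [0.51719, 0.24792]  # l[0, 3:5], the normalized box width and height
print([d * b for d, b in zip(s, box)])  # [334.23..., 90.12...] -> matches wh[0]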

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    anchors = m.anchor_grid.clone().cpu().view(-1, 2)  # current anchors
    bpr, aat = metric(anchors)

metric() is defined first but only invoked by the last line above, so reading in execution order, anchors is computed first:

anchors = tensor([[ 10.,  13.],
                  [ 16.,  30.],
                  [ 33.,  23.],
                  [ 30.,  61.],
                  [ 62.,  45.],
                  [ 59., 119.],
                  [116.,  90.],
                  [156., 198.],
                  [373., 326.]])

Entering metric(), the two metrics bpr and aat (anchors above threshold) are computed from anchors and wh.
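The broadcasting inside metric() is the crux, so here is a self-contained sketch with stand-in data that shows the shape at each step (random values; only the shapes matter):

import torch

wh = torch.rand(697, 2) * 640       # stand-in gt box widths/heights
k = torch.rand(9, 2) * 300          # stand-in anchors
r = wh[:, None] / k[None]           # (697, 9, 2): w and h ratios for every (box, anchor) pair
x = torch.min(r, 1. / r).min(2)[0]  # (697, 9): worse of the two ratios, symmetric under inversion
best = x.max(1)[0]                  # (697,): each box's best-fitting anchor score
print(r.shape, x.shape, best.shape)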

r = tensor([[[33.42315,  6.93241],
             ...,
             [ 0.89606,  0.27645]],
            ...,
            [[35.62486,  9.29093],
             ...,
             [ 0.95509,  0.37050]]])
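The first entry of r can be verified directly; it pairs the first gt box with the first anchor [10, 13]:

import torch

wh0 = torch.tensor([334.23145, 90.12132])  # wh[0] from above
print(wh0 / torch.tensor([10., 13.]))      # tensor([33.42315, 6.93241]) -> matches r[0][0]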

x has shape [N, 9], i.e. [697, 9]: one ratio score per (gt box, anchor) pair.

x = tensor([[0.02992, 0.04787, 0.09873,  ..., 0.34706, 0.45516, 0.27645],
            ...,
            [0.02807, 0.04491, 0.09263,  ..., 0.32562, 0.43790, 0.37050]])
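x[0][0] then follows by folding r[0][0] with its reciprocal and keeping the smaller per-dimension score:

import torch

r00 = torch.tensor([33.42315, 6.93241])  # r[0][0] from above
print(torch.min(r00, 1. / r00).min())    # tensor(0.02992) -> matches x[0][0]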

best=tensor([0.45516, 0.46953, 0.44630, 0.49288, 0.41652, 0.43961, 0.42956, ..., 0.43790])

best has shape [N], i.e. [697]: for each gt box, the score of its best-fitting anchor.
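best[0] = 0.45516 comes from the anchor that fits the first gt box best, which from the x row above is the eighth anchor, [156, 198]:

import torch

wh0 = torch.tensor([334.23145, 90.12132])
r = wh0 / torch.tensor([156., 198.])  # tensor([2.14251, 0.45516])
print(torch.min(r, 1. / r).min())     # tensor(0.45516) -> matches best[0]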

thr=4.0

aat=tensor(2.94835)

bpr=tensor(1.)

  • bpr (best possible recall): the fraction of gt boxes that could be recalled at best, i.e. (recallable gt boxes) / (all gt boxes). Its maximum is 1, and higher is better; if it drops below 0.98, k-means plus a genetic evolution algorithm is used to select anchors that better match the dataset.
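With thr = 4.0 the cut-off is 1/thr = 0.25: an anchor "fits" a gt box only when both its width and height agree within a factor of 4. A small check against the printed values (note that aat = 2.94835 means roughly 2.95 of the 9 anchors fit each gt box on average):

import torch

thr = 4.0
best = torch.tensor([0.45516, 0.46953, 0.44630, 0.49288])  # first few best values from above
print((best > 1. / thr).float().mean())                    # tensor(1.): every sampled box is recallable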
    print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to improve anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        try:
            anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        except Exception as e:
            print(f'{prefix}ERROR: {e}')
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline

Here bpr = 1.0 > 0.98, so the if branch is not entered.

If the if branch were entered, kmean_anchors() would be called to regenerate the anchors.
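For reference, a hedged sketch of invoking it by hand, assuming the utils/autoanchor.py layout of the YOLOv5 repo and the same arguments as in the snippet above:

from utils.autoanchor import kmean_anchors  # module path assumed from the YOLOv5 repo

# 9 anchors, ratio threshold 4.0, 1000 generations of genetic evolution after k-means
new_anchors = kmean_anchors(dataset, n=9, img_size=640, thr=4.0, gen=1000, verbose=False)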

