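YOLOv5 source code study: the check_anchors() function
This function is mainly called from train.py.
To make the code easier to follow, the values of its variables are printed out along the way.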
def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    prefix = colorstr('autoanchor: ')
    print(f'\n{prefix}Analyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh
The dataset contains 697 images, and imgsz=640.
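prefix='\x1b[34m\x1b[1mautoanchor: \x1b[0m' (the string wrapped in the ANSI escape codes for blue and bold text, followed by a reset)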
m=Detect(
  (m): ModuleList(
    (0): Conv2d(128, 21, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(256, 21, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(512, 21, kernel_size=(1, 1), stride=(1, 1))
  )
)
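shapes holds the size of each image after its longer side has been rescaled to imgsz, one row per image (697 rows in total, indexed from 0 to 696).
scale holds one random augmentation factor per image, drawn uniformly from [0.9, 1.1), so its shape is (697, 1).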
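As a minimal sketch of the two lines that build shapes and scale, the snippet below runs them on three made-up image sizes. The sizes, the name raw_shapes, and the seeded NumPy generator are assumptions for illustration only; the real values come from dataset.shapes, which is assumed here to store (width, height) per image.

```python
import numpy as np

imgsz = 640
# Hypothetical (width, height) of three images; stands in for dataset.shapes.
raw_shapes = np.array([[1920, 1080], [1080, 1920], [800, 800]], dtype=np.float64)

# Rescale every image so that its longer side equals imgsz.
shapes = imgsz * raw_shapes / raw_shapes.max(1, keepdims=True)

# One random factor per image in [0.9, 1.1); seeded only to make the sketch repeatable.
rng = np.random.default_rng(0)
scale = rng.uniform(0.9, 1.1, size=(shapes.shape[0], 1))

print(shapes)          # rows: [640, 360], [360, 640], [640, 640]
print(scale.shape)     # (3, 1)
print(shapes * scale)  # each row is jittered by its own factor
```

Because scale has shape (N, 1), broadcasting applies the same factor to the width and height of an image, so the aspect ratio is preserved.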
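In the for loop, the first s=[ 646.25 363.51]; these two numbers are the first row of shapes multiplied by the first entry of scale.
The first l=[[ 0 0.50703 0.45451 0.51719 0.24792]] is that image's label from the dataset: the class index followed by the normalized x center, y center, width, and height, so l[:, 3:5] selects the normalized width and height.
wh stores the widths and heights of all ground-truth (gt) boxes in the training data. Its shape is (N, 2), where N is the total number of gt boxes; here N=697.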
wh=tensor([ [334.23145, 90.12132],
[350.72150, 153.06581],
[349.53928, 101.28985],
...,
[349.43201, 202.97702],
[358.07135, 103.62663],
[356.24854, 120.78210]])
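The first row of wh can be checked by hand from the first s and l printed above. The sketch below repeats that single loop iteration in NumPy; the name wh_first is introduced here for illustration, and the result differs from the printed tensor in the third decimal only because s was rounded when it was printed.

```python
import numpy as np

# First scaled image size and first label, copied from the values printed above.
s = np.array([646.25, 363.51])
l = np.array([[0, 0.50703, 0.45451, 0.51719, 0.24792]])

# Columns 3:5 are the normalized width and height; multiplying by s gives pixels.
wh_first = l[:, 3:5] * s
print(wh_first)  # approximately [[334.23, 90.12]]
```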
    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    anchors = m.anchor_grid.clone().cpu().view(-1, 2)  # current anchors
    bpr, aat = metric(anchors)
metric() is defined first but only runs when the last line above calls it, so the values below are listed in execution order.
anchors=tensor([ [ 10., 13.],
[ 16., 30.],
[ 33., 23.],
[ 30., 61.],
[ 62., 45.],
[ 59., 119.],
[116., 90.],
[156., 198.],
[373., 326.]])
Execution then enters metric(), which uses anchors and wh to compute two metrics: bpr (best possible recall) and aat (anchors above threshold).
r=tensor([ [ [33.42315, 6.93241],
...,
[ 0.89606, 0.27645]],
...,
[ [35.62486, 9.29093],
...,
[ 0.95509, 0.37050]]])
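x has shape [N, 9], i.e. [697, 9]. Each entry is the worse of the width and height ratios between one gt box and one anchor, folded into the range (0, 1].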
x=tensor([ [0.02992, 0.04787, 0.09873, ..., 0.34706, 0.45516, 0.27645],
...,
[0.02807, 0.04491, 0.09263, ..., 0.32562, 0.43790, 0.37050]])
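The shapes of r and x follow from broadcasting, which is easier to see on a single gt box. The sketch below is a NumPy rendering of the first two lines of metric() (NumPy is used in place of torch purely for illustration), applied to the first row of wh and the nine anchors printed above.

```python
import numpy as np

wh_first = np.array([[334.23145, 90.12132]])  # first row of wh, shape (1, 2)
anchors = np.array([[10, 13], [16, 30], [33, 23],
                    [30, 61], [62, 45], [59, 119],
                    [116, 90], [156, 198], [373, 326]], dtype=np.float64)  # shape (9, 2)

# (1, 1, 2) / (1, 9, 2) broadcasts to (1, 9, 2): one width/height ratio pair per box-anchor pair.
r = wh_first[:, None] / anchors[None]

# Fold ratios above 1 back below 1, then keep the worse of the two dimensions.
x = np.minimum(r, 1.0 / r).min(axis=2)

print(r.shape, x.shape)  # (1, 9, 2) (1, 9)
print(r[0, 0])           # approximately [33.42315, 6.93241]
print(x[0, -3:])         # approximately [0.34706, 0.45516, 0.27645]
```

Taking min(r, 1/r) makes the measure symmetric: an anchor twice as wide as the box and one half as wide both score 0.5.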
best=tensor([0.45516, 0.46953, 0.44630, 0.49288, 0.41652, 0.43961, 0.42956, ..., 0.43790])
best has shape (N,), i.e. (697,): it is the best ratio each gt box achieves over the 9 anchors.
thr=4.0
aat=tensor(2.94835)
bpr=tensor(1.)
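- bpr (best possible recall): the number of gt boxes that can be recalled at best, divided by the total number of gt boxes. A gt box counts as recallable when its best ratio exceeds 1/thr, i.e. when at least one anchor is within a factor of thr of it in both width and height. The maximum is 1, and larger is better. If bpr is below 0.98, k-means plus a genetic evolution algorithm is used to select anchors that fit the dataset better.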
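To see how the threshold turns ratios into bpr and aat, here is a hedged NumPy port of metric(). The name metric_np and the tiny anchor set are assumptions made for this sketch; the six wh rows are the ones visible in the printout above, so the result covers only those rows, not all 697.

```python
import numpy as np

def metric_np(wh, k, thr=4.0):
    """Illustrative NumPy port of metric(): boxes wh (N, 2), anchors k (A, 2)."""
    r = wh[:, None] / k[None]
    x = np.minimum(r, 1.0 / r).min(axis=2)     # (N, A) ratio metric
    best = x.max(axis=1)                       # best ratio per gt box
    aat = (x > 1.0 / thr).sum(axis=1).mean()   # average count of usable anchors per box
    bpr = (best > 1.0 / thr).mean()            # fraction of boxes with a usable anchor
    return bpr, aat

# The six rows of wh that are visible in the printout above.
wh = np.array([[334.23145, 90.12132], [350.72150, 153.06581], [349.53928, 101.28985],
               [349.43201, 202.97702], [358.07135, 103.62663], [356.24854, 120.78210]])
anchors = np.array([[10, 13], [16, 30], [33, 23],
                    [30, 61], [62, 45], [59, 119],
                    [116, 90], [156, 198], [373, 326]], dtype=np.float64)

bpr, aat = metric_np(wh, anchors)
print(bpr)  # 1.0: each of these boxes has an anchor within a factor of 4

# Hypothetical anchors that are far too small for these boxes.
tiny = np.array([[4, 4], [6, 8], [8, 6]], dtype=np.float64)
bpr_tiny, aat_tiny = metric_np(wh, tiny)
print(bpr_tiny, aat_tiny)  # 0.0 0.0, which would fall below the 0.98 threshold
```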
    print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to improve anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        try:
            anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        except Exception as e:
            print(f'{prefix}ERROR: {e}')
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline
Here bpr=1.0 > 0.98, so the if branch is not entered.
Had it been entered, kmean_anchors() would be called.
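When the branch does replace the anchors, it stores them twice: in pixels in m.anchor_grid for inference, and divided by each layer's stride in m.anchors for the loss. The sketch below illustrates only that division, in NumPy, using the current anchors as stand-ins for new ones. The strides 8, 16, 32 and the name grid_anchors are assumptions: the walkthrough shows three detection layers but never prints m.stride.

```python
import numpy as np

# Anchors in pixels, shape (9, 2); here the current anchors printed earlier.
anchors = np.array([[10, 13], [16, 30], [33, 23],
                    [30, 61], [62, 45], [59, 119],
                    [116, 90], [156, 198], [373, 326]], dtype=np.float64)

# Assumed strides of the three detection layers (not printed in this walkthrough).
stride = np.array([8.0, 16.0, 32.0])

# Group anchors as (layers, anchors per layer, 2) and divide each layer by its stride,
# mirroring anchors.view_as(m.anchors) / m.stride.view(-1, 1, 1).
grid_anchors = anchors.reshape(3, 3, 2) / stride.reshape(-1, 1, 1)

print(grid_anchors[0, 0])  # 10/8 and 13/8, i.e. 1.25 and 1.625
print(grid_anchors[2, 2])  # 373/32 and 326/32, i.e. 11.65625 and 10.1875
```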