AutoAnchor bug fix

pull/1/head
Glenn Jocher 5 years ago
parent 57a0ae3350
commit bafbc65ee3

@ -719,7 +719,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
return x, x.max(1)[0] # x, best_x return x, x.max(1)[0] # x, best_x
def fitness(k): # mutation fitness def fitness(k): # mutation fitness
_, best = metric(k) _, best = metric(torch.tensor(k, dtype=torch.float32))
return (best * (best > thr).float()).mean() # fitness return (best * (best > thr).float()).mean() # fitness
def print_results(k): def print_results(k):
@ -743,8 +743,8 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# Get label wh # Get label wh
shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True) shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh wh = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
wh = wh[(wh > 2.0).all(1)].numpy() # filter > 2 pixels wh = wh[(wh > 2.0).all(1)] # filter > 2 pixels
# Kmeans calculation # Kmeans calculation
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
@ -752,7 +752,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
s = wh.std(0) # sigmas for whitening s = wh.std(0) # sigmas for whitening
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
k *= s k *= s
wh = torch.tensor(wh) wh = torch.tensor(wh, dtype=torch.float32)
k = print_results(k) k = print_results(k)
# Plot # Plot
@ -771,7 +771,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
# Evolve # Evolve
npr = np.random npr = np.random
f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
for _ in tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm:'): for _ in tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm'):
v = np.ones(sh) v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates) while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)

Loading…
Cancel
Save