From 56e71b2fb0df078cafa05632630cbd326d05a255 Mon Sep 17 00:00:00 2001 From: qinxiaonan_branch <860289024@qq.com> Date: Thu, 4 Jul 2024 11:01:53 +0800 Subject: [PATCH] =?UTF-8?q?train.py=E7=9A=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/image_recognition/train.py | 49 ++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/src/image_recognition/train.py b/src/image_recognition/train.py index c652328..766337f 100644 --- a/src/image_recognition/train.py +++ b/src/image_recognition/train.py @@ -557,46 +557,49 @@ def main(opt, callbacks=Callbacks()): evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv' if opt.bucket: os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # 下载evolve.csv + - for _ in range(opt.evolve): # generations to evolve - if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate - # Select parent(s) - parent = 'single' # parent selection method: 'single' or 'weighted' + #进行opt.evolve指定次数的超参数演化 + for _ in range(opt.evolve): + if evolve_csv.exists(): # 如果存在演化过程中生成的csv文件,选择最佳超参数并进行变异 + # 选择父方法 + parent = 'single' # 'single' 表示随机选择一个最佳超参数集作为父代;'weighted' 表示加权选择 x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1) - n = min(5, len(x)) # number of previous results to consider - x = x[np.argsort(-fitness(x))][:n] # top n mutations - w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0) + n = min(5, len(x)) # 加载超参数数据 + x = x[np.argsort(-fitness(x))][:n] # 考虑前n个历史结果 + w = fitness(x) - fitness(x).min() + 1E-6 if parent == 'single' or len(x) == 1: - # x = x[random.randint(0, n - 1)] # random selection - x = x[random.choices(range(n), weights=w)[0]] # weighted selection + # 随机选择一个超参数集 + x = x[random.choices(range(n), weights=w)[0]] # 权重选择 elif parent == 'weighted': - x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination + x = (x * w.reshape(n, 1)).sum(0) / w.sum() # 权重组合 + + # 变异操作 + mp, s = 0.8, 0.2 # 变异概率,标准差 - # Mutate - mp, s = 0.8, 0.2 # mutation probability, sigma npr = np.random npr.seed(int(time.time())) - g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1 + g = np.array([meta[k][0] for k in hyp.keys()]) # ng = len(meta) v = np.ones(ng) - while all(v == 1): # mutate until a change occurs (prevent duplicates) + while all(v == 1): # 进行变异,直到发生改变,防止生成重复的超参数集 v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0) - for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300) - hyp[k] = float(x[i + 7] * v[i]) # mutate + for i, k in enumerate(hyp.keys()): + hyp[k] = float(x[i + 7] * v[i]) - # Constrain to limits + # 约束超参数在设定的范围内 for k, v in meta.items(): - hyp[k] = max(hyp[k], v[1]) # lower limit - hyp[k] = min(hyp[k], v[2]) # upper limit - hyp[k] = round(hyp[k], 5) # significant digits + hyp[k] = max(hyp[k], v[1]) # 下限 + hyp[k] = min(hyp[k], v[2]) # 上限 + hyp[k] = round(hyp[k], 5) # 5位有效数字 - # Train mutation + # 变异训练 results = train(hyp.copy(), opt, device, callbacks) - # Write mutation results + #记录结果 print_mutation(results, hyp.copy(), save_dir, opt.bucket) - # Plot results + # 绘制结果 plot_evolve(evolve_csv) LOGGER.info(f'Hyperparameter evolution finished\n' f"Results saved to {colorstr('bold', save_dir)}\n"