You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

94 lines
3.0 KiB

import json
import os
import sys
import time

from colorama import init, Fore
from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Integer, Float
from ConfigSpace.conditions import InCondition, EqualsCondition, AndConjunction
from ConfigSpace.read_and_write import json as csj
from smac import Scenario, BlackBoxFacade

# The project root must be on sys.path BEFORE the project-local imports below;
# appending it afterwards (as before) could not help them resolve.
sys.path.append('/root/hjt/md_bayesian_er_ditto/')

from ml_er.ditto_er import matching  # noqa: E402
from setting import hpo_output_dir  # noqa: E402
class Optimization:
    """SMAC optimization target: a search space plus an objective function.

    ``configspace`` describes the hyperparameters being searched and
    ``train`` evaluates one sampled configuration via ``matching``.
    """

    @property
    def configspace(self) -> ConfigurationSpace:
        """Build the hyperparameter search space.

        A new ``ConfigurationSpace`` (seeded with 0) is created on every
        access and populated with categorical hyperparameters only.

        Returns:
            The populated ``ConfigurationSpace``.
        """
        space = ConfigurationSpace(seed=0)
        searched = [
            # Not searched here: task, run_id.
            Categorical('batch_size', [32, 64], default=64),
            Categorical('max_len', [64, 128, 256], default=256),
            # Not searched here: lr (noted as 3e-5), n_epochs (noted as 20),
            # fine_tuning, save_model, logdir.
            Categorical(
                'language_model',
                ['distilbert', 'roberta', 'bert-base-uncased', 'xlnet-base-cased'],
                default='distilbert',
            ),
            Categorical('half_precision_float', [True, False]),
            Categorical(
                'data_augmentation',
                ['del', 'swap', 'drop_col', 'append_col', 'all'],
            ),
            # Not searched here: alpha_aug, dk.
            Categorical('summarize', [True, False]),
            # Not searched here: size.
        ]
        space.add_hyperparameters(searched)
        return space

    # TODO: flesh out the train function.
    def train(self, config: Configuration, seed: int = 0) -> float:
        """Objective for SMAC: evaluate *config* and return its cost.

        Args:
            config: The sampled configuration, forwarded to ``matching``.
            seed: Part of the target-function signature SMAC calls; unused.

        Returns:
            ``1 - performance`` from the indicators ``matching`` reports, so a
            lower cost corresponds to a higher performance score.
        """
        performance = matching(config)['performance']
        return 1 - performance
def ml_er_hpo():
    """Run SMAC hyperparameter optimization and persist the results.

    Steps:
      1. Write the search space to ``<hpo_output_dir>/configspace.json``.
      2. Optimize with ``BlackBoxFacade`` (16 trials, 5 initial configs).
      3. Re-evaluate both the incumbent and the default configuration with
         ``smac.validate`` and keep the default if it is strictly cheaper.
      4. Write the chosen configuration to ``<hpo_output_dir>/incumbent.json``.

    Returns:
        The selected configuration (SMAC's incumbent, or the default
        configuration when its validated cost is lower).
    """
    optimization = Optimization()
    cs = optimization.configspace

    # Save the hyperparameter space locally. os.path.join replaces the former
    # hard-coded "\" separator, which on POSIX wrote the file beside (not
    # inside) hpo_output_dir under a name containing a backslash.
    os.makedirs(hpo_output_dir, exist_ok=True)
    dict_configspace = json.loads(csj.write(cs))
    configspace_path = os.path.join(hpo_output_dir, "configspace.json")
    with open(configspace_path, "w", encoding="utf-8") as f:
        json.dump(dict_configspace, f, indent=4)

    scenario = Scenario(
        cs,
        crash_cost=1.0,
        deterministic=True,
        n_trials=16,
        n_workers=1,
    )
    initial_design = BlackBoxFacade.get_initial_design(scenario, n_configs=5)
    smac = BlackBoxFacade(
        scenario,
        optimization.train,
        initial_design=initial_design,
        overwrite=True,  # If the run exists, we overwrite it; alternatively, we can continue from last state
    )
    incumbent = smac.optimize()

    # Compare against the default configuration and fall back to it if the
    # search did not beat it.
    incumbent_cost = smac.validate(incumbent)
    default = cs.get_default_configuration()
    default_cost = smac.validate(default)
    print(Fore.BLUE + f"Default Cost: {default_cost}")
    print(Fore.BLUE + f"Incumbent Cost: {incumbent_cost}")
    if incumbent_cost > default_cost:
        incumbent = default
        print(Fore.RED + f'Updated Incumbent Cost: {default_cost}')

    # Print name -> value pairs; `.values()` alone dropped the names.
    incumbent_dict = dict(incumbent)
    print(Fore.BLUE + f"Optimized Configuration:{incumbent_dict}")
    incumbent_path = os.path.join(hpo_output_dir, "incumbent.json")
    with open(incumbent_path, "w", encoding="utf-8") as f:
        json.dump(incumbent_dict, f, indent=4)
    return incumbent
if __name__ == '__main__':
    # Colorama: reset styling automatically after each print.
    init(autoreset=True)
    # Start time as seconds since the epoch, as returned by time.time().
    started_at = time.time()
    print(Fore.CYAN + f'Start Time: {started_at}')
    ml_er_hpo()