You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
3.0 KiB
94 lines
3.0 KiB
5 months ago
|
import json
import os
import sys
import time

from colorama import init, Fore
from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Integer, Float
from ConfigSpace.conditions import InCondition, EqualsCondition, AndConjunction
from ConfigSpace.read_and_write import json as csj
from smac import Scenario, BlackBoxFacade

# Must run before the project-local imports below; placed after them it could
# never help resolve them.
sys.path.append('/root/hjt/md_bayesian_er_ditto/')

from ml_er.ditto_er import matching
from setting import hpo_output_dir
|
||
|
|
||
|
|
||
|
class Optimization:
    """Search space and SMAC target function for tuning the Ditto matcher."""

    @property
    def configspace(self) -> ConfigurationSpace:
        """Build a fresh, seeded search space of categorical hyperparameters.

        A new ``ConfigurationSpace`` (seed 0) is created on every access.
        Hyperparameters without an explicit ``default`` use ConfigSpace's
        default for ``Categorical``.

        Ditto options that are noted but NOT searched here: task, run_id,
        lr (note says 3e-5), n_epochs (note says 20), fine_tuning,
        save_model, logdir, alpha_aug, dk, size.
        NOTE(review): presumably these are fixed inside ``matching`` — confirm.

        Returns:
            The search space holding batch_size, max_len, language_model,
            half_precision_float, data_augmentation and summarize.
        """
        space = ConfigurationSpace(seed=0)
        # Order matters for reproducibility of the seeded space; keep it stable.
        hyperparameters = [
            Categorical('batch_size', [32, 64], default=64),
            Categorical('max_len', [64, 128, 256], default=256),
            Categorical(
                'language_model',
                ['distilbert', 'roberta', 'bert-base-uncased', 'xlnet-base-cased'],
                default='distilbert',
            ),
            Categorical('half_precision_float', [True, False]),
            Categorical('data_augmentation', ['del', 'swap', 'drop_col', 'append_col', 'all']),
            Categorical('summarize', [True, False]),
        ]
        space.add_hyperparameters(hyperparameters)
        return space

    # TODO: train function
    def train(self, config: Configuration, seed: int = 0) -> float:
        """SMAC target function: run matching with ``config`` and return its cost.

        SMAC minimises cost, so the cost is ``1 - performance``. A higher
        ``performance`` therefore means a lower cost.
        NOTE(review): assumes ``performance`` lies in [0, 1] (consistent with
        ``crash_cost=1.0`` in the scenario) — confirm against ``matching``.

        Args:
            config: The sampled configuration, passed straight to ``matching``.
            seed: Supplied by SMAC; not used by this function.

        Returns:
            ``1 - matching(config)['performance']``.
        """
        performance = matching(config)['performance']
        return 1 - performance
|
||
|
|
||
|
|
||
|
def ml_er_hpo():
    """Run SMAC black-box HPO over the Ditto search space and save the results.

    Steps:
      1. Write the search space to ``<hpo_output_dir>/configspace.json``.
      2. Run 16 SMAC trials, with 5 configurations from the initial design.
      3. Re-evaluate (``smac.validate``) both the incumbent and the default
         configuration. Keep the default if it validated strictly cheaper.
      4. Write the chosen configuration to ``<hpo_output_dir>/incumbent.json``.

    Returns:
        The chosen ``Configuration``.
    """
    optimization = Optimization()
    cs = optimization.configspace

    # Save the hyperparameter space locally. csj.write returns a JSON string;
    # parsing it back lets json.dump re-emit it with indent=4.
    dict_configspace = json.loads(csj.write(cs))
    # os.path.join instead of `dir + r"\name"`: a literal backslash is only a
    # path separator on Windows. On POSIX it produces a sibling file whose
    # name contains "\" instead of a file inside hpo_output_dir.
    configspace_path = os.path.join(hpo_output_dir, "configspace.json")
    with open(configspace_path, "w", encoding="utf-8") as f:
        json.dump(dict_configspace, f, indent=4)

    scenario = Scenario(
        cs,
        crash_cost=1.0,
        deterministic=True,
        n_trials=16,
        n_workers=1,
    )
    initial_design = BlackBoxFacade.get_initial_design(scenario, n_configs=5)
    smac = BlackBoxFacade(
        scenario,
        optimization.train,
        initial_design=initial_design,
        overwrite=True,  # If the run exists, we overwrite it; alternatively, we can continue from last state
    )

    incumbent = smac.optimize()
    incumbent_cost = smac.validate(incumbent)
    default = cs.get_default_configuration()
    default_cost = smac.validate(default)
    print(Fore.BLUE + f"Default Cost: {default_cost}")
    print(Fore.BLUE + f"Incumbent Cost: {incumbent_cost}")

    # Fall back to the default configuration when it validated strictly cheaper.
    if incumbent_cost > default_cost:
        incumbent = default
        print(Fore.RED + f'Updated Incumbent Cost: {default_cost}')

    # dict(...) keeps hyperparameter names next to their values; the previous
    # incumbent.values() printed an unlabeled dict_values view.
    print(Fore.BLUE + f"Optimized Configuration:{dict(incumbent)}")

    incumbent_path = os.path.join(hpo_output_dir, "incumbent.json")
    with open(incumbent_path, "w", encoding="utf-8") as f:
        json.dump(dict(incumbent), f, indent=4)
    return incumbent
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # autoreset=True makes colorama reset the colour after every print.
    init(autoreset=True)
    # Human-readable local start time; raw time.time() epoch seconds are hard
    # to read in logs.
    print(Fore.CYAN + f'Start Time: {time.strftime("%Y-%m-%d %H:%M:%S")}')
    ml_er_hpo()
|