You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
matching_dependency/hpo/er_model_hpo.py

74 lines
2.4 KiB

import json
from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Integer, Float
1 year ago
from ConfigSpace.conditions import InCondition
from ConfigSpace.read_and_write import json as csj
1 year ago
import py_entitymatching.catalog.catalog_manager as cm
import pandas as pd
from smac import HyperparameterOptimizationFacade, Scenario
from settings import *
1 year ago
from ml_er.ml_entity_resolver import er_process
1 year ago
class Classifier:
    """SMAC target wrapper around the entity-resolution pipeline.

    Exposes the hyperparameter search space (``configspace``) and the
    objective function (``train``) that SMAC minimizes.
    """

    @property
    def configspace(self) -> ConfigurationSpace:
        """Build the search space explored by SMAC.

        Only the choice of matcher is searched at the moment. The space is
        created with ``seed=0`` so that sampling from it is reproducible.

        Returns:
            A ``ConfigurationSpace`` containing the categorical hyperparameter
            ``ml_matcher`` (default ``"rf"``).
        """
        cs = ConfigurationSpace(seed=0)

        # The abbreviations are interpreted by ``er_process``. Presumably:
        # dt=decision tree, svm=SVM, rf=random forest, lg=logistic regression,
        # ln=linear regression, nb=naive Bayes. TODO confirm against er_process.
        ml_matcher = Categorical(
            "ml_matcher", ["dt", "svm", "rf", "lg", "ln", "nb"], default="rf"
        )
        cs.add_hyperparameters([ml_matcher])

        # TODO: add per-classifier hyperparameters. Example: a "dt_criterion"
        # Categorical over ["gini", "entropy", "log_loss"] (default "gini"),
        # active only when ml_matcher == "dt" via InCondition. Such a
        # hyperparameter was previously constructed here but never added to the
        # space, so it had no effect on the search.
        return cs

    def train(self, config: Configuration, seed: int = 0) -> float:
        """Evaluate one configuration and return its cost for SMAC.

        Args:
            config: Sampled configuration, passed unchanged to ``er_process``.
            seed: Required by SMAC's target-function signature; not used here.

        Returns:
            ``1 - indicators['performance']``. SMAC minimizes cost, so this
            maximizes the performance reported by ``er_process``.
        """
        # Clear py_entitymatching's catalog before each trial, presumably so
        # metadata registered by a previous trial does not affect this one.
        cm.del_catalog()
        indicators = er_process(config)
        return 1 - indicators["performance"]
1 year ago
1 year ago
def _write_json(obj, path: str) -> None:
    """Serialize ``obj`` to ``path`` as JSON indented by 4 spaces (UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4)


def ml_er_hpo(n_trials: int = 12, n_initial_configs: int = 5):
    """Run SMAC hyperparameter optimization for the ML entity-resolution matcher.

    Writes the search space to ``hpo_output_dir + "configspace.json"`` and the
    selected configuration to ``hpo_output_dir + "incumbent.json"``. The paths
    are built by plain string concatenation, so ``hpo_output_dir`` must end
    with a path separator (TODO confirm in settings).

    Args:
        n_trials: Maximum number of (configuration, seed) evaluations SMAC may
            run. Defaults to the previously hard-coded value.
        n_initial_configs: Number of configurations in SMAC's initial design.

    Returns:
        The best configuration found. If the default configuration validates
        with a strictly lower cost than SMAC's incumbent, the default is
        returned instead.
    """
    classifier = Classifier()
    cs = classifier.configspace
    # csj.write returns a JSON string. Parse it so it can be re-dumped with
    # indentation.
    _write_json(json.loads(csj.write(cs)), hpo_output_dir + "configspace.json")

    scenario = Scenario(
        cs,
        deterministic=True,
        n_trials=n_trials,
        n_workers=1,
    )
    initial_design = HyperparameterOptimizationFacade.get_initial_design(
        scenario, n_configs=n_initial_configs
    )
    smac = HyperparameterOptimizationFacade(
        scenario,
        classifier.train,
        initial_design=initial_design,
        # If the run exists, overwrite it rather than continuing from last state.
        overwrite=True,
    )
    incumbent = smac.optimize()

    # Validate both the incumbent and the default configuration, and keep the
    # cheaper one, so the result is never worse than the default.
    incumbent_cost = smac.validate(incumbent)
    default = cs.get_default_configuration()
    default_cost = smac.validate(default)
    print(f"Default Cost: {default_cost}")
    print(f"Incumbent Cost: {incumbent_cost}")

    if incumbent_cost > default_cost:
        incumbent = default
        print(f"Updated Incumbent Cost: {default_cost}")

    # Print name/value pairs (not bare values) so the output is interpretable.
    print(f"Optimized Configuration:{dict(incumbent)}")
    _write_json(dict(incumbent), hpo_output_dir + "incumbent.json")

    return incumbent
1 year ago
# Script entry point: run the full hyperparameter search with default settings.
if __name__ == "__main__":
    ml_er_hpo()