解决很多路径问题

master
HuangJintao 5 months ago
parent 894b69e9a7
commit a317d886f8

@ -1,3 +1,6 @@
import sys
sys.path.append('/root/hjt/md_bayesian_er_ditto/')
import json import json
import time import time
from colorama import init, Fore from colorama import init, Fore
@ -8,8 +11,6 @@ from smac import Scenario, BlackBoxFacade
from ml_er.ditto_er import matching from ml_er.ditto_er import matching
from setting import hpo_output_dir from setting import hpo_output_dir
import sys
sys.path.append('/root/hjt/md_bayesian_er_ditto/')
class Optimization: class Optimization:
@ -38,8 +39,8 @@ class Optimization:
return cs return cs
# todo train函数 # todo train函数
def train(self, config: Configuration, seed: int = 0, ) -> float: def train(self, hpo_config: Configuration, seed: int = 0, ) -> float:
indicators = matching(config) indicators = matching(hpo_config)
return 1 - indicators['performance'] return 1 - indicators['performance']
@ -49,7 +50,7 @@ def ml_er_hpo():
str_configspace = csj.write(cs) str_configspace = csj.write(cs)
dict_configspace = json.loads(str_configspace) dict_configspace = json.loads(str_configspace)
# 将超参数空间保存本地 # 将超参数空间保存本地
with open(hpo_output_dir + r"\configspace.json", "w") as f: with open(hpo_output_dir + "/configspace.json", "w") as f:
json.dump(dict_configspace, f, indent=4) json.dump(dict_configspace, f, indent=4)
scenario = Scenario( scenario = Scenario(
@ -82,7 +83,7 @@ def ml_er_hpo():
print(Fore.BLUE + f"Optimized Configuration:{incumbent.values()}") print(Fore.BLUE + f"Optimized Configuration:{incumbent.values()}")
with open(hpo_output_dir + r"\incumbent.json", "w") as f: with open(hpo_output_dir + "/incumbent.json", "w") as f:
json.dump(dict(incumbent), f, indent=4) json.dump(dict(incumbent), f, indent=4)
return incumbent return incumbent

@ -1,52 +1,69 @@
import os
import sys
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
sys.path.append('/root/hjt/md_bayesian_er_ditto/')
import pdb
import pickle import pickle
import torch import torch
import json import json
import numpy as np import numpy as np
import random import random
# from ditto.matcher import *
from setting import * from setting import *
from colorama import Fore from colorama import Fore
from argparse import Namespace from argparse import Namespace
import ConfigSpace import ConfigSpace
from ConfigSpace import Configuration from ConfigSpace import Configuration
from ditto.matcher import set_seed, to_str, classify, predict, tune_threshold, load_model from ditto.matcher import set_seed, predict, tune_threshold, load_model
from ConfigSpace.read_and_write import json as csj from ConfigSpace.read_and_write import json as csj
from ditto.ditto_light.dataset import DittoDataset from ditto.ditto_light.dataset import DittoDataset
from ditto.ditto_light.summarize import Summarizer from ditto.ditto_light.summarize import Summarizer
from ditto.ditto_light.knowledge import * from ditto.ditto_light.knowledge import *
from ditto.ditto_light.ditto import train from ditto.ditto_light.ditto import train
import os
import sys
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
sys.path.append('/root/hjt/md_bayesian_er_ditto/')
def matching(hpo_config):
print(Fore.BLUE + f'Config: {hpo_config}')
def matching(config): with open(md_output_dir + "/mds.pickle", "rb") as file:
print(Fore.BLUE + f'Config: {config}') md_list = pickle.load(file)
# with open(md_output_dir + r"\mds.pickle", "rb") as file:
# md_list = pickle.load(file)
hp = Namespace() hp = Namespace()
hp.task = directory_path.replace('/root/hjt/DeepMatcher Dataset/', '') hp.task = directory_path.replace('/root/hjt/DeepMatcher Dataset/', '')
# only a single task for baseline
task = hp.task
# load task configuration
configs = json.load(open('../ditto/configs.json'))
configs = {conf['name']: conf for conf in configs}
config = configs[task]
config['trainset'] = '/root/hjt/md_bayesian_er_ditto/ditto/' + config['trainset']
config['validset'] = '/root/hjt/md_bayesian_er_ditto/ditto/' + config['validset']
config['testset'] = '/root/hjt/md_bayesian_er_ditto/ditto/' + config['testset']
trainset = config['trainset']
validset = config['validset']
testset = config['testset']
hp.run_id = 0 hp.run_id = 0
hp.batch_size = config['batch_size'] hp.batch_size = hpo_config['batch_size']
hp.max_len = config['max_len'] hp.max_len = hpo_config['max_len']
hp.lr = 3e-5 hp.lr = 3e-5
hp.n_epochs = 20 hp.n_epochs = 20
# hp.finetuning # hp.finetuning
hp.save_model = True hp.save_model = True
hp.input_path = '/root/autodl-tmp/input/candidates_small.jsonl' hp.input_path = config['testset']
hp.output_path = '/root/autodl-tmp/output/matched_small.jsonl' hp.output_path = '/root/autodl-tmp/output/matched_small.jsonl'
hp.logdir = '/root/autodl-tmp/checkpoints/' hp.logdir = '/root/autodl-tmp/checkpoints/'
hp.checkpoint_path = '/root/autodl-tmp/checkpoints/' hp.checkpoint_path = '/root/autodl-tmp/checkpoints/'
hp.lm = config['language_model'] hp.lm = hpo_config['language_model']
hp.fp16 = config['half_precision_float'] hp.fp16 = hpo_config['half_precision_float']
hp.da = config['data_augmentation'] hp.da = hpo_config['data_augmentation']
hp.alpha_aug = 0.8 hp.alpha_aug = 0.8
hp.dk = None hp.dk = None
hp.summarize = config['summarize'] hp.summarize = hpo_config['summarize']
hp.size = None hp.size = None
hp.use_gpu = True hp.use_gpu = True
@ -57,23 +74,11 @@ def matching(config):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
# only a single task for baseline
task = hp.task
# create the tag of the run # create the tag of the run
run_tag = '%s_lm=%s_da=%s_dk=%s_su=%s_size=%s_id=%d' % (task, hp.lm, hp.da, run_tag = '%s_lm=%s_da=%s_dk=%s_su=%s_size=%s_id=%d' % (task, hp.lm, hp.da,
hp.dk, hp.summarize, str(hp.size), hp.run_id) hp.dk, hp.summarize, str(hp.size), hp.run_id)
run_tag = run_tag.replace('/', '_') run_tag = run_tag.replace('/', '_')
# load task configuration
configs = json.load(open('configs.json'))
configs = {conf['name']: conf for conf in configs}
config = configs[task]
trainset = config['trainset']
validset = config['validset']
testset = config['testset']
# summarize the sequences up to the max sequence length # summarize the sequences up to the max sequence length
if hp.summarize: if hp.summarize:
summarizer = Summarizer(config, lm=hp.lm) summarizer = Summarizer(config, lm=hp.lm)
@ -101,6 +106,7 @@ def matching(config):
hp.lm, hp.use_gpu, hp.fp16) hp.lm, hp.use_gpu, hp.fp16)
summarizer = dk_injector = None summarizer = dk_injector = None
pdb.set_trace()
if hp.summarize: if hp.summarize:
summarizer = Summarizer(config, hp.lm) summarizer = Summarizer(config, hp.lm)
@ -117,12 +123,14 @@ def matching(config):
# todo indicators # todo indicators
# write results # write results
# interpretability # interpretability
indicators = {}
return indicators
# todo ml_er function # todo ml_er function
def ml_er(config: Configuration): def ml_er(config: Configuration):
indicators = matching(config) indicators = matching(config)
output_path = er_output_dir + r"\eval_result.txt" output_path = er_output_dir + "/eval_result.txt"
with open(output_path, 'w') as _f: with open(output_path, 'w') as _f:
_f.write('F1:' + str(indicators["F1"]) + '\n') _f.write('F1:' + str(indicators["F1"]) + '\n')
_f.write('interpretability:' + str(indicators['interpretability']) + '\n') _f.write('interpretability:' + str(indicators['interpretability']) + '\n')
@ -130,12 +138,12 @@ def ml_er(config: Configuration):
if __name__ == '__main__': if __name__ == '__main__':
if os.path.isfile(hpo_output_dir + r"\incumbent.json"): if os.path.isfile(hpo_output_dir + "/incumbent.json"):
with open(hpo_output_dir + r"\configspace.json", 'r') as f: with open(hpo_output_dir + "/configspace.json", 'r') as f:
dict_configspace = json.load(f) dict_configspace = json.load(f)
str_configspace = json.dumps(dict_configspace) str_configspace = json.dumps(dict_configspace)
configspace = csj.read(str_configspace) configspace = csj.read(str_configspace)
with open(hpo_output_dir + r"\incumbent.json", 'r') as f: with open(hpo_output_dir + "/incumbent.json", 'r') as f:
dic = json.load(f) dic = json.load(f)
configuration = ConfigSpace.Configuration(configspace, values=dic) configuration = ConfigSpace.Configuration(configspace, values=dic)
ml_er(configuration) ml_er(configuration)

Loading…
Cancel
Save