根据JedAI样例加入数值类超参

main
HuangJintao 9 months ago
parent 253eb0835f
commit d5f60a4c99

@ -1,7 +1,7 @@
import pandas as pd
import json
from time import *
from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Integer, Float
from ConfigSpace import EqualsCondition
from ConfigSpace.read_and_write import json as csj
from smac import HyperparameterOptimizationFacade, Scenario
from settings import *
@ -20,8 +20,12 @@ class Classifier:
jed_blocker = Categorical("jed_blocker",
["Standard", "QGrams", "ExtendedQG", "SuffixArrays", "ExtendedSA"])
qgrams = Integer('qgrams', (3, 10), default=6)
use_qgrams = EqualsCondition(child=qgrams, parent=jed_blocker, value="QGrams")
block_attr = Categorical("block_attr", block_attr_items)
# filter_ratio = Float("filter_ratio", (0.0, 1.0), default=0.8)
block_filtering_ratio = Float("block_filtering_ratio", (0.7, 0.95), default=0.8)
meta_blocker = Categorical("meta_blocker",
["WEP", "WNP", "CEP", "CNP", "BLAST", "RCNP", "RWNP", "CP"])
weighting_scheme = Categorical("weighting_scheme",
@ -35,11 +39,14 @@ class Classifier:
['char_tokenizer', 'word_tokenizer', 'white_space_tokenizer'])
matching_vectorizer = Categorical("matching_vectorizer",
['tfidf', 'tf', 'boolean'])
matching_sim_thresh = Float("similarity_threshold", (0.05, 0.9))
clusteror = Categorical("clusteror_name",
["CCC", "UMC", "EC", "CenterC", "BMC", "MCC", "CC", "CTC", "MCL", "KMAC", "RSRC"])
["CCC", "UMC", "CenterC", "BMC", "MCC", "CC", "CTC", "MCL", "KMAC", "RSRC"])
cs.add_hyperparameters([jed_blocker, block_attr, meta_blocker, weighting_scheme, matching_metric,
matching_tokenizer, matching_vectorizer, clusteror])
matching_tokenizer, matching_vectorizer, clusteror, qgrams, block_filtering_ratio,
matching_sim_thresh])
cs.add_conditions([use_qgrams])
return cs
def train(self, config: Configuration, seed: int = 0) -> float:
@ -57,6 +64,7 @@ def ml_er_hpo():
scenario = Scenario(
cs,
crash_cost=1,
deterministic=True,
n_trials=50, # We want to run max 50 trials (combination of config and seed)
n_workers=1

@ -163,7 +163,7 @@ def er_process(config: Configuration):
case "Standard":
blocker = StandardBlocking()
case "QGrams":
blocker = QGramsBlocking()
blocker = QGramsBlocking(config["qgrams"])
case "ExtendedQG":
blocker = ExtendedQGramsBlocking()
case "SuffixArrays":
@ -177,7 +177,7 @@ def er_process(config: Configuration):
cleaned_blocks = bp.process(blocks, data, tqdm_disable=False)
# block cleaning(optional)
bf = BlockFiltering(ratio=0.8) # todo what is ratio for?
bf = BlockFiltering(ratio=config["block_filtering_ratio"]) # todo what is ratio for?
filtered_blocks = bf.process(cleaned_blocks, data, tqdm_disable=False)
# Comparison Cleaning - Meta Blocking(optional)
@ -208,7 +208,7 @@ def er_process(config: Configuration):
tokenizer=config["matching_tokenizer"],
vectorizer=config["matching_vectorizer"],
qgram=3,
similarity_threshold=0.0
similarity_threshold=config["similarity_threshold"]
)
pairs_graph = em.predict(candidate_pairs_blocks, data, tqdm_disable=True)
@ -221,8 +221,6 @@ def er_process(config: Configuration):
clusteror = ConnectedComponentsClustering()
case "UMC":
clusteror = UniqueMappingClustering()
case "EC":
clusteror = ExactClustering()
case "CenterC":
clusteror = CenterClustering()
case "BMC":
@ -240,7 +238,7 @@ def er_process(config: Configuration):
case "RSRC":
clusteror = RicochetSRClustering()
# 得到预测结果与评估指标
clusters = clusteror.process(pairs_graph, data, similarity_threshold=0.17)
clusters = clusteror.process(pairs_graph, data, similarity_threshold=0.17) # todo cluster sim thresh
matches_dataframe = clusteror.export_to_df(clusters)
matches_dataframe_path = er_output_dir + r'\matches_dataframe.csv'
matches_dataframe.to_csv(matches_dataframe_path, sep=',', index=False, header=True, quoting=1)

@ -11,7 +11,7 @@ target_attr = 'id' # 进行md挖掘时的目标字段
# lr_attrs_map = {} # 如果两个表中存在对应字段名称不一样的情况,将名称加入列表便于调整一致
model = SentenceTransformer('E:\\Data\\Research\\Models\\roberta-large-nli-stsb-mean-tokens')
interpre_weight = 0.5 # 可解释性权重
interpre_weight = 0 # 可解释性权重
similarity_threshold = 0.1
support_threshold = 1
confidence_threshold = 0.25

Loading…
Cancel
Save