You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.9 KiB
76 lines
2.9 KiB
11 months ago
|
import time
|
||
|
import pandas as pd
|
||
|
import py_entitymatching as em
|
||
|
import py_entitymatching.catalog.catalog_manager as cm
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from md_discovery.md_mining import mining
|
||
|
from settings import *
|
||
|
|
||
|
|
||
|
def blocking_mining(train_proportion=0.5, negative_ratio=3, random_state=0):
    """Block the input tables, label candidate pairs, and mine MDs on a training split.

    Steps:
      1. Read the left/right tables and the mapping table from ``ltable_path``,
         ``rtable_path`` and ``mapping_path``. Register the tables' keys
         (``ltable_id`` / ``rtable_id``) in the py_entitymatching catalog.
      2. Run an ``OverlapBlocker`` on ``ltable_block_attr`` / ``rtable_block_attr``
         (overlap_size=1, missing values allowed).
      3. Label each candidate pair with ``gold`` = 1 if its (left id, right id)
         pair appears in the mapping table, otherwise 0.
      4. Keep all matching pairs plus a random sample of non-matching pairs.
      5. Split into train/test sets and call ``mining`` on the train set.

    Args:
        train_proportion: Fraction of the labelled pairs placed in the train
            split. Defaults to 0.5, the previous hard-coded value.
        negative_ratio: Number of non-matching pairs sampled per matching
            pair. Defaults to 3, the previous hard-coded value. The sample is
            capped at the number of non-matching pairs actually available.
        random_state: Seed used for both the negative sampling and the
            train/test split, so runs are reproducible.

    Returns:
        int: Always 1.
    """
    start = time.time()

    ltable = pd.read_csv(ltable_path, encoding='ISO-8859-1')
    cm.set_key(ltable, ltable_id)
    rtable = pd.read_csv(rtable_path, encoding='ISO-8859-1')
    cm.set_key(rtable, rtable_id)
    # The mapping table is read through its 'ltable_id' / 'rtable_id' columns below.
    mappings = pd.read_csv(mapping_path, encoding='ISO-8859-1')
    matching_number = len(mappings)

    # The left table's columns are used as output attributes for BOTH sides.
    # NOTE(review): this assumes both tables share the same schema — confirm.
    attributes = ltable.columns.values.tolist()

    blocker = em.OverlapBlocker()
    candidate = blocker.block_tables(ltable, rtable, ltable_block_attr, rtable_block_attr, allow_missing=True,
                                     l_output_attrs=attributes, r_output_attrs=attributes, n_jobs=1,
                                     overlap_size=1, show_progress=False)
    candidate = candidate.reset_index(drop=True)

    # Label the candidate pairs from the mapping table. A pair is a match
    # (gold = 1) iff its (left id, right id) appears in the mapping table.
    # A single set lookup replaces a full scan of `candidate` per mapping row.
    gold_pairs = set(zip(mappings['ltable_id'], mappings['rtable_id']))
    candidate_pairs = zip(candidate['ltable_' + ltable_id], candidate['rtable_' + rtable_id])
    candidate['gold'] = [int(pair in gold_pairs) for pair in candidate_pairs]
    candidate.fillna(value="", inplace=True)

    candidate_match = candidate[candidate['gold'] == 1]
    candidate_mismatch = candidate[candidate['gold'] == 0]
    # sample() without replacement raises ValueError when n exceeds the number
    # of rows, so cap the request at what is available.
    n_negative = min(negative_ratio * len(candidate_match), len(candidate_mismatch))
    candidate_mismatch = candidate_mismatch.sample(n=n_negative, random_state=random_state)

    candidate_for_train_test = pd.concat([candidate_mismatch, candidate_match])
    # Reset the index after concatenation; otherwise index labels may repeat.
    candidate_for_train_test = candidate_for_train_test.reset_index(drop=True)
    # Re-register catalog metadata: the frame is new, and split_train_test
    # reads the key / foreign keys / base tables from the catalog.
    cm.set_key(candidate_for_train_test, '_id')
    cm.set_fk_ltable(candidate_for_train_test, 'ltable_' + ltable_id)
    cm.set_fk_rtable(candidate_for_train_test, 'rtable_' + rtable_id)
    cm.set_ltable(candidate_for_train_test, ltable)
    cm.set_rtable(candidate_for_train_test, rtable)

    # Fraction of mapped (true) matches that survived blocking. Guarded so an
    # empty mapping table does not raise ZeroDivisionError.
    # NOTE(review): this value is computed but currently neither reported nor returned.
    block_recall = len(candidate_match) / matching_number if matching_number else 0.0

    # Split the labelled pairs into train and test sets. Only the train split
    # is consumed here; the test split remains available in sets['test'].
    sets = em.split_train_test(candidate_for_train_test, train_proportion=train_proportion,
                               random_state=random_state)
    train_set = sets['train']

    end_blocking = time.time()
    print(end_blocking - start)

    mining(train_set)
    return 1
|
||
|
|
||
|
|
||
|
def matching():
    """Run the matching stage.

    Currently a stub: it performs no work.

    Returns:
        int: Always 1.
    """
    status = 1
    return status
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point: run blocking followed by MD mining on the train split.
    blocking_mining()
|