HuangJintao 1 year ago
parent 24985da169
commit ea3b247da0

@@ -0,0 +1,27 @@
import pandas as pd
# The imports below are needed only if the commented-out embedding experiment is re-enabled.
# import sentence_transformers.util
# from sentence_transformers import SentenceTransformer
if __name__ == '__main__':
    # model = SentenceTransformer('E:\\Data\\Research\\Models\\roberta-large-nli-stsb-mean-tokens')
    # sentences = ['公积金转入深圳', '公积金转出深圳', None, None, 114514, 114514, 1919810]
    # embedding = model.encode(sentences, device='cuda')
    # outcome1 = sentence_transformers.util.cos_sim(embedding[4], embedding[5])
    # outcome2 = sentence_transformers.util.cos_sim(embedding[4], embedding[6])
    # print(outcome1.item())
    # print(outcome2.item())
    train = pd.read_csv(r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon\train.csv', encoding='ISO-8859-1')
    valid = pd.read_csv(r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon\valid.csv', encoding='ISO-8859-1')
    test = pd.read_csv(r'E:\Data\Research\Datasets\DeepMatcher dataset\Dirty\Walmart-Amazon\test.csv', encoding='ISO-8859-1')
    # Keep only the matching pairs (label == 1) from each DeepMatcher split.
    train = train[train['label'] == 1]
    valid = valid[valid['label'] == 1]
    test = test[test['label'] == 1]
    # Merge the splits into a single table of ground-truth matches.
    matches = pd.concat([train, valid, test])
    matches.drop(columns=['label'], inplace=True)
    matches = matches.sort_values(by='ltable_id')
    matches.to_csv(r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\matches.csv', sep=',', index=False, header=True)
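
The script above collapses the three DeepMatcher splits into one ground-truth file of matching pairs. A minimal sanity check for the generated matches.csv might look like the sketch below, assuming the DeepMatcher-style ltable_id/rtable_id columns used above (the check itself is illustrative):

    import pandas as pd

    # Illustrative check: every row of matches.csv should be one ground-truth pair.
    matches = pd.read_csv(r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\matches.csv')
    assert {'ltable_id', 'rtable_id'}.issubset(matches.columns)
    print(len(matches), 'matching pairs')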

File diff suppressed because it is too large

@@ -0,0 +1,55 @@
import os

from pyecharts.charts import Line
from pyecharts import options as opts
from pyecharts.globals import ThemeType

if __name__ == '__main__':
    dir_path = r'E:\Data\Research\Outcome\Abt-Buy'
    filename_list = os.listdir(dir_path)
    iter_list = []
    precision = []
    recall = []
    f1 = []
    interpretability = []
    performance = []
    # Sort the filenames so iterations appear on the x-axis in order
    # (os.listdir returns them in arbitrary order).
    for filename in sorted(filename_list):
        if filename.startswith('eval_result'):
            it = int(filename[12:13])  # iteration number encoded in the filename
            iter_list.append(str(it))
            with open(os.path.join(dir_path, filename), 'r') as f:
                # Parse the metric value from each line of the eval_result file.
                for line in f.readlines():
                    if line.startswith('Precision'):
                        lt = line.split(' ')
                        value = float(lt[2].replace('%', '')) / 100
                        precision.append(value)
                    elif line.startswith('Recall'):
                        lt = line.split(' ')
                        value = float(lt[2].replace('%', '')) / 100
                        recall.append(value)
                    elif line.startswith('F1'):
                        lt = line.split(' ')
                        value = float(lt[2].replace('%', '')) / 100
                        f1.append(value)
                    elif line.startswith('interpretability'):
                        lt = line.split(':')
                        value = float(lt[1])
                        interpretability.append(value)
                    elif line.startswith('performance'):
                        lt = line.split(':')
                        value = float(lt[1])
                        performance.append(value)
    # Renamed from `line` to avoid shadowing the loop variable above.
    chart = (
        Line(init_opts=opts.InitOpts(theme=ThemeType.LIGHT))
        .add_xaxis(iter_list)
        .add_yaxis('Precision', precision)
        .add_yaxis('Recall', recall)
        .add_yaxis('F1', f1)
        .add_yaxis('Interpretability', interpretability)
        .add_yaxis('Performance', performance)
        .set_global_opts(title_opts=opts.TitleOpts(title=dir_path.split('\\')[-1]))
    )
    chart.render(os.path.join(dir_path, 'line.html'))
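
The parser assumes each eval_result file stores percentage metrics as space-separated `Name : value%` lines (hence `split(' ')[2]`) and colon-separated raw values for interpretability and performance. A small self-contained demo of the parse on made-up file contents:

    sample = """Precision : 83.52%
    Recall : 79.10%
    F1 : 81.25%
    interpretability: 0.4375
    performance: 0.6812
    """
    # Same tokenization as the loop above: token 2 holds the percentage.
    for line in sample.splitlines():
        if line.startswith('Precision'):
            print(float(line.split(' ')[2].replace('%', '')) / 100)  # 0.8352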

@@ -17,17 +17,17 @@ from ml_er.ml_entity_resolver import evaluate_prediction, load_mds, is_explicabl
 # Data is loaded externally
 ########################################################################################################################
 ltable = pd.read_csv(ltable_path, encoding='ISO-8859-1')
-# ltable.fillna("", inplace=True)
+ltable.fillna("", inplace=True)
 rtable = pd.read_csv(rtable_path, encoding='ISO-8859-1')
-# rtable.fillna("", inplace=True)
+rtable.fillna("", inplace=True)
 mappings = pd.read_csv(mapping_path)
 lid_mapping_list = []
 rid_mapping_list = []
 # Convert everything to strings
-# ltable = ltable.astype(str)
-# rtable = rtable.astype(str)
-# mappings = mappings.astype(str)
+ltable = ltable.astype(str)
+rtable = rtable.astype(str)
+mappings = mappings.astype(str)
 matching_number = len(mappings)  # number of positive pairs; should be 1300 for this product dataset
 for index, row in mappings.iterrows():
@@ -162,8 +162,6 @@ class Classifier:
                                                     attrs_after=test_feature_after, show_progress=False)
         fit_exclude = ['_id', 'ltable_' + tables_id, 'rtable_' + tables_id, 'gold']
-        train_feature_vecs.fillna(0, inplace=True)
-        test_feature_vecs.fillna(0, inplace=True)
         matcher.fit(table=train_feature_vecs, exclude_attrs=fit_exclude, target_attr='gold')
         test_feature_after.extend(['_id', 'ltable_' + tables_id, 'rtable_' + tables_id])
@@ -181,7 +179,7 @@ class Classifier:
         predictions_attrs.extend(['gold', 'predicted'])
         predictions = predictions[predictions_attrs]
         predictions = predictions.reset_index(drop=True)
-        # predictions = predictions.astype(str)
+        predictions = predictions.astype(str)
         sim_tensor_dict = build_col_pairs_sim_tensor_dict(predictions)
         # the default path is "../md_discovery/output/xxx.txt"
@@ -197,10 +195,11 @@ class Classifier:
         ppre = predictions[predictions['predicted'] == str(1)]
         interpretability = epl_match / len(ppre)  # interpretability score
-        if indicators["block_recall"] >= 0.8:
-            f1 = indicators["F1"]
+        if (indicators["block_recall"] < 0.8) and (indicators["block_recall"] < indicators["recall"]):
+            f1 = (2.0 * indicators["precision"] * indicators["block_recall"]) / (
+                    indicators["precision"] + indicators["block_recall"])
         else:
-            f1 = (2.0 * indicators["precision"] * indicators["block_recall"]) / (indicators["precision"] + indicators["block_recall"])
+            f1 = indicators["F1"]
         # if indicators["block_recall"] < 0.8:
         #     return 1
         # f1 = indicators["F1"]
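
Both changed branches implement the same rule: when blocking is the bottleneck (block recall under 0.8 and below the matcher's recall), the score falls back to the harmonic mean of precision and block recall instead of the reported F1. A standalone sketch of that rule (the helper name penalized_f1 is ours, not the repo's):

    def penalized_f1(precision: float, recall: float, block_recall: float, f1: float) -> float:
        # Mirror of the branch above: penalize runs whose blocking loses too many
        # true matches by scoring against block recall rather than the raw F1.
        if block_recall < 0.8 and block_recall < recall:
            return 2.0 * precision * block_recall / (precision + block_recall)
        return f1

    # e.g. precision=0.90, recall=0.85, block_recall=0.60 -> 1.08 / 1.50 = 0.72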

@@ -157,9 +157,9 @@ def ml_er(iter_round: int, config: Configuration = None, ):
     lid_mapping_list = []
     rid_mapping_list = []
     # Convert everything to strings
-    # ltable = ltable.astype(str)
-    # rtable = rtable.astype(str)
-    # mappings = mappings.astype(str)
+    ltable = ltable.astype(str)
+    rtable = rtable.astype(str)
+    mappings = mappings.astype(str)
     matching_number = len(mappings)  # number of positive pairs
     for index, row in mappings.iterrows():
@@ -206,9 +206,9 @@ def ml_er(iter_round: int, config: Configuration = None, ):
config["block_attr"], allow_missing=True,
l_output_attrs=selected_attrs, r_output_attrs=selected_attrs)
else:
matcher = em.RFMatcher(name='RF', random_state=0)
matcher = em.SVMMatcher(name='SVM', random_state=0)
blocker = em.OverlapBlocker()
candidate = blocker.block_tables(selected_ltable, selected_rtable, items_but_id[0], items_but_id[0],
candidate = blocker.block_tables(selected_ltable, selected_rtable, selected_attrs[-1], selected_attrs[-1],
l_output_attrs=selected_attrs, r_output_attrs=selected_attrs,
overlap_size=1, show_progress=False, allow_missing=True)
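
For context on the blocking change: py_entitymatching's OverlapBlocker keeps a candidate pair only if the blocking attribute shares at least overlap_size tokens between the two records, so with overlap_size=1 on the last selected attribute the survival test amounts to the sketch below (values are illustrative):

    # Token-overlap test that overlap blocking applies per candidate pair.
    left_val = 'sony cybershot dsc-w120 digital camera'
    right_val = 'sony dsc-w120 camera silver'
    shared = set(left_val.split()) & set(right_val.split())  # {'sony', 'dsc-w120', 'camera'}
    keeps_pair = len(shared) >= 1  # overlap_size=1 -> pair survives blocking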
@@ -229,6 +229,8 @@ def ml_er(iter_round: int, config: Configuration = None, ):
     for row in candidate_match_rows:
         candidate.loc[row, 'gold'] = 1
+    candidate.fillna("", inplace=True)
+    # Trim negative samples so positive and negative counts stay equal
     candidate_mismatch = candidate[candidate['gold'] == 0]
     candidate_match = candidate[candidate['gold'] == 1]
@@ -266,8 +268,6 @@ def ml_er(iter_round: int, config: Configuration = None, ):
                                                 attrs_after=test_feature_after, show_progress=False)
     fit_exclude = ['_id', 'ltable_' + tables_id, 'rtable_' + tables_id, 'gold']
-    train_feature_vecs.fillna(0, inplace=True)
-    test_feature_vecs.fillna(0, inplace=True)
     matcher.fit(table=train_feature_vecs, exclude_attrs=fit_exclude, target_attr='gold')
     test_feature_after.extend(['_id', 'ltable_' + tables_id, 'rtable_' + tables_id])
     predictions = matcher.predict(table=test_feature_vecs, exclude_attrs=test_feature_after,
@@ -299,11 +299,11 @@ def ml_er(iter_round: int, config: Configuration = None, ):
     df = predictions[predictions['predicted'] == str(1)]
     interpretability = epl_match / len(df)  # interpretability score
-    if indicators["block_recall"] >= 0.8:
-        f1 = indicators["F1"]
-    else:
+    if (indicators["block_recall"] < 0.8) and (indicators["block_recall"] < indicators["recall"]):
         f1 = (2.0 * indicators["precision"] * indicators["block_recall"]) / (
                 indicators["precision"] + indicators["block_recall"])
+    else:
+        f1 = indicators["F1"]
     performance = interpre_weight * interpretability + (1 - interpre_weight) * f1
     ################################################################################################################
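
The final objective blends interpretability with the (possibly penalized) F1. Using the interpre_weight = 0.4 set in the settings change below, a worked example with illustrative numbers:

    interpre_weight = 0.4
    interpretability, f1 = 0.5, 0.9
    performance = interpre_weight * interpretability + (1 - interpre_weight) * f1
    # 0.4 * 0.5 + 0.6 * 0.9 = 0.74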

@@ -4,16 +4,16 @@ import numpy as np
 ltable_path = r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\tableA.csv'
 rtable_path = r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\tableB.csv'
 mapping_path = r'E:\Data\Research\Projects\matching_dependency\datasets\Walmart-Amazon_dirty\matches.csv'
-mapping_lid = 'id1'  # name of the left-table id column in the mapping table
-mapping_rid = 'id2'  # name of the right-table id column in the mapping table
+mapping_lid = 'ltable_id'  # name of the left-table id column in the mapping table
+mapping_rid = 'rtable_id'  # name of the right-table id column in the mapping table
 ltable_id = 'id'  # id column name of the left table
 rtable_id = 'id'  # id column name of the right table
 target_attr = 'id'  # target attribute for MD discovery
 lr_attrs_map = {}  # if corresponding columns have different names in the two tables, add them here so they can be renamed consistently
 similarity_threshold = 0.2
 support_threshold = 1
-confidence_threshold = 0.5
-interpre_weight = 0.3  # weight of interpretability in the objective
+confidence_threshold = 0.4
+interpre_weight = 0.4  # weight of interpretability in the objective
 er_output_dir = 'E:\\Data\\Research\\Projects\\matching_dependency\\ml_er\\output\\'
 md_output_dir = 'E:\\Data\\Research\\Projects\\matching_dependency\\md_discovery\\output\\'
 hpo_output_dir = 'E:\\Data\\Research\\Projects\\matching_dependency\\hpo\\output\\'
