matching_dependency/samples_generator.py

import os
import random
import pandas as pd
import Levenshtein
import ml_er.ml_entity_resolver


def my_Levenshtein_ratio(str1, str2):
    # Normalized Levenshtein similarity in [0, 1]; two empty strings are treated as identical.
    if max(len(str1), len(str2)) == 0:
        return 1
    return 1 - Levenshtein.distance(str1, str2) / max(len(str1), len(str2))
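# Quick illustration: Levenshtein.distance("apple", "appl") == 1, so
# my_Levenshtein_ratio("apple", "appl") == 1 - 1/5 == 0.8.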


def load_mds(paths: list) -> list:
    # Load matching dependencies (MDs) from a list of MD file paths.
    if len(paths) == 0:
        return []
    all_mds = []
    for md_path in paths:
        if not os.path.exists(md_path):
            continue
        mds = []
        # Open each MD file and read the MD on every line into this file's MD list.
        with open(md_path, 'r') as f:
            for line in f.readlines():
                md_metadata = line.strip().split('\t')
                md = eval(md_metadata[0].replace('md:', ''))
                confidence = eval(md_metadata[2].replace('confidence:', ''))
                # Keep only MDs with positive confidence.
                if confidence > 0:
                    mds.append(md)
        all_mds.extend(mds)
    return all_mds
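# Illustrative line format assumed by the parser above (attribute names are placeholders;
# the middle field is whatever the MD-discovery step writes there):
#   md:{'attr1': 0.8, 'attr2': 0.7}\t...\tconfidence:0.9
# The first field is a dict of per-attribute similarity thresholds, the third is the
# MD's confidence.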


# Input: list of MD file paths, path of the prediction table, number of random generations.
# Output: some positive samples (with a 'gold' column, without a prediction column).
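# Assumed shape of the prediction CSV (illustrative attribute names): paired
# 'ltable_<attr>' / 'rtable_<attr>' columns plus 'gold' and 'predicted' columns
# holding 0/1 labels, e.g.
#   ltable_id, ltable_attr1, ..., rtable_id, rtable_attr1, ..., gold, predicted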
def generate_samples(md_path_list, pred_path, count: int):
    all_mds = load_mds(md_path_list)
    predictions = pd.read_csv(pred_path, low_memory=False, encoding='ISO-8859-1')
    predictions.fillna("", inplace=True)
    predictions = predictions.astype(str)
    pred_attrs = predictions.columns.values.tolist()  # prediction-table columns, with prefixes, including gold and predicted
    attrs = []  # column names without prefixes, excluding gold and predicted
    l_attrs = []
    r_attrs = []
    for _ in pred_attrs:
        if _.startswith('ltable_'):
            attrs.append(_.replace('ltable_', ''))
            l_attrs.append(_)
        elif _.startswith('rtable_'):
            r_attrs.append(_)
    # False positives and false negatives from the prediction table.
    fp = predictions[(predictions['gold'] == '0') & (predictions['predicted'] == '1')]
    fn = predictions[(predictions['gold'] == '1') & (predictions['predicted'] == '0')]
    fpl = fp[l_attrs]
    fpr = fp[r_attrs]
    # Unify the column names of the left and right halves.
    fpl.columns = attrs
    fpr.columns = attrs
    fnl = fn[l_attrs]
    fnr = fn[r_attrs]
    fnl.columns = attrs
    fnr.columns = attrs
    fp = pd.concat([fpl, fpr])
    fn = pd.concat([fnl, fnr])
    df = pd.concat([fp, fn])
    length = len(df)
    result = pd.DataFrame()
    for i in range(0, count):
        dic = {}
        for _ in attrs:
            if _ == 'id':
                index = random.randint(0, length-1)
                value = df.iloc[index]['id']
                dic['ltable_'+_] = value
                dic['rtable_'+_] = value
            else:
                index1 = random.randint(0, length-1)
                index2 = random.randint(0, length-1)
                value1 = df.iloc[index1][_]
                value2 = df.iloc[index2][_]
                dic['ltable_'+_] = value1
                dic['rtable_'+_] = value2
        # Check whether the dict satisfies an MD; if so, convert it to a Series and
        # append it to the (initially empty) result DataFrame.
        for md in all_mds:
            satis = True
            for _ in attrs:
                if my_Levenshtein_ratio(str(dic['ltable_'+_]), str(dic['rtable_'+_])) < md[_]:
                    satis = False
                    break
            if satis:
                series = pd.Series(dic)
                result = pd.concat([result, series.to_frame().T], ignore_index=True)
    result['gold'] = 1
    return result


if __name__ == '__main__':
    md_paths = ['/home/w/PycharmProjects/matching_dependency/md_discovery/output/tp_mds.txt',
                '/home/w/PycharmProjects/matching_dependency/md_discovery/output/fn_mds.txt',
                '/home/w/PycharmProjects/matching_dependency/md_discovery/output/tp_vio.txt',
                '/home/w/PycharmProjects/matching_dependency/md_discovery/output/fn_vio.txt']
    pre_p = '/home/w/pred.csv'
    # A generation count of a thousand or ten thousand is fine here.
    generate_samples(md_paths, pre_p, 10000)
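    # The return value is discarded above; a minimal sketch for persisting it
    # (hypothetical output path) would be:
    #   samples = generate_samples(md_paths, pre_p, 10000)
    #   samples.to_csv('/home/w/samples.csv', index=False)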