You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
matching_dependency/md_discovery/discovery_executor.py

201 lines
7.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import operator
import numpy as np
import pandas as pd
import copy
import torch
from ConfigSpace import Configuration
from tqdm import tqdm
from settings import model, similarity_threshold, support_threshold, confidence_threshold
def is_minimal(md, md_list, target_col):
    """Check whether ``md`` is minimal with respect to the MDs in ``md_list``.

    ``md`` is NOT minimal when some other entry of ``md_list`` meets both conditions:
    - none of its left-hand-side thresholds (every key except ``target_col``)
      is greater than the corresponding threshold of ``md``;
    - its right-hand-side threshold (``target_col``) is not smaller than that
      of ``md``.

    Args:
        md: dict mapping column name -> similarity threshold.
        md_list: list whose items are either MD dicts or tuples whose first
            element is an MD dict (e.g. ``(md, support, confidence)``).
        target_col: name of the right-hand-side column.

    Returns:
        False if such a dominating entry exists, True otherwise (including for
        an empty ``md_list``). Entries equal to ``md`` are ignored.
    """
    for entry in md_list:
        # Entries may be wrapped as (md, support, confidence) tuples.
        candidate = entry[0] if isinstance(entry, tuple) else entry
        if candidate == md:
            continue
        lhs_columns = list(set(candidate.keys()) - {target_col})
        lhs_not_stricter = not any(candidate[c] > md[c] for c in lhs_columns)
        # Written as `not (<)` rather than `>=` to keep the original comparison.
        rhs_not_weaker = not (candidate[target_col] < md[target_col])
        if lhs_not_stricter and rhs_not_weaker:
            return False
    return True
def pairs_inference(path, target_col):
    """Discover matching dependencies (MDs) whose right-hand side is ``target_col``.

    Loading and similarities:
    - The CSV at ``path`` is read with ISO-8859-1 encoding. Missing values are
      replaced by ``""`` and every cell is cast to ``str``.
    - Every cell is embedded with the ``model`` imported from ``settings`` on
      the ``"cuda"`` device.
    - For each column, a pairwise cosine-similarity matrix over rows is built.
      Together these form ``sim_tensor`` with shape (width, length, length).

    Search:
    - It starts from a single MD whose LHS thresholds are all -1 and whose RHS
      threshold is 1.
    - For every row pair (row1 < row2), each MD whose LHS thresholds are all
      met but whose RHS threshold is not is removed from ``md_list``.
    - If that violated MD's support (from ``get_metrics``) is at least
      ``support_threshold``, one specialised copy is built per LHS column. For
      each column whose pair similarity is below 1.0, that column's threshold
      is raised above the pair's similarity. The copy is added to ``md_list``
      when ``is_minimal`` accepts it.
    - The violated MD itself is recorded in ``minimal_vio``.
    - The scan stops early if ``md_list`` becomes empty.

    Afterwards, MDs from both lists are turned into
    ``(md, support, confidence)`` tuples. Tuples that do not reach both
    ``support_threshold`` and ``confidence_threshold`` are dropped. Entries
    that are both non-minimal and below 0.5 confidence are then pruned.

    Args:
        path: CSV file path passed to ``pandas.read_csv``.
        target_col: name of the column used as the MD right-hand side.

    Returns:
        list of ``(md, support, confidence)`` tuples sorted by confidence,
        highest first; ``md`` is a dict mapping column name -> threshold.
    """
    data = pd.read_csv(path, low_memory=False, encoding='ISO-8859-1')
    data.fillna("", inplace=True)
    data = data.astype(str)
    columns = data.columns.values.tolist()
    target_index = columns.index(target_col)
    cols_but_target = list(set(columns) - {target_col})
    length = data.shape[0]
    width = data.shape[1]
    # Flatten the table column by column, so the embeddings of column c occupy
    # rows [c * length, (c + 1) * length) of `embedding`.
    sentences = []
    for col in range(0, width):
        for row in range(0, length):
            cell_value = data.values[row, col]
            sentences.append(cell_value)
    embedding = model.encode(sentences, convert_to_tensor=True, device="cuda")
    # Split back into one `length`-row chunk per column and stack them, giving
    # shape (width, length, embedding_dim).
    split_embedding = torch.split(embedding, length, dim=0)
    table_tensor = torch.stack(split_embedding, dim=0, out=None)
    # L2-normalise so the batched matmul yields cosine similarities:
    # sim_tensor[c, i, j] compares rows i and j in column c.
    norm_table_tensor = torch.nn.functional.normalize(table_tensor, dim=2)
    sim_tensor = torch.matmul(norm_table_tensor, norm_table_tensor.transpose(1, 2))
    # sim_tensor = torch.round(sim_tensor, decimals=3)
    # torch.save(sim_tensor, md_output_dir + "tensor.pt")
    # md_list: current candidate MDs (dicts column -> threshold).
    # minimal_vio: MDs removed from md_list after being violated by a pair.
    md_list = []
    minimal_vio = []
    # Most general starting MD: every LHS threshold is -1 (the lowest cosine
    # similarity) and the RHS threshold is 1 (the highest).
    init_md = {}
    for col in columns:
        init_md[col] = 1 if col == target_col else -1
    md_list.append(init_md)
    for row1 in tqdm(range(0, length - 1)):
        terminate = False
        for row2 in range(row1 + 1, length):
            violated_mds = []
            # sims: per-column similarity between row1 and row2
            sims = {}
            for col_index in range(0, width):
                col = columns[col_index]
                similarity = sim_tensor[col_index, row1, row2].item()
                sims[col] = similarity
            # Find violated MDs (LHS satisfied, RHS not), remove them from
            # md_list and collect them in violated_mds. Iterate over a copy
            # because md_list is modified inside the loop.
            for md in md_list[:]:
                lhs_satis = True
                rhs_satis = True
                for col in cols_but_target:
                    if sims[col] < md[col]:
                        lhs_satis = False
                        break
                if sims[target_col] < md[target_col]:
                    rhs_satis = False
                if lhs_satis == True and rhs_satis == False:
                    md_list.remove(md)
                    violated_mds.append(md)
            # for vio_md in violated_mds:
            #     # specialise the left-hand side
            #     for col in cols_but_target:
            #         if sims[col] + 0.01 <= 1:
            #             spec_l_md = copy.deepcopy(vio_md)
            #             spec_l_md[col] = simt if sims[col] < simt else sims[col] + 0.01
            #             if is_minimal(spec_l_md, md_list, target_col):
            #                 md_list.append(spec_l_md)
            #     if vio_md not in minimal_vio:
            #         minimal_vio.append(vio_md)
            for vio_md in violated_mds:
                vio_md_support, vio_md_confidence = get_metrics(vio_md, data, sim_tensor, target_col, target_index)
                # Only specialise violated MDs that still have enough support.
                if vio_md_support >= support_threshold:
                    # One specialised MD per LHS column. The column's threshold
                    # becomes similarity_threshold when this pair's similarity is
                    # below it, otherwise the similarity + 0.01 (capped at 1.0).
                    # Either way the new threshold exceeds this pair's similarity.
                    for col in cols_but_target:
                        if sims[col] < 1.0:
                            spec_l_md = copy.deepcopy(vio_md)
                            if sims[col] < similarity_threshold:
                                spec_l_md[col] = similarity_threshold
                            else:
                                if sims[col] + 0.01 <= 1.0:
                                    spec_l_md[col] = sims[col] + 0.01
                                else:
                                    spec_l_md[col] = 1.0
                            if is_minimal(spec_l_md, md_list, target_col):
                                md_list.append(spec_l_md)
                    # NOTE(review): the reviewed copy had lost its indentation, so
                    # nesting these two lines inside the support check is a
                    # reconstruction; confirm against version control. Violated MDs
                    # below support_threshold are dropped from minimal_vio after
                    # the scan regardless.
                    if vio_md not in minimal_vio:
                        minimal_vio.append(vio_md)
            # No candidates left: stop scanning pairs altogether.
            if len(md_list) == 0:
                terminate = True
                break
        if terminate:
            break
    # Replace each violated MD with (md, support, confidence) if it reaches
    # both thresholds; otherwise drop it.
    if len(minimal_vio) > 0:
        for md in minimal_vio[:]:
            support, confidence = get_metrics(md, data, sim_tensor, target_col, target_index)
            if support >= support_threshold and confidence >= confidence_threshold:
                minimal_vio.append((md, support, confidence))
            minimal_vio.remove(md)
    if len(md_list) > 0:
        # Remove duplicate MDs
        tmp = []
        for _ in md_list:
            if _ not in tmp:
                tmp.append(_)
        md_list = tmp
        # Replace each MD with (md, support, confidence); drop MDs below the
        # support or confidence threshold.
        for _ in md_list[:]:
            support, confidence = get_metrics(_, data, sim_tensor, target_col, target_index)
            if support >= support_threshold and confidence >= confidence_threshold:
                md_list.append((_, support, confidence))
            md_list.remove(_)
        # Remove non-minimal MDs, but only those with confidence below 0.5
        # (a hard-coded cut-off, separate from confidence_threshold). The check
        # runs against the live list, so the result depends on removal order.
        for md_tuple in md_list[:]:
            if not is_minimal(md_tuple[0], md_list, target_col) and md_tuple[2] < 0.5:
                md_list.remove(md_tuple)
        # Same pruning for violated MDs, compared against the surviving md_list.
        if len(minimal_vio) > 0:
            for vio_tuple in minimal_vio[:]:
                if not is_minimal(vio_tuple[0], md_list, target_col) and vio_tuple[2] < 0.5:
                    minimal_vio.remove(vio_tuple)
    # Prune violated MDs dominated by other violated MDs (same 0.5 cut-off).
    if len(minimal_vio) > 0:
        for vio_tuple in minimal_vio[:]:
            if not is_minimal(vio_tuple[0], minimal_vio, target_col) and vio_tuple[2] < 0.5:
                minimal_vio.remove(vio_tuple)
    result = []
    result.extend(md_list)
    result.extend(minimal_vio)
    # Every entry is now a (md, support, confidence) tuple; sort by confidence.
    result.sort(key=operator.itemgetter(2), reverse=True)
    print(f'\033[33mList Length: {len(result)}\033[0m')
    return result
def get_metrics(current_md, data, sim_tensor, target_col, target_index):
    """Compute the support and confidence of one matching dependency (MD).

    Args:
        current_md: dict mapping each column name of ``data`` to its similarity
            threshold; the entry for the target column is the RHS threshold.
        data: the DataFrame the similarities were computed from. Only its
            shape and column names are read.
        sim_tensor: per-column pairwise similarities, expected with shape
            (width, length, length), so that ``sim_tensor[c, i, j]`` is the
            similarity of rows i and j in column c (the layout built by
            ``pairs_inference``). It may live on any device.
        target_col: name of the RHS column. Not read here; the RHS slice is
            selected with ``target_index``. Kept for interface compatibility.
        target_index: position of the RHS column in ``data.columns``.

    Returns:
        ``(support, confidence)``:
        - ``support`` is the number of row pairs ``i < j`` meeting every
          threshold (LHS and RHS).
        - ``confidence`` is ``support`` divided by the number of pairs meeting
          the LHS thresholds, or 0 when no pair does.
    """
    length = data.shape[0]
    width = data.shape[1]
    # Build every helper tensor where the similarities already live instead of
    # assuming CUDA.
    device = sim_tensor.device
    # Look thresholds up by column name so they line up with the first axis of
    # sim_tensor even if current_md's keys are in a different order.
    thresholds = torch.tensor([current_md[col] for col in data.columns], device=device)
    # Similarities are compared at 4-decimal precision.
    rounded = torch.round(sim_tensor, decimals=4)
    # Broadcast one threshold per column over the (length, length) pair grid
    # instead of materialising a repeated copy of the thresholds.
    satisfied = torch.ge(rounded, thresholds.view(-1, 1, 1))
    # Count each unordered pair exactly once (strict upper triangle). This
    # neither relies on every self-pair (i, i) passing the thresholds nor on
    # the similarity matrix being exactly symmetric.
    pair_mask = torch.ones((length, length), device=device).triu(diagonal=1).bool()
    for i in range(width):
        if i != target_index:
            pair_mask &= satisfied[i]
    support_naumann = int(pair_mask.sum().item())
    support_fan = int((pair_mask & satisfied[target_index]).sum().item())
    confidence = support_fan / support_naumann if support_naumann > 0 else 0
    return support_fan, confidence