将不匹配的数据点也加上聚类

MD-metrics-HPO
HuangJintao 11 months ago
parent 106f3eabf3
commit d03d300f8c

@@ -8,16 +8,18 @@ from draw_md_cluster import DBSCAN
from ml_er.ml_entity_resolver import build_col_pairs_sim_tensor_dict from ml_er.ml_entity_resolver import build_col_pairs_sim_tensor_dict
def plot(md_keys_, data_, data_points_, labels_, output_path_): def plot(md_keys_, md_data_, pre_match_points_, pre_mismatch_points_, labels_, output_path_):
clusterNum = len(set(labels_)) clusterNum = len(set(labels_))
fig = plt.figure() fig = plt.figure()
scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown'] scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown']
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
for i in range(-1, clusterNum): for i in range(-1, clusterNum):
colorStyle = scatterColors[i % len(scatterColors)] colorStyle = scatterColors[i % len(scatterColors)]
subCluster = data_[np.where(labels_ == i)] subCluster = md_data_[np.where(labels_ == i)]
ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12) ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12)
ax.scatter(data_points_[:, 0], data_points_[:, 1], data_points_[:, 2], c='#66CCFF', s=12, marker='x') ax.scatter(pre_match_points_[:, 0], pre_match_points_[:, 1], pre_match_points_[:, 2], c='#66CCFF', s=12, marker='x')
if pre_mismatch_points_.shape[0] > 0:
ax.scatter(pre_mismatch_points_[:, 0], pre_mismatch_points_[:, 1], pre_mismatch_points_[:, 2], c='#006666', s=12, marker='x')
ax.set_xlabel(md_keys_[0], rotation=0) # 设置标签角度 ax.set_xlabel(md_keys_[0], rotation=0) # 设置标签角度
ax.set_ylabel(md_keys_[1], rotation=-45) ax.set_ylabel(md_keys_[1], rotation=-45)
ax.set_zlabel(md_keys_[2], rotation=0) ax.set_zlabel(md_keys_[2], rotation=0)
@@ -35,32 +37,37 @@ if __name__ == '__main__':
predictions = outcome_path + rf'\{dataset_name}' + config_dir + r'\predictions.csv' # prediction路径 predictions = outcome_path + rf'\{dataset_name}' + config_dir + r'\predictions.csv' # prediction路径
pred = pd.read_csv(predictions) pred = pd.read_csv(predictions)
pred = pred.astype(str) pred = pred.astype(str)
pred = pred[pred['predicted'] == str(1)] # pred = pred[pred['predicted'] == str(1)]
sim_tensor_dict = build_col_pairs_sim_tensor_dict(pred) sim_tensor_dict = build_col_pairs_sim_tensor_dict(pred)
# 选取的三个字段 # 选取的三个字段
md_keys = [] md_keys = []
with open(absolute_path, 'r') as f: with open(absolute_path, 'r') as f:
# 读取每一行的md加入该文件的md列表 # 读取每一行的md加入该文件的md列表
data = [] md_data = []
for line in f.readlines(): for line in f.readlines():
md_metadata = line.strip().split('\t') md_metadata = line.strip().split('\t')
md_tuple = eval(md_metadata[1]) md_tuple = eval(md_metadata[1])
md_keys = list(md_tuple[0].keys())[1:4] md_keys = list(md_tuple[0].keys())[1:4]
md_values = list(md_tuple[0].values()) md_values = list(md_tuple[0].values())
data.append(md_values[1:4]) md_data.append(md_values[1:4])
if len(data) == 10000: if len(md_data) == 10000:
break break
data_points = [] pre_match_points = []
for _ in range(len(pred)): pre_mismatch_points = []
for _ in pred.itertuples():
data_point_value = [] data_point_value = []
for key in md_keys: for key in md_keys:
sim_tensor = sim_tensor_dict[key] sim_tensor = sim_tensor_dict[key]
data_point_value.append(round(float(sim_tensor[_]), 4)) data_point_value.append(round(float(sim_tensor[_[0]]), 4))
data_points.append(data_point_value) if getattr(_, 'predicted') == str(1):
pre_match_points.append(data_point_value)
elif getattr(_, 'predicted') == str(0):
pre_mismatch_points.append(data_point_value)
data = np.array(data, dtype=np.float32) md_data = np.array(md_data, dtype=np.float32)
data_points = np.array(data_points, dtype=np.float32) pre_match_points = np.array(pre_match_points, dtype=np.float32)
labels = DBSCAN(data, 0.5, 30) pre_mismatch_points = np.array(pre_mismatch_points, dtype=np.float32)
labels = DBSCAN(md_data, 0.5, 30)
output_path = outcome_path + rf'\{dataset_name}_MD&data.png' output_path = outcome_path + rf'\{dataset_name}_MD&data.png'
plot(md_keys, data, data_points, labels, output_path) plot(md_keys, md_data, pre_match_points, pre_mismatch_points, labels, output_path)

Loading…
Cancel
Save