You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
matching_dependency/draw/draw_md_cluster.py

80 lines
3.6 KiB

11 months ago
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
def DBSCAN(data_, eps, minPts):
# 获得距离矩阵
disMat = squareform(pdist(data_, metric='euclidean'))
# 获得数据的行和列(一共有n条数据)
n, m = data_.shape
# 将矩阵的中小于eps的数赋予1, 大于eps的数置0, 按行求和, 求核心点坐标的索引
core_points_index = np.where(np.sum(np.where(disMat <= eps, 1, 0), axis=1) >= minPts)[0]
# 初始化类别,-1代表未分类。
labels = np.full((n,), -1)
clusterId = 0
# 遍历所有的核心点
for pointId in core_points_index:
# 如果核心点未被分类,将其作为的种子点,开始寻找相应簇集
if labels[pointId] == -1:
# 首先将点pointId标记为当前类别(即标识为已操作)
labels[pointId] = clusterId
# 然后寻找种子点的eps邻域且没有被分类的点将其放入种子集合
neighbour = np.where((disMat[:, pointId] <= eps) & (labels == -1))[0]
seeds = set(neighbour)
# 通过种子点,开始生长,寻找密度可达的数据点,一直到种子集合为空,一个簇集寻找完毕
while len(seeds) > 0:
# 弹出一个新种子点
newPoint = seeds.pop()
# 将newPoint标记为当前类
labels[newPoint] = clusterId
# 寻找newPoint种子点eps邻域包含自己
queryResults = np.where(disMat[:, newPoint] <= eps)[0]
# 如果newPoint属于核心点那么newPoint是可以扩展的即密度是可以通过newPoint继续密度可达的
if len(queryResults) >= minPts:
# 将邻域内且没有被分类的点压入种子集合
for resultPoint in queryResults:
if labels[resultPoint] == -1:
seeds.add(resultPoint)
# 簇集生长完毕,寻找到一个类别
clusterId = clusterId + 1
return labels
def plotFeature(data_, labels_, output_path_):
clusterNum = len(set(labels_))
fig = plt.figure()
scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown']
ax = fig.add_subplot(111, projection='3d')
for i in range(-1, clusterNum):
colorStyle = scatterColors[i % len(scatterColors)]
subCluster = data_[np.where(labels_ == i)]
ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12)
plt.savefig(output_path_, dpi=500)
plt.show()
if __name__ == '__main__':
outcome_path = r'E:\Data\Research\Outcome'
config_dir = r'\Magellan+Smac+roberta-large-nli-stsb-mean-tokens+inter-0.5'
dataset_name_list = [f.name for f in os.scandir(outcome_path) if f.is_dir()]
for dataset_name in dataset_name_list:
absolute_path = outcome_path + rf'\{dataset_name}' + config_dir + r'\mds.txt'
with open(absolute_path, 'r') as f:
# 读取每一行的md加入该文件的md列表
data = []
for line in f.readlines():
md_metadata = line.strip().split('\t')
md_tuple = eval(md_metadata[1])
md_values = list(md_tuple[0].values())
data.append(md_values[1:4])
if len(data) == 10000:
break
data = np.array(data, dtype=np.float32)
labels = DBSCAN(data, 0.5, 30)
output_path = outcome_path + rf'\{dataset_name}.png'
plotFeature(data, labels, output_path)