You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
matching_dependency/draw/draw_md_cluster.py

87 lines
4.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# https://github.com/GISerWang/Spatio-temporal-Clustering.git
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
def DBSCAN(data_, eps, minPts):
# 获得距离矩阵
disMat = squareform(pdist(data_, metric='euclidean'))
# 获得数据的行和列(一共有n条数据)
n, m = data_.shape
# 将矩阵的中小于eps的数赋予1, 大于eps的数置0, 按行求和, 求核心点坐标的索引
core_points_index = np.where(np.sum(np.where(disMat <= eps, 1, 0), axis=1) >= minPts)[0]
# 初始化类别,-1代表未分类。
labels = np.full((n,), -1)
clusterId = 0
# 遍历所有的核心点
for pointId in core_points_index:
# 如果核心点未被分类,将其作为的种子点,开始寻找相应簇集
if labels[pointId] == -1:
# 首先将点pointId标记为当前类别(即标识为已操作)
labels[pointId] = clusterId
# 然后寻找种子点的eps邻域且没有被分类的点将其放入种子集合
neighbour = np.where((disMat[:, pointId] <= eps) & (labels == -1))[0]
seeds = set(neighbour)
# 通过种子点,开始生长,寻找密度可达的数据点,一直到种子集合为空,一个簇集寻找完毕
while len(seeds) > 0:
# 弹出一个新种子点
newPoint = seeds.pop()
# 将newPoint标记为当前类
labels[newPoint] = clusterId
# 寻找newPoint种子点eps邻域包含自己
queryResults = np.where(disMat[:, newPoint] <= eps)[0]
# 如果newPoint属于核心点那么newPoint是可以扩展的即密度是可以通过newPoint继续密度可达的
if len(queryResults) >= minPts:
# 将邻域内且没有被分类的点压入种子集合
for resultPoint in queryResults:
if labels[resultPoint] == -1:
seeds.add(resultPoint)
# 簇集生长完毕,寻找到一个类别
clusterId = clusterId + 1
return labels
def plotFeature(md_keys_, data_, labels_, output_path_):
clusterNum = len(set(labels_))
fig = plt.figure()
scatterColors = ['black', 'blue', 'green', 'yellow', 'red', 'purple', 'orange', 'brown']
ax = fig.add_subplot(111, projection='3d')
for i in range(-1, clusterNum):
colorStyle = scatterColors[i % len(scatterColors)]
subCluster = data_[np.where(labels_ == i)]
ax.scatter(subCluster[:, 0], subCluster[:, 1], subCluster[:, 2], c=colorStyle, s=12)
ax.set_xlabel(md_keys_[0], rotation=0) # 设置标签角度
ax.set_ylabel(md_keys_[1], rotation=-45)
ax.set_zlabel(md_keys_[2], rotation=0)
plt.title(output_path_.split('\\')[-1].split('.')[0])
plt.savefig(output_path_, dpi=500)
plt.show()
if __name__ == '__main__':
outcome_path = r'E:\Data\Research\Outcome'
config_dir = r'\Magellan+Smac+roberta-large-nli-stsb-mean-tokens+inter-0.5'
dataset_name_list = [f.name for f in os.scandir(outcome_path) if f.is_dir()]
for dataset_name in dataset_name_list:
absolute_path = outcome_path + rf'\{dataset_name}' + config_dir + r'\mds.txt'
md_keys = []
with open(absolute_path, 'r') as f:
# 读取每一行的md加入该文件的md列表
data = []
for line in f.readlines():
md_metadata = line.strip().split('\t')
md_tuple = eval(md_metadata[1])
md_keys = list(md_tuple[0].keys())[1:4]
md_values = list(md_tuple[0].values())
data.append(md_values[1:4])
if len(data) == 10000:
break
data = np.array(data, dtype=np.float32)
labels = DBSCAN(data, 0.5, 30)
output_path = outcome_path + rf'\{dataset_name}.png'
plotFeature(md_keys, data, labels, output_path)