# 7.4: Implementing k-means by hand
import numpy as np
def euclidean_distance(one_sample, X):
    '''
    Squared Euclidean distance from one sample to every row of X.

    input:
        one_sample(ndarray): a single sample (flattened to one row)
        X(ndarray): all samples, one per row
    output:
        distances(ndarray): squared Euclidean distance from one_sample
            to each row of X (no square root is taken)
    '''
    row = one_sample.reshape(1, -1)
    # Stack the sample once per row of X so the subtraction is element-wise.
    stacked = np.repeat(row, X.shape[0], axis=0)
    squared_diff = np.square(stacked - X)
    return squared_diff.sum(axis=1)
def init_random_centroids(k, X):
    '''
    Pick k samples of X at random to serve as the initial centroids.

    input:
        k(int): number of clusters
        X(ndarray): all samples, shape (n_samples, n_features)
    output:
        centroids(ndarray): float array of shape (k, n_features) holding
            copies of the chosen samples
    '''
    X = np.asarray(X)
    n_samples, n_features = np.shape(X)
    # Draw distinct sample indices so two centroids never start on the same
    # sample (which would leave one cluster empty from the first iteration).
    # Only when more centroids than samples are requested is sampling with
    # replacement unavoidable.
    with_replacement = k > n_samples
    chosen = np.random.choice(n_samples, size=k, replace=with_replacement)
    # astype(float) copies, so the centroids never alias rows of X.
    return X[chosen].astype(float).reshape(k, n_features)
def _closest_centroid(sample, centroids):
    '''
    Find which centroid lies nearest to a sample.

    input:
        sample(ndarray): a single sample
        centroids(ndarray): the k cluster centroids, one per row
    output:
        closest_i(int): index of the centroid with the smallest distance
            as computed by euclidean_distance
    '''
    return euclidean_distance(sample, centroids).argmin()
def create_clusters(k,centroids, X):
    '''
    Group sample indices by their nearest centroid.

    input:
        k(int): number of clusters
        centroids(ndarray): the k cluster centroids
        X(ndarray): all samples
    output:
        clusters(list): list of k lists; entry i holds the indices of the
            samples whose nearest centroid is centroid i
    '''
    groups = [list() for _ in range(k)]
    for row_index, row in enumerate(X):
        # Each sample index goes into the bucket of its nearest centroid.
        groups[_closest_centroid(row, centroids)].append(row_index)
    return groups
def update_centroids(k, clusters, X):
    '''
    Recompute each centroid as the mean of the samples assigned to it.

    input:
        k(int): number of clusters
        clusters(list): list of index lists; entry i holds the indices of
            the samples currently assigned to cluster i
        X(ndarray): all samples, shape (n_samples, n_features)
    output:
        centroids(ndarray): float array of shape (k, n_features) with the
            updated centroids
    '''
    X = np.asarray(X)
    n_samples, n_features = np.shape(X)
    centroids = np.zeros((k, n_features))
    for i, cluster in enumerate(clusters):
        if len(cluster) > 0:
            centroids[i] = np.mean(X[cluster], axis=0)
        elif n_samples > 0:
            # The mean of an empty cluster is NaN, which would poison every
            # later distance computation. Re-seed the centroid on a randomly
            # chosen sample instead so the cluster can pick up members again.
            centroids[i] = X[np.random.randint(n_samples)]
        # With no samples at all there is nothing to re-seed from; the
        # centroid stays at the zero row it was initialised with.
    return centroids
def get_cluster_labels(clusters, X):
    '''
    Turn per-cluster index lists into one label per sample.

    input:
        clusters(list): list of k lists; entry i holds the indices of the
            samples belonging to cluster i
        X(ndarray): all samples
    output:
        y_pred(ndarray): float array with one entry per sample giving the
            index of its cluster (samples in no cluster keep label 0)
    '''
    labels = np.zeros(np.shape(X)[0])
    for label, members in enumerate(clusters):
        # Assign the whole cluster in one indexed write.
        labels[list(members)] = label
    return labels
def predict(k, X, max_iterations, varepsilon):
    '''
    Cluster X with k-means and return a cluster label for every sample.

    input:
        k(int): number of clusters
        X(ndarray): all samples
        max_iterations(int): maximum number of assign/update rounds
        varepsilon(float): convergence threshold; iteration stops once
            every centroid coordinate changes by less than this amount
    output:
        y_pred(ndarray): cluster label of every sample
    '''
    centroids = init_random_centroids(k, X)
    clusters = None
    for _ in range(max_iterations):
        clusters = create_clusters(k, centroids, X)
        former_centroids = centroids
        centroids = update_centroids(k, clusters, X)
        # Compare the magnitude of the movement element-wise. Testing
        # `diff.any() < varepsilon` would compare a single boolean to the
        # threshold and ignore how far the centroids actually moved.
        if np.all(np.abs(centroids - former_centroids) < varepsilon):
            break
    if clusters is None:
        # No iteration ran (max_iterations <= 0): label the samples by the
        # initial centroids instead of failing on an unbound variable.
        clusters = create_clusters(k, centroids, X)
    return get_cluster_labels(clusters, X)