You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
4.1 KiB
4.1 KiB
7.3 图像分割
了解数据
既然要对图像进行图像分割,那肯定需要有图像。在这里,为你准备了一张图像,如下图所示:
如果我们想将整张图分割成双黄线,马路,路边,天空这 4 个部分,那么很明显我们可以使用 k均值算法来进行分割,而且此时的 k 为 4。因为我们样将所有的像素聚类成 4 个簇。
代码实现
读取图像
首先我们需要将图像读入内存,python
有很多库都实现了读取图像的功能。在这里,我将使用opencv
这个库来读取图像。opencv
在计算机视觉领域的使用是非常广泛的,如果你对计算机视觉感兴趣,可以深入了解一下opencv
以及一些图像处理的知识,在这里我就不多做介绍了。
opencv
读图像很简单,只需要使用如下代码即可:
# 导入opencv库
import cv2
# 读取图像,图像名字为test.jpg,并将图像保存到img变量中
img = cv2.imread('test.jpg')
读取到图像后,就可以着手实现 k 均值算法了。
k均值算法
在实现 k 均值算法之前,可以思考一下 k 均值算法所需要的参数,很明显需要三个参数,一个是数据,在这里也就是读取到的图像;另一个就是 k,刚刚已经提到过,在这里,k 为 4。还有一个就是 k 均值算法的最大迭代次数。
因此可以写出如下函数声明:
def kmeans(n, k, image):
接下来我们来实现函数体。对于图像来说,它可以看成是一个三维数组,但是 k 均值算法需要的是一个二维数组,所以我们需要对数据进行变形。而且还需要将图像进行升维,因为在算法结束时,需要知道每个像素所对应的簇标签值。所以会有如下代码:
# 图像的高
height = image.shape[0]
# 图像的高
width = image.shape[1]
tmp = image.reshape(-1, 3)
result = tmp.copy()
#扩展一个维度用来存放标签
result = np.column_stack((result, np.ones(height*width)))
做好数据处理后,可以开始根据上一节所提到的 k 均值算法的原理来实现该算法了。首先需要初始化质心。
# 初始化质心
center_point = np.random.choice(height*width, k, replace=False)
center = result[center_point, :]
# 初始化距离矩阵
distance = [[] for i in range(k)]
然后需要不断地迭代更新我们的质心。
# 迭代n次
for i in range(n):
# 计算每个像素到各个质心的距离
for j in range(k):
distance[j] = np.sqrt(np.sum(np.square(result - np.array(center[j])), axis=1))
# 为每个像素打上簇标签
result[:, 3] = np.argmin(np.array(distance), axis=0)
# 更新质心
for j in range(k):
center[j] = np.mean(result[result[:, 3] == j], axis=0)
return result
因此,完整的代码如下:
def kmeans(n, k, image):
height = image.shape[0]
width = image.shape[1]
tmp = image.reshape(-1, 3)
result = tmp.copy()
#扩展一个维度用来存放标签
result = np.column_stack((result, np.ones(height*width)))
center_point = np.random.choice(height*width, k, replace=False)
center = result[center_point, :]
distance = [[] for i in range(k)]
#迭代
for i in range(n):
for j in range(k):
distance[j] = np.sqrt(np.sum(np.square(result - np.array(center[j])), axis=1))
result[:, 3] = np.argmin(np.array(distance), axis=0)
for j in range(k):
center[j] = np.mean(result[result[:, 3] == j], axis=0)
return result
分割图像
有了 k 均值算法后,就可以分割图像了。
plt.subplot('121')
plt.imshow(img)
height = img.shape[0]
width = img.shape[1]
# 使用刚刚实现的k均值算法
result_img = kmeans(150, 5, img)
# 将最终结果变形为二维数组
result_img = result_img[:, 3].reshape(height, width)
plt.subplot('122')
plt.imshow(result_img)
plt.show()
最终可以看到 k 均值算法能够较好的对图像进行分割。