You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

4.1 KiB

7.3 图像分割

了解数据

既然要对图像进行图像分割,那肯定需要有图像。在这里,为你准备了一张图像,如下图所示:

如果我们想将整张图分割成双黄线,马路,路边,天空这 4 个部分,那么很明显我们可以使用 k均值算法来进行分割而且此时的 k 为 4。因为我们样将所有的像素聚类成 4 个簇。

代码实现

读取图像

首先我们需要将图像读入内存,python有很多库都实现了读取图像的功能。在这里,我将使用opencv这个库来读取图像。opencv在计算机视觉领域的使用是非常广泛的,如果你对计算机视觉感兴趣,可以深入了解一下opencv以及一些图像处理的知识,在这里我就不多做介绍了。

opencv读图像很简单,只需要使用如下代码即可:

# 导入opencv库
import cv2

# 读取图像图像名字为test.jpg并将图像保存到img变量中
img = cv2.imread('test.jpg')

读取到图像后,就可以着手实现 k 均值算法了。

k均值算法

在实现 k 均值算法之前,可以思考一下 k 均值算法所需要的参数,很明显需要三个参数,一个是数据,在这里也就是读取到的图像;另一个就是 k刚刚已经提到过在这里k 为 4。还有一个就是 k 均值算法的最大迭代次数。

因此可以写出如下函数声明:

def kmeans(n, k, image):

接下来我们来实现函数体。对于图像来说,它可以看成是一个三维数组,但是 k 均值算法需要的是一个二维数组,所以我们需要对数据进行变形。而且还需要将图像进行升维,因为在算法结束时,需要知道每个像素所对应的簇标签值。所以会有如下代码:

    # 图像的高
    height = image.shape[0]
    # 图像的高
    width = image.shape[1]
    tmp = image.reshape(-1, 3)
    result = tmp.copy()

    #扩展一个维度用来存放标签
    result = np.column_stack((result, np.ones(height*width)))

做好数据处理后,可以开始根据上一节所提到的 k 均值算法的原理来实现该算法了。首先需要初始化质心。

    # 初始化质心
    center_point = np.random.choice(height*width, k, replace=False)
    center = result[center_point, :]
    # 初始化距离矩阵
    distance = [[] for i in range(k)]

然后需要不断地迭代更新我们的质心。

    # 迭代n次
    for i in range(n):
        # 计算每个像素到各个质心的距离
        for j in range(k):
            distance[j] = np.sqrt(np.sum(np.square(result - np.array(center[j])), axis=1))
        # 为每个像素打上簇标签
        result[:, 3] = np.argmin(np.array(distance), axis=0)

        # 更新质心
        for j in range(k):
            center[j] = np.mean(result[result[:, 3] == j], axis=0)
    return result

因此,完整的代码如下:

def kmeans(n, k, image):
    height = image.shape[0]
    width = image.shape[1]
    tmp = image.reshape(-1, 3)
    result = tmp.copy()

    #扩展一个维度用来存放标签
    result = np.column_stack((result, np.ones(height*width)))

    center_point = np.random.choice(height*width, k, replace=False)
    center = result[center_point, :]
    distance = [[] for i in range(k)]

    #迭代
    for i in range(n):
        for j in range(k):
            distance[j] = np.sqrt(np.sum(np.square(result - np.array(center[j])), axis=1))
        result[:, 3] = np.argmin(np.array(distance), axis=0)
        for j in range(k):
            center[j] = np.mean(result[result[:, 3] == j], axis=0)
    return result

分割图像

有了 k 均值算法后,就可以分割图像了。

plt.subplot('121')
plt.imshow(img)

height = img.shape[0]
width = img.shape[1]

# 使用刚刚实现的k均值算法
result_img = kmeans(150, 5, img)
# 将最终结果变形为二维数组
result_img = result_img[:, 3].reshape(height, width)

plt.subplot('122')
plt.imshow(result_img)
plt.show()

最终可以看到 k 均值算法能够较好的对图像进行分割。