# Section 5.2.3 - Loading the dataset with Spark
# (from chapter5/5.2Spark机器学习-坦克卫星图片识别分类.md)
#
# Once the dataset has been prepared as described above, it can be loaded
# with Spark: every image is flattened into a feature vector, paired with a
# label derived from its file name, collected into a pandas DataFrame and
# finally converted into a Spark DataFrame.

import os

import cv2
import pandas as pd
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression


def get_file_path(root_path):
    """Return the paths of the direct children of ``root_path``.

    Args:
        root_path: Directory to list (not searched recursively).

    Returns:
        A list of ``root_path`` joined with each entry name. Entries are
        sorted by name because ``os.listdir`` returns them in arbitrary
        order, which would otherwise make the row order of the resulting
        dataset differ between runs.

    Raises:
        OSError: If ``root_path`` cannot be listed.
    """
    return [
        os.path.join(root_path, entry)
        for entry in sorted(os.listdir(root_path))
    ]


def img2vector(imgfilename):
    """Load an image as grayscale and flatten it into a dense vector.

    Args:
        imgfilename: Path of the image file to read.

    Returns:
        A ``pyspark.ml.linalg`` dense vector holding every pixel of the
        grayscale image, flattened to one dimension.

    Raises:
        ValueError: If OpenCV cannot read the file as an image.
    """
    img = cv2.imread(imgfilename, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly rather than failing later with an AttributeError.
    if img is None:
        raise ValueError(f"cannot read image file: {imgfilename!r}")
    # Flatten the 2-D (rows, columns) grayscale array into one dimension.
    return Vectors.dense(img.reshape(-1))


def _label_for(path):
    """Return the numeric class label encoded in the file name of ``path``.

    Names containing "btr-70" map to 0, names containing "t-72" map to 1,
    and everything else maps to 2.
    """
    # Match on the base name only, so that a directory component of the
    # path can never decide the label.
    name = os.path.basename(path)
    if "btr-70" in name:
        return 0
    if "t-72" in name:
        return 1
    return 2


root_path = r"testing2"
fileList = get_file_path(root_path)
# One (features, label) pair per file found in root_path.
vectos = [(img2vector(x), _label_for(x)) for x in fileList]
df = pd.DataFrame(vectos, columns=['features', 'label'])

spark = SparkSession.builder.master("local[*]").appName("demo").getOrCreate()
# Restrict Spark's console logging to the "error" level.
spark.sparkContext.setLogLevel("error")
sparkDF = spark.createDataFrame(df)  # df is the DataFrame we built with pandas