master
XLC 6 years ago
parent 238d93a23a
commit 2da5ca2500

@ -148,11 +148,48 @@ print(type(img))
#### 5.2.3 Spark 加载数据集
上面我们已经把数据集处理完毕,接下来我们可以通过`Spark`来加载数据集。
```
import cv2
import os
import pandas as pd
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
def get_file_path(root_path):
    """Return the full paths of all entries directly under *root_path*.

    Note: includes subdirectories as well as files, and does not recurse.
    """
    return [os.path.join(root_path, entry) for entry in os.listdir(root_path)]
def img2vector(imgfilename):
    """Load a grayscale image and flatten it into a dense feature vector.

    Parameters
    ----------
    imgfilename : str
        Path to an image file readable by OpenCV.

    Returns
    -------
    pyspark.ml.linalg.DenseVector
        1-D vector of length rows * columns holding the pixel values.

    Raises
    ------
    FileNotFoundError
        If the image cannot be read.
    """
    img = cv2.imread(imgfilename, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread silently returns None for missing/corrupt files;
        # fail fast with a clear error instead of AttributeError on .shape.
        raise FileNotFoundError("cannot read image: %s" % imgfilename)
    rows, columns = img.shape
    img = img.reshape(rows * columns)
    return Vectors.dense(img)
# ---- Build a labelled feature table from the image files ----
root_path = r"testing2"
fileList = get_file_path(root_path)

# Collect (feature-vector, label) pairs; the label is encoded from a
# class-name keyword appearing in the file path.
vectors = []
for x in fileList:
    vector = img2vector(x)
    if "btr-70" in x:
        label = 0
    elif "t-72" in x:
        label = 1
    else:
        label = 2  # any other class falls into the third bucket
    vectors.append((vector, label))

# Pandas DataFrame is only an intermediate container before handing off to Spark.
df = pd.DataFrame(vectors, columns=['features', 'label'])

# ---- Load the data into Spark ----
spark = SparkSession.builder.master("local[*]").appName("demo").getOrCreate()
spark.sparkContext.setLogLevel("error")
sparkDF = spark.createDataFrame(df)  # df is the DataFrame built with pandas above
```

Loading…
Cancel
Save