You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
1.5 KiB
38 lines
1.5 KiB
from sklearn.neighbors import KNeighborsClassifier
|
|
from sklearn.preprocessing import LabelEncoder
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.metrics import classification_report
|
|
from simplepreprocessor import SimplePreprocessor
|
|
from simpledatasetloader import SimpleDatasetLoader
|
|
from imutils import paths
|
|
import argparse
|
|
|
|
def build_arg_parser():
    """Build the command-line parser for the k-NN evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``-d/--dataset`` (required),
        ``-k/--neighbors`` (int, default 1) and ``-j/--jobs`` (int, no
        default, so it is ``None`` when omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", required=True,
                        help="path to input dataset")
    parser.add_argument("-k", "--neighbors", type=int, default=1,
                        help="number of nearest neighbors for classification")
    parser.add_argument("-j", "--jobs", type=int,
                        help="number of jobs for K-NN distance "
                             "(-1 uses all available cores)")
    return parser


def main(argv=None):
    """Load an image dataset, fit a k-NN classifier and print its report.

    Args:
        argv: optional list of command-line tokens; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Exits with status 2 (via ``parser.error``) when no images are found
    under the dataset path.
    """
    parser = build_arg_parser()
    args = vars(parser.parse_args(argv))

    print("[INFO] loading images...")
    image_paths = list(paths.list_images(args["dataset"]))
    if not image_paths:
        # Fail early with a clear message instead of an opaque error from
        # train_test_split on an empty feature matrix.
        parser.error("no images found under {!r}".format(args["dataset"]))

    # The flattened row length 32 * 32 * 3 == 3072 matches the original
    # hard-coded value.
    # NOTE(review): assumes the loader yields 32x32 images with 3 channels
    # after SimplePreprocessor -- confirm against those project modules.
    side = 32
    channels = 3
    sp = SimplePreprocessor(side, side)
    sdl = SimpleDatasetLoader(preprocessors=[sp])
    (data, labels) = sdl.load(image_paths, verbose=100)
    data = data.reshape((data.shape[0], side * side * channels))

    # Report the size in mebibytes (1024 * 1024 bytes); the previous divisor
    # of 1024 * 1000 was neither MB nor MiB.
    print("[INFO] features matrix:{:.1f}MB".format(data.nbytes / (1024 * 1024.0)))

    # Encode the labels as integers for scikit-learn.
    le = LabelEncoder()
    labels = le.fit_transform(labels)

    # Hold out 25% of the samples for testing; fixed seed for repeatability.
    (trainX, testX, trainY, testY) = train_test_split(
        data, labels, test_size=0.25, random_state=42)

    print("[INFO] evaluating K-NN classifier...")

    # Honor the command-line settings; previously k was hard-coded to 3 and
    # --jobs was ignored. n_jobs=None is scikit-learn's own default.
    model = KNeighborsClassifier(n_neighbors=args["neighbors"],
                                 n_jobs=args["jobs"])
    model.fit(trainX, trainY)
    print(classification_report(testY, model.predict(testX),
                                target_names=le.classes_))


if __name__ == '__main__':
    main()
|
|
|