From 1e27452da0bf5515b70fcff93c1a63c04e5ae61f Mon Sep 17 00:00:00 2001 From: px38ly72e <494532044@qq.com> Date: Sun, 17 Dec 2023 20:44:09 +0800 Subject: [PATCH] Add main.py --- main.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..af9cf8d --- /dev/null +++ b/main.py @@ -0,0 +1,81 @@ +!pip install paddlehub==1.8.1 -i https://pypi.tuna.tsinghua.edu.cn/simple +import paddlehub as hub +module = hub.Module(name="mobilenet_v2_imagenet") +!unzip -o /data/shixunfiles/26a2e3c3b2c50fe54e2fcab6e031a141_1607408726958.zip +from paddlehub.dataset.base_cv_dataset import BaseCVDataset +class DemoDataset(BaseCVDataset): + def __init__(self): + self.dataset_dir = "car_datasets" + super(DemoDataset, self).__init__( + base_path=self.dataset_dir, + train_list_file="train_list.txt", + validate_list_file="validate_list.txt", + test_list_file="test_list.txt", + label_list_file="label_list.txt", + ) +dataset = DemoDataset() +data_reader = hub.reader.ImageClassificationReader( + image_width=module.get_expected_image_width(), + image_height=module.get_expected_image_height(), + images_mean=module.get_pretrained_images_mean(), + images_std=module.get_pretrained_images_std(), + dataset=dataset) +config = hub.RunConfig( + use_cuda=False, + num_epoch=10, + batch_size=32, + eval_interval=50, + strategy=hub.finetune.strategy.DefaultFinetuneStrategy()) +input_dict, output_dict, program = module.context(trainable=True) +img = input_dict["image"] +feature_map = output_dict["feature_map"] +feed_list = [img.name] +task = hub.ImageClassifierTask( + data_reader=data_reader, + feed_list=feed_list, + feature=feature_map, + num_classes=dataset.num_labels, + config=config) +run_states = task.finetune_and_eval() +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +from pandas import Series,DataFrame +%matplotlib inline +import os +dirs=os.listdir('car_datasets/test') +num=0 +for i in os.listdir('car_datasets/test'): + m='car_datasets/test/'+i + dirs[num]=m + num+=1 +s=0 +b=0 +a=os.listdir('car_datasets/test') +for i in a: + b+=len(os.listdir('car_datasets/test/'+i)) +data=[] +for i in range(b): + data.append('o') +for i in dirs: + for j in os.listdir(i): + n=i+'/'+j + data[s]=n + s+=1 +label_map = dataset.label_dict() +index = 0 +true=0 +run_states = task.predict(data=data) +results = [run_state.run_results for run_state in run_states] +for batch_result in results: + batch_result = np.argmax(batch_result, axis=2)[0] + for result in batch_result: + index += 1 + result = label_map[result] + actual=os.path.dirname(data[index - 1]) + actual=actual.split('/') + if actual[-1]==result: + true+=1 + print("input %i is %s, and the predict result is ( %s )" % + (index, data[index - 1], result)) +print( '预测正确率为{:.2%}'.format(true/index)) \ No newline at end of file