diff --git a/tryClsGarbage.py b/tryClsGarbage.py deleted file mode 100644 index f2fd4e9..0000000 --- a/tryClsGarbage.py +++ /dev/null @@ -1,138 +0,0 @@ -#coding:utf-8 -# 环境(均为CPU) -# paddlehub 1.8.0 -# paddlepaddle 2.3.0 - - -import argparse -import os -import ast -import paddle.fluid as fluid -import paddlehub as hub -import paddle -import numpy as np -import matplotlib -matplotlib.rcParams['font.family']='simHei' #显示汉字 -import matplotlib.pyplot as plt -import matplotlib.image as mpimg -from paddlehub.dataset.base_cv_dataset import BaseCVDataset - -# yapf: disable -parser = argparse.ArgumentParser(__doc__) -parser.add_argument("--num_epoch", type=int, default=1, help="Number of epoches for fine-tuning.") -parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for fine-tuning.") -parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.") -parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.") -parser.add_argument("--module", type=str, default="resnet50", help="Module used as feature extractor.") -parser.add_argument("--dataset", type=str, default="flowers", help="Dataset to fine-tune.") -parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True, help="Whether use data parallel.") -# yapf: enable. - -module_map = { - "resnet50": "resnet_v2_50_imagenet", - "resnet101": "resnet_v2_101_imagenet", - "resnet152": "resnet_v2_152_imagenet", - "mobilenet": "mobilenet_v2_imagenet", - "nasnet": "nasnet_imagenet", - "pnasnet": "pnasnet_imagenet" -} - -class DemoDataset(BaseCVDataset): - def __init__(self): - # 数据集存放位置(换成自己电脑数据集的位置) - self.dataset_dir = "D:\clsGarbageCode\garbage" - super(DemoDataset, self).__init__( - base_path=self.dataset_dir, - train_list_file="train_list.txt", - validate_list_file="val_list.txt", - test_list_file="test_list.txt", - # predict_file="predict_list.txt", - label_list_file="label_list.txt" - ) - -def finetune(args): - # Load Paddlehub pretrained model - paddle.enable_static() - module = hub.Module(name=args.module) - input_dict, output_dict, program = module.context(trainable=True) - - dataset = DemoDataset() - - # Use ImageClassificationReader to read dataset - data_reader = hub.reader.ImageClassificationReader( - image_width=module.get_expected_image_width(), - image_height=module.get_expected_image_height(), - images_mean=module.get_pretrained_images_mean(), - images_std=module.get_pretrained_images_std(), - dataset=dataset) - - feature_map = output_dict["feature_map"] - - # Setup feed list for data feeder - feed_list = [input_dict["image"].name] - - # Setup RunConfig for PaddleHub Fine-tune API - config = hub.RunConfig( - use_data_parallel=args.use_data_parallel, - use_cuda=False, - # use_cuda=args.use_gpu, - num_epoch=args.num_epoch, - batch_size=args.batch_size, - checkpoint_dir=args.checkpoint_dir, - strategy=hub.finetune.strategy.DefaultFinetuneStrategy()) - - # Define a image classification task by PaddleHub Fine-tune API - task = hub.ImageClassifierTask( - data_reader=data_reader, - feed_list=feed_list, - feature=feature_map, - num_classes=dataset.num_labels, - config=config) - - # Fine-tune by PaddleHub's API - # task.finetune_and_eval() - return task - -#暂时注释的代码 -'''def testShow(task): - # "任选一张图片路径/xx.jpg" - data = ["/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"] - dataset = DemoDataset() - label_map = dataset.label_dict() - index = 0 - # get classification result - run_states = task.predict(data=data) - results = [run_state.run_results for run_state in run_states] - for batch_result in results: - # get predict index - batch_result = np.argmax(batch_result, axis=2)[0] - for result in batch_result: - index += 1 - result = label_map[result] - print("input %i is %s, and the predict result is ( %s )" % - (index, data[index - 1], result)) - d=plt.imread("/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg") - plt.imshow(d)''' - -def testShow(task): - # "任选一张图片路径/xx.jpg" - data = ["D:\clsGarbageCode\garbage\你是什么垃圾.(jpg)\下载.jpg"] - #print('图片类型是',end='') - #print(task.predict(data=data,return_result=True)) - matplotlib.use('TkAgg') - d=mpimg.imread(data[0]) - #plt.xlabel('垃圾类型为') - plt.title(task.predict(data=data,return_result=True)) - plt.imshow(d) - plt.show() - -if __name__ == "__main__": - args = parser.parse_args() - if not args.module in module_map: - hub.logger.error("module should in %s" % module_map.keys()) - exit(1) - args.module = module_map[args.module] - - task = finetune(args) # 微调模型 - testShow(task) # 传入微调模型,使用任意图片进行测试 -