From 6f616683b5a8d40e54cd9a4ccdc85aabb76e1094 Mon Sep 17 00:00:00 2001 From: hnu202401010612 <3117511861@qq.com> Date: Thu, 26 Dec 2024 00:05:34 +0800 Subject: [PATCH] ADD file via upload --- tryClsGarbage.py | 138 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 tryClsGarbage.py diff --git a/tryClsGarbage.py b/tryClsGarbage.py new file mode 100644 index 0000000..f2fd4e9 --- /dev/null +++ b/tryClsGarbage.py @@ -0,0 +1,138 @@ +#coding:utf-8 +# 环境(均为CPU) +# paddlehub 1.8.0 +# paddlepaddle 2.3.0 + + +import argparse +import os +import ast +import paddle.fluid as fluid +import paddlehub as hub +import paddle +import numpy as np +import matplotlib +matplotlib.rcParams['font.family']='simHei' #显示汉字 +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +from paddlehub.dataset.base_cv_dataset import BaseCVDataset + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--num_epoch", type=int, default=1, help="Number of epoches for fine-tuning.") +parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for fine-tuning.") +parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.") +parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.") +parser.add_argument("--module", type=str, default="resnet50", help="Module used as feature extractor.") +parser.add_argument("--dataset", type=str, default="flowers", help="Dataset to fine-tune.") +parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True, help="Whether use data parallel.") +# yapf: enable. + +module_map = { + "resnet50": "resnet_v2_50_imagenet", + "resnet101": "resnet_v2_101_imagenet", + "resnet152": "resnet_v2_152_imagenet", + "mobilenet": "mobilenet_v2_imagenet", + "nasnet": "nasnet_imagenet", + "pnasnet": "pnasnet_imagenet" +} + +class DemoDataset(BaseCVDataset): + def __init__(self): + # 数据集存放位置(换成自己电脑数据集的位置) + self.dataset_dir = "D:\clsGarbageCode\garbage" + super(DemoDataset, self).__init__( + base_path=self.dataset_dir, + train_list_file="train_list.txt", + validate_list_file="val_list.txt", + test_list_file="test_list.txt", + # predict_file="predict_list.txt", + label_list_file="label_list.txt" + ) + +def finetune(args): + # Load Paddlehub pretrained model + paddle.enable_static() + module = hub.Module(name=args.module) + input_dict, output_dict, program = module.context(trainable=True) + + dataset = DemoDataset() + + # Use ImageClassificationReader to read dataset + data_reader = hub.reader.ImageClassificationReader( + image_width=module.get_expected_image_width(), + image_height=module.get_expected_image_height(), + images_mean=module.get_pretrained_images_mean(), + images_std=module.get_pretrained_images_std(), + dataset=dataset) + + feature_map = output_dict["feature_map"] + + # Setup feed list for data feeder + feed_list = [input_dict["image"].name] + + # Setup RunConfig for PaddleHub Fine-tune API + config = hub.RunConfig( + use_data_parallel=args.use_data_parallel, + use_cuda=False, + # use_cuda=args.use_gpu, + num_epoch=args.num_epoch, + batch_size=args.batch_size, + checkpoint_dir=args.checkpoint_dir, + strategy=hub.finetune.strategy.DefaultFinetuneStrategy()) + + # Define a image classification task by PaddleHub Fine-tune API + task = hub.ImageClassifierTask( + data_reader=data_reader, + feed_list=feed_list, + feature=feature_map, + num_classes=dataset.num_labels, + config=config) + + # Fine-tune by PaddleHub's API + # task.finetune_and_eval() + return task + +#暂时注释的代码 +'''def testShow(task): + # "任选一张图片路径/xx.jpg" + data = ["/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"] + dataset = DemoDataset() + label_map = dataset.label_dict() + index = 0 + # get classification result + run_states = task.predict(data=data) + results = [run_state.run_results for run_state in run_states] + for batch_result in results: + # get predict index + batch_result = np.argmax(batch_result, axis=2)[0] + for result in batch_result: + index += 1 + result = label_map[result] + print("input %i is %s, and the predict result is ( %s )" % + (index, data[index - 1], result)) + d=plt.imread("/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg") + plt.imshow(d)''' + +def testShow(task): + # "任选一张图片路径/xx.jpg" + data = ["D:\clsGarbageCode\garbage\你是什么垃圾.(jpg)\下载.jpg"] + #print('图片类型是',end='') + #print(task.predict(data=data,return_result=True)) + matplotlib.use('TkAgg') + d=mpimg.imread(data[0]) + #plt.xlabel('垃圾类型为') + plt.title(task.predict(data=data,return_result=True)) + plt.imshow(d) + plt.show() + +if __name__ == "__main__": + args = parser.parse_args() + if not args.module in module_map: + hub.logger.error("module should in %s" % module_map.keys()) + exit(1) + args.module = module_map[args.module] + + task = finetune(args) # 微调模型 + testShow(task) # 传入微调模型,使用任意图片进行测试 +