#coding:utf-8
"""Image classification transfer-learning demo built on PaddleHub."""
# Environment (all CPU builds):
#   paddlehub 1.8.0
#   paddlepaddle 2.3.0
import argparse
import os
import ast

import paddle.fluid as fluid
import paddlehub as hub
import paddle
import numpy as np
import matplotlib
# Select a font that can render Chinese characters; set before pyplot is
# imported, exactly as in the original statement order.
matplotlib.rcParams['font.family'] = 'simHei'
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from paddlehub.dataset.base_cv_dataset import BaseCVDataset

# yapf: disable
# The module docstring is passed as ``description``: the first positional
# parameter of ArgumentParser is ``prog``, not the description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num_epoch", type=int, default=1,
                    help="Number of epoches for fine-tuning.")
# ast.literal_eval turns the command-line text "True"/"False" into a bool.
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False,
                    help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str,
                    default="paddlehub_finetune_ckpt",
                    help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=16,
                    help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="resnet50",
                    help="Module used as feature extractor.")
parser.add_argument("--dataset", type=str, default="flowers",
                    help="Dataset to fine-tune.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True,
                    help="Whether use data parallel.")
# yapf: enable.
# Short names accepted by --module, mapped to the PaddleHub module names
# that are passed to hub.Module.
module_map = {
    "resnet50": "resnet_v2_50_imagenet",
    "resnet101": "resnet_v2_101_imagenet",
    "resnet152": "resnet_v2_152_imagenet",
    "mobilenet": "mobilenet_v2_imagenet",
    "nasnet": "nasnet_imagenet",
    "pnasnet": "pnasnet_imagenet"
}


class DemoDataset(BaseCVDataset):
    """Dataset described by list files stored under ``dataset_dir``.

    Args:
        dataset_dir: Directory that holds train_list.txt, val_list.txt,
            test_list.txt and label_list.txt. Defaults to the location the
            script was written for; change it to the dataset path on your
            own machine.
    """

    # Raw string: the original literal relied on invalid escape sequences
    # (\c, \g) being kept verbatim, which newer Python versions warn about.
    # The runtime value is unchanged.
    def __init__(self, dataset_dir=r"D:\clsGarbageCode\garbage"):
        self.dataset_dir = dataset_dir
        super(DemoDataset, self).__init__(
            base_path=self.dataset_dir,
            train_list_file="train_list.txt",
            validate_list_file="val_list.txt",
            test_list_file="test_list.txt",
            # predict_file="predict_list.txt",
            label_list_file="label_list.txt"
        )


def finetune(args):
    """Build a PaddleHub image-classification task from parsed options.

    Args:
        args: Namespace providing ``module``, ``use_data_parallel``,
            ``use_gpu``, ``num_epoch``, ``batch_size`` and
            ``checkpoint_dir``.

    Returns:
        The configured ``hub.ImageClassifierTask``. Training is NOT started
        here: the ``finetune_and_eval()`` call is left commented out.
    """
    # Load Paddlehub pretrained model
    paddle.enable_static()
    module = hub.Module(name=args.module)
    input_dict, output_dict, program = module.context(trainable=True)

    dataset = DemoDataset()

    # Use ImageClassificationReader to read dataset
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    feature_map = output_dict["feature_map"]

    # Setup feed list for data feeder
    feed_list = [input_dict["image"].name]

    # Setup RunConfig for PaddleHub Fine-tune API.
    # use_cuda now follows --use_gpu (default False, so CPU behaviour is
    # unchanged); it was previously hard-coded to False, which silently
    # ignored the flag.
    config = hub.RunConfig(
        use_data_parallel=args.use_data_parallel,
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    # Define a image classification task by PaddleHub Fine-tune API
    task = hub.ImageClassifierTask(
        data_reader=data_reader,
        feed_list=feed_list,
        feature=feature_map,
        num_classes=dataset.num_labels,
        config=config)

    # Fine-tune by PaddleHub's API (currently disabled)
    #task.finetune_and_eval()
    return task


# Temporarily disabled earlier version of testShow, kept as a string literal.
'''def testShow(task):
    # path of any image, e.g. xx.jpg
    data = ["/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"]
    dataset = DemoDataset()
    label_map = dataset.label_dict()
    index = 0
    # get classification result
    run_states = task.predict(data=data)
    results = [run_state.run_results for run_state in run_states]

    for batch_result in results:
        # get predict index
        batch_result = np.argmax(batch_result, axis=2)[0]
        for result in batch_result:
            index += 1
            result = label_map[result]
            print("input %i is %s, and the predict result is ( %s )" %
                  (index, data[index - 1], result))
    d=plt.imread("/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg")
    plt.imshow(d)'''


# Raw string default: the original literal contained invalid escape
# sequences (\c, \g, \s, \i); the runtime value is unchanged.
def testShow(task, image_path=r"D:\clsGarbageCode\garbage\show\images.jpg"):
    """Predict one image with ``task`` and display it with the result as title.

    Args:
        task: Object whose ``predict(data=..., return_result=True)`` is
            called on a one-element list containing ``image_path``.
        image_path: Path of the image to classify and show. Defaults to the
            sample picture the script was written for.
    """
    data = [image_path]
    #print('图片类型是',end='')
    #print(task.predict(data=data,return_result=True))
    matplotlib.use('TkAgg')
    picture = mpimg.imread(data[0])
    #plt.ylabel('垃圾类型为')
    # The value returned by predict() is handed to plt.title unchanged.
    plt.title(task.predict(data=data, return_result=True))
    plt.imshow(picture)
    plt.show()


if __name__ == "__main__":
    args = parser.parse_args()
    if args.module not in module_map:
        hub.logger.error("module should in %s" % module_map.keys())
        # Same effect as exit(1) without relying on the helper that the
        # site module injects into builtins.
        raise SystemExit(1)
    args.module = module_map[args.module]

    task = finetune(args)  # build the fine-tune task
    testShow(task)  # pass the task in and test it on a sample image