You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

139 lines
5.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#coding:utf-8
# 环境均为CPU
# paddlehub 1.8.0
# paddlepaddle 2.3.0
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
import paddle
import numpy as np
import matplotlib
matplotlib.rcParams['font.family']='simHei' #显示汉字
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from paddlehub.dataset.base_cv_dataset import BaseCVDataset
# yapf: disable
# Command-line options for the fine-tuning demo.
# NOTE: the module docstring goes in as the *description*; passing it
# positionally (as before) would make argparse use it as ``prog``, the
# program name printed in usage lines.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epoches for fine-tuning.")
# Boolean flags are parsed with ast.literal_eval, so pass True/False literally.
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
# Short alias; the __main__ block maps it to a full name through module_map.
parser.add_argument("--module", type=str, default="resnet50", help="Module used as feature extractor.")
parser.add_argument("--dataset", type=str, default="flowers", help="Dataset to fine-tune.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True, help="Whether use data parallel.")
# yapf: enable
# Short aliases accepted by --module, mapped to the PaddleHub module names
# that are passed to hub.Module(name=...) in finetune().
module_map = dict(
    resnet50="resnet_v2_50_imagenet",
    resnet101="resnet_v2_101_imagenet",
    resnet152="resnet_v2_152_imagenet",
    mobilenet="mobilenet_v2_imagenet",
    nasnet="nasnet_imagenet",
    pnasnet="pnasnet_imagenet",
)
class DemoDataset(BaseCVDataset):
    """Image-classification dataset defined by list files in one directory.

    Hands ``train_list.txt``, ``val_list.txt``, ``test_list.txt`` and
    ``label_list.txt`` to PaddleHub's ``BaseCVDataset`` together with the
    dataset directory as ``base_path`` (presumably the list files are looked
    up inside that directory -- confirm against BaseCVDataset).
    """

    def __init__(self, dataset_dir=r"D:\clsGarbageCode\garbage"):
        """Create the dataset.

        Args:
            dataset_dir: location of the dataset. Replace the default (or
                pass your own path) with where the dataset lives on your
                machine. The default is a raw string so the backslashes of
                the Windows path are not treated as escape sequences.
        """
        self.dataset_dir = dataset_dir
        super().__init__(
            base_path=self.dataset_dir,
            train_list_file="train_list.txt",
            validate_list_file="val_list.txt",
            test_list_file="test_list.txt",
            # predict_file="predict_list.txt",
            label_list_file="label_list.txt",
        )
def finetune(args):
    """Build a PaddleHub image-classification task on a pretrained module.

    Args:
        args: parsed command-line namespace. Reads ``module`` (full PaddleHub
            module name), ``use_gpu``, ``use_data_parallel``, ``num_epoch``,
            ``batch_size`` and ``checkpoint_dir``.

    Returns:
        The configured ``hub.ImageClassifierTask``. Training is not run here
        (the ``task.finetune_and_eval()`` call below is disabled), so the
        task is returned exactly as constructed.
    """
    # Switch Paddle to static-graph mode before building the module program.
    paddle.enable_static()
    # Load the PaddleHub pretrained model.
    module = hub.Module(name=args.module)
    # The returned program object is not needed here.
    input_dict, output_dict, _ = module.context(trainable=True)
    dataset = DemoDataset()

    # Use ImageClassificationReader to read the dataset, preprocessing images
    # with the size and normalisation statistics the pretrained module expects.
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    feature_map = output_dict["feature_map"]
    # Feed list for the data feeder: the name of the module's image input.
    feed_list = [input_dict["image"].name]

    # RunConfig for the PaddleHub fine-tune API. --use_gpu defaults to False,
    # so a default run still stays on the CPU.
    config = hub.RunConfig(
        use_data_parallel=args.use_data_parallel,
        use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    # Image classification task with one output per dataset label.
    task = hub.ImageClassifierTask(
        data_reader=data_reader,
        feed_list=feed_list,
        feature=feature_map,
        num_classes=dataset.num_labels,
        config=config)

    # Training is disabled; uncomment to fine-tune and evaluate before returning.
    # task.finetune_and_eval()
    return task
# Temporarily disabled code: an earlier testShow variant (Linux image path, manual argmax over predict() run results), kept inside a string literal so it never executes.
'''def testShow(task):
# "任选一张图片路径/xx.jpg"
data = ["/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"]
dataset = DemoDataset()
label_map = dataset.label_dict()
index = 0
# get classification result
run_states = task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
index += 1
result = label_map[result]
print("input %i is %s, and the predict result is ( %s )" %
(index, data[index - 1], result))
d=plt.imread("/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg")
plt.imshow(d)'''
def testShow(task, image_path=r"D:\clsGarbageCode\garbage\show\images.jpg"):
    """Classify one image with *task* and display it, titled with the result.

    Args:
        task: the task built by ``finetune``; only its ``predict`` method is used.
        image_path: path of the image to classify (the garbage item whose
            type you want to identify). Defaults to the original hard-coded
            location; the raw string keeps the Windows backslashes literal
            instead of treating them as escape sequences.
    """
    data = [image_path]
    # Use the Tk interactive backend so plt.show() opens a window.
    matplotlib.use('TkAgg')
    image = mpimg.imread(image_path)
    # Whatever predict(..., return_result=True) returns becomes the plot title.
    plt.title(task.predict(data=data, return_result=True))
    plt.imshow(image)
    plt.show()
if __name__ == "__main__":
    args = parser.parse_args()
    # Only the short aliases listed in module_map are accepted; translate the
    # alias into the full PaddleHub module name.
    if args.module not in module_map:
        hub.logger.error("module should in %s" % module_map.keys())
        # ``exit`` is a site-module helper meant for the interactive shell;
        # raising SystemExit works even when ``site`` is not loaded.
        raise SystemExit(1)
    args.module = module_map[args.module]
    task = finetune(args)  # build the fine-tuning task
    testShow(task)  # pass the task in and test it on an arbitrary image