#coding:utf-8
# Environment (all CPU):
#   paddlehub 1.8.0
#   paddlepaddle 2.3.0

import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
import paddle
import numpy as np
import matplotlib

matplotlib.rcParams['font.family'] = 'SimHei'  # display Chinese characters correctly

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from paddlehub.dataset.base_cv_dataset import BaseCVDataset

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use the GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save checkpoints and log data.")
parser.add_argument("--batch_size", type=int, default=16, help="Total number of examples in a training batch.")
parser.add_argument("--module", type=str, default="resnet50", help="Module used as the feature extractor.")
parser.add_argument("--dataset", type=str, default="flowers", help="Dataset to fine-tune on.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True, help="Whether to use data parallelism.")
# yapf: enable
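
# Illustrative invocation (the script file name below is a placeholder; every flag is
# optional and falls back to the defaults declared above):
#
#     python finetune_garbage.py --module resnet50 --num_epoch 1 --batch_size 16 --use_gpu False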

# Map the short names accepted by --module to the corresponding PaddleHub
# pretrained module names.
module_map = {
    "resnet50": "resnet_v2_50_imagenet",
    "resnet101": "resnet_v2_101_imagenet",
    "resnet152": "resnet_v2_152_imagenet",
    "mobilenet": "mobilenet_v2_imagenet",
    "nasnet": "nasnet_imagenet",
    "pnasnet": "pnasnet_imagenet"
}
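
# On the first run, hub.Module(name=...) downloads the selected pretrained module into
# the local PaddleHub cache; subsequent runs reuse the cached copy.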


class DemoDataset(BaseCVDataset):
    def __init__(self):
        # Dataset location (replace with the path to the dataset on your own machine)
        self.dataset_dir = r"D:\clsGarbageCode\garbage"
        super(DemoDataset, self).__init__(
            base_path=self.dataset_dir,
            train_list_file="train_list.txt",
            validate_list_file="val_list.txt",
            test_list_file="test_list.txt",
            # predict_file="predict_list.txt",
            label_list_file="label_list.txt"
        )
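
# A note on the list files referenced above (this describes the usual PaddleHub 1.x
# custom-dataset layout and is an assumption about this project's data): each line of
# train_list.txt / val_list.txt / test_list.txt is "relative/image/path label_index",
# and label_list.txt holds one class name per line, ordered by label index.
# Hypothetical example lines:
#
#     recyclable/bottle_001.jpg 0
#     kitchen_waste/banana_peel_003.jpg 1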


def finetune(args):
    # Load the PaddleHub pretrained model
    paddle.enable_static()  # PaddleHub 1.x's Fine-tune API runs in static-graph mode under PaddlePaddle 2.x
    module = hub.Module(name=args.module)
    input_dict, output_dict, program = module.context(trainable=True)

    dataset = DemoDataset()

    # Use ImageClassificationReader to read the dataset
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    feature_map = output_dict["feature_map"]

    # Set up the feed list for the data feeder
    feed_list = [input_dict["image"].name]

    # Set up RunConfig for the PaddleHub Fine-tune API
    config = hub.RunConfig(
        use_data_parallel=args.use_data_parallel,
        use_cuda=False,
        # use_cuda=args.use_gpu,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

    # Define an image classification task with the PaddleHub Fine-tune API
    task = hub.ImageClassifierTask(
        data_reader=data_reader,
        feed_list=feed_list,
        feature=feature_map,
        num_classes=dataset.num_labels,
        config=config)

    # Fine-tune with PaddleHub's API
    # task.finetune_and_eval()
    return task
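
# Note: finetune() only builds the ImageClassifierTask; the call that actually trains the
# network, task.finetune_and_eval(), is left commented out above, so prediction relies on
# whatever checkpoint already exists under --checkpoint_dir. A minimal sketch of running
# the full fine-tuning pass (assuming the PaddleHub 1.x Fine-tune API):
#
#     task = finetune(args)
#     task.finetune_and_eval()  # train on train_list.txt and evaluate on val_list.txt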


# Temporarily commented-out code
'''def testShow(task):
    # "path to any image/xx.jpg"
    data = ["/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"]
    dataset = DemoDataset()
    label_map = dataset.label_dict()
    index = 0
    # get the classification result
    run_states = task.predict(data=data)
    results = [run_state.run_results for run_state in run_states]
    for batch_result in results:
        # get the predicted index
        batch_result = np.argmax(batch_result, axis=2)[0]
        for result in batch_result:
            index += 1
            result = label_map[result]
            print("input %i is %s, and the predict result is ( %s )" %
                  (index, data[index - 1], result))
    d = plt.imread("/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg")
    plt.imshow(d)'''


def testShow(task):
    # "path to any image/xx.jpg"
    data = [r"D:\clsGarbageCode\garbage\你是什么垃圾.(jpg)\下载.jpg"]
    # print('The image type is ', end='')
    # print(task.predict(data=data, return_result=True))
    matplotlib.use('TkAgg')
    d = mpimg.imread(data[0])
    # plt.xlabel('The garbage type is')
    plt.title(task.predict(data=data, return_result=True))
    plt.imshow(d)
    plt.show()
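
# Assumption about the PaddleHub 1.x API: task.predict(data=..., return_result=True)
# returns the decoded predictions (label names) instead of raw run states, so the plot
# title above ends up showing the predicted label list for the chosen image.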


if __name__ == "__main__":
    args = parser.parse_args()
    if args.module not in module_map:
        hub.logger.error("module should be in %s" % module_map.keys())
        exit(1)
    args.module = module_map[args.module]

    task = finetune(args)  # fine-tune the model
    testShow(task)  # pass the fine-tuned model in and test it on an arbitrary image