Delete 'tryClsGarbage.py'

main
hnu202401010612 8 months ago
parent 6f616683b5
commit de7a7057c2

@ -1,138 +0,0 @@
#coding:utf-8
# 环境均为CPU
# paddlehub 1.8.0
# paddlepaddle 2.3.0
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
import paddle
import numpy as np
import matplotlib
matplotlib.rcParams['font.family']='simHei' #显示汉字
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from paddlehub.dataset.base_cv_dataset import BaseCVDataset
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="resnet50", help="Module used as feature extractor.")
parser.add_argument("--dataset", type=str, default="flowers", help="Dataset to fine-tune.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True, help="Whether use data parallel.")
# yapf: enable.
module_map = {
"resnet50": "resnet_v2_50_imagenet",
"resnet101": "resnet_v2_101_imagenet",
"resnet152": "resnet_v2_152_imagenet",
"mobilenet": "mobilenet_v2_imagenet",
"nasnet": "nasnet_imagenet",
"pnasnet": "pnasnet_imagenet"
}
class DemoDataset(BaseCVDataset):
def __init__(self):
# 数据集存放位置(换成自己电脑数据集的位置)
self.dataset_dir = "D:\clsGarbageCode\garbage"
super(DemoDataset, self).__init__(
base_path=self.dataset_dir,
train_list_file="train_list.txt",
validate_list_file="val_list.txt",
test_list_file="test_list.txt",
# predict_file="predict_list.txt",
label_list_file="label_list.txt"
)
def finetune(args):
# Load Paddlehub pretrained model
paddle.enable_static()
module = hub.Module(name=args.module)
input_dict, output_dict, program = module.context(trainable=True)
dataset = DemoDataset()
# Use ImageClassificationReader to read dataset
data_reader = hub.reader.ImageClassificationReader(
image_width=module.get_expected_image_width(),
image_height=module.get_expected_image_height(),
images_mean=module.get_pretrained_images_mean(),
images_std=module.get_pretrained_images_std(),
dataset=dataset)
feature_map = output_dict["feature_map"]
# Setup feed list for data feeder
feed_list = [input_dict["image"].name]
# Setup RunConfig for PaddleHub Fine-tune API
config = hub.RunConfig(
use_data_parallel=args.use_data_parallel,
use_cuda=False,
# use_cuda=args.use_gpu,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
# Define a image classification task by PaddleHub Fine-tune API
task = hub.ImageClassifierTask(
data_reader=data_reader,
feed_list=feed_list,
feature=feature_map,
num_classes=dataset.num_labels,
config=config)
# Fine-tune by PaddleHub's API
# task.finetune_and_eval()
return task
#暂时注释的代码
'''def testShow(task):
# "任选一张图片路径/xx.jpg"
data = ["/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg"]
dataset = DemoDataset()
label_map = dataset.label_dict()
index = 0
# get classification result
run_states = task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
for batch_result in results:
# get predict index
batch_result = np.argmax(batch_result, axis=2)[0]
for result in batch_result:
index += 1
result = label_map[result]
print("input %i is %s, and the predict result is ( %s )" %
(index, data[index - 1], result))
d=plt.imread("/media/lsz/lll/DLcode/paddle/flower_photos/dandelion/7355522_b66e5d3078_m.jpg")
plt.imshow(d)'''
def testShow(task):
# "任选一张图片路径/xx.jpg"
data = ["D:\clsGarbageCode\garbage\你是什么垃圾.(jpg)\下载.jpg"]
#print('图片类型是',end='')
#print(task.predict(data=data,return_result=True))
matplotlib.use('TkAgg')
d=mpimg.imread(data[0])
#plt.xlabel('垃圾类型为')
plt.title(task.predict(data=data,return_result=True))
plt.imshow(d)
plt.show()
if __name__ == "__main__":
args = parser.parse_args()
if not args.module in module_map:
hub.logger.error("module should in %s" % module_map.keys())
exit(1)
args.module = module_map[args.module]
task = finetune(args) # 微调模型
testShow(task) # 传入微调模型,使用任意图片进行测试
Loading…
Cancel
Save