forked from pl8qemw3k/garbage
Compare commits
4 Commits
Author | SHA1 | Date |
---|---|---|
|
2a7b22261d | 2 years ago |
|
0020179263 | 2 years ago |
|
a24a14ec18 | 2 years ago |
|
646cb3884c | 2 years ago |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,81 @@
|
||||
'''
|
||||
数据样本分析
|
||||
画出数据量条形图
|
||||
画出图像分辨率散点图
|
||||
'''
|
||||
import os
|
||||
import PIL.Image as Image
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def plot_resolution(dataset_root_path):
    """Scatter-plot the width and height of every image under a directory.

    The tree below ``dataset_root_path`` is walked recursively. Each file that
    Pillow can open contributes one ``(width, height)`` point; files that
    cannot be opened as images are reported and skipped.

    Args:
        dataset_root_path: root directory of the dataset.
    """
    img_size_list = []  # (width, height) of every readable image
    for root, _dirs, files in os.walk(dataset_root_path):
        for file_i in files:
            file_i_full_path = os.path.join(root, file_i)
            try:
                # The context manager releases the file handle immediately
                # instead of leaving one open handle per image.
                with Image.open(file_i_full_path) as img_i:
                    img_size_list.append(img_i.size)
            except OSError as err:
                # Pillow's "cannot identify image file" error is an OSError.
                print("跳过无法读取的文件", file_i_full_path, err)

    print(img_size_list)

    # Split the (width, height) pairs into the two plot axes.
    width_list = [width for width, _ in img_size_list]
    height_list = [height for _, height in img_size_list]

    plt.rcParams["font.sans-serif"] = ["SimHei"]  # font able to render the Chinese labels
    plt.rcParams["font.size"] = 8
    plt.rcParams["axes.unicode_minus"] = False  # keep the minus sign renderable with that font

    plt.scatter(width_list, height_list, s=1)
    plt.xlabel("宽")
    plt.ylabel("高")
    plt.title("图像宽高分布")
    plt.show()
|
||||
|
||||
|
||||
# 画出条形图
|
||||
def plot_bar(dataset_root_path):
    """Draw a bar chart of the number of files in each sub-directory.

    Every directory below ``dataset_root_path`` (the root itself excluded)
    becomes one bar labelled with the directory name; a red line marks the
    mean file count.

    Args:
        dataset_root_path: root directory of the dataset.
    """
    file_name_list = []
    file_num_list = []

    walker = os.walk(dataset_root_path)
    next(walker, None)  # the first entry is the root itself, which is not a class
    for root, _dirs, files in walker:
        # Name and count are recorded together so that they can never get
        # out of step, even when directories are nested.
        file_name_list.append(os.path.basename(root))
        file_num_list.append(len(files))

    if not file_num_list:
        # Nothing to plot; also avoids taking the mean of an empty list.
        print("未找到任何子目录", dataset_root_path)
        return

    # Mean file count, drawn as a horizontal line.
    mean = np.mean(file_num_list)
    print("mean = ", mean)

    bar_positions = np.arange(len(file_name_list))

    # Font settings are applied before the figure and its texts are created.
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # font able to render the Chinese labels
    plt.rcParams["font.size"] = 8
    plt.rcParams["axes.unicode_minus"] = False  # keep the minus sign renderable with that font

    fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # x positions, heights, bar width
    ax.plot(bar_positions, [mean] * len(bar_positions), color="red")  # mean line

    ax.set_xticks(bar_positions)
    ax.set_xticklabels(file_name_list, rotation=90)
    ax.set_ylabel("类别数量")
    ax.set_title("数据分布图")
    plt.show()
|
||||
|
||||
if __name__ == '__main__':
    # Root folder of the dataset to analyse.
    dataset_root_path = "dataset"

    # Resolution scatter plot; switch to the call below for the bar chart.
    plot_resolution(dataset_root_path)
    # plot_bar(dataset_root_path)
|
||||
|
||||
|
||||
|
@ -0,0 +1,42 @@
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
dataset_root_path = "dataset"

# Thresholds are named so that they no longer shadow the builtins min()/max().
MIN_SIDE = 200    # smallest allowed edge length
MAX_SIDE = 2000   # largest allowed edge length
MIN_RATIO = 0.5   # smallest allowed short-edge / long-edge ratio

delete_list = []  # paths of the images that fail at least one check
for root, dirs, files in os.walk(dataset_root_path):
    for file_i in files:
        file_i_full_path = os.path.join(root, file_i)
        try:
            # Close the image right after reading its size: a handle that is
            # still open can make the later os.remove() fail.
            with Image.open(file_i_full_path) as img_i:
                width, height = img_i.size
        except OSError as err:
            # Not an image (or unreadable): report it and leave it alone.
            print(file_i_full_path, " 无法读取,已跳过", err)
            continue

        short, long = sorted((width, height))
        rejected = False

        # An edge is too short.
        if short < MIN_SIDE:
            print(file_i_full_path, " 不满足要求")
            rejected = True

        # An edge is too long.
        if long > MAX_SIDE:
            print(file_i_full_path, " 不满足要求")
            rejected = True

        # The aspect ratio is too extreme.
        if short / long < MIN_RATIO:
            print(file_i_full_path, " 不满足要求", width, height)
            rejected = True

        # Queue each file once, however many checks it failed.
        if rejected:
            delete_list.append(file_i_full_path)

for file_i in delete_list:
    print("正在删除", file_i)
    try:
        os.remove(file_i)
    except OSError as err:
        # Report the failure instead of silently ignoring it.
        print("删除失败", file_i, err)
|
@ -0,0 +1,49 @@
|
||||
'''
|
||||
使用翻转进行数据增强
|
||||
'''
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# 水平翻转
|
||||
def Horizontal(image):
    """Return a horizontally mirrored copy of ``image`` (``cv2.flip`` code 1)."""
    mirrored = cv2.flip(image, 1)
    return mirrored
|
||||
|
||||
# 垂直翻转
|
||||
def Vertical(image):
    """Return a vertically mirrored copy of ``image`` (``cv2.flip`` code 0)."""
    mirrored = cv2.flip(image, 0)
    return mirrored
|
||||
|
||||
if __name__ == '__main__':
    from_root = r"dataset"
    save_root = r"enhance_dataset"

    # Directories holding fewer files than this also get flipped copies.
    threshold = 200

    for a, b, c in os.walk(from_root):
        for file_i in c:
            file_i_path = os.path.join(a, file_i)

            # Output goes to <save_root>/<name of the directory holding the file>.
            dir_loc = os.path.basename(os.path.dirname(file_i_path))
            save_path = os.path.join(save_root, dir_loc)

            print(file_i_path)
            print(save_path)

            os.makedirs(save_path, exist_ok=True)

            # Read the raw bytes and decode them with flag -1.
            # NOTE(review): presumably done instead of cv2.imread to cope with
            # non-ASCII paths -- confirm.
            img_i = cv2.imdecode(np.fromfile(file_i_path, dtype=np.uint8), -1)
            if img_i is None:
                # imdecode returns None for data it cannot decode.
                print("无法解码,已跳过", file_i_path)
                continue

            # Strip the real extension whatever its length; slicing a fixed
            # five characters truncated the stem of ".jpg"/".png" files and
            # could make different inputs overwrite each other.
            stem = os.path.splitext(file_i)[0]

            variants = [("_original", img_i)]
            if len(c) < threshold:
                variants.append(("_horizontal", Horizontal(img_i)))
                variants.append(("_vertical", Vertical(img_i)))

            for suffix, img in variants:
                out_path = os.path.join(save_path, stem + suffix + ".jpg")
                cv2.imencode('.jpg', img)[1].tofile(out_path)
|
@ -0,0 +1,23 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
img_root = r"enhance_dataset"
threshold = 300  # maximum number of files kept per directory

for a, b, c in os.walk(img_root):
    # Directories at or below the limit are left untouched.
    if len(c) <= threshold:
        continue

    delete_list = [os.path.join(a, file_i) for file_i in c]
    random.shuffle(delete_list)

    print(delete_list)
    # Everything after the first `threshold` shuffled entries is removed.
    delete_list = delete_list[threshold:]
    for file_delete_i in delete_list:
        os.remove(file_delete_i)
        print("将会删除", file_delete_i)
|
||||
|
||||
|
||||
|
@ -0,0 +1,28 @@
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torch
|
||||
from torchvision import transforms as T
|
||||
from tqdm import tqdm
|
||||
|
||||
# Preprocessing applied to each image before the statistics are computed:
# a random resized crop of size 224, then conversion to a tensor.
_steps = [T.RandomResizedCrop(224), T.ToTensor()]
transform = T.Compose(_steps)
|
||||
|
||||
def getStat(train_data):
    """Estimate the per-channel mean and standard deviation of a dataset.

    Samples are loaded one at a time (batch size 1). For each of the three
    channels the mean and std of every sample are accumulated and finally
    divided by ``len(train_data)``, so the returned std is the average of the
    per-image stds rather than the std over all pixels of the dataset.

    Args:
        train_data: dataset yielding ``(image_tensor, label)`` pairs.

    Returns:
        ``(means, stds)``: two lists with one value per channel.
    """
    loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)
    for batch, _ in tqdm(loader):
        # batch is indexed as (N, C, H, W)
        for ch in range(3):
            plane = batch[:, ch, :, :]
            channel_mean[ch] += plane.mean()
            channel_std[ch] += plane.std()

    n_samples = len(train_data)
    channel_mean.div_(n_samples)
    channel_std.div_(n_samples)
    return list(channel_mean.numpy()), list(channel_std.numpy())
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Compute and print the statistics of the augmented dataset.
    train_dataset = ImageFolder(root=r'enhance_dataset', transform=transform)
    stats = getStat(train_dataset)
    print(stats)
|
@ -0,0 +1,40 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
train_ratio = 0.9  # fraction of each directory's files written to train.txt
test_ratio = 1 - train_ratio  # remaining fraction, written to test.txt

rootdata = r"dataset"


def build_split(root, ratio):
    """Build the train/test index lines for every directory below ``root``.

    Directories are visited in sorted order so that the label assigned to each
    one is the same on every run and every machine. The root directory itself
    gets label -1 and each directory visited afterwards gets 0, 1, 2, ...

    Args:
        root: dataset root directory.
        ratio: fraction of each directory's files that goes to the train list.

    Returns:
        ``(train_lines, test_lines)``; every line is ``"<path>\\t<label>\\n"``.
    """
    train_lines, test_lines = [], []
    class_flag = -1  # label of the directory currently being visited
    for dir_path, dir_names, file_names in os.walk(root):
        # Sorting in place also fixes the order in which os.walk descends.
        dir_names.sort()
        file_names = sorted(file_names)

        n_train = int(len(file_names) * ratio)  # first n_train files -> train
        for idx, name in enumerate(file_names):
            line = os.path.join(dir_path, name) + '\t' + str(class_flag) + '\n'
            if idx < n_train:
                train_lines.append(line)
            else:
                test_lines.append(line)

        class_flag += 1
    return train_lines, test_lines


train_list, test_list = build_split(rootdata, train_ratio)

# Shuffle so that the classes are interleaved in the index files.
random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt', 'w', encoding='UTF-8') as f:
    f.writelines(train_list)

with open('test.txt', 'w', encoding='UTF-8') as f:
    f.writelines(test_list)
|
@ -0,0 +1,74 @@
|
||||
# -*-coding:utf-8-*-
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ReadData(data_loc):
    """Parse a tab-separated training log into four parallel lists.

    Every non-blank line is split on tabs; fields 1, 3, 5 and 7 are read as
    epoch, train loss, test loss and test accuracy. Blank lines are skipped.

    Args:
        data_loc: path of the log file.

    Returns:
        ``(epoch_list, train_loss_list, test_loss_list, test_accuracy_list)``.
    """
    epoch_list = []
    train_loss_list = []
    test_loss_list = []
    test_accuracy_list = []

    with open(data_loc, "r") as f:
        for line_i in f:
            if not line_i.strip():
                # A blank or trailing empty line carries no record.
                continue

            data = line_i.split('\t')
            print("data = ", data)
            epoch_i, train_loss_i, test_loss_i, test_accuracy_i = data[1], data[3], data[5], data[7]
            epoch_list.append(int(epoch_i))
            train_loss_list.append(float(train_loss_i))
            test_loss_list.append(float(test_loss_i))
            test_accuracy_list.append(float(test_accuracy_i))

    return epoch_list, train_loss_list, test_loss_list, test_accuracy_list
|
||||
|
||||
|
||||
|
||||
def DrawLoss(train_loss_list, train_loss_list_2, max_epochs=10):
    """Plot two loss curves ("with pretrain" / "no pretrain") on one figure.

    Args:
        train_loss_list: loss per epoch of the first run.
        train_loss_list_2: loss per epoch of the second run.
        max_epochs: maximum number of leading epochs to draw (default 10).
    """
    plt.style.use('dark_background')
    plt.title("Loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")

    # Both curves are cut to the same length as the x axis; plotting series of
    # different lengths against one x axis raises an error.
    n_points = min(max_epochs, len(train_loss_list), len(train_loss_list_2))
    epoch_list = list(range(n_points))

    p1, = plt.plot(epoch_list, train_loss_list[:n_points], linewidth=3)
    p2, = plt.plot(epoch_list, train_loss_list_2[:n_points], linewidth=3)

    plt.legend([p1, p2], ["with pretrain", "no pretrain"])
    plt.show()
|
||||
|
||||
def DrawAcc(train_loss_list, train_loss_list_2, max_epochs=10):
    """Plot two accuracy curves ("with pretrain" / "no pretrain") on one figure.

    The parameter names are kept for compatibility with existing callers; the
    values are plotted on an axis labelled "accuracy".

    Args:
        train_loss_list: accuracy per epoch of the first run.
        train_loss_list_2: accuracy per epoch of the second run.
        max_epochs: maximum number of leading epochs to draw (default 10).
    """
    plt.style.use('dark_background')
    plt.title("Accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    # Both curves are cut to the same length as the x axis; plotting series of
    # different lengths against one x axis raises an error.
    n_points = min(max_epochs, len(train_loss_list), len(train_loss_list_2))
    epoch_list = list(range(n_points))

    p1, = plt.plot(epoch_list, train_loss_list[:n_points], linewidth=3)
    p2, = plt.plot(epoch_list, train_loss_list_2[:n_points], linewidth=3)

    plt.legend([p1, p2], ["with pretrain", "no pretrain"])
    plt.show()
|
||||
|
||||
if __name__ == '__main__':
    # Training logs of the pretrained and the non-pretrained run.
    data_1_loc = "output/resnet18.txt"
    data_2_loc = "output/resnet18_no_pretrain.txt"

    run_1 = ReadData(data_1_loc)
    run_2 = ReadData(data_2_loc)
    _, train_loss_list, test_loss_list, test_accuracy_list = run_1
    _, train_loss_list_2, test_loss_list_2, test_accuracy_list_2 = run_2

    # Compare training loss first, then test accuracy.
    DrawLoss(train_loss_list, train_loss_list_2)
    DrawAcc(test_accuracy_list, test_accuracy_list_2)
|
@ -0,0 +1,139 @@
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import io
|
||||
from flask import Flask, request
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlopen
|
||||
from alibabacloud_imagerecog20190930.client import Client
|
||||
from alibabacloud_imagerecog20190930.models import ClassifyingRubbishAdvanceRequest
|
||||
from alibabacloud_tea_openapi.models import Config
|
||||
from alibabacloud_tea_util.models import RuntimeOptions
|
||||
import mysql.connector
|
||||
from mysql.connector import Error
|
||||
|
||||
def insert_feedback(feedbackInfo, feedbackType):
    """Insert one feedback row, together with a photo, into the database.

    The photo is read from a fixed path on disk. Connection settings default
    to the local ``garbage`` database and can be overridden with the
    ``GARBAGE_DB_HOST``, ``GARBAGE_DB_PORT``, ``GARBAGE_DB_USER``,
    ``GARBAGE_DB_PASSWORD`` and ``GARBAGE_DB_NAME`` environment variables.
    Errors are printed, not raised.

    Args:
        feedbackInfo: value stored in the ``feedbackInfo`` column.
        feedbackType: value stored in the ``feedbackType`` column.
    """
    # Image stored in the feedbackPhoto column.
    file_path = r'C:\Users\admin\Desktop\garbage.jpg'

    connection = None
    cursor = None
    try:
        # Read the picture as raw bytes.
        with open(file_path, "rb") as f:
            image_data = f.read()

        connection = mysql.connector.connect(
            host=os.environ.get("GARBAGE_DB_HOST", '127.0.0.1'),
            port=int(os.environ.get("GARBAGE_DB_PORT", 3306)),
            user=os.environ.get("GARBAGE_DB_USER", "root"),
            password=os.environ.get("GARBAGE_DB_PASSWORD", "root"),
            database=os.environ.get("GARBAGE_DB_NAME", "garbage"),
        )
        cursor = connection.cursor(prepared=True)

        # Parameterised insert: the values are never formatted into the SQL text.
        sql = "INSERT INTO feedback (feedbackInfo, feedbackType, feedbackPhoto) VALUES (%s, %s, %s)"
        cursor.execute(sql, (feedbackInfo, feedbackType, image_data))
        connection.commit()
        print("信息上传成功!")
    except FileNotFoundError:
        print("未选择文件。")
    except (Error, IOError) as e:
        print(e)
    finally:
        # Release the cursor and the connection on every path, including errors.
        if cursor is not None:
            try:
                cursor.close()
            except Error as e:
                print(e)
        if connection is not None and connection.is_connected():
            connection.close()
|
||||
|
||||
|
||||
def recognition(url):
    """Classify the rubbish shown in an image through the cloud recognition API.

    Args:
        url: path of a local image file, or a URL that ``urlopen`` can fetch.

    Returns:
        The ``category`` of the first recognised element, or ``None`` when the
        service call fails (the error is printed).

    Raises:
        RuntimeError: if the access-key environment variables are not set.
    """
    # Credentials are read from the environment so that no secret is stored
    # in the source file.
    access_key_id = os.environ.get('ALIBABA_CLOUD_ACCESS_KEY_ID')
    access_key_secret = os.environ.get('ALIBABA_CLOUD_ACCESS_KEY_SECRET')
    if not access_key_id or not access_key_secret:
        raise RuntimeError(
            "ALIBABA_CLOUD_ACCESS_KEY_ID / ALIBABA_CLOUD_ACCESS_KEY_SECRET are not set")

    config = Config(
        access_key_id=access_key_id,
        access_key_secret=access_key_secret,
        endpoint='imagerecog.cn-shanghai.aliyuncs.com',
        region_id='cn-shanghai'
    )

    # A local path is read directly: urlopen() would take the drive letter of
    # a Windows path for a URL scheme and fail.
    if os.path.isfile(url):
        with open(url, 'rb') as local_file:
            img = io.BytesIO(local_file.read())
    else:
        with urlopen(url) as remote:
            img = io.BytesIO(remote.read())

    classifying_rubbish_request = ClassifyingRubbishAdvanceRequest()
    classifying_rubbish_request.image_urlobject = img

    runtime = RuntimeOptions()
    try:
        client = Client(config)
        response = client.classifying_rubbish_advance(classifying_rubbish_request, runtime)
        return response.body.data.elements[0].category
    except Exception as error:
        # Best effort: report the failure and let the caller handle None.
        print(error)
        return None
|
||||
|
||||
|
||||
# Flask application exposing the rubbish-recognition endpoints.
app = Flask(__name__)

# Directory that uploaded files are saved to.
UPLOAD_FOLDER = 'C:/Users/admin/Desktop'
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)
|
||||
|
||||
@app.route('/upload_image', methods=['POST'])
def upload_image():
    """Save the uploaded image and return its recognised rubbish category.

    Expects a multipart form with a ``file`` part. The file is stored as
    ``garbage.jpg`` in the upload folder and then passed to ``recognition``.
    """
    if 'file' not in request.files:
        return 'No file part'

    file = request.files['file']
    if not file.filename:
        return 'No selected file'

    # Save the upload to the configured folder under a fixed name.
    saved_path = os.path.join(app.config['UPLOAD_FOLDER'], 'garbage.jpg')
    file.save(saved_path)

    a = recognition(saved_path)
    print(a)
    if a is None:
        # recognition() yields None when the service call fails; None is not
        # a valid Flask response, so answer with an explicit error instead.
        return 'Recognition failed', 500
    return a
|
||||
|
||||
@app.route('/upload_text', methods=['POST'])
def upload_text():
    """Store a feedback message posted as ``<flags>|<text>``.

    The request body is split on ``|``. The characters of the first part are
    read as one flag per category, in the order of ``labels`` below; the
    selected category names are joined with commas and passed to
    ``insert_feedback`` together with the second part.
    """
    data = request.data.decode('utf-8')  # decode the received bytes
    lst = data.split('|')
    if len(lst) < 2:
        return "Bad request: expected '<flags>|<text>'", 400

    labels = ('用户体验', '识别偏差', '功能建议', '分类扩充', '其他问题', '好评推荐')

    # A one-character string is always truthy, so testing the character itself
    # selected every category.
    # NOTE(review): assumes the client sends '0' for "not selected" -- confirm
    # against the client code.
    selected = [label for flag, label in zip(lst[0], labels) if flag != '0']

    insert_feedback(','.join(selected), lst[1])
    return "Successfully"
|
||||
|
||||
if __name__ == '__main__':
    # Start the Flask server on host 0.0.0.0, port 5000.
    app.run(port=5000, host='0.0.0.0')
|
Loading…
Reference in new issue