Merge pull request 'final' (#4) from develop into main
commit
1003aaa8fb
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 6.3 MiB |
@ -0,0 +1,81 @@
|
||||
'''
|
||||
数据样本分析
|
||||
画出数据量条形图
|
||||
画出图像分辨率散点图
|
||||
'''
|
||||
import os
|
||||
import PIL.Image as Image
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def plot_resolution(dataset_root_path):
|
||||
img_size_list = [] # 存储图片长宽数据
|
||||
for root, dirs, files in os.walk(dataset_root_path):
|
||||
for file_i in files:
|
||||
file_i_full_path = os.path.join(root, file_i)
|
||||
img_i = Image.open(file_i_full_path)
|
||||
img_i_size = img_i.size # 获取单张图像的长宽
|
||||
img_size_list.append(img_i_size)
|
||||
|
||||
print(img_size_list) #
|
||||
|
||||
width_list = [img_size_list[i][0] for i in range(len(img_size_list))]#提取所有图片的宽度信息,构建一个新的列表
|
||||
height_list = [img_size_list[i][1] for i in range(len(img_size_list))]#提取所有图片的高度信息,构建一个新的列表
|
||||
|
||||
# print(width_list) # 640
|
||||
# print(height_list) # 346
|
||||
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"] # 设置中文字体
|
||||
plt.rcParams["font.size"] = 8
|
||||
plt.rcParams["axes.unicode_minus"] = False # 该语句解决图像中的“-”负号的乱码问题
|
||||
|
||||
plt.scatter(width_list, height_list, s=1)
|
||||
plt.xlabel("宽")
|
||||
plt.ylabel("高")
|
||||
plt.title("图像宽高分布")
|
||||
plt.show()
|
||||
|
||||
|
||||
# 画出条形图
|
||||
def plot_bar(dataset_root_path):
|
||||
|
||||
file_name_list = []
|
||||
file_num_list = []
|
||||
for root, dirs, files in os.walk(dataset_root_path):
|
||||
if len(dirs) != 0:
|
||||
for dir_i in dirs:
|
||||
file_name_list.append(dir_i)
|
||||
file_num_list.append(len(files))
|
||||
|
||||
file_num_list = file_num_list[1:]
|
||||
# 求均值,并把均值以横线形式显示出来
|
||||
mean = np.mean(file_num_list)
|
||||
print("mean = ", mean)
|
||||
|
||||
bar_positions = np.arange(len(file_name_list))
|
||||
|
||||
fig, ax = plt.subplots() # 定义画的区间和子画
|
||||
ax.bar(bar_positions, file_num_list, 0.5) # 画柱图,参数:柱间的距离,柱的值,柱的宽度
|
||||
|
||||
ax.plot(bar_positions, [mean for i in bar_positions], color="red") # 显示平均值
|
||||
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"] # 设置中文字体
|
||||
plt.rcParams["font.size"] = 8
|
||||
plt.rcParams["axes.unicode_minus"] = False # 该语句解决图像中的“-”负号的乱码问题
|
||||
|
||||
ax.set_xticks(bar_positions) # 设置x轴的刻度
|
||||
ax.set_xticklabels(file_name_list, rotation=90) # 设置x轴的标签
|
||||
ax.set_ylabel("类别数量")
|
||||
ax.set_title("数据分布图")
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
dataset_root_path = "dataset"
|
||||
|
||||
plot_resolution(dataset_root_path)
|
||||
|
||||
# plot_bar(dataset_root_path)
|
||||
|
||||
|
||||
|
@ -0,0 +1,42 @@
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
dataset_root_path = "dataset"
|
||||
|
||||
min = 200 # 短边
|
||||
max = 2000 # 长边
|
||||
ratio = 0.5 # 短边 / 长边
|
||||
|
||||
delete_list = [] # 所有图片的长宽数据
|
||||
for root,dirs,files in os.walk(dataset_root_path):
|
||||
for file_i in files:
|
||||
file_i_full_path = os.path.join(root, file_i)#构建当前文件的完整路径
|
||||
img_i = Image.open(file_i_full_path)#将其加载为一个图像对象
|
||||
img_i_size = img_i.size # 获取单张图像的长宽
|
||||
|
||||
# 删除单边过短的图片
|
||||
if img_i_size[0]<min or img_i_size[1]<min:
|
||||
print(file_i_full_path, " 不满足要求")
|
||||
delete_list.append(file_i_full_path)
|
||||
|
||||
# 删除单边过长的图片
|
||||
if img_i_size[0] > max or img_i_size[1] > max:
|
||||
print(file_i_full_path, " 不满足要求")
|
||||
delete_list.append(file_i_full_path)
|
||||
|
||||
# 删除宽高比例不当的图片
|
||||
long = img_i_size[0] if img_i_size[0] > img_i_size[1] else img_i_size[1]
|
||||
short = img_i_size[0] if img_i_size[0] < img_i_size[1] else img_i_size[1]
|
||||
|
||||
if short / long < ratio:
|
||||
print(file_i_full_path, " 不满足要求",img_i_size[0],img_i_size[1])
|
||||
delete_list.append(file_i_full_path)
|
||||
|
||||
|
||||
# print(delete_list)
|
||||
for file_i in delete_list:
|
||||
try:
|
||||
print("正在删除",file_i)
|
||||
os.remove(file_i)
|
||||
except:
|
||||
pass
|
@ -0,0 +1,49 @@
|
||||
'''
|
||||
使用翻转进行数据增强
|
||||
'''
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# 水平翻转
|
||||
def Horizontal(image):
|
||||
return cv2.flip(image,1,dst=None) #水平镜像
|
||||
|
||||
# 垂直翻转
|
||||
def Vertical(image):
|
||||
return cv2.flip(image,0,dst=None) #垂直镜像
|
||||
|
||||
if __name__ == '__main__':
|
||||
from_root = r"dataset"
|
||||
save_root = r"enhance_dataset"
|
||||
|
||||
threshold = 200
|
||||
|
||||
for a,b,c in os.walk(from_root):
|
||||
for file_i in c:
|
||||
file_i_path = os.path.join(a,file_i)
|
||||
|
||||
split = os.path.split(file_i_path)
|
||||
dir_loc = os.path.split(split[0])[1]
|
||||
save_path = os.path.join(save_root,dir_loc)
|
||||
|
||||
print(file_i_path)
|
||||
print(save_path)
|
||||
|
||||
if os.path.isdir(save_path) == False:
|
||||
os.makedirs(save_path)
|
||||
|
||||
img_i = cv2.imdecode(np.fromfile(file_i_path, dtype=np.uint8),-1) # 读取图片
|
||||
|
||||
cv2.imencode('.jpg', img_i)[1].tofile(os.path.join(save_path, file_i[:-5] + "_original.jpg")) # 保存图片
|
||||
|
||||
if len(c) < threshold:
|
||||
|
||||
img_horizontal = Horizontal(img_i)
|
||||
cv2.imencode('.jpg', img_horizontal)[1].tofile(os.path.join(save_path, file_i[:-5] + "_horizontal.jpg")) # 保存图片
|
||||
|
||||
img_vertical = Vertical(img_i)
|
||||
cv2.imencode('.jpg', img_vertical)[1].tofile(os.path.join(save_path, file_i[:-5] + "_vertical.jpg")) # 保存图片
|
||||
|
||||
else:
|
||||
pass
|
@ -0,0 +1,23 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
img_root = r"enhance_dataset"
|
||||
threshold = 300
|
||||
|
||||
for a,b,c in os.walk(img_root):
|
||||
if len(c) > threshold:
|
||||
delete_list = []
|
||||
for file_i in c:
|
||||
file_i_full_path = os.path.join(a,file_i)
|
||||
delete_list.append(file_i_full_path)
|
||||
|
||||
random.shuffle(delete_list)
|
||||
|
||||
print(delete_list)
|
||||
delete_list = delete_list[threshold:]
|
||||
for file_delete_i in delete_list:
|
||||
os.remove(file_delete_i)
|
||||
print("将会删除",file_delete_i)
|
||||
|
||||
|
||||
|
@ -0,0 +1,28 @@
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torch
|
||||
from torchvision import transforms as T
|
||||
from tqdm import tqdm
|
||||
|
||||
transform = T.Compose([
|
||||
T.RandomResizedCrop(224),
|
||||
T.ToTensor(),
|
||||
])
|
||||
|
||||
def getStat(train_data):
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
|
||||
|
||||
mean = torch.zeros(3)
|
||||
std = torch.zeros(3)
|
||||
for X, _ in tqdm(train_loader):
|
||||
for d in range(3):
|
||||
mean[d] += X[:, d, :, :].mean() # N, C, H ,W
|
||||
std[d] += X[:, d, :, :].std()
|
||||
mean.div_(len(train_data))
|
||||
std.div_(len(train_data))
|
||||
return list(mean.numpy()), list(std.numpy())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_dataset = ImageFolder(root=r'enhance_dataset', transform=transform)
|
||||
print(getStat(train_dataset))
|
@ -0,0 +1,40 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
train_ratio = 0.9 #训练集比例
|
||||
test_ratio = 1-train_ratio #测试集比例
|
||||
|
||||
rootdata = r"dataset"
|
||||
|
||||
train_list, test_list = [],[]
|
||||
|
||||
class_flag = -1#初始类别标签
|
||||
t=1
|
||||
for a,b,c in os.walk(rootdata):#目录路径,子目录名,文件名
|
||||
# if t==1:
|
||||
# # print(a)
|
||||
# print(b)
|
||||
# print(len(c))
|
||||
# t=-1
|
||||
|
||||
for i in range(0, int(len(c)*train_ratio)):#根据训练集比例确定训练集的文件数量
|
||||
train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n'
|
||||
# print('666'+train_data)
|
||||
train_list.append(train_data)
|
||||
|
||||
for i in range(int(len(c) * train_ratio), len(c)):
|
||||
test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'
|
||||
test_list.append(test_data)
|
||||
|
||||
class_flag += 1
|
||||
|
||||
random.shuffle(train_list)#随机打乱训练集列表 train_list 中的元素顺序
|
||||
random.shuffle(test_list)
|
||||
|
||||
with open('train.txt','w',encoding='UTF-8') as f:
|
||||
for train_img in train_list:
|
||||
f.write(str(train_img))#将训练集中的文件路径和类别标签写入文件 'train.txt',将其转换为字符串形式并写入文件。
|
||||
|
||||
with open('test.txt','w',encoding='UTF-8') as f:
|
||||
for test_img in test_list:
|
||||
f.write(test_img)
|
@ -0,0 +1,74 @@
|
||||
# -*-coding:utf-8-*-
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ReadData(data_loc):
|
||||
epoch_list = []
|
||||
train_loss_list = []
|
||||
test_loss_list = []
|
||||
test_accuracy_list = []
|
||||
|
||||
# open(data_loc,"r").readlines()
|
||||
with open(data_loc, "r") as f:
|
||||
linedata = f.readlines()
|
||||
|
||||
for line_i in linedata:
|
||||
data = line_i.split('\t')
|
||||
print("data = ", data)
|
||||
epoch_i , train_loss_i,test_loss_i,test_accuracy_i =data[1], data[3],data[5],data[7]
|
||||
epoch_list.append(int(epoch_i))
|
||||
train_loss_list.append(float(train_loss_i))
|
||||
test_loss_list.append(float(test_loss_i))
|
||||
test_accuracy_list.append(float(test_accuracy_i))
|
||||
|
||||
# print(epoch_list)
|
||||
# print(train_loss_list)
|
||||
# print(test_loss_list)
|
||||
# print(test_accuracy_list)
|
||||
return epoch_list, train_loss_list ,test_loss_list,test_accuracy_list
|
||||
|
||||
|
||||
|
||||
def DrawLoss(train_loss_list,train_loss_list_2):
|
||||
plt.style.use('dark_background')
|
||||
plt.title("Loss")
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("loss")
|
||||
|
||||
train_loss_list = train_loss_list[:10]
|
||||
|
||||
epoch_list = [i for i in range(len(train_loss_list))]
|
||||
|
||||
p1, = plt.plot(epoch_list, train_loss_list, linewidth=3)
|
||||
p2, = plt.plot(epoch_list, train_loss_list_2, linewidth=3)
|
||||
|
||||
plt.legend([p1, p2], ["with pretrain", "no pretrain"])
|
||||
plt.show()
|
||||
|
||||
def DrawAcc(train_loss_list,train_loss_list_2):
|
||||
plt.style.use('dark_background')
|
||||
plt.title("Accuracy")
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("accuracy")
|
||||
|
||||
train_loss_list = train_loss_list[:10]
|
||||
|
||||
epoch_list = [i for i in range(len(train_loss_list))]
|
||||
|
||||
p1, = plt.plot(epoch_list, train_loss_list, linewidth=3)
|
||||
p2, = plt.plot(epoch_list, train_loss_list_2, linewidth=3)
|
||||
|
||||
plt.legend([p1, p2], ["with pretrain", "no pretrain"])
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_1_loc = "output/resnet18.txt"
|
||||
data_2_loc = "output/resnet18_no_pretrain.txt"
|
||||
|
||||
_, train_loss_list ,test_loss_list,test_accuracy_list = ReadData(data_1_loc)
|
||||
_, train_loss_list_2 ,test_loss_list_2,test_accuracy_list_2 = ReadData(data_2_loc)
|
||||
|
||||
DrawLoss(train_loss_list,train_loss_list_2)
|
||||
|
||||
DrawAcc(test_accuracy_list,test_accuracy_list_2)
|
@ -0,0 +1,109 @@
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import io
|
||||
from flask import Flask, request
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlopen
|
||||
import mysql.connector
|
||||
from mysql.connector import Error
|
||||
|
||||
def insert_feedback(feedbackInfo,feedbackType):
|
||||
try:
|
||||
# 弹出文件选择对话框,选择要插入的图片文件
|
||||
file_path = r'C:\Users\admin\Desktop\garbage.jpg'
|
||||
|
||||
# 读取图片文件并转换为字节数组
|
||||
with open(file_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
# files = {'file': open(file_path, 'rb')}
|
||||
|
||||
# 建立数据库连接
|
||||
database_name = "garbage"
|
||||
username = "root"
|
||||
password = "root"
|
||||
host = '127.0.0.1'
|
||||
port = 3306
|
||||
connection = mysql.connector.connect(
|
||||
host=host,
|
||||
port=port,
|
||||
user=username,
|
||||
password=password,
|
||||
database=database_name
|
||||
)
|
||||
cursor = connection.cursor(prepared=True)
|
||||
# 插入用户反馈信息数据到数据库
|
||||
sql = "INSERT INTO feedback (feedbackInfo, feedbackType, feedbackPhoto) VALUES (%s, %s, %s)"
|
||||
cursor.execute(sql, (feedbackInfo, feedbackType, image_data))
|
||||
connection.commit()
|
||||
print("信息上传成功!")
|
||||
if cursor is not None:
|
||||
cursor.close()
|
||||
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print("未选择文件。")
|
||||
except (Error, IOError) as e:
|
||||
print(e)
|
||||
finally:
|
||||
if 'connection' in locals():
|
||||
if connection.is_connected():
|
||||
connection.close()
|
||||
|
||||
|
||||
#垃圾识别接口
|
||||
app = Flask(__name__)
|
||||
# 保存上传文件的目标目录
|
||||
UPLOAD_FOLDER = 'C:/Users/admin/Desktop'
|
||||
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
||||
|
||||
@app.route('/upload_image', methods=['POST'])
|
||||
def upload_image():
|
||||
if 'file' not in request.files:
|
||||
return 'No file part'
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
if file.filename == '':
|
||||
return 'No selected file'
|
||||
|
||||
if file:
|
||||
# 保存上传的文件到指定目录
|
||||
file.save(os.path.join(app.config['UPLOAD_FOLDER'], 'garbage.jpg'))
|
||||
a = recognition(os.path.join(app.config['UPLOAD_FOLDER'], 'garbage.jpg'))
|
||||
print(a)
|
||||
return a
|
||||
|
||||
@app.route('/upload_text', methods=['POST'])
|
||||
def upload_text():
|
||||
data = request.data.decode('utf-8') # 解码接收到的数据
|
||||
lst = data.split('|')
|
||||
str = ''
|
||||
if lst[0][0]:
|
||||
str+='用户体验'
|
||||
if lst[0][1]:
|
||||
if len(str)>0:
|
||||
str+=','
|
||||
str+='识别偏差'
|
||||
if lst[0][2]:
|
||||
if len(str)>0:
|
||||
str+=','
|
||||
str+='功能建议'
|
||||
if lst[0][3]:
|
||||
if len(str)>0:
|
||||
str+=','
|
||||
str+='分类扩充'
|
||||
if lst[0][4]:
|
||||
if len(str)>0:
|
||||
str+=','
|
||||
str+='其他问题'
|
||||
if lst[0][5]:
|
||||
if len(str)>0:
|
||||
str+=','
|
||||
str+='好评推荐'
|
||||
|
||||
insert_feedback(str,lst[1])
|
||||
return "Successfully"
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0', port=5000)
|
Loading…
Reference in new issue