forked from pl8qemw3k/garbage
develop
parent
c1c27834ff
commit
646cb3884c
@@ -0,0 +1,81 @@
'''
Dataset sample analysis:
draw a bar chart of the number of images per class
draw a scatter plot of the image resolutions
'''
import os
import PIL.Image as Image
import matplotlib.pyplot as plt
import numpy as np


def plot_resolution(dataset_root_path):
    img_size_list = []  # (width, height) of every image
    for root, dirs, files in os.walk(dataset_root_path):
        for file_i in files:
            file_i_full_path = os.path.join(root, file_i)
            img_i = Image.open(file_i_full_path)
            img_i_size = img_i.size  # (width, height) of a single image
            img_size_list.append(img_i_size)

    print(img_size_list)

    width_list = [img_size_list[i][0] for i in range(len(img_size_list))]   # widths of all images
    height_list = [img_size_list[i][1] for i in range(len(img_size_list))]  # heights of all images

    # print(width_list)   # e.g. 640
    # print(height_list)  # e.g. 346

    plt.rcParams["font.sans-serif"] = ["SimHei"]  # font that can render the Chinese labels below
    plt.rcParams["font.size"] = 8
    plt.rcParams["axes.unicode_minus"] = False  # render the minus sign correctly

    plt.scatter(width_list, height_list, s=1)
    plt.xlabel("宽")
    plt.ylabel("高")
    plt.title("图像宽高分布")
    plt.show()


# bar chart of the number of files per class
def plot_bar(dataset_root_path):
    file_name_list = []
    file_num_list = []
    for root, dirs, files in os.walk(dataset_root_path):
        if len(dirs) != 0:
            for dir_i in dirs:
                file_name_list.append(dir_i)
        file_num_list.append(len(files))

    file_num_list = file_num_list[1:]  # drop the count for the root directory itself
    # compute the mean count; it is drawn below as a horizontal line
    mean = np.mean(file_num_list)
    print("mean = ", mean)

    bar_positions = np.arange(len(file_name_list))

    fig, ax = plt.subplots()  # figure and axes
    ax.bar(bar_positions, file_num_list, 0.5)  # bar positions, bar heights, bar width

    ax.plot(bar_positions, [mean for i in bar_positions], color="red")  # draw the mean

    plt.rcParams["font.sans-serif"] = ["SimHei"]  # font that can render the Chinese labels below
    plt.rcParams["font.size"] = 8
    plt.rcParams["axes.unicode_minus"] = False  # render the minus sign correctly

    ax.set_xticks(bar_positions)  # x-axis tick positions
    ax.set_xticklabels(file_name_list, rotation=90)  # x-axis tick labels
    ax.set_ylabel("类别数量")
    ax.set_title("数据分布图")
    plt.show()


if __name__ == '__main__':
    dataset_root_path = "dataset"

    plot_resolution(dataset_root_path)

    # plot_bar(dataset_root_path)
@@ -0,0 +1,42 @@
from PIL import Image
import os

dataset_root_path = "dataset"

min_side = 200   # shortest edge allowed
max_side = 2000  # longest edge allowed
ratio = 0.5      # minimum short edge / long edge

delete_list = []  # images that fail any of the checks
for root, dirs, files in os.walk(dataset_root_path):
    for file_i in files:
        file_i_full_path = os.path.join(root, file_i)  # full path of the current file
        img_i = Image.open(file_i_full_path)  # load it as an image object
        img_i_size = img_i.size  # (width, height) of the image

        # flag images whose shortest edge is too small
        if img_i_size[0] < min_side or img_i_size[1] < min_side:
            print(file_i_full_path, " does not meet the size requirements")
            delete_list.append(file_i_full_path)

        # flag images whose longest edge is too large
        if img_i_size[0] > max_side or img_i_size[1] > max_side:
            print(file_i_full_path, " does not meet the size requirements")
            delete_list.append(file_i_full_path)

        # flag images with an unsuitable aspect ratio
        long = img_i_size[0] if img_i_size[0] > img_i_size[1] else img_i_size[1]
        short = img_i_size[0] if img_i_size[0] < img_i_size[1] else img_i_size[1]

        if short / long < ratio:
            print(file_i_full_path, " does not meet the aspect-ratio requirement", img_i_size[0], img_i_size[1])
            delete_list.append(file_i_full_path)

# print(delete_list)
for file_i in set(delete_list):  # deduplicate: an image may fail more than one check
    try:
        print("deleting", file_i)
        os.remove(file_i)
    except OSError:
        pass
@@ -0,0 +1,49 @@
'''
Data augmentation by flipping
'''
import os
import cv2
import numpy as np


# horizontal flip
def Horizontal(image):
    return cv2.flip(image, 1, dst=None)  # mirror around the vertical axis


# vertical flip
def Vertical(image):
    return cv2.flip(image, 0, dst=None)  # mirror around the horizontal axis


if __name__ == '__main__':
    from_root = r"dataset"
    save_root = r"enhance_dataset"

    threshold = 200  # only classes with fewer images than this are augmented

    for a, b, c in os.walk(from_root):
        for file_i in c:
            file_i_path = os.path.join(a, file_i)

            split = os.path.split(file_i_path)
            dir_loc = os.path.split(split[0])[1]  # class folder name
            save_path = os.path.join(save_root, dir_loc)

            print(file_i_path)
            print(save_path)

            if not os.path.isdir(save_path):
                os.makedirs(save_path)

            stem = os.path.splitext(file_i)[0]  # file name without its extension

            img_i = cv2.imdecode(np.fromfile(file_i_path, dtype=np.uint8), -1)  # read the image from raw bytes

            cv2.imencode('.jpg', img_i)[1].tofile(os.path.join(save_path, stem + "_original.jpg"))  # save the original

            if len(c) < threshold:
                img_horizontal = Horizontal(img_i)
                cv2.imencode('.jpg', img_horizontal)[1].tofile(os.path.join(save_path, stem + "_horizontal.jpg"))  # save the horizontal flip

                img_vertical = Vertical(img_i)
                cv2.imencode('.jpg', img_vertical)[1].tofile(os.path.join(save_path, stem + "_vertical.jpg"))  # save the vertical flip
            else:
                pass
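The imdecode/imencode calls above go through raw byte buffers, a common workaround for cv2.imread and cv2.imwrite failing on paths with non-ASCII characters on Windows. A minimal sketch of the plain equivalents, assuming ASCII-only paths (the file names below are placeholders, not files in this repository):

import cv2

# Plain OpenCV I/O; only reliable when the path contains no non-ASCII characters.
img = cv2.imread("dataset/some_class/0001.jpg")                   # placeholder path
cv2.imwrite("enhance_dataset/some_class/0001_original.jpg", img)  # placeholder path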
@@ -0,0 +1,23 @@
import os
import random

img_root = r"enhance_dataset"
threshold = 300  # keep at most this many images per class

for a, b, c in os.walk(img_root):
    if len(c) > threshold:
        delete_list = []
        for file_i in c:
            file_i_full_path = os.path.join(a, file_i)
            delete_list.append(file_i_full_path)

        random.shuffle(delete_list)

        print(delete_list)
        delete_list = delete_list[threshold:]  # keep the first `threshold` shuffled images, delete the rest
        for file_delete_i in delete_list:
            os.remove(file_delete_i)
            print("deleted", file_delete_i)
@@ -0,0 +1,28 @@
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms as T
from tqdm import tqdm

transform = T.Compose([
    T.RandomResizedCrop(224),
    T.ToTensor(),
])


def getStat(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in tqdm(train_loader):
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()  # X has shape (N, C, H, W)
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'enhance_dataset', transform=transform)
    print(getStat(train_dataset))
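For context, the per-channel mean and std printed by getStat are typically pasted into a T.Normalize step of the training transform. A minimal sketch, with placeholder values standing in for the actual output of this script:

from torchvision import transforms as T

# Placeholder statistics; replace them with the pair of lists printed by getStat().
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

train_transform = T.Compose([
    T.RandomResizedCrop(224),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),  # normalize with the dataset's own statistics
])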
@@ -0,0 +1,40 @@
import os
import random

train_ratio = 0.9  # fraction of each class used for training
test_ratio = 1 - train_ratio  # fraction used for testing

rootdata = r"dataset"

train_list, test_list = [], []

class_flag = -1  # class label: -1 for the root directory itself, then 0, 1, ... for each class folder
for a, b, c in os.walk(rootdata):  # directory path, sub-directory names, file names

    for i in range(0, int(len(c) * train_ratio)):  # the first train_ratio share of the files goes to the training set
        train_data = os.path.join(a, c[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio), len(c)):  # the remaining files go to the test set
        test_data = os.path.join(a, c[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(test_data)

    class_flag += 1

random.shuffle(train_list)  # shuffle the order of the training samples
random.shuffle(test_list)

with open('train.txt', 'w', encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(train_img)  # one "image_path<TAB>class_label" line per training image

with open('test.txt', 'w', encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)
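For reference, a minimal sketch of a dataset class that could consume the image_path<TAB>label lines written to train.txt and test.txt; the class name and structure are illustrative assumptions, not part of this commit:

from PIL import Image
from torch.utils.data import Dataset


class TxtLineDataset(Dataset):  # hypothetical helper, not defined in this repository
    def __init__(self, txt_path, transform=None):
        with open(txt_path, encoding='UTF-8') as f:
            # each non-empty line is "image_path<TAB>class_label"
            self.samples = [line.strip().split('\t') for line in f if line.strip()]
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)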
@@ -0,0 +1,74 @@
# -*-coding:utf-8-*-
from matplotlib import pyplot as plt
import numpy as np


def ReadData(data_loc):
    epoch_list = []
    train_loss_list = []
    test_loss_list = []
    test_accuracy_list = []

    with open(data_loc, "r") as f:
        linedata = f.readlines()

    for line_i in linedata:
        data = line_i.split('\t')
        print("data = ", data)
        # the log lines are tab-separated name/value pairs, so the values sit at the odd indices
        epoch_i, train_loss_i, test_loss_i, test_accuracy_i = data[1], data[3], data[5], data[7]
        epoch_list.append(int(epoch_i))
        train_loss_list.append(float(train_loss_i))
        test_loss_list.append(float(test_loss_i))
        test_accuracy_list.append(float(test_accuracy_i))

    # print(epoch_list)
    # print(train_loss_list)
    # print(test_loss_list)
    # print(test_accuracy_list)
    return epoch_list, train_loss_list, test_loss_list, test_accuracy_list


def DrawLoss(train_loss_list, train_loss_list_2):
    plt.style.use('dark_background')
    plt.title("Loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")

    # only plot the first 10 epochs; truncate both curves so they match the x-axis
    train_loss_list = train_loss_list[:10]
    train_loss_list_2 = train_loss_list_2[:10]

    epoch_list = [i for i in range(len(train_loss_list))]

    p1, = plt.plot(epoch_list, train_loss_list, linewidth=3)
    p2, = plt.plot(epoch_list, train_loss_list_2, linewidth=3)

    plt.legend([p1, p2], ["with pretrain", "no pretrain"])
    plt.show()


def DrawAcc(acc_list, acc_list_2):
    plt.style.use('dark_background')
    plt.title("Accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    # only plot the first 10 epochs; truncate both curves so they match the x-axis
    acc_list = acc_list[:10]
    acc_list_2 = acc_list_2[:10]

    epoch_list = [i for i in range(len(acc_list))]

    p1, = plt.plot(epoch_list, acc_list, linewidth=3)
    p2, = plt.plot(epoch_list, acc_list_2, linewidth=3)

    plt.legend([p1, p2], ["with pretrain", "no pretrain"])
    plt.show()


if __name__ == '__main__':
    data_1_loc = "output/resnet18.txt"
    data_2_loc = "output/resnet18_no_pretrain.txt"

    _, train_loss_list, test_loss_list, test_accuracy_list = ReadData(data_1_loc)
    _, train_loss_list_2, test_loss_list_2, test_accuracy_list_2 = ReadData(data_2_loc)

    DrawLoss(train_loss_list, train_loss_list_2)

    DrawAcc(test_accuracy_list, test_accuracy_list_2)
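ReadData takes the values at indices 1, 3, 5 and 7 of each tab-split line, which implies the training logs are written as alternating name/value fields. A sample line in that assumed layout (the field names are illustrative; only the tab-separated alternation is implied by the code):

# Assumed log-line layout; only the "name<TAB>value" alternation is implied by ReadData.
sample_line = "epoch\t3\ttrain_loss\t0.412\ttest_loss\t0.385\ttest_acc\t0.861\n"

data = sample_line.split('\t')
epoch, train_loss, test_loss, test_acc = data[1], data[3], data[5], data[7]
print(epoch, train_loss, test_loss, test_acc)  # 3 0.412 0.385 0.861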