From de37ea63a079277b18e4cde50561ce5c4ee77b4a Mon Sep 17 00:00:00 2001 From: p4w2aybsf <2363061197@qq.com> Date: Thu, 29 Apr 2021 17:16:57 +0800 Subject: [PATCH] Add 'util.py' --- util.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 util.py diff --git a/util.py b/util.py new file mode 100644 index 0000000..9e0f0b3 --- /dev/null +++ b/util.py @@ -0,0 +1,99 @@ +# coding: utf-8 +import numpy as np + + +def smooth_curve(x): + """用于使损失函数的图形变圆滑 + + 参考:http://glowingpython.blogspot.jp/2012/02/convolution-with-numpy.html + """ + window_len = 11 + s = np.r_[x[window_len-1:0:-1], x, x[-1:-window_len:-1]] + w = np.kaiser(window_len, 2) + y = np.convolve(w/w.sum(), s, mode='valid') + return y[5:len(y)-5] + + +def shuffle_dataset(x, t): + """打乱数据集 + + Parameters + ---------- + x : 训练数据 + t : 监督数据 + + Returns + ------- + x, t : 打乱的训练数据和监督数据 + """ + permutation = np.random.permutation(x.shape[0]) + x = x[permutation,:] if x.ndim == 2 else x[permutation,:,:,:] + t = t[permutation] + + return x, t + +def conv_output_size(input_size, filter_size, stride=1, pad=0): + return (input_size + 2*pad - filter_size) / stride + 1 + + +def im2col(input_data, filter_h, filter_w, stride=1, pad=0): + """ + + Parameters + ---------- + input_data : 由(数据量, 通道, 高, 长)的4维数组构成的输入数据 + filter_h : 滤波器的高 + filter_w : 滤波器的长 + stride : 步幅 + pad : 填充 + + Returns + ------- + col : 2维数组 + """ + N, C, H, W = input_data.shape + out_h = (H + 2*pad - filter_h)//stride + 1 + out_w = (W + 2*pad - filter_w)//stride + 1 + + img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant') + col = np.zeros((N, C, filter_h, filter_w, out_h, out_w)) + + for y in range(filter_h): + y_max = y + stride*out_h + for x in range(filter_w): + x_max = x + stride*out_w + col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride] + + col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1) + return col + + +def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0): + """ + + Parameters + ---------- + col : + input_shape : 输入数据的形状(例:(10, 1, 28, 28)) + filter_h : + filter_w + stride + pad + + Returns + ------- + + """ + N, C, H, W = input_shape + out_h = (H + 2*pad - filter_h)//stride + 1 + out_w = (W + 2*pad - filter_w)//stride + 1 + col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2) + + img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1)) + for y in range(filter_h): + y_max = y + stride*out_h + for x in range(filter_w): + x_max = x + stride*out_w + img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :] + + return img[:, :, pad:H + pad, pad:W + pad] \ No newline at end of file