parent
86e5127d50
commit
9b75ab2155
@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import requests
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import gzip
|
||||||
|
|
||||||
|
def unzipfile(gzip_path):
|
||||||
|
#定义解压缩文件
|
||||||
|
open_file = open(gzip_path.replace('.gz', ''), 'wb')
|
||||||
|
gz_file = gzip.GzipFile(gzip_path)
|
||||||
|
open_file.write(gz_file.read())
|
||||||
|
gz_file.close()
|
||||||
|
|
||||||
|
def download_progress(url, file_name):
|
||||||
|
res = requests.get(url, stream=True, verify=False)
|
||||||
|
# 获取mnist数据集大小
|
||||||
|
total_size = int(res.headers["Content-Length"])
|
||||||
|
temp_size = 0
|
||||||
|
with open(file_name, "wb+") as f:
|
||||||
|
for chunk in res.iter_content(chunk_size=1024):
|
||||||
|
temp_size += len(chunk)
|
||||||
|
f.write(chunk)
|
||||||
|
f.flush()
|
||||||
|
done = int(100 * temp_size / total_size)
|
||||||
|
# 显示下载进度
|
||||||
|
sys.stdout.write("\r[{}{}] {:.2f}%".format("█" * done, " " * (100 - done), 100 * temp_size / total_size))
|
||||||
|
sys.stdout.flush()
|
||||||
|
print("\n============== {} is already ==============".format(file_name))
|
||||||
|
unzipfile(file_name) # 解压压缩包
|
||||||
|
os.remove(file_name) # 删除压缩包
|
||||||
|
|
||||||
|
def download_dataset():
|
||||||
|
"""从 http://yann.lecun.com/exdb/mnist/ 下载数据集"""
|
||||||
|
print("************** Downloading the MNIST dataset **************")
|
||||||
|
train_path = "./MNIST_Data/train/"
|
||||||
|
test_path = "./MNIST_Data/test/"
|
||||||
|
train_path_check = os.path.exists(train_path)
|
||||||
|
test_path_check = os.path.exists(test_path)
|
||||||
|
if not train_path_check and not test_path_check:
|
||||||
|
os.makedirs(train_path)
|
||||||
|
os.makedirs(test_path)
|
||||||
|
train_url = {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"}
|
||||||
|
test_url = {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
|
||||||
|
for url in train_url:
|
||||||
|
url_parse = urlparse(url)
|
||||||
|
# 从url分割文件名
|
||||||
|
file_name = os.path.join(train_path, url_parse.path.split('/')[-1])
|
||||||
|
if not os.path.exists(file_name.replace('.gz', '')):
|
||||||
|
download_progress(url, file_name)
|
||||||
|
for url in test_url:
|
||||||
|
url_parse = urlparse(url)
|
||||||
|
# 从url分割文件名
|
||||||
|
file_name = os.path.join(test_path,url_parse.path.split('/')[-1])
|
||||||
|
if not os.path.exists(file_name.replace('.gz', '')):
|
||||||
|
download_progress(url, file_name)
|
Loading…
Reference in new issue