parent
86e5127d50
commit
9b75ab2155
@ -0,0 +1,55 @@
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
import gzip
|
||||
|
||||
def unzipfile(gzip_path):
|
||||
#定义解压缩文件
|
||||
open_file = open(gzip_path.replace('.gz', ''), 'wb')
|
||||
gz_file = gzip.GzipFile(gzip_path)
|
||||
open_file.write(gz_file.read())
|
||||
gz_file.close()
|
||||
|
||||
def download_progress(url, file_name):
|
||||
res = requests.get(url, stream=True, verify=False)
|
||||
# 获取mnist数据集大小
|
||||
total_size = int(res.headers["Content-Length"])
|
||||
temp_size = 0
|
||||
with open(file_name, "wb+") as f:
|
||||
for chunk in res.iter_content(chunk_size=1024):
|
||||
temp_size += len(chunk)
|
||||
f.write(chunk)
|
||||
f.flush()
|
||||
done = int(100 * temp_size / total_size)
|
||||
# 显示下载进度
|
||||
sys.stdout.write("\r[{}{}] {:.2f}%".format("█" * done, " " * (100 - done), 100 * temp_size / total_size))
|
||||
sys.stdout.flush()
|
||||
print("\n============== {} is already ==============".format(file_name))
|
||||
unzipfile(file_name) # 解压压缩包
|
||||
os.remove(file_name) # 删除压缩包
|
||||
|
||||
def download_dataset():
|
||||
"""从 http://yann.lecun.com/exdb/mnist/ 下载数据集"""
|
||||
print("************** Downloading the MNIST dataset **************")
|
||||
train_path = "./MNIST_Data/train/"
|
||||
test_path = "./MNIST_Data/test/"
|
||||
train_path_check = os.path.exists(train_path)
|
||||
test_path_check = os.path.exists(test_path)
|
||||
if not train_path_check and not test_path_check:
|
||||
os.makedirs(train_path)
|
||||
os.makedirs(test_path)
|
||||
train_url = {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"}
|
||||
test_url = {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
|
||||
for url in train_url:
|
||||
url_parse = urlparse(url)
|
||||
# 从url分割文件名
|
||||
file_name = os.path.join(train_path, url_parse.path.split('/')[-1])
|
||||
if not os.path.exists(file_name.replace('.gz', '')):
|
||||
download_progress(url, file_name)
|
||||
for url in test_url:
|
||||
url_parse = urlparse(url)
|
||||
# 从url分割文件名
|
||||
file_name = os.path.join(test_path,url_parse.path.split('/')[-1])
|
||||
if not os.path.exists(file_name.replace('.gz', '')):
|
||||
download_progress(url, file_name)
|
Loading…
Reference in new issue