You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
2.4 KiB
56 lines
2.4 KiB
import os
|
|
import sys
|
|
import requests
|
|
from urllib.parse import urlparse
|
|
import gzip
|
|
|
|
def unzipfile(gzip_path):
    """Decompress a .gz archive in place.

    Writes the decompressed bytes to the same path with the '.gz'
    suffix stripped (via str.replace, matching the original behavior).

    Args:
        gzip_path: Path to a gzip-compressed file.
    """
    # BUG FIX: the original opened the output file and never closed it
    # (resource leak); context managers close both handles even on error.
    with gzip.GzipFile(gzip_path) as gz_file, \
            open(gzip_path.replace('.gz', ''), 'wb') as out_file:
        out_file.write(gz_file.read())
|
|
|
|
def download_progress(url, file_name):
    """Download *url* to *file_name* with a console progress bar,
    then decompress the downloaded archive and delete it.

    Args:
        url: Source URL of a .gz archive.
        file_name: Local destination path for the archive.

    Raises:
        KeyError: if the response has no Content-Length header.
    """
    # NOTE(review): verify=False disables TLS certificate checks; kept
    # for compatibility with the original mirror — confirm this is intended.
    res = requests.get(url, stream=True, verify=False)
    # Total download size, used to scale the progress bar.
    total_size = int(res.headers["Content-Length"])
    temp_size = 0
    with open(file_name, "wb") as f:
        for chunk in res.iter_content(chunk_size=1024):
            # iter_content can yield empty keep-alive chunks; skip them
            # so they do not trigger needless progress redraws.
            if not chunk:
                continue
            temp_size += len(chunk)
            f.write(chunk)
            done = int(100 * temp_size / total_size)
            # Redraw the progress bar on the same console line.
            sys.stdout.write("\r[{}{}] {:.2f}%".format("█" * done, " " * (100 - done), 100 * temp_size / total_size))
            sys.stdout.flush()
    print("\n============== {} is already ==============".format(file_name))
    unzipfile(file_name)  # decompress the archive
    os.remove(file_name)  # delete the archive
|
|
|
|
def download_dataset():
    """Download the MNIST dataset from http://yann.lecun.com/exdb/mnist/
    into ./MNIST_Data/train/ and ./MNIST_Data/test/.

    Files whose decompressed version already exists locally are skipped.
    """
    print("************** Downloading the MNIST dataset **************")
    train_path = "./MNIST_Data/train/"
    test_path = "./MNIST_Data/test/"
    # BUG FIX: the original only created the directories when BOTH were
    # missing (`if not a and not b`), so with exactly one present the
    # other was never created and the download open() would fail.
    # exist_ok handles each path independently and idempotently.
    os.makedirs(train_path, exist_ok=True)
    os.makedirs(test_path, exist_ok=True)
    train_url = {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"}
    test_url = {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
    # The train and test loops were identical except for their target
    # directory; iterate over (directory, urls) pairs instead.
    for dest_dir, urls in ((train_path, train_url), (test_path, test_url)):
        for url in urls:
            # File name is the last path segment of the URL.
            file_name = os.path.join(dest_dir, urlparse(url).path.split('/')[-1])
            # Skip when the decompressed file is already present.
            if not os.path.exists(file_name.replace('.gz', '')):
                download_progress(url, file_name)
|