import bz2
import gzip
import hashlib
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urlparse

import numpy as np
import torch
from torch.utils.model_zoo import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"


def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        with open(filename, "wb") as fh, tqdm(total=response.length) as pbar:
            while chunk := response.read(chunk_size):
                fh.write(chunk)
                pbar.update(len(chunk))


def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


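# Usage sketch for calculate_md5 / check_integrity (hypothetical local path, for illustration only):
#
#   digest = calculate_md5("data/archive.tar.gz")            # hex digest of the file contents
#   ok = check_integrity("data/archive.tar.gz", md5=digest)  # True if the file exists and matches
#
# Passing md5=None to check_integrity only verifies that the file exists.

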
def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


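# What _get_google_drive_file_id extracts (hypothetical URLs, for illustration only; derived from the regexes above):
#
#   _get_google_drive_file_id("https://drive.google.com/file/d/ABC123/view")  # -> "ABC123"
#   _get_google_drive_file_id("https://example.com/file.zip")                 # -> None (not a Drive URL)

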
def download_url(
    url: str,
    root: Union[str, pathlib.Path],
    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
    max_redirect_hops: int = 3,
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.fspath(os.path.join(root, filename))

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


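# Usage sketch for download_url (hypothetical URL and directory, for illustration only):
#
#   download_url(
#       "https://example.com/dataset.tar.gz",  # resolved through up to max_redirect_hops redirects
#       root="~/datasets",                     # created if missing; "~" is expanded
#       md5=None,                              # pass a checksum string to verify the download
#   )
#
# An already present and verified file is reused instead of being downloaded again.

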
def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files


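# Usage sketch for list_dir / list_files (hypothetical directory, for illustration only):
#
#   class_dirs = list_dir("~/datasets/imagenet/train", prefix=True)      # paths joined with the root
#   images = list_files("~/datasets/imagenet/train/n01440764", ".JPEG")  # bare file names (prefix=False)

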
def download_file_from_google_drive(
    file_id: str,
    root: Union[str, pathlib.Path],
    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    try:
        import gdown
    except ModuleNotFoundError:
        raise RuntimeError(
            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
        )

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.fspath(os.path.join(root, filename))

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)

    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


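# Usage sketch for download_file_from_google_drive (hypothetical file id, for illustration only;
# requires the optional 'gdown' dependency):
#
#   download_file_from_google_drive("ABC123", root="~/datasets", filename="archive.zip", md5=None)

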
def _extract_tar(
    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(
    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


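# What _detect_file_type returns for a few representative names (derived from the tables above, file names hypothetical):
#
#   _detect_file_type("archive.tar.gz")  # -> (".tar.gz", ".tar", ".gz")
#   _detect_file_type("archive.tgz")     # -> (".tgz", ".tar", ".gz")   via _FILE_TYPE_ALIASES
#   _detect_file_type("archive.zip")     # -> (".zip", ".zip", None)
#   _detect_file_type("labels.csv.gz")   # -> (".gz", None, ".gz")      compressed but not an archive

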
def _decompress(
    from_path: Union[str, pathlib.Path],
    to_path: Optional[Union[str, pathlib.Path]] = None,
    remove_finished: bool = False,
) -> pathlib.Path:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return pathlib.Path(to_path)


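# Usage sketch for _decompress (hypothetical path, for illustration only):
#
#   _decompress("data/labels.csv.gz")  # writes data/labels.csv and returns its path as pathlib.Path

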
def extract_archive(
    from_path: Union[str, pathlib.Path],
    to_path: Optional[Union[str, pathlib.Path]] = None,
    remove_finished: bool = False,
) -> Union[str, pathlib.Path]:
    """Extract an archive.

    The archive type and a possible compression are automatically detected from the file name. If the file is
    compressed but not an archive, the call is dispatched to :func:`decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """

    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
        if isinstance(from_path, str):
            return os.fspath(ret_path)
        else:
            return ret_path

    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        ret_path = _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )
        return path_or_str(ret_path)

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return path_or_str(pathlib.Path(to_path))


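# Usage sketch for extract_archive (hypothetical paths, for illustration only):
#
#   extract_archive("data/images.tar.gz", to_path="data/images")  # extracts into data/images
#   extract_archive("data/labels.csv.gz")                         # not an archive: dispatched to _decompress

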
def download_and_extract_archive(
    url: str,
    download_root: Union[str, pathlib.Path],
    extract_root: Optional[Union[str, pathlib.Path]] = None,
    filename: Optional[Union[str, pathlib.Path]] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)


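# Usage sketch for download_and_extract_archive (hypothetical URL and directories, for illustration only):
#
#   download_and_extract_archive(
#       "https://example.com/dataset.tar.gz",
#       download_root="~/datasets/raw",        # where the archive is stored
#       extract_root="~/datasets/extracted",   # defaults to download_root when omitted
#       remove_finished=True,                  # delete the archive after extraction
#   )

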
def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value


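# Usage sketch for verify_str_arg (values chosen for illustration only):
#
#   split = verify_str_arg("train", arg="split", valid_values=("train", "test"))  # returns "train"
#   verify_str_arg("eval", arg="split", valid_values=("train", "test"))           # raises ValueError

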
def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)


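# Usage sketch for _read_pfm (hypothetical file, for illustration only): a 3-channel "PF" file is returned
# as a float32 array of shape (slice_channels, H, W), flipped along the height dimension.
#
#   flow = _read_pfm("frame_0001.pfm", slice_channels=2)  # e.g. optical flow with u/v channels

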
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
    # Reverse the byte order of every element: view the tensor as raw uint8 bytes grouped per element,
    # flip the byte axis, and reinterpret the result as the original dtype.
    return (
        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
    )