510 lines
18 KiB
Python
510 lines
18 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
import errno
|
|
import hashlib
|
|
import os
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
import torch
|
|
import warnings
|
|
import zipfile
|
|
|
|
if sys.version_info[0] == 2:
|
|
from urlparse import urlparse
|
|
from urllib2 import urlopen # noqa f811
|
|
else:
|
|
from urllib.request import urlopen
|
|
from urllib.parse import urlparse # noqa: F401
|
|
|
|
try:
    # tqdm.auto selects the most suitable tqdm flavour when it is available.
    from tqdm.auto import tqdm
except ImportError:
    try:
        from tqdm import tqdm
    except ImportError:
        class tqdm(object):
            """Bare-bones progress reporter used when tqdm is not installed.

            It supports only what this module needs: construction with
            ``total``/``disable``, ``update(n)`` and use as a context manager.
            Progress is written to stderr on a single, carriage-return
            refreshed line.
            """

            def __init__(self, total=None, disable=False,
                         unit=None, unit_scale=None, unit_divisor=None):
                # unit, unit_scale and unit_divisor are accepted only so the
                # call site can stay identical to the real tqdm; they are unused.
                self.total = total
                self.disable = disable
                self.n = 0

            def update(self, n):
                """Advance the counter by ``n`` and redraw the progress line."""
                if self.disable:
                    return

                self.n += n
                if self.total is None:
                    # Unknown total size: report the raw amount seen so far.
                    progress_text = "\r{0:.1f} bytes".format(self.n)
                else:
                    progress_text = "\r{0:.1f}%".format(100 * self.n / float(self.total))
                sys.stderr.write(progress_text)
                sys.stderr.flush()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Terminate the progress line, unless nothing was ever printed.
                if not self.disable:
                    sys.stderr.write('\n')
|
|
|
|
# Captures the hex hash embedded in a checkpoint file name,
# e.g. bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

# Branch used when the github string carries no explicit branch/tag.
MASTER_BRANCH = 'master'

# Environment variables consulted when resolving the cache location,
# and the fallback used when XDG_CACHE_HOME is not set.
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'

# Name of the optional hubconf attribute listing required packages.
VAR_DEPENDENCY = 'dependencies'

# File that a hub repository must provide to expose its entrypoints.
MODULE_HUBCONF = 'hubconf.py'

# Chunk size, in bytes, for streamed reads.
READ_DATA_CHUNK = 8192

# Directory for downloaded repositories; None until set_dir() or the
# first hub call assigns it.
hub_dir = None
|
|
|
|
|
|
# Copied from tools/shared/module_loader to be included in torch package
def import_module(name, path):
    """Load the Python source file at ``path`` as a module called ``name``.

    The loading mechanism depends on the running interpreter: ``imp`` on
    Python 2, ``SourceFileLoader`` on Python 3.0-3.4 and the
    ``importlib.util`` spec API on Python 3.5 and newer.

    Args:
        name (string): name given to the loaded module.
        path (string): path of the source file to execute.

    Returns:
        The loaded module object.
    """
    py_version = sys.version_info

    if py_version < (3, 0):
        import imp
        return imp.load_source(name, path)

    if py_version < (3, 5):
        from importlib.machinery import SourceFileLoader
        return SourceFileLoader(name, path).load_module()

    import importlib.util
    module_spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
|
|
|
|
|
|
def _remove_if_exists(path):
|
|
if os.path.exists(path):
|
|
if os.path.isfile(path):
|
|
os.remove(path)
|
|
else:
|
|
shutil.rmtree(path)
|
|
|
|
|
|
def _git_archive_link(repo_owner, repo_name, branch):
|
|
return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
|
|
|
|
|
|
def _load_attr_from_module(module, func_name):
|
|
# Check if callable is defined in the module
|
|
if func_name not in dir(module):
|
|
return None
|
|
return getattr(module, func_name)
|
|
|
|
|
|
def _get_torch_home():
    """Resolve the base directory used for cached torch data.

    When ``hub_dir`` has been set it is returned as is. Otherwise the value
    of ``$TORCH_HOME`` is used, falling back to ``$XDG_CACHE_HOME/torch``
    and finally to ``~/.cache/torch``; the result is user-expanded.
    """
    if hub_dir is not None:
        return hub_dir

    xdg_cache_home = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    default_torch_home = os.path.join(xdg_cache_home, 'torch')
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, default_torch_home))
|
|
|
|
|
|
def _setup_hubdir():
    """Make sure the global ``hub_dir`` is set and present on disk.

    If ``hub_dir`` has not been assigned yet it becomes ``<torch home>/hub``.
    A warning is issued when the deprecated ``TORCH_HUB`` environment
    variable is set.

    Raises:
        OSError: if the directory cannot be created and the path does not
            exist afterwards.
    """
    global hub_dir
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    if hub_dir is None:
        torch_home = _get_torch_home()
        hub_dir = os.path.join(torch_home, 'hub')

    # Create first and inspect afterwards: checking for existence before
    # calling makedirs() leaves a window in which another process can create
    # the directory and make makedirs() fail.
    try:
        os.makedirs(hub_dir)
    except OSError:
        if not os.path.exists(hub_dir):
            raise
|
|
|
|
|
|
def _parse_repo_info(github):
    """Split ``"repo_owner/repo_name[:branch]"`` into its three components.

    Returns:
        tuple: ``(repo_owner, repo_name, branch)``; ``branch`` is
        ``MASTER_BRANCH`` when the input has no ``:branch`` suffix.
    """
    if ':' in github:
        repo_info, branch = github.split(':')
    else:
        repo_info, branch = github, MASTER_BRANCH
    repo_owner, repo_name = repo_info.split('/')
    return repo_owner, repo_name, branch
|
|
|
|
|
|
def _get_cache_or_reload(github, force_reload, verbose=True):
    """Return the local directory holding the requested repository.

    A cached copy under ``hub_dir`` is reused unless ``force_reload`` is
    true; otherwise the branch archive is downloaded, extracted and moved
    into place.

    Args:
        github (string): ``"repo_owner/repo_name[:branch]"``.
        force_reload (bool): download even when a cached copy exists.
        verbose (bool, optional): when False, the cache-hit message is
            suppressed. The download message is always written.

    Returns:
        string: path of the repository directory inside ``hub_dir``.
    """
    owner, repo, branch = _parse_repo_info(github)

    # Github allows '/' in branch names, which would be read as a path
    # separator on both Linux and Windows. Backslash is not allowed in
    # Github branch names, so only the forward slash needs replacing.
    safe_branch = branch.replace('/', '_')

    # Github renames folder repo-v1.x.x to repo-1.x.x, so the extracted
    # folder name is only known once the zip file has been downloaded and
    # inspected. A normalized folder name is used for the cache lookup.
    repo_dir = os.path.join(hub_dir, '_'.join([owner, repo, safe_branch]))

    if not force_reload and os.path.exists(repo_dir):
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    archive_path = os.path.join(hub_dir, safe_branch + '.zip')
    _remove_if_exists(archive_path)

    url = _git_archive_link(owner, repo, branch)
    sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, archive_path))
    download_url_to_file(url, archive_path, progress=False)

    with zipfile.ZipFile(archive_path) as archive:
        # The first archive entry names the top-level folder of the repo.
        top_level_name = archive.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, top_level_name)
        _remove_if_exists(extracted_repo)
        # Unzip the code; the base folder is renamed below.
        archive.extractall(hub_dir)

    _remove_if_exists(archive_path)
    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)  # rename the repo

    return repo_dir
|
|
|
|
|
|
def _check_module_exists(name):
|
|
if sys.version_info >= (3, 4):
|
|
import importlib.util
|
|
return importlib.util.find_spec(name) is not None
|
|
elif sys.version_info >= (3, 3):
|
|
# Special case for python3.3
|
|
import importlib.find_loader
|
|
return importlib.find_loader(name) is not None
|
|
else:
|
|
# NB: Python2.7 imp.find_module() doesn't respect PEP 302,
|
|
# it cannot find a package installed as .egg(zip) file.
|
|
# Here we use workaround from:
|
|
# https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1
|
|
# Also imp doesn't handle hierarchical module names (names contains dots).
|
|
try:
|
|
# 1. Try imp.find_module(), which searches sys.path, but does
|
|
# not respect PEP 302 import hooks.
|
|
import imp
|
|
result = imp.find_module(name)
|
|
if result:
|
|
return True
|
|
except ImportError:
|
|
pass
|
|
path = sys.path
|
|
for item in path:
|
|
# 2. Scan path for import hooks. sys.path_importer_cache maps
|
|
# path items to optional "importer" objects, that implement
|
|
# find_module() etc. Note that path must be a subset of
|
|
# sys.path for this to work.
|
|
importer = sys.path_importer_cache.get(item)
|
|
if importer:
|
|
try:
|
|
result = importer.find_module(name, [item])
|
|
if result:
|
|
return True
|
|
except ImportError:
|
|
pass
|
|
return False
|
|
|
|
def _check_dependencies(m):
    """Verify that every package named in the module's ``dependencies``
    attribute can be found.

    Args:
        m: loaded hubconf module.

    Raises:
        RuntimeError: listing the packages that could not be found.
    """
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
    if dependencies is None:
        # The module declares no dependencies; nothing to verify.
        return

    missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
    if missing_deps:
        raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))
|
|
|
|
|
|
def _load_entry_from_hubconf(m, model):
    """Look up the entrypoint callable ``model`` in hubconf module ``m``.

    Args:
        m: loaded hubconf module.
        model (string): name of the entrypoint.

    Returns:
        The callable registered under that name.

    Raises:
        ValueError: if ``model`` is not a string.
        RuntimeError: if a declared dependency is missing, or if the name is
            absent from the module or not callable.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # A dependency imported at the top level of hubconf fails before this
    # point: hubconf must be loaded to learn its dependencies, yet loading it
    # already needs them. That is acceptable, since Python reports a proper
    # import error to the user in that case.
    _check_dependencies(m)

    entry = _load_attr_from_module(m, model)
    # None (attribute not found) is not callable, so one test covers both cases.
    if not callable(entry):
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

    return entry
|
|
|
|
|
|
def set_dir(d):
    r"""
    Optionally point hub_dir at a local directory used to save downloaded
    models & weights.

    When ``set_dir`` is never called, the location defaults to
    ``$TORCH_HOME/hub``, where ``$TORCH_HOME`` itself defaults to
    ``$XDG_CACHE_HOME/torch``. ``$XDG_CACHE_HOME`` follows the XDG Base
    Directory specification of the Linux filesystem layout and defaults to
    ``~/.cache`` if the environment variable is not set.

    Args:
        d (string): path to a local folder to save downloaded models & weights.
    """
    global hub_dir
    hub_dir = d
|
|
|
|
|
|
def list(github, force_reload=False):
    r"""
    List all entrypoints available in `github` hubconf.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Returns:
        entrypoints: a list of available entrypoint names

    Example:
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    repo_dir = _get_cache_or_reload(github, force_reload, True)

    # The repository is put on sys.path only while hubconf is being imported;
    # the finally clause restores sys.path even when that import raises.
    sys.path.insert(0, repo_dir)
    try:
        hub_module = import_module(MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF))
    finally:
        sys.path.remove(repo_dir)

    # We take functions starts with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]

    return entrypoints
|
|
|
|
|
|
def help(github, model, force_reload=False):
    r"""
    Show the docstring of entrypoint `model`.

    Args:
        github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): a string of entrypoint name defined in repo's hubconf.py
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is `False`.
    Returns:
        the ``__doc__`` attribute of the entrypoint

    Example:
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    repo_dir = _get_cache_or_reload(github, force_reload, True)

    # The repository is put on sys.path only while hubconf is being imported;
    # the finally clause restores sys.path even when that import raises.
    sys.path.insert(0, repo_dir)
    try:
        hub_module = import_module(MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF))
    finally:
        sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__
|
|
|
|
|
|
# Ideally this should be `def load(github, model, *args, force_reload=False, **kwargs):`,
# but Python2 complains syntax error for it. We have to skip force_reload in function
# signature here but detect it in kwargs instead.
# TODO: fix it after Python2 EOL
def load(github, model, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model (string): a string of entrypoint name defined in repo's hubconf.py
        *args (optional): the corresponding args for callable `model`.
        force_reload (bool, optional): whether to force a fresh download of github repo unconditionally.
            Default is `False`.
        verbose (bool, optional): If False, mute messages about hitting local caches. Note that the message
            about first download cannot be muted.
            Default is `True`.
        **kwargs (optional): the corresponding kwargs for callable `model`.

    Returns:
        a single model with corresponding pretrained weights.

    Example:
        >>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
    """
    # Setup hub_dir to save downloaded files
    _setup_hubdir()

    # force_reload and verbose are options of load() itself: take them out of
    # kwargs so that they are not forwarded to the entrypoint.
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)

    repo_dir = _get_cache_or_reload(github, force_reload, verbose)

    # The repository stays on sys.path while hubconf is imported and the
    # entrypoint runs; the finally clause restores sys.path even when one of
    # those steps raises.
    sys.path.insert(0, repo_dir)
    try:
        hub_module = import_module(MODULE_HUBCONF, os.path.join(repo_dir, MODULE_HUBCONF))

        entry = _load_entry_from_hubconf(hub_module, model)

        model = entry(*args, **kwargs)
    finally:
        sys.path.remove(repo_dir)

    return model
|
|
|
|
|
|
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: if `hash_prefix` is given and the SHA256 digest of the
            downloaded data does not start with it.

    Example:
        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

    """
    file_size = None
    u = urlopen(url)
    try:
        meta = u.info()
        # The header object offers either getheaders() or get_all(),
        # depending on the Python version; use whichever is present.
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        dst = os.path.expanduser(dst)
        dst_dir = os.path.dirname(dst)
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

        try:
            if hash_prefix is not None:
                sha256 = hashlib.sha256()
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    # 8192 bytes per read, the same value as READ_DATA_CHUNK.
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if hash_prefix is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            # Close before moving so all data is flushed to disk.
            f.close()
            if hash_prefix is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            f.close()
            # The temp file is still present only if the download or the
            # hash check failed before the move.
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        # Always release the network connection, whether or not the download succeeded.
        u.close()
|
|
|
|
def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Deprecated alias of :func:`download_url_to_file`.

    Emits a warning and forwards all arguments unchanged.
    """
    # Adjacent string literals are used instead of backslash continuations,
    # which would put the source indentation into the message.
    warnings.warn('torch.hub._download_url_to_file has been renamed to '
                  'torch.hub.download_url_to_file to be a public API, '
                  '_download_url_to_file will be removed after the 1.3 release')
    download_url_to_file(url, dst, hash_prefix, progress)
|
|
|
|
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the XDG Base Directory specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False

    Raises:
        RuntimeError: if `check_hash` is True but the file name carries no
            hash, if the downloaded data does not match the hash, or if a
            downloaded zip file does not contain exactly one member.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        # An already existing directory is fine; anything else is re-raised.
        if e.errno != errno.EEXIST:
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        hash_prefix = None
        if check_hash:
            # Fail with a clear message when the file name has no usable
            # hash. An empty prefix is rejected as well, because every digest
            # starts with the empty string and the check would be a no-op.
            hash_match = HASH_REGEX.search(filename)
            if hash_match is None or not hash_match.group(1):
                raise RuntimeError(
                    'check_hash is True but no hash was found in file name "{}" '
                    '(expected the form filename-<sha256>.ext)'.format(filename))
            hash_prefix = hash_match.group(1)
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    #       E.g. resnet18-5c106cde.pth which is widely used.
    if zipfile.is_zipfile(cached_file):
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            members = cached_zipfile.infolist()
            if len(members) != 1:
                raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
            cached_zipfile.extractall(model_dir)
            extracted_name = members[0].filename
            cached_file = os.path.join(model_dir, extracted_name)

    # NOTE(security): torch.load unpickles the file, so only URLs from
    # trusted sources should be passed to this function.
    return torch.load(cached_file, map_location=map_location)
|