Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not." I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work.

Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
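The commit message above refers to flake8's F401 check, which flags imports that are bound but never used. A minimal sketch of the two remedies it mentions; the package and names below are hypothetical, chosen only for illustration:

```python
# some_package/__init__.py  (hypothetical module, for illustration only)

# Approach taken in this change: suppress the warning at the import site.
from .engine import Engine  # noqa: F401

# Alternative left for future work: list the re-export in __all__, which
# marks the name as intentionally public, so flake8 no longer reports it.
from .utils import helper

__all__ = ['helper']
```

Either form keeps the re-export available to users of the package while satisfying the linter.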
88 lines · 2.6 KiB · Python
from __future__ import division
from __future__ import print_function

import argparse
import gzip
import os
import sys

# Python 2/3 compatibility: urlretrieve and URLError live in different
# modules on the two versions.
try:
    from urllib.error import URLError
    from urllib.request import urlretrieve
except ImportError:
    from urllib2 import URLError
    from urllib import urlretrieve

# The four gzipped IDX files that make up the MNIST dataset.
RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def report_download_progress(chunk_number, chunk_size, file_size):
    # urlretrieve reports a file size of -1 when the server sends no
    # Content-Length header; skip the progress bar in that case.
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))


def download(destination_path, url, quiet):
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
    else:
        print('Downloading {} ...'.format(url))
        try:
            hook = None if quiet else report_download_progress
            urlretrieve(url, destination_path, reporthook=hook)
        except URLError:
            raise RuntimeError('Error downloading resource!')
        finally:
            if not quiet:
                # Just a newline.
                print()


def unzip(zipped_path, quiet):
    # 'foo-ubyte.gz' is decompressed to 'foo-ubyte' in the same directory.
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return
    with gzip.open(zipped_path, 'rb') as zipped_file:
        with open(unzipped_path, 'wb') as unzipped_file:
            unzipped_file.write(zipped_file.read())
            if not quiet:
                print('Unzipped {} ...'.format(zipped_path))


def main():
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
            download(path, url, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print('Interrupted')


if __name__ == '__main__':
    main()
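When run directly, `main()` downloads and unpacks all four RESOURCES into the `--destination` directory. A minimal sketch of reusing the helpers programmatically instead; it assumes the file above is saved as `download_mnist.py` next to the snippet (the filename is an assumption, not shown on this page):

```python
import os

# Assumes the script above is importable as download_mnist (hypothetical filename).
from download_mnist import download, unzip

destination = 'data'
if not os.path.exists(destination):
    os.makedirs(destination)

resource = 'train-images-idx3-ubyte.gz'
path = os.path.join(destination, resource)
url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)

download(path, url, quiet=False)  # prints a progress bar via report_download_progress
unzip(path, quiet=False)          # writes train-images-idx3-ubyte next to the .gz file
```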