import math
import sys
import errno
import os
import ctypes
import torch
import gc
import time
import signal
import unittest
import itertools
import warnings
from torch import multiprocessing as mp
from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
from torch._utils import ExceptionWrapper
from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, PY3,
                                                  IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm,
                                                  load_tests, TEST_WITH_TSAN)

try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False
    err_msg = ("psutil not found. Some critical data loader tests relying on it "
               "(e.g., TestDataLoader.test_proper_exit) will not run.")
    if IS_PYTORCH_CI:
        raise ImportError(err_msg)
    else:
        warnings.warn(err_msg)

try:
    import faulthandler
    HAS_FAULTHANDLER = True
except ImportError:
    HAS_FAULTHANDLER = False
    err_msg = ("faulthandler not found. Some data loader tests use it for error "
               "reporting (e.g., TestDataLoader.test_proper_exit).")
    if IS_PYTORCH_CI:
        raise ImportError(err_msg)
    else:
        warnings.warn(err_msg)


# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# We cannot import TEST_CUDA from torch.testing._internal.common_cuda here, because if we do that,
# the TEST_CUDNN line from torch.testing._internal.common_cuda will be executed multiple times
# as well during the execution of this test suite, and it will cause
# CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()


if not NO_MULTIPROCESSING_SPAWN:
    # We want to use `spawn` whenever possible because some of our tests check that the
    # data loader terminates gracefully. To prevent hanging in the testing
    # process, such data loaders are run in a separate subprocess.
    #
    # We also want to test the `pin_memory=True` configuration; those processes
    # initialize the CUDA context, so `spawn` is required to launch them.
    #
    # Mixing different start methods is a recipe for disaster (e.g., using a fork
    # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
    # to avoid bugs.
    #
    # Get a multiprocessing context, because some tests / third-party libraries set
    # the start method when imported, and setting it again raises a `RuntimeError`.
    mp = mp.get_context(method='spawn')


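# Illustrative sketch (editor's addition, not used by any test below): the safe
# pattern described in the comment above is to create synchronization primitives
# and worker processes from the *same* multiprocessing context. The two helpers
# below are hypothetical, defined only for illustration and never called.
def _set_event_sketch(event):
    event.set()


def _same_context_sketch():
    # `mp` is already a spawn context here when spawn is available, so the Event
    # and the Process below are guaranteed to agree on the start method.
    event = mp.Event()
    p = mp.Process(target=_set_event_sketch, args=(event,))
    p.start()
    p.join()
    return event.is_set()

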
# Why a 60s timeout?
# In environments where physical CPU resources are shared, e.g., CI, the time for
# an inter-process communication can vary widely. With a 15~17s timeout, we have
# observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608). We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0  # seconds


supported_multiprocessing_contexts = [None]
if torch.multiprocessing._supports_context:
    supported_multiprocessing_contexts += list(torch.multiprocessing.get_all_start_methods())


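# Illustrative sketch (editor's addition, never called): the entries above feed the
# `multiprocessing_context=` argument of DataLoader -- either None, a start-method
# name such as 'spawn', or a context object -- as exercised by
# `test_multiprocessing_contexts` further below. `CountingDataset` is defined later
# in this file; the helper name is hypothetical.
def _multiprocessing_context_sketch():
    return [DataLoader(CountingDataset(10), num_workers=2, multiprocessing_context=ctx)
            for ctx in supported_multiprocessing_contexts]

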
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestDatasetRandomSplit(TestCase):
    def test_lengths_must_equal_dataset_size(self):
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)

    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

    def test_splits_indexing_type(self):
        r"""Indices generated by random_split
        should be of integer type
        """
        class CustomDataset():
            def __init__(self, test_object, custom_list):
                self.data = custom_list
                self.test_object = test_object

            def __getitem__(self, key):
                self.test_object.assertEqual(type(key), type(0))
                return self.data[key]

            def __len__(self):
                return len(self.data)

        x = [1, 2, 3, 4, 5]
        dataset = CustomDataset(self, x)
        dataset = random_split(dataset, [5])[0]
        data_loader = DataLoader(dataset)
        for batch in data_loader:
            pass


class CUDACountingDataset(Dataset):
    def __init__(self, n):
        super(CUDACountingDataset, self).__init__()
        self.n = n

    def __getitem__(self, i):
        return torch.as_tensor(i, device='cuda')

    def __len__(self):
        return self.n


class CountingDataset(Dataset):
    def __init__(self, n):
        super(CountingDataset, self).__init__()
        self.n = n

    def __getitem__(self, i):
        return i

    def __len__(self):
        return self.n


class CountingIterableDataset(IterableDataset):
    def __init__(self, n):
        super(CountingIterableDataset, self).__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestTensorDataset(TestCase):

    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_single_tensor(self):
        t = torch.randn(5, 10)
        source = TensorDataset(t)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t[i], source[i][0])

    def test_many_tensors(self):
        t0 = torch.randn(5, 10, 2, 3, 4, 5)
        t1 = torch.randn(5, 10)
        t2 = torch.randn(5, 10, 2, 5)
        t3 = torch.randn(5, 10, 3, 7)
        source = TensorDataset(t0, t1, t2, t3)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t0[i], source[i][0])
            self.assertEqual(t1[i], source[i][1])
            self.assertEqual(t2[i], source[i][2])
            self.assertEqual(t3[i], source[i][3])


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestConcatDataset(TestCase):

    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

    def test_concat_two_non_singletons(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())

    def test_iterable_dataset_err(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        it1 = CountingIterableDataset(5)
        it2 = CountingIterableDataset(10)

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([d1, it2, it1])

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([it2])

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([it1, d1])


# Takes a dummy argument so that it can also be used as a `worker_init_fn`.
def set_faulthander_if_available(_=None):
    if HAS_FAULTHANDLER:
        faulthandler.enable(sys.__stderr__)
        if not IS_WINDOWS:
            # Windows does not have faulthandler.register
            # chain=False prevents the default behavior of killing the process
            faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False)


set_faulthander_if_available()

# Process `pid` must have called `set_faulthander_if_available`
def print_traces_of_all_threads(pid):
    if HAS_FAULTHANDLER:
        if not IS_WINDOWS:
            # use the custom signal if available
            os.kill(pid, signal.SIGUSR1)
        else:
            # otherwise we can still use the handler given by faulthandler.enable()
            # at the cost of killing the process.
            os.kill(pid, signal.SIGSEGV)
    else:
        # if there is no faulthandler, use SIGINT and hope for the best
        os.kill(pid, signal.SIGINT)
    # wait in parent process to give subprocess some time to print
    time.sleep(5)


# The following `ErrorTrackingProcess` stores the first encountered exception in
# its `.exception` attribute.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):

    # Why no *args?
    #   py2 doesn't support def fn(x, *args, key=val, **kwargs)
    # Setting disable_stderr=False may generate a lot of unrelated error outputs,
    # but could be helpful for debugging.
    def __init__(self, disable_stderr=True, **kwargs):
        super(ErrorTrackingProcess, self).__init__(**kwargs)
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None
        self.disable_stderr = disable_stderr

    def run(self):
        set_faulthander_if_available()
        if self.disable_stderr:
            # Disable polluting stderr with errors that are supposed to happen.
            sys.stderr = open(os.devnull, "w")
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    def print_traces_of_all_threads(self):
        assert self.is_alive(), "can only use print_traces_of_all_threads if the process is alive"
        assert not self.disable_stderr, "do not disable stderr if you use print_traces_of_all_threads"
        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
        # the process. So let's poll the exception first
        _ = self.exception
        print_traces_of_all_threads(self.pid)

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill can't find an alive process with that pid
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise


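# Illustrative sketch (editor's addition, not exercised by the tests): typical use
# of `ErrorTrackingProcess`. The target raises, the wrapped exception travels back
# over the pipe, and the parent rebuilds it lazily through the `.exception`
# property. `_failing_target_sketch` and `_error_tracking_sketch` are hypothetical
# helper names used only here.
def _failing_target_sketch():
    raise ValueError("boom")


def _error_tracking_sketch():
    p = ErrorTrackingProcess(target=_failing_target_sketch)
    p.start()
    p.join(JOIN_TIMEOUT)
    # `.exception` polls the pipe and reconstructs the original error type/message.
    assert isinstance(p.exception, ValueError)

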
class ErrorDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size


class SegfaultDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return ctypes.string_at(0)

    def __len__(self):
        return self.size


class SleepDataset(Dataset):

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        self.sleeped = False

    def __getitem__(self, idx):
        if not self.sleeped:
            time.sleep(self.sleep_sec)
            self.sleeped = True
        return idx

    def __len__(self):
        return self.size


class SeedDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return torch.initial_seed()

    def __len__(self):
        return self.size


class WorkerSpecificIterableDataset(IterableDataset):
    def __init__(self, sizes_for_all_workers):
        self.sizes_for_all_workers = sizes_for_all_workers

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is not None
        return iter(range(self.sizes_for_all_workers[worker_info.id]))

    def __len__(self):
        return sum(self.sizes_for_all_workers)


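# Illustrative sketch (editor's addition, not used by the tests): the class above
# gives every worker its own size. The more common pattern is to shard one range
# across workers with `get_worker_info()`, so that the union over all workers
# covers each element exactly once. `_ShardedRangeDataset` is a hypothetical name
# used only for this sketch.
class _ShardedRangeDataset(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process loading: yield everything
            return iter(range(self.n))
        # multi-process loading: stride the range by the number of workers
        return iter(range(worker_info.id, self.n, worker_info.num_workers))

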
# Inspired by https://stackoverflow.com/a/26703365
# If all workers call `sync_once`, each of them blocks until every worker has
# reached the call (i.e., it acts like a barrier).
# This can be used to ensure that each worker processes at least one sample.
class SynchronizedDataset(Dataset):

    def __init__(self, size, batch_size, num_workers):
        assert size >= num_workers * batch_size
        self.count = mp.Value('i', 0, lock=True)
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def sync_once(self):
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        self.barrier.acquire()
        self.barrier.release()

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        return self.size


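# Illustrative sketch (editor's addition, never called by the tests): the same
# counter + semaphore trick used in `sync_once`, demonstrated with plain threads.
# The last arrival releases the semaphore once; every waiting caller then acquires
# and immediately re-releases it, so all `n` callers pass exactly once.
# `_barrier_sketch` is a hypothetical helper name.
def _barrier_sketch(n=3):
    import threading

    count = [0]
    lock = threading.Lock()
    gate = threading.Semaphore(0)
    passed = []

    def arrive(i):
        with lock:
            count[0] += 1
            if count[0] == n:
                gate.release()
        gate.acquire()
        gate.release()
        passed.append(i)

    threads = [threading.Thread(target=arrive, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(passed) == list(range(n))

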
class SynchronizedSeedDataset(SynchronizedDataset):
    def __getitem__(self, idx):
        self.sync_once()
        return torch.initial_seed()


def _test_timeout():
|
|
dataset = SleepDataset(10, 3)
|
|
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
|
|
_ = next(iter(dataloader))
|
|
|
|
|
|
def _test_timeout_pin_memory():
|
|
dataset = SleepDataset(10, 3)
|
|
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1, pin_memory=True)
|
|
_ = next(iter(dataloader))
|
|
|
|
|
|
def disable_stderr(worker_id):
|
|
r"""
|
|
Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
|
|
from workers. Since worker signal handler prints with low-level write(),
|
|
this has to be done on OS level via dup.
|
|
|
|
This is used as worker_init_fn for test_segfault.
|
|
"""
|
|
sys.stderr.flush() # flush library buffers that dup2 knows nothing about
|
|
# Can't use a with-block because otherwise the fd will be closed when this
|
|
# function ends.
|
|
devnull = open(os.devnull, 'w')
|
|
os.dup2(devnull.fileno(), sys.stderr.fileno())
|
|
|
|
|
|
def _test_segfault():
|
|
dataset = SegfaultDataset(10)
|
|
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr)
|
|
_ = next(iter(dataloader))
|
|
|
|
|
|
class TestProperExitDataset(Dataset):
|
|
def __init__(self, size, error_event):
|
|
self.size = size
|
|
self.error_event = error_event
|
|
|
|
def __len__(self):
|
|
return self.size
|
|
|
|
def __getitem__(self, idx):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
if self.error_event is not None and self.error_event.is_set() and \
|
|
worker_info.id == worker_info.num_workers - 1:
|
|
# only error in the last worker
|
|
raise RuntimeError('Worker error')
|
|
return torch.tensor([idx])
|
|
|
|
|
|
class TestProperExitIterableDataset(IterableDataset):
|
|
def __init__(self, size, error_event):
|
|
self.error_event = error_event
|
|
self.size = size
|
|
self.remaining = size
|
|
|
|
def __len__(self):
|
|
return self.size
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
if self.error_event is not None and self.error_event.is_set() and \
|
|
worker_info.id == worker_info.num_workers - 1:
|
|
# only error in the last worker
|
|
raise RuntimeError('Worker error')
|
|
self.remaining -= 1
|
|
if self.remaining < 0:
|
|
raise StopIteration
|
|
return torch.tensor(-1000)
|
|
|
|
next = __next__ # py2 compatibility
|
|
|
|
|
|
# See TestDataLoader.test_proper_exit for usage
|
|
def _test_proper_exit(is_iterable_dataset, use_workers, pin_memory, exit_method,
|
|
hold_iter_reference, loader_setup_event, tester_setup_event):
|
|
num_workers = 2 if use_workers else 0
|
|
|
|
if exit_method == 'worker_error' or exit_method == 'worker_kill':
|
|
assert use_workers is True
|
|
|
|
if exit_method == 'worker_error':
|
|
worker_error_event = mp.Event()
|
|
else:
|
|
worker_error_event = None
|
|
|
|
if is_iterable_dataset:
|
|
ds = TestProperExitIterableDataset(7, worker_error_event)
|
|
else:
|
|
ds = TestProperExitDataset(12, worker_error_event)
|
|
|
|
loader = DataLoader(ds, batch_size=1, shuffle=False,
|
|
num_workers=num_workers, pin_memory=pin_memory,
|
|
worker_init_fn=set_faulthander_if_available)
|
|
|
|
error_it = 2
|
|
|
|
if use_workers:
|
|
# 2 is the magical per-worker prefetch number...
|
|
# FIXME: change this after the number becomes configurable.
|
|
if is_iterable_dataset:
|
|
assert len(ds) * num_workers > (error_it + 2 + 1)
|
|
else:
|
|
assert len(loader) > (error_it + 2 + 1) * num_workers
|
|
else:
|
|
if is_iterable_dataset:
|
|
assert len(ds) > error_it + 1
|
|
else:
|
|
assert len(loader) > error_it + 1
|
|
|
|
it = iter(loader)
|
|
if use_workers:
|
|
workers = it._workers
|
|
|
|
def kill_pid(pid):
|
|
psutil_p = psutil.Process(pid)
|
|
psutil_p.kill()
|
|
psutil_p.wait(JOIN_TIMEOUT)
|
|
assert not psutil_p.is_running()
|
|
|
|
for i, _ in enumerate(it):
|
|
if i == 0:
|
|
if not hold_iter_reference:
|
|
del it
|
|
loader_setup_event.set()
|
|
tester_setup_event.wait()
|
|
# ensure that the workers are still alive
|
|
if use_workers:
|
|
for w in workers:
|
|
assert w.is_alive()
|
|
if worker_error_event is not None:
|
|
worker_error_event.set()
|
|
|
|
if i == error_it:
|
|
if exit_method == 'loader_error':
|
|
raise RuntimeError('Loader error')
|
|
elif exit_method == 'loader_kill':
|
|
kill_pid(os.getpid())
|
|
elif exit_method == 'worker_kill':
|
|
kill_pid(workers[-1].pid) # kill last worker
|
|
|
|
if not hold_iter_reference:
|
|
# Tries to trigger the __del__ clean-up rather than the automatic
|
|
# exiting of daemonic children. Technically it should be automatically
|
|
# triggered, but I don't want to rely on the implementation detail of
|
|
# Python gc.
|
|
gc.collect()
|
|
|
|
|
|
class TestWorkerInfoDataset(SynchronizedDataset):
|
|
def __getitem__(self, idx):
|
|
self.sync_once()
|
|
return torch.tensor(self.value)
|
|
|
|
|
|
# Should be used as worker_init_fn with TestWorkerInfoDataset.
|
|
# See _test_get_worker_info below for usage.
|
|
def test_worker_info_init_fn(worker_id):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
assert worker_id == worker_info.id, "worker_init_fn and worker_info should have consistent id"
|
|
assert worker_id < worker_info.num_workers, "worker_init_fn and worker_info should have valid id"
|
|
assert worker_info.seed == torch.initial_seed(), "worker_init_fn and worker_info should have consistent seed"
|
|
dataset = worker_info.dataset
|
|
assert isinstance(dataset, TestWorkerInfoDataset), "worker_info should have correct dataset copy"
|
|
assert not hasattr(dataset, 'value'), "worker_info should have correct dataset copy"
|
|
# test that WorkerInfo attributes are read-only
|
|
try:
|
|
worker_info.id = 3999
|
|
except RuntimeError as e:
|
|
assert str(e) == "Cannot assign attributes to WorkerInfo objects"
|
|
try:
|
|
worker_info.a = 3
|
|
except RuntimeError as e:
|
|
assert str(e) == "Cannot assign attributes to WorkerInfo objects"
|
|
dataset.value = [worker_id, os.getpid()]
|
|
|
|
|
|
def _test_get_worker_info():
|
|
# get_worker_info returns None in main proc
|
|
assert torch.utils.data.get_worker_info() is None
|
|
num_workers = 2
|
|
batch_size = 2
|
|
dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
|
|
dataloader = DataLoader(dataset, batch_size=batch_size,
|
|
num_workers=num_workers,
|
|
worker_init_fn=test_worker_info_init_fn)
|
|
it = iter(dataloader)
|
|
data = []
|
|
for d in it:
|
|
data.append(d)
|
|
worker_pids = [w.pid for w in it._workers]
|
|
data = torch.cat(data, 0)
|
|
for d in data:
|
|
# each `d` is a [worker_id, worker_pid] pair, which is set in
|
|
# test_worker_info_init_fn
|
|
assert d[1] == worker_pids[d[0]]
|
|
# get_worker_info returns None in main proc after data loading
|
|
assert torch.utils.data.get_worker_info() is None
|
|
# main proc dataset was never assigned this attribute
|
|
assert not hasattr(dataset, 'value')
|
|
try:
|
|
_ = dataset[0]
|
|
except AttributeError:
|
|
return
|
|
raise RuntimeError('Expected AttributeError')
|
|
|
|
|
|
# test custom init function
|
|
def init_fn(worker_id):
|
|
torch.manual_seed(12345)
|
|
|
|
|
|
# used with test_error_in_init
|
|
class ErrorIterableDataset(IterableDataset):
|
|
def __iter__(self):
|
|
raise RuntimeError("Error in __iter__")
|
|
|
|
|
|
# used with test_error_in_init
|
|
def error_worker_init_fn(_):
|
|
raise RuntimeError("Error in worker_init_fn")
|
|
|
|
|
|
class BulkLoadingDataset(Dataset):
    def __init__(self, length):
        self.length = length

    def __getitem__(self, indices):
        assert isinstance(indices, (list, tuple))
        return torch.as_tensor(indices)

    def __len__(self):
        return self.length


class BulkLoadingSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        for x in torch.randperm(len(self.dataset)).split(self.batch_size):
            yield x.tolist()

    def __len__(self):
        return int(math.ceil(len(self.dataset) / float(self.batch_size)))


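# Illustrative sketch (editor's addition): how the two classes above combine. With
# `batch_size=None`, auto-batching is disabled, so each list of indices yielded by
# the sampler is handed straight to `BulkLoadingDataset.__getitem__`, which returns
# one pre-built batch tensor. `test_bulk_loading_nobatch` below exercises the same
# setup with assertions; `_bulk_loading_sketch` is a hypothetical helper name.
def _bulk_loading_sketch():
    ds = BulkLoadingDataset(35)
    sampler = BulkLoadingSampler(ds, batch_size=4)
    loader = DataLoader(ds, batch_size=None, sampler=sampler)
    # each element is a 1-D tensor of up to 4 dataset indices
    return [batch for batch in loader]

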
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)")
|
|
class TestDataLoader(TestCase):
|
|
|
|
def setUp(self):
|
|
super(TestDataLoader, self).setUp()
|
|
self.data = torch.randn(100, 2, 3, 5)
|
|
self.labels = torch.randperm(50).repeat(2)
|
|
self.dataset = TensorDataset(self.data, self.labels)
|
|
|
|
def _test_sequential(self, loader):
|
|
batch_size = loader.batch_size
|
|
if batch_size is None:
|
|
for idx, (sample, target) in enumerate(loader):
|
|
self.assertEqual(sample, self.data[idx])
|
|
self.assertEqual(target, self.labels[idx])
|
|
self.assertEqual(idx, len(self.dataset) - 1)
|
|
else:
|
|
for i, (sample, target) in enumerate(loader):
|
|
idx = i * batch_size
|
|
self.assertEqual(sample, self.data[idx:idx + batch_size])
|
|
self.assertEqual(target, self.labels[idx:idx + batch_size])
|
|
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
|
|
|
|
def _test_shuffle(self, loader):
|
|
found_data = {i: 0 for i in range(self.data.size(0))}
|
|
found_labels = {i: 0 for i in range(self.labels.size(0))}
|
|
batch_size = loader.batch_size
|
|
for i, (batch_samples, batch_targets) in enumerate(loader):
|
|
for sample, target in zip(batch_samples, batch_targets):
|
|
for data_point_idx, data_point in enumerate(self.data):
|
|
if data_point.eq(sample).all():
|
|
self.assertFalse(found_data[data_point_idx])
|
|
found_data[data_point_idx] += 1
|
|
break
|
|
self.assertEqual(target, self.labels[data_point_idx])
|
|
found_labels[data_point_idx] += 1
|
|
self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
|
|
self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
|
|
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
|
|
|
|
def _test_error(self, loader):
|
|
it = iter(loader)
|
|
errors = 0
|
|
while True:
|
|
try:
|
|
next(it)
|
|
except NotImplementedError:
|
|
errors += 1
|
|
except StopIteration:
|
|
self.assertEqual(errors,
|
|
math.ceil(float(len(loader.dataset)) / loader.batch_size))
|
|
return
|
|
|
|
def test_error_in_init(self):
|
|
for num_workers in [0, 2]:
|
|
loader = DataLoader(ErrorIterableDataset(), num_workers=num_workers)
|
|
with self.assertRaisesRegex(RuntimeError, 'Error in __iter__'):
|
|
list(iter(loader))
|
|
|
|
loader = DataLoader(self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn)
|
|
with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
|
|
list(iter(loader))
|
|
|
|
def test_invalid_assign_after_init(self):
|
|
dl = DataLoader(self.dataset)
|
|
for attr in ('batch_size', 'sampler', 'batch_sampler', 'drop_last', 'dataset'):
|
|
def fn():
|
|
setattr(dl, attr, {})
|
|
|
|
self.assertRaises(ValueError, fn)
|
|
|
|
def test_sequential_nonbatch(self):
|
|
self._test_sequential(DataLoader(self.dataset, batch_size=None))
|
|
|
|
def test_sequential_batch(self):
|
|
self._test_sequential(DataLoader(self.dataset))
|
|
self._test_sequential(DataLoader(self.dataset, batch_size=2))
|
|
|
|
def test_bulk_loading_nobatch(self):
|
|
n = 35
|
|
bs = 4
|
|
ds = BulkLoadingDataset(n)
|
|
sampler = BulkLoadingSampler(ds, batch_size=4)
|
|
|
|
for num_workers in [0, 4]:
|
|
dl = DataLoader(ds, num_workers=num_workers, batch_size=None, sampler=sampler, pin_memory=TEST_CUDA)
|
|
self.assertFalse(dl._auto_collation)
|
|
samples = list(dl)
|
|
self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
|
|
self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))
|
|
|
|
def test_growing_dataset(self):
|
|
dataset = [torch.ones(4) for _ in range(4)]
|
|
dataloader_seq = DataLoader(dataset, shuffle=False)
|
|
dataloader_shuffle = DataLoader(dataset, shuffle=True)
|
|
dataset.append(torch.ones(4))
|
|
self.assertEqual(len(dataloader_seq), 5)
|
|
self.assertEqual(len(dataloader_shuffle), 5)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_sequential_pin_memory(self):
|
|
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
|
|
for input, target in loader:
|
|
self.assertTrue(input.is_pinned())
|
|
self.assertTrue(target.is_pinned())
|
|
|
|
def test_multiple_dataloaders(self):
|
|
for multiprocessing_context in supported_multiprocessing_contexts:
|
|
loader1_it = iter(DataLoader(self.dataset, num_workers=1))
|
|
loader2_it = iter(DataLoader(self.dataset, num_workers=2, multiprocessing_context=multiprocessing_context))
|
|
next(loader1_it)
|
|
next(loader1_it)
|
|
next(loader2_it)
|
|
next(loader2_it)
|
|
next(loader1_it)
|
|
next(loader2_it)
|
|
|
|
@unittest.skip("temporarily disable until flaky failures are fixed")
|
|
def test_segfault(self):
|
|
p = ErrorTrackingProcess(target=_test_segfault)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertNotEqual(p.exitcode, 0)
|
|
if IS_WINDOWS:
|
|
self.assertIsInstance(p.exception, OSError)
|
|
self.assertRegex(str(p.exception), r'access violation reading ')
|
|
else:
|
|
self.assertIsInstance(p.exception, RuntimeError)
|
|
self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_timeout(self):
|
|
if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
|
|
# This test runs in a subprocess, which can only initialize CUDA with spawn.
|
|
# _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is
|
|
# constructed.
|
|
targets = (_test_timeout, _test_timeout_pin_memory)
|
|
else:
|
|
targets = (_test_timeout,)
|
|
for target in targets:
|
|
p = ErrorTrackingProcess(target=target)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertNotEqual(p.exitcode, 0)
|
|
self.assertIsInstance(p.exception, RuntimeError)
|
|
self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_invalid_ctor_args_combinations(self):
|
|
# general
|
|
with self.assertRaisesRegex(ValueError, "num_workers option should be non-negative"):
|
|
DataLoader(self.dataset, num_workers=-1)
|
|
with self.assertRaisesRegex(ValueError, "timeout option should be non-negative"):
|
|
DataLoader(self.dataset, timeout=-1)
|
|
|
|
|
|
# disable auto-batching
|
|
with self.assertRaisesRegex(ValueError,
|
|
"batch_size=None option disables auto-batching and is mutually exclusive"):
|
|
DataLoader(self.dataset, batch_size=None, shuffle=True)
|
|
with self.assertRaisesRegex(ValueError,
|
|
"batch_size=None option disables auto-batching and is mutually exclusive"):
|
|
DataLoader(self.dataset, batch_size=None, drop_last=True)
|
|
|
|
if torch.multiprocessing._supports_context:
|
|
valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
|
|
with self.assertRaisesRegex(ValueError, r"multi-process loading \(num_workers > 0\), but got"):
|
|
DataLoader(self.dataset, num_workers=0, multiprocessing_context=valid_ctx)
|
|
with self.assertRaisesRegex(ValueError, "should specify a valid start method in"):
|
|
DataLoader(self.dataset, num_workers=1, multiprocessing_context='bad')
|
|
with self.assertRaisesRegex(ValueError, "multiprocessing_context option should be a valid context "):
|
|
DataLoader(self.dataset, num_workers=1, multiprocessing_context=object())
|
|
else:
|
|
with self.assertRaisesRegex(ValueError, "multiprocessing_context relies on Python >= 3.4"):
|
|
DataLoader(self.dataset, num_workers=1, multiprocessing_context='fork')
|
|
|
|
# map-style
|
|
sampler = torch.utils.data.SequentialSampler(self.dataset)
|
|
batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
|
|
with self.assertRaisesRegex(ValueError, "sampler option is mutually exclusive with shuffle"):
|
|
DataLoader(self.dataset, batch_size=11, sampler=sampler, shuffle=True)
|
|
with self.assertRaisesRegex(ValueError, "sampler option is mutually exclusive with shuffle"):
|
|
DataLoader(self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True)
|
|
with self.assertRaisesRegex(ValueError, "sampler option is mutually exclusive with shuffle"):
|
|
DataLoader(self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3)
|
|
with self.assertRaisesRegex(ValueError, "batch_sampler option is mutually exclusive with"):
|
|
DataLoader(self.dataset, batch_size=11, batch_sampler=batch_sampler)
|
|
with self.assertRaisesRegex(ValueError, "batch_sampler option is mutually exclusive with"):
|
|
DataLoader(self.dataset, shuffle=True, batch_sampler=batch_sampler)
|
|
with self.assertRaisesRegex(ValueError, "batch_sampler option is mutually exclusive with"):
|
|
DataLoader(self.dataset, drop_last=True, batch_sampler=batch_sampler)
|
|
with self.assertRaisesRegex(ValueError, "batch_sampler option is mutually exclusive with"):
|
|
DataLoader(self.dataset, drop_last=3, batch_sampler=batch_sampler)
|
|
|
|
# iterable-style
|
|
dataset = CountingIterableDataset(20)
|
|
with self.assertRaisesRegex(ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"):
|
|
DataLoader(dataset, shuffle=True)
|
|
with self.assertRaisesRegex(ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"):
|
|
DataLoader(dataset, shuffle=3)
|
|
with self.assertRaisesRegex(ValueError, "DataLoader with IterableDataset: expected unspecified sampler"):
|
|
DataLoader(dataset, sampler=torch.utils.data.SequentialSampler(dataset))
|
|
with self.assertRaisesRegex(ValueError, "DataLoader with IterableDataset: expected unspecified sampler"):
|
|
DataLoader(dataset, sampler=3)
|
|
with self.assertRaisesRegex(ValueError, "DataLoader with IterableDataset: expected unspecified batch_sampler"):
|
|
DataLoader(dataset, batch_sampler=torch.utils.data.BatchSampler(
|
|
torch.utils.data.SequentialSampler(dataset), 3, False))
|
|
with self.assertRaisesRegex(ValueError, "DataLoader with IterableDataset: expected unspecified batch_sampler"):
|
|
DataLoader(dataset, batch_sampler=3)
|
|
|
|
def test_builtin_collection_conversion(self):
|
|
for coll_ty in (list, tuple):
|
|
for num_workers in (0, 1):
|
|
# map-style dataset
|
|
dataset = CountingDataset(20)
|
|
# no auto-batching
|
|
fetched = coll_ty(DataLoader(dataset, batch_size=None, num_workers=num_workers))
|
|
self.assertEqual(fetched, coll_ty(range(20)))
|
|
# auto-batching
|
|
fetched = coll_ty(DataLoader(dataset, batch_size=2, num_workers=num_workers))
|
|
self.assertEqual(fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2)))
|
|
|
|
# iterable-style dataset
|
|
dataset = CountingIterableDataset(20)
|
|
# no auto-batching
|
|
fetched = coll_ty(DataLoader(dataset, batch_size=None, num_workers=num_workers))
|
|
self.assertEqual(fetched, coll_ty(range(20)))
|
|
# auto-batching
|
|
                # this IterableDataset isn't configured for each worker, so for
                # the equality test below to be valid, we cannot have more than 1 worker.
|
|
assert num_workers in [0, 1], "invalid test"
|
|
fetched = coll_ty(DataLoader(dataset, batch_size=2, num_workers=num_workers))
|
|
self.assertEqual(fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2)))
|
|
|
|
def test_iterable_style_dataset(self):
|
|
# [no auto-batching] single process loading
|
|
dataset = CountingIterableDataset(20)
|
|
dataloader = DataLoader(dataset, batch_size=None)
|
|
fetched = list(dataloader)
|
|
self.assertEqual(len(fetched), 20)
|
|
for i, d in enumerate(fetched):
|
|
# non-batched should not convert ints into tensors
|
|
self.assertIsInstance(d, torch._six.int_classes)
|
|
self.assertEqual(d, i)
|
|
# DataLoader should match len of the iterable-style dataset (if implemented)
|
|
self.assertEqual(len(dataloader), len(dataset))
|
|
|
|
# [no auto-batching] multiprocessing loading
|
|
num_workers = 3
|
|
sizes_for_all_workers = [0, 4, 20]
|
|
expected = sorted(sum((list(range(s)) for s in sizes_for_all_workers), []))
|
|
assert len(sizes_for_all_workers) == num_workers, 'invalid test case'
|
|
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
|
|
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=None,
|
|
worker_init_fn=set_faulthander_if_available)
|
|
dataloader_iter = iter(dataloader)
|
|
fetched = sorted(dataloader_iter)
|
|
for a, b in zip(fetched, expected):
|
|
# non-batched should not convert ints into tensors
|
|
self.assertIsInstance(a, torch._six.int_classes)
|
|
self.assertEqual(a, b)
|
|
# DataLoader should match len of the iterable-style dataset (if implemented)
|
|
self.assertEqual(len(dataloader), len(dataset))
|
|
# When loading more than len(dataset) data, after accessing len(dataloader),
|
|
# we should get a warning. See NOTE [ IterableDataset and __len__ ].
|
|
dataset = CountingIterableDataset(20)
|
|
dataloader = DataLoader(dataset, num_workers=num_workers,
|
|
worker_init_fn=set_faulthander_if_available)
|
|
it = iter(dataloader)
|
|
for _ in range(40):
|
|
self.assertNotWarn(lambda: next(it), "Should not warn before accessing len(dataloader)")
|
|
self.assertEqual(len(dataloader), len(dataset))
|
|
self.assertEqual(len(dataloader), 20)
|
|
it = iter(dataloader)
|
|
for _ in range(20):
|
|
self.assertNotWarn(lambda: next(it), "Should not warn before exceeding length")
|
|
for _ in range(3):
|
|
self.assertWarnsRegex(
|
|
lambda: next(it),
|
|
r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
|
|
"Should always warn after exceeding length")
|
|
|
|
# [no auto-batching] test that workers exit gracefully
|
|
workers = dataloader_iter._workers
|
|
del dataloader_iter
|
|
try:
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive())
|
|
self.assertEqual(w.exitcode, 0)
|
|
finally:
|
|
for w in workers:
|
|
w.terminate()
|
|
|
|
# [auto-batching] single process loading
|
|
dataset = CountingIterableDataset(20)
|
|
fetched = list(DataLoader(dataset, batch_size=7))
|
|
self.assertEqual(len(fetched), 3)
|
|
self.assertEqual(fetched[0].tolist(), list(range(7)))
|
|
self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
|
|
self.assertEqual(fetched[2].tolist(), list(range(14, 20)))
|
|
|
|
# [auto-batching] multiprocessing loading
|
|
num_workers = 3
|
|
sizes_for_all_workers = [0, 4, 20]
|
|
expected = sorted(sum((list(range(s)) for s in sizes_for_all_workers), []))
|
|
assert len(sizes_for_all_workers) == num_workers, 'invalid test case'
|
|
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
|
|
        # worker 0 should return 0 batches
        # worker 1 should return 1 batch
        # worker 2 should return 3 batches
|
|
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=7)
|
|
dataloader_iter = iter(dataloader)
|
|
fetched = list(dataloader_iter)
|
|
self.assertEqual(len(fetched), 4)
|
|
fetched = set(tuple(t.tolist()) for t in fetched)
|
|
self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))})
|
|
|
|
# [auto-batching] test that workers exit gracefully
|
|
workers = dataloader_iter._workers
|
|
del dataloader_iter
|
|
try:
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive())
|
|
self.assertEqual(w.exitcode, 0)
|
|
finally:
|
|
for w in workers:
|
|
w.terminate()
|
|
|
|
# [auto-batching & drop_last] single process loading
|
|
dataset = CountingIterableDataset(20)
|
|
fetched = list(DataLoader(dataset, batch_size=7, drop_last=True))
|
|
self.assertEqual(len(fetched), 2)
|
|
self.assertEqual(fetched[0].tolist(), list(range(7)))
|
|
self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
|
|
|
|
# [auto-batching & drop_last] multiprocessing loading
|
|
num_workers = 3
|
|
sizes_for_all_workers = [0, 4, 20]
|
|
expected = sorted(sum((list(range(s)) for s in sizes_for_all_workers), []))
|
|
assert len(sizes_for_all_workers) == num_workers, 'invalid test case'
|
|
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
|
|
        # worker 0 should return 0 batches
        # worker 1 should return 1 batch
        # worker 2 should return 3 batches
|
|
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=7, drop_last=True,
|
|
worker_init_fn=set_faulthander_if_available)
|
|
dataloader_iter = iter(dataloader)
|
|
fetched = list(dataloader_iter)
|
|
self.assertEqual(len(fetched), 2)
|
|
fetched = set(tuple(t.tolist()) for t in fetched)
|
|
self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
|
|
|
|
# [auto-batching & drop_last] test that workers exit gracefully
|
|
workers = dataloader_iter._workers
|
|
del dataloader_iter
|
|
try:
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive())
|
|
self.assertEqual(w.exitcode, 0)
|
|
finally:
|
|
for w in workers:
|
|
w.terminate()
|
|
|
|
def test_chain_iterable_style_dataset(self):
|
|
# chaining (concatenation)
|
|
dataset1 = CountingIterableDataset(20)
|
|
dataset2 = CountingIterableDataset(15)
|
|
expected = list(range(20)) + list(range(15))
|
|
for num_workers in [0, 1]:
|
|
for chained_dataset in [dataset1 + dataset2, ChainDataset([dataset1, dataset2])]:
|
|
fetched = list(DataLoader(chained_dataset, num_workers=num_workers))
|
|
self.assertEqual(len(fetched), len(expected))
|
|
for e, d in zip(expected, fetched):
|
|
self.assertIsInstance(d, torch.Tensor)
|
|
self.assertEqual(e, d)
|
|
|
|
with self.assertRaisesRegex(AssertionError, "ChainDataset only supports IterableDataset"):
|
|
list(iter(dataset1 + self.dataset))
|
|
|
|
with self.assertRaisesRegex(AssertionError, "ChainDataset only supports IterableDataset"):
|
|
list(iter(ChainDataset([dataset1, self.dataset])))
|
|
|
|
def test_multiprocessing_contexts(self):
|
|
reference = [
|
|
torch.arange(3),
|
|
torch.arange(3, 6),
|
|
torch.arange(6, 9),
|
|
torch.arange(9, 11),
|
|
]
|
|
counting_ds_n = 11
|
|
dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
|
|
for ctx in supported_multiprocessing_contexts:
|
|
            if ctx in ['spawn', 'forkserver'] and TEST_CUDA and not IS_WINDOWS:  # Windows doesn't support sharing CUDA tensors
                ds_cls = CUDACountingDataset
            else:
                ds_cls = CountingDataset
|
|
self.assertEqual(
|
|
reference, list(DataLoader(ds_cls(counting_ds_n), multiprocessing_context=ctx, **dl_common_args)))
|
|
if ctx is not None:
|
|
# test ctx object
|
|
ctx = mp.get_context(ctx)
|
|
self.assertEqual(
|
|
reference, list(DataLoader(ds_cls(counting_ds_n), multiprocessing_context=ctx, **dl_common_args)))
|
|
|
|
def test_worker_seed(self):
|
|
num_workers = 6
|
|
batch_size = 1
|
|
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
|
|
seeds = set()
|
|
for batch in dataloader:
|
|
seeds.add(batch[0])
|
|
self.assertEqual(len(seeds), num_workers)
|
|
|
|
def test_worker_init_fn(self):
|
|
dataset = SeedDataset(4)
|
|
dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
|
|
worker_init_fn=init_fn)
|
|
for batch in dataloader:
|
|
self.assertEqual(12345, batch[0])
|
|
self.assertEqual(12345, batch[1])
|
|
|
|
def test_get_worker_info(self):
|
|
p = ErrorTrackingProcess(target=_test_get_worker_info)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertEqual(p.exitcode, 0)
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_shuffle(self):
|
|
self._test_shuffle(DataLoader(self.dataset, shuffle=True))
|
|
|
|
def test_shuffle_batch(self):
|
|
self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))
|
|
|
|
def test_sequential_workers(self):
|
|
self._test_sequential(DataLoader(self.dataset, num_workers=4))
|
|
|
|
def test_seqential_batch_workers(self):
|
|
self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))
|
|
|
|
def test_shuffle_workers(self):
|
|
self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))
|
|
|
|
def test_shuffle_batch_workers(self):
|
|
self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))
|
|
|
|
def test_RandomSampler(self):
|
|
|
|
from collections import Counter
|
|
from torch.utils.data import RandomSampler
|
|
|
|
def sample_stat(sampler, num_samples):
|
|
counts = Counter(sampler)
|
|
count_repeated = sum(val > 1 for val in counts.values())
|
|
return (count_repeated, min(counts.keys()), max(counts.keys()))
|
|
|
|
# test sample with replacement
|
|
n = len(self.dataset) + 1 # ensure at least one sample is drawn more than once
|
|
sampler_with_replacement = RandomSampler(self.dataset, replacement=True, num_samples=n)
|
|
count_repeated, minval, maxval = sample_stat(sampler_with_replacement, n)
|
|
self.assertTrue(count_repeated > 0)
|
|
self.assertTrue(minval >= 0)
|
|
self.assertTrue(maxval < len(self.dataset))
|
|
|
|
# test sample without replacement
|
|
sampler_without_replacement = RandomSampler(self.dataset)
|
|
count_repeated, minval, maxval = sample_stat(sampler_without_replacement, len(self.dataset))
|
|
self.assertTrue(count_repeated == 0)
|
|
self.assertTrue(minval == 0)
|
|
self.assertTrue(maxval == len(self.dataset) - 1)
|
|
|
|
# raise error when replacement=False and num_samples is not None
|
|
self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=len(self.dataset)))
|
|
|
|
self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))
|
|
|
|
def test_random_sampler_len_with_replacement(self):
|
|
from torch.utils.data import RandomSampler
|
|
# add 5 extra samples
|
|
num_samples = len(self.dataset) + 5
|
|
sampler = RandomSampler(self.dataset,
|
|
replacement=True,
|
|
num_samples=num_samples)
|
|
# test len method
|
|
self.assertEqual(num_samples, len(sampler))
|
|
|
|
# test with iteration
|
|
count_num_samples = sum(1 for _ in sampler)
|
|
self.assertEqual(num_samples, count_num_samples)
|
|
|
|
# test with dataloader, batch_size = 1
|
|
batch_size = 1
|
|
count_num_samples_in_data_loader = len(DataLoader(
|
|
self.dataset, batch_size=batch_size, sampler=sampler))
|
|
self.assertEqual(num_samples, count_num_samples_in_data_loader)
|
|
|
|
# test with dataloader, batch_size = 6
|
|
batch_size = 6
|
|
count_num_samples_in_data_loader = len(DataLoader(
|
|
self.dataset, batch_size=batch_size, sampler=sampler))
|
|
self.assertEqual(int(math.ceil(float(num_samples) / batch_size)),
|
|
count_num_samples_in_data_loader)
|
|
|
|
def test_duplicating_data_with_drop_last(self):
|
|
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
num_processes = 4
|
|
num_batches = 9
|
|
data_set = torch.IntTensor(range(num_batches))
|
|
scanned_data = torch.IntTensor([])
|
|
for i in range(num_processes):
|
|
s = DistributedSampler(data_set, num_processes, i)
|
|
d_loader = DataLoader(data_set, batch_size=int(num_batches / num_processes), drop_last=True, sampler=s)
|
|
for data in d_loader:
|
|
scanned_data = torch.cat((scanned_data, data), 0)
|
|
|
|
self.assertEqual(scanned_data.size(), scanned_data.unique().size())
|
|
|
|
def _test_batch_sampler(self, **kwargs):
|
|
# [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
|
|
batches = []
|
|
for i in range(0, 20, 5):
|
|
batches.append(tuple(range(i, i + 2)))
|
|
batches.append(tuple(range(i + 2, i + 5)))
|
|
|
|
dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
|
|
self.assertEqual(len(dl), 8)
|
|
for i, (input, _target) in enumerate(dl):
|
|
if i % 2 == 0:
|
|
offset = i * 5 // 2
|
|
self.assertEqual(len(input), 2)
|
|
self.assertEqual(input, self.data[offset:offset + 2])
|
|
else:
|
|
offset = i * 5 // 2
|
|
self.assertEqual(len(input), 3)
|
|
self.assertEqual(input, self.data[offset:offset + 3])
|
|
|
|
def test_batch_sampler(self):
|
|
self._test_batch_sampler()
|
|
self._test_batch_sampler(num_workers=4)
|
|
if not NO_MULTIPROCESSING_SPAWN and torch.multiprocessing._supports_context:
|
|
self._test_batch_sampler(num_workers=4, multiprocessing_context='spawn')
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_shuffle_pin_memory(self):
|
|
loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
|
|
for input, target in loader:
|
|
self.assertTrue(input.is_pinned())
|
|
self.assertTrue(target.is_pinned())
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_numpy(self):
|
|
import numpy as np
|
|
|
|
class TestDataset(torch.utils.data.Dataset):
|
|
def __getitem__(self, i):
|
|
return np.ones((2, 3, 4)) * i
|
|
|
|
def __len__(self):
|
|
return 1000
|
|
|
|
loader = DataLoader(TestDataset(), batch_size=12)
|
|
batch = next(iter(loader))
|
|
self.assertIsInstance(batch, torch.DoubleTensor)
|
|
self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
|
|
|
|
def test_error(self):
|
|
self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))
|
|
|
|
def test_error_workers(self):
|
|
self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
|
|
def test_partial_workers(self):
|
|
r"""Check that workers exit even if the iterator is not exhausted."""
|
|
if TEST_CUDA:
|
|
pin_memory_configs = (True, False)
|
|
else:
|
|
pin_memory_configs = (False,)
|
|
|
|
for pin_memory in pin_memory_configs:
|
|
loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
|
|
workers = loader._workers
|
|
if pin_memory:
|
|
pin_memory_thread = loader._pin_memory_thread
|
|
for i, _ in enumerate(loader):
|
|
if i == 10:
|
|
break
|
|
assert i == 10
|
|
del loader
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive(), 'subprocess not terminated')
|
|
if pin_memory:
|
|
pin_memory_thread.join(JOIN_TIMEOUT)
|
|
self.assertFalse(pin_memory_thread.is_alive())
|
|
|
|
@skipIfRocm
|
|
@unittest.skipIf(not HAS_PSUTIL, "psutil not found")
|
|
def test_proper_exit(self):
|
|
(r'''There might be ConnectionResetError or leaked semaphore warning '''
|
|
r'''(due to dirty process exit), but they are all safe to ignore''')
|
|
|
|
# TODO: test the case where the pin_memory_thread triggers an
|
|
# error/fatal signal. I haven't found out how to properly do that.
|
|
|
|
for is_iterable_dataset, use_workers, pin_memory, hold_iter_reference in \
|
|
itertools.product([True, False], repeat=4):
|
|
|
|
# `hold_iter_reference` specifies whether we hold a reference to the
|
|
# iterator. This is interesting because Python3 error traces holds a
|
|
# reference to the frames, which hold references to all the local
|
|
# variables including the iterator, and then the iterator dtor may
|
|
# not be called before process end. It is important to see that the
|
|
# processes still exit in both cases.
|
|
|
|
if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
|
|
# This test runs in a subprocess, which can only initialize CUDA with spawn.
|
|
# DataLoader with pin_memory=True initializes CUDA when its iterator is constructed.
|
|
# For windows, pin_memory sometimes causes CUDA oom.
|
|
continue
|
|
|
|
# `exit_method` controls the way the loader process ends.
|
|
# - `*_kill` means that `*` is killed by OS.
|
|
# - `*_error` means that `*` raises an error.
|
|
# - `None` means that no error happens.
|
|
# In all cases, all processes should end properly.
|
|
if use_workers:
|
|
exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill']
|
|
else:
|
|
exit_methods = [None, 'loader_error', 'loader_kill']
|
|
|
|
for exit_method in exit_methods:
|
|
if exit_method == 'worker_kill':
|
|
# FIXME: This sometimes hangs. See #16608.
|
|
continue
|
|
|
|
desc = []
|
|
desc.append('is_iterable_dataset={}'.format(is_iterable_dataset))
|
|
desc.append('use_workers={}'.format(use_workers))
|
|
desc.append('pin_memory={}'.format(pin_memory))
|
|
desc.append('hold_iter_reference={}'.format(hold_iter_reference))
|
|
desc.append('exit_method={}'.format(exit_method))
|
|
desc = 'test_proper_exit with ' + ', '.join(desc)
|
|
|
|
# Event that the loader process uses to signal testing process
|
|
# that various things are setup, including that the worker pids
|
|
# are specified in `worker_pids` array.
|
|
loader_setup_event = mp.Event()
|
|
|
|
# Event that this process has finished setting up, and the
|
|
# loader process can now proceed to trigger error events or
|
|
# finish normally.
|
|
tester_setup_event = mp.Event()
|
|
|
|
loader_p = ErrorTrackingProcess(target=_test_proper_exit,
|
|
args=(is_iterable_dataset, use_workers, pin_memory,
|
|
exit_method, hold_iter_reference,
|
|
loader_setup_event, tester_setup_event),
|
|
disable_stderr=False)
|
|
loader_p.start()
|
|
loader_psutil_p = psutil.Process(loader_p.pid)
|
|
|
|
# Wait for loader process to set everything up, e.g., starting
|
|
# workers.
|
|
loader_setup_event.wait(timeout=JOIN_TIMEOUT)
|
|
if not loader_setup_event.is_set():
|
|
fail_msg = desc + ': loader process failed to setup within given time'
|
|
if loader_p.exception is not None:
|
|
fail_msg += ', and had exception {}'.format(loader_p.exception)
|
|
elif not loader_p.is_alive():
|
|
fail_msg += ', and exited with code {} but had no exception'.format(loader_p.exitcode)
|
|
else:
|
|
fail_msg += ', and is still alive.'
|
|
if loader_p.is_alive():
|
|
# this may kill the process, needs to run after the above lines
|
|
loader_p.print_traces_of_all_threads()
|
|
self.fail(fail_msg)
|
|
|
|
# We are certain that the workers have started now.
|
|
worker_psutil_ps = loader_psutil_p.children()
|
|
|
|
def fail(reason):
|
|
report_psutil_attrs = ['pid', 'name', 'cpu_times', 'io_counters',
|
|
'memory_full_info', 'num_ctx_switches',
|
|
'open_files', 'threads', 'status',
|
|
'nice', 'ionice']
|
|
if reason is None:
|
|
err_msg = desc
|
|
else:
|
|
err_msg = '{}: {}'.format(desc, reason)
|
|
err_msg += '\nLoader info:\n\t'
|
|
if loader_psutil_p.is_running():
|
|
err_msg += str(loader_psutil_p.as_dict(attrs=report_psutil_attrs))
|
|
# this may kill the process, needs to run after the above line
|
|
loader_p.print_traces_of_all_threads()
|
|
else:
|
|
err_msg += 'exited with code {}'.format(loader_p.exitcode)
|
|
if use_workers:
|
|
err_msg += '\nWorker(s) info:'
|
|
for idx, worker_psutil_p in enumerate(worker_psutil_ps):
|
|
err_msg += '\n\tWorker {}:\n\t\t'.format(idx)
|
|
if worker_psutil_p.is_running():
|
|
err_msg += str(worker_psutil_p.as_dict(attrs=report_psutil_attrs))
|
|
# this may kill the process, needs to run after the above line
|
|
print_traces_of_all_threads(worker_psutil_p.pid)
|
|
else:
|
|
err_msg += 'exited with unknown code'
|
|
self.fail(err_msg)
|
|
|
|
tester_setup_event.set()
|
|
|
|
try:
|
|
loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
|
|
if loader_p.is_alive():
|
|
fail_reason = 'loader process did not terminate'
|
|
if loader_p.exception is not None:
|
|
fail(fail_reason + ', and had exception {}'.format(loader_p.exception))
|
|
else:
|
|
fail(fail_reason + ', and had no exception')
|
|
_, alive = psutil.wait_procs(worker_psutil_ps, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
|
|
if len(alive) > 0:
|
|
                        fail('worker process (pid(s) {}) did not terminate'.format(
                            ', '.join(str(p.pid) for p in alive)))
|
|
if exit_method is None:
|
|
if loader_p.exitcode != 0:
|
|
fail('loader process had nonzero exitcode {}'.format(loader_p.exitcode))
|
|
else:
|
|
if loader_p.exitcode == 0:
|
|
fail('loader process had zero exitcode')
|
|
if exit_method == 'loader_error':
|
|
if not isinstance(loader_p.exception, RuntimeError) or \
|
|
'Loader error' not in str(loader_p.exception):
|
|
fail('loader process did not raise expected exception, but had {}'.format(
|
|
loader_p.exception))
|
|
elif exit_method == 'worker_kill':
|
|
if isinstance(loader_p.exception, RuntimeError):
|
|
if 'DataLoader worker (pid' not in str(loader_p.exception):
|
|
fail('loader process did not raise expected exception, but had {}'.format(
|
|
loader_p.exception))
|
|
elif PY3 and isinstance(loader_p.exception, ConnectionRefusedError):
|
|
                                # Sometimes, when the worker is being killed and is freeing its
                                # resources, the unpickling in the loader process hits a
                                # `ConnectionRefusedError` because it cannot open a socket to
                                # receive the resource. In such cases, the worker may not have
                                # fully exited, and the loader can't know this via an `is_alive`
                                # check or a `SIGCHLD` handler, so we permit this as an allowed
                                # error as well. After all, we are happy as long as it terminates.
|
|
pass
|
|
elif not PY3 and isinstance(loader_p.exception, OSError):
|
|
# Same reasoning as the above if-block for Py2,
|
|
# where ConnectionRefusedError isn't a thing.
|
|
if loader_p.exception.errno != errno.ECONNREFUSED:
|
|
fail('loader process did not raise expected exception, but had {}'.format(
|
|
loader_p.exception))
|
|
else:
|
|
fail('loader process did not raise expected exception, but had {}'.format(
|
|
loader_p.exception))
|
|
elif exit_method == 'worker_error':
|
|
if not isinstance(loader_p.exception, RuntimeError) or \
|
|
'Worker error' not in str(loader_p.exception):
|
|
fail('loader process did not raise expected exception, but had {}'.format(
|
|
loader_p.exception))
|
|
finally:
|
|
loader_p.terminate()
|
|
|
|
def test_len(self):
|
|
def check_len(dl, expected):
|
|
self.assertEqual(len(dl), expected)
|
|
n = 0
|
|
for _ in dl:
|
|
n += 1
|
|
self.assertEqual(n, expected)
|
|
check_len(self.dataset, 100)
|
|
check_len(DataLoader(self.dataset, batch_size=2), 50)
|
|
check_len(DataLoader(self.dataset, batch_size=3), 34)
|
|
|
|
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

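    # For batches of plain Python scalars, default_collate infers the dtype:
    # ints collate to int64, floats to float64, bools to bool, and a list of
    # strings is returned unchanged.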
    def test_default_collate_dtype(self):
        arr = [1, 2, -1]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.int64)

        arr = [1.1, 2.3, -0.9]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.float64)

        arr = [True, False]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.bool)

        # Should be a no-op
        arr = ['a', 'b', 'c']
        self.assertEqual(arr, _utils.collate.default_collate(arr))

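    # default_collate rejects numpy arrays it cannot convert to tensors:
    # object arrays and nested string arrays raise TypeError, while a flat
    # string array is passed through unchanged.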
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        self.assertEqual(arr, _utils.collate.default_collate(arr))

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_shared_tensor(self):
        import numpy as np
        t_in = torch.zeros(1)
        n_in = np.zeros(1)

        self.assertEqual(t_in.is_shared(), False)

        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)

        # FIXME: fix the following hack that makes `default_collate` believe
        #        that it is in a worker process (since it tests
        #        `get_worker_info() != None`), even though it is not.
        old = _utils.worker._worker_info
        try:
            _utils.worker._worker_info = 'x'
            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
        finally:
            _utils.worker._worker_info = old


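# Tiny map-style dataset yielding (character, index) pairs, used to check that
# string elements survive collation (and that the index tensor can be pinned).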
class StringDataset(Dataset):
    def __init__(self):
        self.s = '12345'

    def __len__(self):
        return len(self.s)

    def __getitem__(self, ndx):
        return (self.s[ndx], ndx)


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestStringDataLoader(TestCase):
    def setUp(self):
        super(TestStringDataLoader, self).setUp()
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for (s, n) in loader:
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())


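# Dataset returning nested dicts; default_collate recurses into the dict
# structure and collates each leaf value into a batched tensor.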
class DictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return {
            'a_tensor': torch.Tensor(4, 2).fill_(ndx),
            'another_dict': {
                'a_number': ndx,
            },
        }


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestDictDataLoader(TestCase):
    def setUp(self):
        super(TestDictDataLoader, self).setUp()
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        batch_size = loader.batch_size
        for i, sample in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

            t = sample['a_tensor']
            self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
            self.assertTrue((t[0] == idx).all())
            self.assertTrue((t[1] == idx + 1).all())

            n = sample['another_dict']['a_number']
            self.assertEqual(n.size(), torch.Size([batch_size]))
            self.assertEqual(n[0], idx)
            self.assertEqual(n[1], idx + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            self.assertTrue(sample['a_tensor'].is_pinned())
            self.assertTrue(sample['another_dict']['a_number'].is_pinned())


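# Dataset returning nested namedtuples; collation and pin_memory should
# preserve the namedtuple types (checked in TestNamedTupleDataLoader below).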
class NamedTupleDataset(Dataset):
    from collections import namedtuple
    Batch = namedtuple('Batch', ['data', 'label', 'random_tensor'])
    Data = namedtuple('Data', ['positive', 'negative'])

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return self.Batch(data=self.Data(positive=ndx, negative=-ndx),
                          label=str(ndx), random_tensor=torch.randn(3))


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestNamedTupleDataLoader(TestCase):
    def setUp(self):
        super(TestNamedTupleDataLoader, self).setUp()
        self.dataset = NamedTupleDataset()

    def test_dataloader_with_namedtuple(self):
        # auto-collation
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA)
        for batch in loader:
            self.assertIsInstance(batch, NamedTupleDataset.Batch)
            self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
            self.assertIsInstance(batch.data.positive, torch.Tensor)
            self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA)
        # no auto-collation
        loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA)
        for batch in loader:
            self.assertIsInstance(batch, NamedTupleDataset.Batch)
            self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
            self.assertNotIsInstance(batch.data.positive, torch.Tensor)


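# Custom batch type implementing the pin_memory() protocol: when a collate_fn
# returns an object with a pin_memory() method, DataLoader(pin_memory=True)
# calls that method to pin the wrapped tensors.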
class SimpleCustomBatch(object):
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

    def is_pinned(self):
        return self.inp.is_pinned() and self.tgt.is_pinned()


def collate_wrapper(batch):
    return SimpleCustomBatch(batch)


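# These collate functions build PackedSequence batches (sequence-major and
# batch-first); the random lengths only matter for constructing a valid
# PackedSequence, since the tests below just check the type and pinning.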
def collate_into_packed_sequence(batch):
    data = torch.stack([sample[0] for sample in batch], 1)
    t, b = data.size()
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)


def collate_into_packed_sequence_batch_first(batch):
    data = torch.stack([sample[0] for sample in batch], 0)
    b, t = data.size()
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestCustomPinFn(TestCase):
    def setUp(self):
        super(TestCustomPinFn, self).setUp()
        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        self.dataset = TensorDataset(inps, tgts)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_custom_batch_pin(self):
        test_cases = [
            (collate_wrapper, SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
                                pin_memory=True)
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_custom_batch_pin_worker(self):
        test_cases = [
            (collate_wrapper, SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
                                pin_memory=True, num_workers=1)
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())


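# Dataset that records the worker id assigned through worker_init_fn, so the
# test below can check which worker produced each batch.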
class TestWorkerQueueDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.worker_id = None

    def worker_init_fn(self, worker_id):
        self.worker_id = worker_id

    def __getitem__(self, item):
        return self.worker_id, self.data[item]

    def __len__(self):
        return len(self.data)


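# With shuffle=False and multiple workers, index batches are handed out to the
# workers in round-robin order, so batch i is expected to come from worker
# (i % num_workers) and to contain consecutive dataset indices.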
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestIndividualWorkerQueue(TestCase):
    def setUp(self):
        super(TestIndividualWorkerQueue, self).setUp()
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        loader = DataLoader(
            self.dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
            worker_init_fn=self.dataset.worker_init_fn
        )
        current_worker_idx = 0
        for i, (worker_ids, sample) in enumerate(loader):
            self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
            self.assertEqual(sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size)))
            current_worker_idx += 1
            if current_worker_idx == num_workers:
                current_worker_idx = 0

    def test_ind_worker_queue(self):
        for batch_size in (8, 16, 32, 64):
            for num_workers in range(1, 6):
                self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers)


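# worker_set_affinity pins each worker process to CPU 2; the iterable dataset
# then reports the affinity it actually observes, which the test expects to
# be exactly {2}.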
class SetAffinityDataset(torch.utils.data.IterableDataset):

    def __iter__(self):
        torch.randperm(1)
        after = os.sched_getaffinity(0)
        return iter(after)


def worker_set_affinity(_):
    os.sched_setaffinity(0, [2])


@unittest.skipIf(
    not hasattr(os, 'sched_setaffinity'),
    "os.sched_setaffinity is not available")
class TestSetAffinity(TestCase):
    def test_set_affinity_in_worker_init(self):
        dataset = SetAffinityDataset()

        dataloader = torch.utils.data.DataLoader(
            dataset, num_workers=2, worker_init_fn=worker_set_affinity)
        for sample in dataloader:
            self.assertEqual(sample, [2])


if __name__ == '__main__':
    run_tests()