# Owner(s): ["module: dataloader"]

import ctypes
import errno
import faulthandler
import functools
import gc
import itertools
import math
import operator
import os
import signal
import sys
import tempfile
import time
import unittest
import warnings

import torch
import torch.utils.data.datapipes as dp
from torch import multiprocessing as mp
from torch._utils import ExceptionWrapper
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
    IS_CI,
    IS_JETSON,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    load_tests,
    NO_MULTIPROCESSING_SPAWN,
    parametrize,
    run_tests,
    skipIfNoDill,
    skipIfRocm,
    slowTest,
    TEST_CUDA,
    TEST_NUMPY,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TEST_WITH_TSAN,
    TestCase,
)
from torch.utils.data import (
    _utils,
    ChainDataset,
    ConcatDataset,
    DataLoader,
    Dataset,
    IterableDataset,
    IterDataPipe,
    StackDataset,
    Subset,
    TensorDataset,
)
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data.dataset import random_split


try:
    import psutil

    HAS_PSUTIL = True
except ModuleNotFoundError:
    HAS_PSUTIL = False
    psutil = None
    err_msg = (
        "psutil not found. Some critical data loader tests relying on it "
        "(e.g., TestDataLoader.test_proper_exit) will not run."
    )
    if IS_CI:
        raise ModuleNotFoundError(err_msg) from None
    else:
        warnings.warn(err_msg)


try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not IS_JETSON
    and not TEST_WITH_ROCM
)  # https://github.com/pytorch/pytorch/issues/90940

TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1

if TEST_CUDA_IPC:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

if not NO_MULTIPROCESSING_SPAWN:
    # We want to use `spawn` if able because some of our tests check that the
    # data loader terminates gracefully. To prevent hanging in the testing
    # process, such data loaders are run in a separate subprocess.
    #
    # We also want to test the `pin_memory=True` configuration, thus `spawn` is
    # required to launch such processes and they initialize the CUDA context.
    #
    # Mixing different start methods is a recipe for disaster (e.g., using a fork
    # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
    # to avoid bugs.
    #
    # Get a multiprocessing context because some test / third party library will
    # set start_method when imported, and setting again triggers `RuntimeError`.
    mp = mp.get_context(method="spawn")


# 60s of timeout?
# Yes, in environments where physical CPU resources are shared, e.g., CI, the
# time for an inter-process communication can be highly varying. With 15~17s of
# timeout, we have observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608). We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0  # seconds


supported_multiprocessing_contexts = [None] + list(
    torch.multiprocessing.get_all_start_methods()
)


# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
def _clone_collate(b):
    return [x.clone() for x in b]
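

# Illustrative sketch (not used by any test): with a `spawn` start method the
# `collate_fn` is pickled and sent to each worker process, which is why
# `_clone_collate` above must live at module scope instead of being a lambda or
# a closure. The helper name below is hypothetical and exists only as an example.
def _example_spawn_loader():
    data = [torch.randn(2) for _ in range(8)]  # samples are raw tensors
    return DataLoader(
        data,
        batch_size=2,
        num_workers=2,
        collate_fn=_clone_collate,  # must be a picklable, module-level callable
        multiprocessing_context="spawn",
    )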


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDatasetRandomSplit(TestCase):
    def test_lengths_must_equal_dataset_size(self):
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)

        splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 3)
        self.assertEqual(len(splits[1]), 3)

        # Odd size splits
        self.assertEqual(
            len(
                random_split(
                    range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1)
                )
            ),
            2,
        )

        # Odd sized round-robin splits
        splits = random_split(
            range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1)
        )
        self.assertEqual(len(splits[0]), 11)
        self.assertEqual(len(splits[1]), 22)
        self.assertEqual(len(splits[2]), 31)
        self.assertEqual(len(splits[3]), 42)

    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

        splits = random_split(data, [0.33, 0.67])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

        data = [1, 2, 3, 4]
        splits = random_split(data, [0.25, 0.75])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)

    def test_splits_indexing_type(self):
        r"""Indices generated by random_split
        should be of integer type
        """

        class CustomDataset:
            def __init__(self, test_object, custom_list):
                self.data = custom_list
                self.test_object = test_object

            def __getitem__(self, key):
                self.test_object.assertEqual(type(key), int)
                return self.data[key]

            def __len__(self):
                return len(self.data)

        x = [1, 2, 3, 4, 5]
        dataset = CustomDataset(self, x)
        dataset = random_split(dataset, [5])[0]
        data_loader = DataLoader(dataset)
        for batch in data_loader:
            pass

        # fractional splitting
        dataset = CustomDataset(self, x)
        dataset = random_split(dataset, [1.0])[0]
        data_loader = DataLoader(dataset)
        for batch in data_loader:
            pass

    def test_splits_reproducibility(self):
        self.assertEqual(
            [
                list(x)
                for x in random_split(
                    range(10), [3, 7], generator=torch.Generator().manual_seed(1)
                )
            ],
            [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]],
        )
        self.assertEqual(
            random_split(
                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
            ),
            random_split(
                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
            ),
        )
        self.assertEqual(
            random_split(
                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
            ),
            random_split(
                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
            ),
        )
        self.assertEqual(
            random_split(
                range(100),
                [0.33, 0.33, 0.34],
                generator=torch.Generator().manual_seed(42),
            ),
            random_split(
                range(100),
                [0.33, 0.33, 0.34],
                generator=torch.Generator().manual_seed(42),
            ),
        )

    def test_incomplete_fractional_splits(self):
        with self.assertRaises(ValueError):
            # should raise since the sum of fractions is not 1
            random_split([1, 2, 3, 4], [0.1])

        with self.assertRaises(ValueError):
            # should raise since fraction > 1
            random_split([1, 2, 3, 4], [1.1])

    def test_splits_generator(self):
        # A random_split without a specific generator should affect the default one
        state = torch.get_rng_state()
        a = torch.rand(10)
        torch.set_rng_state(state)
        random_split(range(10), [5, 5])
        b = torch.rand(10)
        self.assertNotEqual(a, b)

        # A random_split with a specific generator should not affect the default one
        state = torch.get_rng_state()
        a = torch.rand(10)
        torch.set_rng_state(state)
        random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42))
        b = torch.rand(10)
        self.assertEqual(a, b)

    def test_slicing_of_subset_of_dataset(self):
        # Testing slicing a subset initialized with a dataset
        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
        self.assertEqual(subset_of_dataset[:], dataset[:])
        self.assertEqual(subset_of_dataset[1:2], dataset[1:2])
        self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2])
        # Testing slicing of subset from random split
        subset1, subset2 = random_split(dataset, [3, 2])
        self.assertEqual(subset1[:], dataset[subset1.indices[:]])
        self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]])
        self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]])

    def test_slicing_of_subset_of_subset(self):
        # Testing slicing a subset initialized with a subset
        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
        subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4])
        self.assertEqual(subset_of_subset[:], dataset[:])
        self.assertEqual(subset_of_subset[0:2], dataset[0:2])
        self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2])
        # Testing slicing of subset of subset from random split
        subset1, subset2 = random_split(dataset, [4, 1])
        subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
        idx = [subset1.indices[i] for i in subset_of_subset1.indices]
        self.assertEqual(subset_of_subset1[:], dataset[idx.copy()])
        self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
        self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
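

# Illustrative sketch (not used by any test): fractional lengths passed to
# `random_split` must sum to 1.0, and any remainder left after flooring each
# fraction is handed out round-robin across the splits, as the tests above verify.
# The helper name is hypothetical.
def _example_fractional_split():
    train, val = random_split(
        range(10), [0.8, 0.2], generator=torch.Generator().manual_seed(0)
    )
    assert len(train) + len(val) == 10
    return train, val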


class CUDACountingDataset(Dataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return torch.as_tensor(i, device="cuda")

    def __len__(self):
        return self.n


class CountingDataset(Dataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return i

    def __len__(self):
        return self.n


class CountingIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestTensorDataset(TestCase):
    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_single_tensor(self):
        t = torch.randn(5, 10)
        source = TensorDataset(t)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t[i], source[i][0])

    def test_many_tensors(self):
        t0 = torch.randn(5, 10, 2, 3, 4, 5)
        t1 = torch.randn(5, 10)
        t2 = torch.randn(5, 10, 2, 5)
        t3 = torch.randn(5, 10, 3, 7)
        source = TensorDataset(t0, t1, t2, t3)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t0[i], source[i][0])
            self.assertEqual(t1[i], source[i][1])
            self.assertEqual(t2[i], source[i][2])
            self.assertEqual(t3[i], source[i][3])


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStackDataset(TestCase):
    def test_empty(self):
        with self.assertRaisesRegex(
            ValueError, "At least one dataset should be passed"
        ):
            StackDataset()

    def test_mixed(self):
        with self.assertRaisesRegex(ValueError, "Supported either"):
            StackDataset(
                TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15))
            )

    def test_size_mismatch(self):
        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
            StackDataset(
                TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15))
            )
        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
            StackDataset(
                a=TensorDataset(torch.randn(15, 10)),
                b=TensorDataset(torch.randn(10, 15)),
            )

    def test_len(self):
        source = StackDataset(
            TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15))
        )
        self.assertEqual(len(source), 15)
        source = StackDataset(TensorDataset(torch.randn(15, 10)))
        self.assertEqual(len(source), 15)
        source = StackDataset(
            a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15))
        )
        self.assertEqual(len(source), 15)
        source = StackDataset(a=TensorDataset(torch.randn(15, 10)))
        self.assertEqual(len(source), 15)

    def test_single(self):
        t = TensorDataset(torch.randn(15, 10))
        source = StackDataset(t)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
        source = StackDataset(a=t)
        for i in range(15):
            self.assertEqual(t[i], source[i]["a"])

    def test_getitem(self):
        t = TensorDataset(torch.randn(15, 10))
        l = TensorDataset(torch.randn(15, 5, 4))
        source = StackDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])
        source = StackDataset(a=t, b=l)
        for i in range(15):
            self.assertEqual(t[i], source[i]["a"])
            self.assertEqual(l[i], source[i]["b"])

    def test_getitems(self):
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items]

            def __len__(self):
                return 4

        t = GetItemsDataset()
        l = [1, 2, 3, 4]

        source = StackDataset(t, l)
        batch = source.__getitems__([0, 1, 2, 3])
        for i in range(4):
            self.assertEqual(t[i], batch[i][0])
            self.assertEqual(l[i], batch[i][1])

        source = StackDataset(t=t, l=l)
        batch = source.__getitems__([0, 1, 2, 3])
        for i in range(4):
            self.assertEqual(t[i], batch[i]["t"])
            self.assertEqual(l[i], batch[i]["l"])

    def test_getitems_raises_index_error(self):
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items]

            def __len__(self):
                return 4

        t = GetItemsDataset()
        l = [1, 2, 3, 4]

        source = StackDataset(t, l)

        with self.assertRaises(IndexError):
            source.__getitems__([0, 4])

    def test_getitems_value_error(self):
        class GetItemsDataset(Dataset):
            def __init__(self) -> None:
                self.data = torch.randn(4)

            def __getitem__(self, item):
                return self.data[item]

            def __getitems__(self, items):
                return self.data[items][:-1]  # return less

            def __len__(self):
                return 4

        t = GetItemsDataset()
        l = [1, 2, 3, 4]

        source = StackDataset(t, l)

        with self.assertRaisesRegex(
            ValueError, "Nested dataset's output size mismatch. Expected 4, got 3"
        ):
            source.__getitems__([0, 1, 2, 3])


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestConcatDataset(TestCase):
    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

    def test_concat_two_non_singletons(self):
        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())

    def test_iterable_dataset_err(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        it1 = CountingIterableDataset(5)
        it2 = CountingIterableDataset(10)

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([d1, it2, it1])

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([it2])

        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
            ConcatDataset([it1, d1])


# takes in a dummy var so this can also be used as a `worker_init_fn`
def set_faulthander_if_available(_=None):
    faulthandler.enable(sys.__stderr__)
    if not IS_WINDOWS:
        # windows does not have faulthandler.register
        # chain=False prevents the default behavior of killing the process
        faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False)


set_faulthander_if_available()


# Process `pid` must have called `set_faulthander_if_available`
def print_traces_of_all_threads(pid):
    if not IS_WINDOWS:
        # use the custom signal if available
        os.kill(pid, signal.SIGUSR1)
    else:
        # otherwise we can still use the handler given by faulthandler.enable()
        # at the cost of killing the process.
        os.kill(pid, signal.SIGSEGV)

    # wait in parent process to give subprocess some time to print
    time.sleep(5)


# The following `ErrorTrackingProcess` stores the first encountered exception in
# its `.exception` attribute.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
    # Why no *args?
    # py2 doesn't support def fn(x, *args, key=val, **kwargs)
    # Setting disable_stderr=True may generate a lot of unrelated error outputs
    # but could be helpful for debugging.
    def __init__(self, disable_stderr=True, **kwargs):
        super().__init__(**kwargs)
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None
        self.disable_stderr = disable_stderr

    def run(self):
        set_faulthander_if_available()
        if self.disable_stderr:
            # Disable polluting stderr with errors that are supposed to happen.
            with open(os.devnull, "w") as devnull:
                os.dup2(devnull.fileno(), sys.stderr.fileno())
        try:
            super().run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    def print_traces_of_all_threads(self):
        assert (
            self.is_alive()
        ), "can only use print_traces_of_all_threads if the process is alive"
        assert (
            not self.disable_stderr
        ), "do not disable stderr if you use print_traces_of_all_threads"
        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
        # the process. So let's poll the exception first
        _ = self.exception
        print_traces_of_all_threads(self.pid)

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill could not find an alive process to signal
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise
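

# Illustrative sketch (not used directly by the suite): this is roughly how the
# tests below drive `ErrorTrackingProcess` -- run a failing target in a
# subprocess, join it with a bounded timeout, and inspect the reconstructed
# exception afterwards. The helper name is hypothetical.
def _example_error_tracking_usage(target):
    p = ErrorTrackingProcess(target=target)
    p.start()
    p.join(JOIN_TIMEOUT)
    try:
        # exitcode tells whether the subprocess died abnormally; exception
        # re-raises the first error captured through the pipe, if any.
        return p.exitcode, p.exception
    finally:
        p.terminate()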


class ErrorDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size


class SegfaultDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return ctypes.string_at(0)

    def __len__(self):
        return self.size


class SleepDataset(Dataset):
    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        self.sleeped = False

    def __getitem__(self, idx):
        if not self.sleeped:
            time.sleep(self.sleep_sec)
            self.sleeped = True
        return idx

    def __len__(self):
        return self.size


class SeedDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return torch.initial_seed()

    def __len__(self):
        return self.size


class WorkerSpecificIterableDataset(IterableDataset):
    def __init__(self, sizes_for_all_workers):
        self.sizes_for_all_workers = sizes_for_all_workers

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is not None
        return iter(range(self.sizes_for_all_workers[worker_info.id]))

    def __len__(self):
        return sum(self.sizes_for_all_workers)
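

# Illustrative sketch (not part of the test suite): `WorkerSpecificIterableDataset`
# above gives each worker its own iteration range via `get_worker_info()`. The
# canonical sharding pattern for a generic IterableDataset looks roughly like
# this; `_ExampleShardedIterable` is a hypothetical name used only for illustration.
class _ExampleShardedIterable(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # single-process loading: yield everything
            return iter(range(self.n))
        # multi-process loading: stride the range so workers do not overlap
        return iter(range(info.id, self.n, info.num_workers))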


# Inspired by https://stackoverflow.com/a/26703365
# If all workers call `sync_once`, they are blocked until every worker reaches
# the call (i.e., it acts like a barrier).
# This can be used to ensure that each worker processes at least one sample.
class SynchronizedDataset(Dataset):
    def __init__(self, size, batch_size, num_workers):
        assert size >= num_workers * batch_size
        self.count = mp.Value("i", 0, lock=True)
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def sync_once(self):
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        self.barrier.acquire()
        self.barrier.release()

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        return self.size
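

# Illustrative sketch (not used by any test): a `SynchronizedDataset` subclass is
# always paired with a DataLoader whose `num_workers` matches the value passed to
# the constructor, otherwise the semaphore barrier in `sync_once` never gets
# released. The helper name is hypothetical.
def _example_synchronized_loader(dataset_cls, size=6, batch_size=2, num_workers=2):
    ds = dataset_cls(size, batch_size, num_workers)
    return DataLoader(ds, batch_size=batch_size, num_workers=num_workers)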


class EmptyTensorDataset(torch.utils.data.Dataset):
    def __init__(self, len):
        self.len = len

    def __len__(self):
        return self.len

    def __getitem__(self, any):
        return torch.empty(0)


class SynchronizedSeedDataset(SynchronizedDataset):
    def __getitem__(self, idx):
        self.sync_once()
        return torch.initial_seed()


def _test_timeout(persistent_workers):
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,
        timeout=1,
        persistent_workers=persistent_workers,
    )
    _ = next(iter(dataloader))


def _test_timeout_pin_memory(persistent_workers):
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,
        timeout=1,
        pin_memory=True,
        persistent_workers=persistent_workers,
    )
    _ = next(iter(dataloader))


def _test_large_sampler_indices(persistent_workers):
    # See
    # test_large_sampler_indices
    # https://github.com/pytorch/pytorch/issues/48666

    dataloader = torch.utils.data.DataLoader(
        EmptyTensorDataset(10000000),
        batch_size=40960,
        persistent_workers=persistent_workers,
        num_workers=1,
    )

    it = iter(dataloader)

    for x in it:
        assert x.numel() == 0
        raise RuntimeError("My Error")


def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since the worker signal handler prints with low-level write(),
    this has to be done at the OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    # Can't use a with-block because otherwise the fd will be closed when this
    # function ends.
    with open(os.devnull, "w") as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())


def _test_segfault():
    dataset = SegfaultDataset(10)
    dataloader = DataLoader(
        dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr
    )
    _ = next(iter(dataloader))


def _test_no_segfault():
    dataset = [1, 2, 3]
    num_threads = torch.get_num_threads()
    if num_threads < 4:
        torch.set_num_threads(4)
    else:
        torch.set_num_threads(num_threads)
    mp_ctx = torch.multiprocessing.get_context(method="fork")
    dataloader = DataLoader(
        dataset,
        num_workers=1,
        worker_init_fn=disable_stderr,
        multiprocessing_context=mp_ctx,
    )
    _ = next(iter(dataloader))


class TestProperExitDataset(Dataset):
|
|
def __init__(self, size, error_event):
|
|
self.size = size
|
|
self.error_event = error_event
|
|
|
|
def __len__(self):
|
|
return self.size
|
|
|
|
def __getitem__(self, idx):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
if (
|
|
self.error_event is not None
|
|
and self.error_event.is_set()
|
|
and worker_info.id == worker_info.num_workers - 1
|
|
):
|
|
# only error in the last worker
|
|
raise RuntimeError("Worker error")
|
|
return torch.tensor([idx])
|
|
|
|
|
|
class TestProperExitIterableDataset(IterableDataset):
|
|
def __init__(self, size, error_event):
|
|
self.error_event = error_event
|
|
self.size = size
|
|
self.remaining = size
|
|
|
|
def __len__(self):
|
|
return self.size
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
if (
|
|
self.error_event is not None
|
|
and self.error_event.is_set()
|
|
and worker_info.id == worker_info.num_workers - 1
|
|
):
|
|
# only error in the last worker
|
|
raise RuntimeError("Worker error")
|
|
self.remaining -= 1
|
|
if self.remaining < 0:
|
|
raise StopIteration
|
|
return torch.tensor(-1000)
|
|
|
|
|
|
# See TestDataLoader.test_proper_exit for usage
|
|
def _test_proper_exit(
|
|
is_iterable_dataset,
|
|
use_workers,
|
|
pin_memory,
|
|
exit_method,
|
|
hold_iter_reference,
|
|
loader_setup_event,
|
|
tester_setup_event,
|
|
persistent_workers,
|
|
):
|
|
num_workers = 2 if use_workers else 0
|
|
|
|
if exit_method == "worker_error" or exit_method == "worker_kill":
|
|
assert use_workers is True
|
|
|
|
if exit_method == "worker_error":
|
|
worker_error_event = mp.Event()
|
|
else:
|
|
worker_error_event = None
|
|
|
|
if is_iterable_dataset:
|
|
ds = TestProperExitIterableDataset(7, worker_error_event)
|
|
else:
|
|
ds = TestProperExitDataset(12, worker_error_event)
|
|
|
|
loader = DataLoader(
|
|
ds,
|
|
batch_size=1,
|
|
shuffle=False,
|
|
num_workers=num_workers,
|
|
pin_memory=pin_memory,
|
|
worker_init_fn=set_faulthander_if_available,
|
|
persistent_workers=persistent_workers,
|
|
)
|
|
|
|
error_it = 2
|
|
|
|
if use_workers:
|
|
# 2 is the magical per-worker prefetch number...
|
|
# FIXME: change this after the number becomes configurable.
|
|
if is_iterable_dataset:
|
|
assert len(ds) * num_workers > (error_it + 2 + 1)
|
|
else:
|
|
assert len(loader) > (error_it + 2 + 1) * num_workers
|
|
else:
|
|
if is_iterable_dataset:
|
|
assert len(ds) > error_it + 1
|
|
else:
|
|
assert len(loader) > error_it + 1
|
|
|
|
it = iter(loader)
|
|
if use_workers:
|
|
workers = it._workers
|
|
|
|
def kill_pid(pid):
|
|
psutil_p = psutil.Process(pid)
|
|
psutil_p.kill()
|
|
psutil_p.wait(JOIN_TIMEOUT)
|
|
assert not psutil_p.is_running()
|
|
|
|
for i, _ in enumerate(it):
|
|
if i == 0:
|
|
if not hold_iter_reference:
|
|
del it
|
|
del loader
|
|
loader_setup_event.set()
|
|
tester_setup_event.wait()
|
|
# ensure that the workers are still alive
|
|
if use_workers:
|
|
for w in workers:
|
|
assert w.is_alive()
|
|
if worker_error_event is not None:
|
|
worker_error_event.set()
|
|
|
|
if i == error_it:
|
|
if exit_method == "loader_error":
|
|
raise RuntimeError("Loader error")
|
|
elif exit_method == "loader_kill":
|
|
kill_pid(os.getpid())
|
|
elif exit_method == "worker_kill":
|
|
kill_pid(workers[-1].pid) # kill last worker
|
|
|
|
if not hold_iter_reference:
|
|
# Tries to trigger the __del__ clean-up rather than the automatic
|
|
# exiting of daemonic children. Technically it should be automatically
|
|
# triggered, but I don't want to rely on the implementation detail of
|
|
# Python gc.
|
|
gc.collect()
|
|
|
|
|
|
class TestWorkerInfoDataset(SynchronizedDataset):
|
|
def __getitem__(self, idx):
|
|
self.sync_once()
|
|
return torch.tensor(self.value)
|
|
|
|
|
|
# Should be used as worker_init_fn with TestWorkerInfoDataset.
|
|
# See _test_get_worker_info below for usage.
|
|
def _test_worker_info_init_fn(worker_id):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
assert (
|
|
worker_id == worker_info.id
|
|
), "worker_init_fn and worker_info should have consistent id"
|
|
assert (
|
|
worker_id < worker_info.num_workers
|
|
), "worker_init_fn and worker_info should have valid id"
|
|
assert (
|
|
worker_info.seed == torch.initial_seed()
|
|
), "worker_init_fn and worker_info should have consistent seed"
|
|
dataset = worker_info.dataset
|
|
assert isinstance(
|
|
dataset, TestWorkerInfoDataset
|
|
), "worker_info should have correct dataset copy"
|
|
assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
|
|
# test that WorkerInfo attributes are read-only
|
|
try:
|
|
worker_info.id = 3999
|
|
except RuntimeError as e:
|
|
assert str(e) == "Cannot assign attributes to WorkerInfo objects"
|
|
try:
|
|
worker_info.a = 3
|
|
except RuntimeError as e:
|
|
assert str(e) == "Cannot assign attributes to WorkerInfo objects"
|
|
for k in ["id", "num_workers", "seed", "dataset"]:
|
|
assert f"{k}=" in repr(worker_info)
|
|
dataset.value = [worker_id, os.getpid()]
|
|
|
|
|
|
def _test_get_worker_info():
|
|
# get_worker_info returns None in main proc
|
|
assert torch.utils.data.get_worker_info() is None
|
|
num_workers = 2
|
|
batch_size = 2
|
|
dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
|
|
dataloader = DataLoader(
|
|
dataset,
|
|
batch_size=batch_size,
|
|
num_workers=num_workers,
|
|
worker_init_fn=_test_worker_info_init_fn,
|
|
)
|
|
it = iter(dataloader)
|
|
data = []
|
|
for d in it:
|
|
data.append(d) # noqa: PERF402
|
|
worker_pids = [w.pid for w in it._workers]
|
|
data = torch.cat(data, 0)
|
|
for d in data:
|
|
# each `d` is a [worker_id, worker_pid] pair, which is set in
|
|
# _test_worker_info_init_fn
|
|
assert d[1] == worker_pids[d[0]]
|
|
# get_worker_info returns None in main proc after data loading
|
|
assert torch.utils.data.get_worker_info() is None
|
|
# main proc dataset was never assigned this attribute
|
|
assert not hasattr(dataset, "value")
|
|
try:
|
|
_ = dataset[0]
|
|
except AttributeError:
|
|
return
|
|
raise RuntimeError("Expected AttributeError")
|
|
|
|
|
|
# test custom init function
|
|
def init_fn(worker_id):
|
|
torch.manual_seed(12345)
|
|
|
|
|
|
# used with test_error_in_init
|
|
class ErrorIterableDataset(IterableDataset):
|
|
def __iter__(self):
|
|
raise RuntimeError("Error in __iter__")
|
|
|
|
|
|
# used with test_error_in_init
|
|
def error_worker_init_fn(_):
|
|
raise RuntimeError("Error in worker_init_fn")
|
|
|
|
|
|
class BulkLoadingDataset(Dataset):
|
|
def __init__(self, length):
|
|
self.length = length
|
|
|
|
def __getitem__(self, indices):
|
|
assert isinstance(indices, (list, tuple))
|
|
return torch.as_tensor(indices)
|
|
|
|
def __len__(self):
|
|
return self.length
|
|
|
|
|
|
class BulkLoadingSampler(torch.utils.data.Sampler):
|
|
def __init__(self, dataset, batch_size):
|
|
self.dataset = dataset
|
|
self.batch_size = batch_size
|
|
|
|
def __iter__(self):
|
|
for x in torch.randperm(len(self.dataset)).split(self.batch_size):
|
|
yield x.tolist()
|
|
|
|
def __len__(self):
|
|
return int(math.ceil(len(self.dataset) / float(self.batch_size)))
|
|
|
|
|
|
class TestMultiEpochDataset(IterableDataset):
|
|
def __init__(self, length):
|
|
self.length = length
|
|
|
|
def __iter__(self):
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
assert worker_info is not None
|
|
worker_id = worker_info.id
|
|
for idx in range(self.length // worker_info.num_workers):
|
|
yield worker_id
|
|
|
|
def __len__(self):
|
|
return self.length
|
|
|
|
|
|
class CustomList(list):
|
|
pass
|
|
|
|
|
|
class CustomDict(dict):
|
|
pass
|
|
|
|
|
|
def row_processor(row):
|
|
return np.add(row, 1)
|
|
|
|
|
|
def filter_len(row):
|
|
return len(row) == 4
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
@unittest.skipIf(
|
|
TEST_WITH_ASAN,
|
|
"DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
|
|
)
|
|
class TestDataLoader(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.data = torch.randn(100, 2, 3, 5)
|
|
self.labels = torch.randperm(50).repeat(2)
|
|
self.dataset = TensorDataset(self.data, self.labels)
|
|
self.persistent_workers = False
|
|
|
|
def _get_data_loader(self, dataset, **kwargs):
|
|
persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
|
|
if persistent_workers and kwargs.get("num_workers", 0) == 0:
|
|
persistent_workers = False
|
|
kwargs["persistent_workers"] = persistent_workers
|
|
return DataLoader(dataset, **kwargs)
|
|
|
|
def _test_sequential(self, loader):
|
|
batch_size = loader.batch_size
|
|
if batch_size is None:
|
|
for idx, (sample, target) in enumerate(loader):
|
|
self.assertEqual(sample, self.data[idx])
|
|
self.assertEqual(target, self.labels[idx])
|
|
self.assertEqual(idx, len(self.dataset) - 1)
|
|
else:
|
|
for i, (sample, target) in enumerate(loader):
|
|
idx = i * batch_size
|
|
self.assertEqual(sample, self.data[idx : idx + batch_size])
|
|
self.assertEqual(target, self.labels[idx : idx + batch_size])
|
|
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
|
|
|
|
def _test_shuffle(self, loader):
|
|
found_data = dict.fromkeys(range(self.data.size(0)), 0)
|
|
found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
|
|
batch_size = loader.batch_size
|
|
if batch_size is None:
|
|
for i, (batch_samples, batch_targets) in enumerate(loader):
|
|
sample, target = (batch_samples, batch_targets)
|
|
for data_point_idx, data_point in enumerate(self.data):
|
|
if data_point.eq(sample).all():
|
|
self.assertFalse(found_data[data_point_idx])
|
|
found_data[data_point_idx] += 1
|
|
break
|
|
self.assertEqual(target, self.labels[data_point_idx])
|
|
found_labels[data_point_idx] += 1
|
|
self.assertEqual(sum(found_data.values()), (i + 1))
|
|
self.assertEqual(sum(found_labels.values()), (i + 1))
|
|
self.assertEqual(i, (len(self.dataset) - 1))
|
|
else:
|
|
for i, (batch_samples, batch_targets) in enumerate(loader):
|
|
for sample, target in zip(batch_samples, batch_targets):
|
|
for data_point_idx, data_point in enumerate(self.data):
|
|
if data_point.eq(sample).all():
|
|
self.assertFalse(found_data[data_point_idx])
|
|
found_data[data_point_idx] += 1
|
|
break
|
|
self.assertEqual(target, self.labels[data_point_idx])
|
|
found_labels[data_point_idx] += 1
|
|
self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
|
|
self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
|
|
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
|
|
|
|
def _test_error(self, loader):
|
|
it = iter(loader)
|
|
errors = 0
|
|
while True:
|
|
try:
|
|
next(it)
|
|
except NotImplementedError:
|
|
errors += 1
|
|
except StopIteration:
|
|
self.assertEqual(
|
|
errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
|
|
)
|
|
return
|
|
|
|
def test_error_in_init(self):
|
|
for num_workers in [0, 2]:
|
|
loader = self._get_data_loader(
|
|
ErrorIterableDataset(), num_workers=num_workers
|
|
)
|
|
with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
|
|
list(iter(loader))
|
|
|
|
loader = self._get_data_loader(
|
|
self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
|
|
)
|
|
with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
|
|
list(iter(loader))
|
|
|
|
def test_typing(self):
|
|
from typing import List
|
|
|
|
# Make sure there is no TypeError
|
|
|
|
class SomeDatasetClass(Dataset[List[torch.Tensor]]):
|
|
pass
|
|
|
|
def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
|
|
pass
|
|
|
|
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
|
|
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
|
|
def test_fd_limit_exceeded(self):
|
|
# See NOTE [ DataLoader on Linux and open files limit ]
|
|
import subprocess
|
|
|
|
subprocess.check_output(
|
|
[
|
|
sys.executable,
|
|
"-c",
|
|
"""\
|
|
import torch
|
|
import resource
|
|
from torch.utils.data import DataLoader, IterableDataset
|
|
|
|
class RandomDataset(IterableDataset):
|
|
def __init__(self, len, size):
|
|
super(RandomDataset).__init__()
|
|
self.len = len
|
|
self.size = size
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.len <= 0:
|
|
raise StopIteration
|
|
self.len -= 1
|
|
return torch.randn(self.size)
|
|
|
|
try:
|
|
keep_fds_alive = []
|
|
resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
|
|
for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
|
|
num_workers=1):
|
|
random_t.max(dim=0)
|
|
keep_fds_alive.append(random_t)
|
|
except RuntimeError as e:
|
|
assert "ulimit -n" in str(e)
|
|
assert "set_sharing_strategy" in str(e)
|
|
""",
|
|
]
|
|
)
|
|
|
|
def test_invalid_assign_after_init(self):
|
|
dl = self._get_data_loader(self.dataset)
|
|
for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):
|
|
|
|
def fn():
|
|
setattr(dl, attr, {})
|
|
|
|
self.assertRaises(ValueError, fn)
|
|
|
|
def test_sequential_nonbatch(self):
|
|
self._test_sequential(self._get_data_loader(self.dataset, batch_size=None))
|
|
|
|
def test_sequential_batch(self):
|
|
self._test_sequential(self._get_data_loader(self.dataset))
|
|
self._test_sequential(self._get_data_loader(self.dataset, batch_size=2))
|
|
|
|
def test_bulk_loading_nobatch(self):
|
|
n = 35
|
|
bs = 4
|
|
ds = BulkLoadingDataset(n)
|
|
sampler = BulkLoadingSampler(ds, batch_size=4)
|
|
|
|
for num_workers in [0, 4]:
|
|
dl = self._get_data_loader(
|
|
ds,
|
|
num_workers=num_workers,
|
|
batch_size=None,
|
|
sampler=sampler,
|
|
pin_memory=TEST_CUDA,
|
|
)
|
|
self.assertFalse(dl._auto_collation)
|
|
samples = list(dl)
|
|
self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
|
|
self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))
|
|
|
|
def test_growing_dataset(self):
|
|
dataset = [torch.ones(4) for _ in range(4)]
|
|
dataloader_seq = self._get_data_loader(dataset, shuffle=False)
|
|
dataloader_shuffle = self._get_data_loader(dataset, shuffle=True)
|
|
dataset.append(torch.ones(4))
|
|
self.assertEqual(len(dataloader_seq), 5)
|
|
self.assertEqual(len(dataloader_shuffle), 5)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_sequential_pin_memory(self):
|
|
loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
|
|
for input, target in loader:
|
|
self.assertTrue(input.is_pinned())
|
|
self.assertTrue(target.is_pinned())
|
|
|
|
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
|
|
def test_multiple_dataloaders(self):
|
|
for multiprocessing_context in supported_multiprocessing_contexts:
|
|
loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
|
|
loader2_it = iter(
|
|
self._get_data_loader(
|
|
self.dataset,
|
|
num_workers=2,
|
|
multiprocessing_context=multiprocessing_context,
|
|
)
|
|
)
|
|
next(loader1_it)
|
|
next(loader1_it)
|
|
next(loader2_it)
|
|
next(loader2_it)
|
|
next(loader1_it)
|
|
next(loader2_it)
|
|
del loader1_it
|
|
del loader2_it
|
|
|
|
def test_segfault(self):
|
|
p = ErrorTrackingProcess(target=_test_segfault)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertNotEqual(p.exitcode, 0)
|
|
if IS_WINDOWS:
|
|
self.assertIsInstance(p.exception, OSError)
|
|
self.assertRegex(str(p.exception), r"access violation reading ")
|
|
else:
|
|
self.assertIsInstance(p.exception, RuntimeError)
|
|
self.assertRegex(
|
|
str(p.exception),
|
|
r"DataLoader worker \(pid \d+\) is killed by signal: ",
|
|
)
|
|
finally:
|
|
p.terminate()
|
|
|
|
# Tests if the child process forked by the DataLoader segfaults due to having more than 3 threads
|
|
# in the parent process after at least one set_num_threads invocation in the parent process.
|
|
# After forking, set_num_threads(1) in the child process entails handling some inherited data-structures
|
|
# of the Caffe2 thread-pool of the parent process, culminating in a segfault.
|
|
# Reference: https://github.com/pytorch/pytorch/issues/54752
|
|
@unittest.skipIf(IS_WINDOWS, "Needs fork")
|
|
def test_no_segfault(self):
|
|
p = ErrorTrackingProcess(target=_test_no_segfault)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
if p.exception:
|
|
self.assertIsInstance(p.exception, RuntimeError)
|
|
self.assertRegex(
|
|
str(p.exception),
|
|
r"DataLoader worker \(pid \d+\) is killed by signal: ",
|
|
)
|
|
self.fail("Segfault occurred in worker process after fork")
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_timeout(self):
|
|
if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
|
|
# This test runs in a subprocess, which can only initialize CUDA with spawn.
|
|
# _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is
|
|
# constructed.
|
|
targets = (_test_timeout, _test_timeout_pin_memory)
|
|
else:
|
|
targets = (_test_timeout,)
|
|
for target in targets:
|
|
p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertNotEqual(p.exitcode, 0)
|
|
self.assertIsInstance(p.exception, RuntimeError)
|
|
self.assertRegex(
|
|
str(p.exception), r"DataLoader timed out after \d+ seconds"
|
|
)
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_large_sampler_indices(self):
|
|
# Test that the data loader exits cleanly when the process errors while
# 1. holding a reference to the iterator
# 2. using a sampler that yields big elements s.t. the _index_queues putters block
#
# More context: https://github.com/pytorch/pytorch/issues/48666
|
|
|
|
p = ErrorTrackingProcess(
|
|
target=_test_large_sampler_indices, args=(self.persistent_workers,)
|
|
)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertNotEqual(p.exitcode, 0)
|
|
self.assertIsInstance(p.exception, RuntimeError)
|
|
self.assertRegex(str(p.exception), r"My Error")
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_invalid_ctor_args_combinations(self):
|
|
# general
|
|
with self.assertRaisesRegex(
|
|
ValueError, "num_workers option should be non-negative"
|
|
):
|
|
self._get_data_loader(self.dataset, num_workers=-1)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "timeout option should be non-negative"
|
|
):
|
|
self._get_data_loader(self.dataset, timeout=-1)
|
|
|
|
# disable auto-batching
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"batch_size=None option disables auto-batching and is mutually exclusive",
|
|
):
|
|
self._get_data_loader(self.dataset, batch_size=None, drop_last=True)
|
|
|
|
valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
|
|
with self.assertRaisesRegex(
|
|
ValueError, r"multi-process loading \(num_workers > 0\), but got"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, num_workers=0, multiprocessing_context=valid_ctx
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "should specify a valid start method in"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, num_workers=1, multiprocessing_context="bad"
|
|
)
|
|
with self.assertRaisesRegex(
|
|
TypeError, "multiprocessing_context option should be a valid context "
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, num_workers=1, multiprocessing_context=object()
|
|
)
|
|
|
|
# map-style
|
|
sampler = torch.utils.data.SequentialSampler(self.dataset)
|
|
batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "sampler option is mutually exclusive with shuffle"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, batch_size=11, sampler=sampler, shuffle=True
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "sampler option is mutually exclusive with shuffle"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "sampler option is mutually exclusive with shuffle"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "batch_sampler option is mutually exclusive with"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, batch_size=11, batch_sampler=batch_sampler
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "batch_sampler option is mutually exclusive with"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, shuffle=True, batch_sampler=batch_sampler
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "batch_sampler option is mutually exclusive with"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, drop_last=True, batch_sampler=batch_sampler
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "batch_sampler option is mutually exclusive with"
|
|
):
|
|
self._get_data_loader(
|
|
self.dataset, drop_last=3, batch_sampler=batch_sampler
|
|
)
|
|
|
|
# iterable-style
|
|
dataset = CountingIterableDataset(20)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
|
|
):
|
|
self._get_data_loader(dataset, shuffle=True)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
|
|
):
|
|
self._get_data_loader(dataset, shuffle=3)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
|
|
):
|
|
self._get_data_loader(
|
|
dataset, sampler=torch.utils.data.SequentialSampler(dataset)
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
|
|
):
|
|
self._get_data_loader(dataset, sampler=3)
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"DataLoader with IterableDataset: expected unspecified batch_sampler",
|
|
):
|
|
self._get_data_loader(
|
|
dataset,
|
|
batch_sampler=torch.utils.data.BatchSampler(
|
|
torch.utils.data.SequentialSampler(dataset), 3, False
|
|
),
|
|
)
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"DataLoader with IterableDataset: expected unspecified batch_sampler",
|
|
):
|
|
self._get_data_loader(dataset, batch_sampler=3)
|
|
|
|
def test_builtin_collection_conversion(self):
|
|
for coll_ty in (list, tuple):
|
|
for num_workers in (0, 1):
|
|
# map-style dataset
|
|
dataset = CountingDataset(20)
|
|
# no auto-batching
|
|
fetched = coll_ty(
|
|
self._get_data_loader(
|
|
dataset, batch_size=None, num_workers=num_workers
|
|
)
|
|
)
|
|
self.assertEqual(fetched, coll_ty(range(20)))
|
|
# auto-batching
|
|
fetched = coll_ty(
|
|
self._get_data_loader(
|
|
dataset, batch_size=2, num_workers=num_workers
|
|
)
|
|
)
|
|
self.assertEqual(
|
|
fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
|
|
)
|
|
|
|
# iterable-style dataset
|
|
dataset = CountingIterableDataset(20)
|
|
# no auto-batching
|
|
fetched = coll_ty(
|
|
self._get_data_loader(
|
|
dataset, batch_size=None, num_workers=num_workers
|
|
)
|
|
)
|
|
self.assertEqual(fetched, coll_ty(range(20)))
|
|
# auto-batching
|
|
# this IterableDataset isn't configured for each worker, so for
# the equality test below to be valid, we cannot have more than 1 worker.
|
|
assert num_workers in [0, 1], "invalid test"
|
|
fetched = coll_ty(
|
|
self._get_data_loader(
|
|
dataset, batch_size=2, num_workers=num_workers
|
|
)
|
|
)
|
|
self.assertEqual(
|
|
fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
|
|
)
|
|
|
|
def test_iterable_style_dataset(self):
|
|
# [no auto-batching] single process loading
|
|
dataset = CountingIterableDataset(20)
|
|
dataloader = self._get_data_loader(dataset, batch_size=None)
|
|
fetched = list(dataloader)
|
|
self.assertEqual(len(fetched), 20)
|
|
for i, d in enumerate(fetched):
|
|
# non-batched should not convert ints into tensors
|
|
self.assertIsInstance(d, int)
|
|
self.assertEqual(d, i)
|
|
# DataLoader should match len of the iterable-style dataset (if implemented)
|
|
self.assertEqual(len(dataloader), len(dataset))
|
|
|
|
# [no auto-batching] multiprocessing loading
|
|
num_workers = 3
|
|
sizes_for_all_workers = [0, 4, 20]
|
|
expected = sorted(
|
|
functools.reduce(
|
|
operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
|
|
)
|
|
)
|
|
assert len(sizes_for_all_workers) == num_workers, "invalid test case"
|
|
for prefetch_factor in [2, 3, 4]:
|
|
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
|
|
dataloader = self._get_data_loader(
|
|
dataset,
|
|
num_workers=num_workers,
|
|
batch_size=None,
|
|
worker_init_fn=set_faulthander_if_available,
|
|
prefetch_factor=prefetch_factor,
|
|
)
|
|
dataloader_iter = iter(dataloader)
|
|
fetched = sorted(dataloader_iter)
|
|
for a, b in zip(fetched, expected):
|
|
# non-batched should not convert ints into tensors
|
|
self.assertIsInstance(a, int)
|
|
self.assertEqual(a, b)
|
|
# DataLoader should match len of the iterable-style dataset (if implemented)
|
|
self.assertEqual(len(dataloader), len(dataset))
|
|
# When loading more than len(dataset) data, after accessing len(dataloader),
|
|
# we should get a warning. See NOTE [ IterableDataset and __len__ ].
|
|
dataset = CountingIterableDataset(20)
|
|
dataloader = self._get_data_loader(
|
|
dataset,
|
|
num_workers=num_workers,
|
|
worker_init_fn=set_faulthander_if_available,
|
|
prefetch_factor=prefetch_factor,
|
|
)
|
|
it = iter(dataloader)
|
|
for _ in range(40):
|
|
self.assertNotWarn(
|
|
lambda: next(it), "Should not warn before accessing len(dataloader)"
|
|
)
|
|
self.assertEqual(len(dataloader), len(dataset))
|
|
self.assertEqual(len(dataloader), 20)
|
|
it = iter(dataloader)
|
|
for _ in range(20):
|
|
self.assertNotWarn(
|
|
lambda: next(it), "Should not warn before exceeding length"
|
|
)
|
|
for _ in range(3):
|
|
with self.assertWarnsRegex(
|
|
UserWarning,
|
|
r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
|
|
msg="Should always warn after exceeding length",
|
|
):
|
|
next(it)
|
|
# [no auto-batching] test that workers exit gracefully
|
|
workers = dataloader_iter._workers
|
|
del dataloader_iter
|
|
del dataloader
|
|
try:
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive())
|
|
self.assertEqual(w.exitcode, 0)
|
|
finally:
|
|
for w in workers:
|
|
w.terminate()
|
|
|
|
# [auto-batching] single process loading
|
|
dataset = CountingIterableDataset(20)
|
|
fetched = list(self._get_data_loader(dataset, batch_size=7))
|
|
self.assertEqual(len(fetched), 3)
|
|
self.assertEqual(fetched[0].tolist(), list(range(7)))
|
|
self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
|
|
self.assertEqual(fetched[2].tolist(), list(range(14, 20)))
|
|
|
|
# [auto-batching] multiprocessing loading
|
|
num_workers = 3
|
|
sizes_for_all_workers = [0, 4, 20]
|
|
expected = sorted(
|
|
functools.reduce(
|
|
operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
|
|
)
|
|
)
|
|
assert len(sizes_for_all_workers) == num_workers, "invalid test case"
|
|
for prefetch_factor in [2, 3, 4]:
|
|
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
|
|
# worker 0 should return 0 batches
|
|
# worker 1 should return 1 batches
|
|
# worker 2 should return 3 batches
|
|
dataloader = self._get_data_loader(
|
|
dataset,
|
|
num_workers=num_workers,
|
|
batch_size=7,
|
|
prefetch_factor=prefetch_factor,
|
|
)
|
|
dataloader_iter = iter(dataloader)
|
|
fetched = list(dataloader_iter)
|
|
self.assertEqual(len(fetched), 4)
|
|
fetched = {tuple(t.tolist()) for t in fetched}
|
|
self.assertEqual(
|
|
fetched,
|
|
{
|
|
tuple(range(4)),
|
|
tuple(range(7)),
|
|
tuple(range(7, 14)),
|
|
tuple(range(14, 20)),
|
|
},
|
|
)
|
|
|
|
# [auto-batching] test that workers exit gracefully
|
|
workers = dataloader_iter._workers
|
|
del dataloader_iter
|
|
del dataloader
|
|
try:
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive())
|
|
self.assertEqual(w.exitcode, 0)
|
|
finally:
|
|
for w in workers:
|
|
w.terminate()
|
|
# [auto-batching & drop_last] single process loading
|
|
dataset = CountingIterableDataset(20)
|
|
fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
|
|
self.assertEqual(len(fetched), 2)
|
|
self.assertEqual(fetched[0].tolist(), list(range(7)))
|
|
self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
|
|
|
|
# [auto-batching & drop_last] multiprocessing loading
|
|
num_workers = 3
|
|
sizes_for_all_workers = [0, 4, 20]
|
|
expected = sorted(
|
|
functools.reduce(
|
|
operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
|
|
)
|
|
)
|
|
assert len(sizes_for_all_workers) == num_workers, "invalid test case"
|
|
for prefetch_factor in [2, 3, 4]:
|
|
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
|
|
# worker 0 should return 0 batches
|
|
# worker 1 should return 1 batches
|
|
# worker 2 should return 3 batches
|
|
dataloader = self._get_data_loader(
|
|
dataset,
|
|
num_workers=num_workers,
|
|
batch_size=7,
|
|
drop_last=True,
|
|
worker_init_fn=set_faulthander_if_available,
|
|
prefetch_factor=prefetch_factor,
|
|
)
|
|
dataloader_iter = iter(dataloader)
|
|
fetched = list(dataloader_iter)
|
|
self.assertEqual(len(fetched), 2)
|
|
fetched = {tuple(t.tolist()) for t in fetched}
|
|
self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
|
|
|
|
# [auto-batching & drop_last] test that workers exit gracefully
|
|
workers = dataloader_iter._workers
|
|
del dataloader_iter
|
|
del dataloader
|
|
try:
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive())
|
|
self.assertEqual(w.exitcode, 0)
|
|
finally:
|
|
for w in workers:
|
|
w.terminate()
|
|
|
|
def test_chain_iterable_style_dataset(self):
|
|
# chaining (concatenation)
|
|
dataset1 = CountingIterableDataset(20)
|
|
dataset2 = CountingIterableDataset(15)
|
|
expected = list(range(20)) + list(range(15))
|
|
for num_workers in [0, 1]:
|
|
for chained_dataset in [
|
|
dataset1 + dataset2,
|
|
ChainDataset([dataset1, dataset2]),
|
|
]:
|
|
fetched = list(
|
|
self._get_data_loader(chained_dataset, num_workers=num_workers)
|
|
)
|
|
self.assertEqual(len(fetched), len(expected))
|
|
for e, d in zip(expected, fetched):
|
|
self.assertIsInstance(d, torch.Tensor)
|
|
self.assertEqual(e, d)
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "ChainDataset only supports IterableDataset"
|
|
):
|
|
list(iter(dataset1 + self.dataset))
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "ChainDataset only supports IterableDataset"
|
|
):
|
|
list(iter(ChainDataset([dataset1, self.dataset])))
|
|
|
|
@unittest.skipIf(IS_MACOS, "Not working on macos")
|
|
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
|
|
@skipIfRocm # https://github.com/pytorch/pytorch/issues/90940
|
|
def test_multiprocessing_contexts(self):
|
|
reference = [
|
|
torch.arange(3),
|
|
torch.arange(3, 6),
|
|
torch.arange(6, 9),
|
|
torch.arange(9, 11),
|
|
]
|
|
counting_ds_n = 11
|
|
dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
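# pin_memory is enabled only in the CPU-only case here: when CUDA is available
# the dataset below may already yield CUDA tensors (CUDACountingDataset), which
# presumably cannot be pinned.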
|
|
for ctx in supported_multiprocessing_contexts:
|
|
# Windows and Jetson devices don't support sharing CUDA tensors; ROCm does not yet fully support IPC
|
|
if (
|
|
ctx in ["spawn", "forkserver"]
|
|
and TEST_CUDA
|
|
and not IS_WINDOWS
|
|
and not IS_JETSON
|
|
):
|
|
ds_cls = CUDACountingDataset
|
|
else:
|
|
ds_cls = CountingDataset
|
|
self.assertEqual(
|
|
reference,
|
|
list(
|
|
self._get_data_loader(
|
|
ds_cls(counting_ds_n),
|
|
multiprocessing_context=ctx,
|
|
**dl_common_args,
|
|
)
|
|
),
|
|
)
|
|
if ctx is not None:
|
|
# test ctx object
|
|
ctx = mp.get_context(ctx)
|
|
self.assertEqual(
|
|
reference,
|
|
list(
|
|
self._get_data_loader(
|
|
ds_cls(counting_ds_n),
|
|
multiprocessing_context=ctx,
|
|
**dl_common_args,
|
|
)
|
|
),
|
|
)
|
|
|
|
def _test_multiprocessing_iterdatapipe(self, with_dill):
|
|
# Test that a function from the global scope (e.g. imported from a library) can be serialized
|
|
# and used with a multiprocessing DataLoader
|
|
|
|
reference = [
|
|
torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
|
|
torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
|
|
]
|
|
datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
|
|
datapipe = datapipe.map(row_processor)
|
|
datapipe = (
|
|
datapipe.filter(lambda row: len(row) == 4)
|
|
if with_dill
|
|
else datapipe.filter(filter_len)
|
|
)
|
|
|
|
dl_common_args = dict(
|
|
num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA)
|
|
)
|
|
for ctx in supported_multiprocessing_contexts:
|
|
self.assertEqual(
|
|
reference,
|
|
[
|
|
t.type(torch.int64)
|
|
for t in self._get_data_loader(
|
|
datapipe, multiprocessing_context=ctx, **dl_common_args
|
|
)
|
|
],
|
|
)
|
|
if ctx is not None:
|
|
# test ctx object
|
|
ctx = mp.get_context(ctx)
|
|
self.assertEqual(
|
|
reference,
|
|
[
|
|
t.type(torch.int64)
|
|
for t in self._get_data_loader(
|
|
datapipe, multiprocessing_context=ctx, **dl_common_args
|
|
)
|
|
],
|
|
)
|
|
|
|
@skipIfNoNumpy
|
|
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
|
|
def test_multiprocessing_iterdatapipe(self):
|
|
self._test_multiprocessing_iterdatapipe(with_dill=False)
|
|
|
|
@unittest.expectedFailure
|
|
@skipIfNoNumpy
|
|
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
|
|
@skipIfNoDill
|
|
def test_multiprocessing_iterdatapipe_with_dill(self):
|
|
self._test_multiprocessing_iterdatapipe(with_dill=True)
|
|
|
|
def test_worker_seed(self):
|
|
num_workers = 6
|
|
batch_size = 1
|
|
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
|
|
dataloader = self._get_data_loader(
|
|
dataset, batch_size=batch_size, num_workers=num_workers
|
|
)
|
|
seeds = set()
|
|
seeds.update(batch[0] for batch in dataloader)
|
|
self.assertEqual(len(seeds), num_workers)
|
|
|
|
def test_worker_seed_reproducibility(self):
|
|
def get_dataloader():
|
|
return DataLoader(
|
|
dataset,
|
|
batch_size=batch_size,
|
|
num_workers=num_workers,
|
|
generator=torch.Generator().manual_seed(42),
|
|
)
|
|
|
|
num_workers = 6
|
|
batch_size = 1
|
|
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
|
|
self.assertEqual(
|
|
{int(batch) for batch in get_dataloader()},
|
|
{int(batch) for batch in get_dataloader()},
|
|
)
|
|
|
|
def test_multi_epochs_reproducibility(self):
|
|
num_workers = 2
|
|
batch_size = 10
|
|
num_epochs = 3
|
|
|
|
dataset = TestMultiEpochDataset(batch_size * num_workers)
|
|
dataloader = self._get_data_loader(
|
|
dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
|
)
|
|
|
|
for ind in range(num_epochs):
|
|
for batch_idx, sample in enumerate(dataloader):
|
|
self.assertEqual(
|
|
sample.tolist(), [batch_idx % num_workers] * batch_size
|
|
)
|
|
|
|
def test_worker_init_fn(self):
|
|
dataset = SeedDataset(4)
|
|
dataloader = self._get_data_loader(
|
|
dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn
|
|
)
|
|
for batch in dataloader:
|
|
self.assertEqual(12345, batch[0])
|
|
self.assertEqual(12345, batch[1])
|
|
|
|
def test_get_worker_info(self):
|
|
p = ErrorTrackingProcess(target=_test_get_worker_info)
|
|
p.start()
|
|
p.join(JOIN_TIMEOUT)
|
|
try:
|
|
self.assertFalse(p.is_alive())
|
|
self.assertEqual(p.exitcode, 0)
|
|
finally:
|
|
p.terminate()
|
|
|
|
def test_shuffle(self):
|
|
self._test_shuffle(self._get_data_loader(self.dataset, shuffle=True))
|
|
|
|
def test_shuffle_batch_none(self):
|
|
self._test_shuffle(DataLoader(self.dataset, batch_size=None, shuffle=True))
|
|
|
|
def test_shuffle_batch(self):
|
|
self._test_shuffle(
|
|
self._get_data_loader(self.dataset, batch_size=2, shuffle=True)
|
|
)
|
|
|
|
def test_shuffle_reproducibility(self):
|
|
for fn in (
|
|
lambda: DataLoader(
|
|
self.dataset,
|
|
shuffle=True,
|
|
num_workers=0,
|
|
generator=torch.Generator().manual_seed(42),
|
|
),
|
|
lambda: DataLoader(
|
|
self.dataset,
|
|
shuffle=True,
|
|
num_workers=2,
|
|
generator=torch.Generator().manual_seed(42),
|
|
),
|
|
):
|
|
self.assertEqual(list(fn()), list(fn()))
|
|
|
|
def test_sequential_workers(self):
|
|
self._test_sequential(self._get_data_loader(self.dataset, num_workers=4))
|
|
|
|
def test_sequential_batch_workers(self):
|
|
self._test_sequential(
|
|
self._get_data_loader(self.dataset, batch_size=2, num_workers=4)
|
|
)
|
|
|
|
def test_sequential_batch_workers_prefetch(self):
|
|
self._test_sequential(
|
|
DataLoader(self.dataset, batch_size=2, num_workers=4, prefetch_factor=3)
|
|
)
|
|
|
|
def test_shuffle_workers(self):
|
|
self._test_shuffle(
|
|
self._get_data_loader(self.dataset, shuffle=True, num_workers=4)
|
|
)
|
|
|
|
def test_shuffle_batch_workers(self):
|
|
self._test_shuffle(
|
|
self._get_data_loader(
|
|
self.dataset, batch_size=2, shuffle=True, num_workers=4
|
|
)
|
|
)
|
|
|
|
def test_shuffle_batch_workers_prefetch(self):
|
|
self._test_shuffle(
|
|
DataLoader(
|
|
self.dataset,
|
|
batch_size=2,
|
|
shuffle=True,
|
|
num_workers=4,
|
|
prefetch_factor=3,
|
|
)
|
|
)
|
|
|
|
def test_random_sampler(self):
|
|
from collections import Counter
|
|
|
|
from torch.utils.data import RandomSampler
|
|
|
|
def sample_stat(sampler, num_samples):
|
|
counts = Counter(sampler)
|
|
count_repeated = sum(val > 1 for val in counts.values())
|
|
return (
|
|
count_repeated,
|
|
min(counts.keys()),
|
|
max(counts.keys()),
|
|
sum(counts.values()),
|
|
)
|
|
|
|
# test sample with replacement
|
|
n = len(self.dataset) + 1 # ensure at least one sample is drawn more than once
|
|
sampler_with_replacement = RandomSampler(
|
|
self.dataset, replacement=True, num_samples=n
|
|
)
|
|
count_repeated, minval, maxval, count_total = sample_stat(
|
|
sampler_with_replacement, n
|
|
)
|
|
self.assertTrue(count_repeated > 0)
|
|
self.assertTrue(minval >= 0)
|
|
self.assertTrue(maxval < len(self.dataset))
|
|
self.assertTrue(count_total == n)
|
|
|
|
# test sample without replacement and without specified num_samples
|
|
sampler_without_replacement = RandomSampler(self.dataset)
|
|
count_repeated, minval, maxval, count_total = sample_stat(
|
|
sampler_without_replacement, len(self.dataset)
|
|
)
|
|
self.assertTrue(count_repeated == 0)
|
|
self.assertTrue(minval == 0)
|
|
self.assertTrue(maxval == len(self.dataset) - 1)
|
|
self.assertTrue(count_total == len(self.dataset))
|
|
|
|
# test sample without replacement and with specified num_samples
|
|
n = len(self.dataset) * 2
|
|
sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
|
|
count_repeated, minval, maxval, count_total = sample_stat(
|
|
sampler_without_replacement, len(self.dataset)
|
|
)
|
|
self.assertTrue(count_repeated == len(self.dataset))
|
|
self.assertTrue(minval == 0)
|
|
self.assertTrue(maxval == len(self.dataset) - 1)
|
|
self.assertTrue(count_total == n)
|
|
|
|
n = len(self.dataset) - 1
|
|
sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
|
|
count_repeated, minval, maxval, count_total = sample_stat(
|
|
sampler_without_replacement, len(self.dataset)
|
|
)
|
|
self.assertTrue(count_repeated == 0)
|
|
self.assertTrue(minval >= 0)
|
|
self.assertTrue(maxval < len(self.dataset))
|
|
self.assertTrue(count_total == n)
|
|
|
|
n = len(self.dataset) + 1
|
|
sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
|
|
count_repeated, minval, maxval, count_total = sample_stat(
|
|
sampler_without_replacement, len(self.dataset)
|
|
)
|
|
self.assertTrue(count_repeated == 1)
|
|
self.assertTrue(minval == 0)
|
|
self.assertTrue(maxval == len(self.dataset) - 1)
|
|
self.assertTrue(count_total == n)
|
|
|
|
# raise error when replacement is non-boolean
|
|
with self.assertRaisesRegex(
|
|
TypeError, "replacement should be a boolean value, but got replacement=0"
|
|
):
|
|
RandomSampler(self.dataset, replacement=0)
|
|
|
|
def test_random_sampler_len_with_replacement(self):
|
|
from torch.utils.data import RandomSampler
|
|
|
|
# add 5 extra samples
|
|
num_samples = len(self.dataset) + 5
|
|
sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples)
|
|
# test len method
|
|
self.assertEqual(num_samples, len(sampler))
|
|
|
|
# test with iteration
|
|
count_num_samples = sum(1 for _ in sampler)
|
|
self.assertEqual(num_samples, count_num_samples)
|
|
|
|
# test with dataloader, batch_size = 1
|
|
batch_size = 1
|
|
count_num_samples_in_data_loader = len(
|
|
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
|
|
)
|
|
self.assertEqual(num_samples, count_num_samples_in_data_loader)
|
|
|
|
# test with dataloader, batch_size = 6
|
|
batch_size = 6
|
|
count_num_samples_in_data_loader = len(
|
|
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
|
|
)
|
|
self.assertEqual(
|
|
int(math.ceil(float(num_samples) / batch_size)),
|
|
count_num_samples_in_data_loader,
|
|
)
|
|
|
|
def test_random_sampler_len_without_replacement(self):
|
|
from torch.utils.data import RandomSampler
|
|
|
|
# add 5 extra samples
|
|
num_samples = len(self.dataset) + 5
|
|
sampler = RandomSampler(
|
|
self.dataset, replacement=False, num_samples=num_samples
|
|
)
|
|
# test len method
|
|
self.assertEqual(num_samples, len(sampler))
|
|
|
|
# test with iteration
|
|
count_num_samples = sum(1 for _ in sampler)
|
|
self.assertEqual(num_samples, count_num_samples)
|
|
|
|
# test with dataloader, batch_size = 1
|
|
batch_size = 1
|
|
count_num_samples_in_data_loader = len(
|
|
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
|
|
)
|
|
self.assertEqual(num_samples, count_num_samples_in_data_loader)
|
|
|
|
# test with dataloader, batch_size = 6
|
|
batch_size = 6
|
|
count_num_samples_in_data_loader = len(
|
|
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
|
|
)
|
|
self.assertEqual(
|
|
num_samples // batch_size + (num_samples % batch_size > 0),
|
|
count_num_samples_in_data_loader,
|
|
)
|
|
|
|
def test_distributed_sampler_invalid_rank(self):
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
dataset = torch.IntTensor(range(10))
|
|
with self.assertRaisesRegex(ValueError, "Invalid rank"):
|
|
sampler = DistributedSampler(dataset, 3, 3)
|
|
|
|
with self.assertRaisesRegex(ValueError, "Invalid rank"):
|
|
sampler = DistributedSampler(dataset, 3, -1)
|
|
|
|
def test_duplicating_data_with_drop_last(self):
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
num_processes = 4
|
|
num_batches = 9
|
|
data_set = torch.IntTensor(range(num_batches))
|
|
scanned_data = torch.IntTensor([])
|
|
for i in range(num_processes):
|
|
s = DistributedSampler(data_set, num_processes, i)
|
|
d_loader = self._get_data_loader(
|
|
data_set,
|
|
batch_size=int(num_batches / num_processes),
|
|
drop_last=True,
|
|
sampler=s,
|
|
)
|
|
for data in d_loader:
|
|
scanned_data = torch.cat((scanned_data, data), 0)
|
|
|
|
self.assertEqual(scanned_data.size(), scanned_data.unique().size())
|
|
|
|
def test_sampler_reproducibility(self):
|
|
from torch.utils.data import (
|
|
RandomSampler,
|
|
SubsetRandomSampler,
|
|
WeightedRandomSampler,
|
|
)
|
|
|
|
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
|
|
for fn in (
|
|
lambda: RandomSampler(
|
|
self.dataset,
|
|
num_samples=5,
|
|
replacement=True,
|
|
generator=torch.Generator().manual_seed(42),
|
|
),
|
|
lambda: RandomSampler(
|
|
self.dataset,
|
|
replacement=False,
|
|
generator=torch.Generator().manual_seed(42),
|
|
),
|
|
lambda: WeightedRandomSampler(
|
|
weights,
|
|
num_samples=5,
|
|
replacement=True,
|
|
generator=torch.Generator().manual_seed(42),
|
|
),
|
|
lambda: WeightedRandomSampler(
|
|
weights,
|
|
num_samples=5,
|
|
replacement=False,
|
|
generator=torch.Generator().manual_seed(42),
|
|
),
|
|
lambda: SubsetRandomSampler(
|
|
range(10), generator=torch.Generator().manual_seed(42)
|
|
),
|
|
):
|
|
self.assertEqual(list(fn()), list(fn()))
|
|
|
|
for sampler in (
|
|
RandomSampler(self.dataset, num_samples=5, replacement=True),
|
|
RandomSampler(self.dataset, replacement=False),
|
|
WeightedRandomSampler(weights, num_samples=5, replacement=True),
|
|
WeightedRandomSampler(weights, num_samples=5, replacement=False),
|
|
SubsetRandomSampler(range(10)),
|
|
):
|
|
torch.manual_seed(0)
|
|
l1 = list(sampler) + list(sampler)
|
|
|
|
torch.manual_seed(0)
|
|
l2 = list(sampler) + list(sampler)
|
|
self.assertEqual(l1, l2)
|
|
|
|
its = (iter(sampler), iter(sampler))
|
|
ls = ([], [])
|
|
for idx in range(len(sampler)):
|
|
for i in range(2):
|
|
if idx == 0:
|
|
torch.manual_seed(0)
|
|
ls[i].append(next(its[i]))
|
|
self.assertEqual(ls[0], ls[1])
|
|
|
|
def _test_sampler(self, **kwargs):
|
|
indices = range(2, 12) # using a regular iterable
|
|
dl = self._get_data_loader(
|
|
self.dataset, sampler=indices, batch_size=2, **kwargs
|
|
)
|
|
self.assertEqual(len(dl), 5)
|
|
for i, (input, _target) in enumerate(dl):
|
|
self.assertEqual(len(input), 2)
|
|
self.assertEqual(input, self.data[i * 2 + 2 : i * 2 + 4])
|
|
|
|
def test_sampler(self):
|
|
self._test_sampler()
|
|
self._test_sampler(num_workers=4)
|
|
if not NO_MULTIPROCESSING_SPAWN:
|
|
self._test_sampler(num_workers=4, multiprocessing_context="spawn")
|
|
|
|
def _test_batch_sampler(self, **kwargs):
|
|
# [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
|
|
batches = [] # using a regular iterable
|
|
for i in range(0, 20, 5):
|
|
batches.append(tuple(range(i, i + 2)))
|
|
batches.append(tuple(range(i + 2, i + 5)))
|
|
|
|
dl = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs)
|
|
self.assertEqual(len(dl), 8)
|
|
for i, (input, _target) in enumerate(dl):
|
|
if i % 2 == 0:
|
|
offset = i * 5 // 2
|
|
self.assertEqual(len(input), 2)
|
|
self.assertEqual(input, self.data[offset : offset + 2])
|
|
else:
|
|
offset = i * 5 // 2
|
|
self.assertEqual(len(input), 3)
|
|
self.assertEqual(input, self.data[offset : offset + 3])
|
|
|
|
def test_batch_sampler(self):
|
|
self._test_batch_sampler()
|
|
self._test_batch_sampler(num_workers=4)
|
|
if not NO_MULTIPROCESSING_SPAWN:
|
|
self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_shuffle_pin_memory(self):
|
|
loader = self._get_data_loader(
|
|
self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
|
|
)
|
|
for input, target in loader:
|
|
self.assertTrue(input.is_pinned())
|
|
self.assertTrue(target.is_pinned())
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_numpy(self):
|
|
import numpy as np
|
|
|
|
class TestDataset(torch.utils.data.Dataset):
|
|
def __getitem__(self, i):
|
|
return np.ones((2, 3, 4)) * i
|
|
|
|
def __len__(self):
|
|
return 1000
|
|
|
|
loader = self._get_data_loader(TestDataset(), batch_size=12)
|
|
batch = next(iter(loader))
|
|
self.assertIsInstance(batch, torch.DoubleTensor)
|
|
self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_numpy_gen_state(self):
|
|
from torch.utils.data._utils.worker import _generate_state
|
|
|
|
# Using NumPy generated states as the reference to test `_generate_state`
|
|
# having the same result.
|
|
# Test case: ((worker_id, base_seed), expected_state)
|
|
test_cases = [
|
|
(
|
|
(4, 13434589827475259383),
|
|
(2884386318, 1088094898, 3523808998, 3860348662),
|
|
),
|
|
((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)),
|
|
(
|
|
(10, 978296274032934101),
|
|
(1759791917, 3550927336, 1225977135, 1036538043),
|
|
),
|
|
(
|
|
(12, 11868770762134256968),
|
|
(3974661794, 3331131333, 3630387033, 2885815368),
|
|
),
|
|
(
|
|
(9, 15378787925219019706),
|
|
(3815056996, 3162224466, 2735102421, 3190253477),
|
|
),
|
|
((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)),
|
|
(
|
|
(15, 14617792358407278405),
|
|
(3402479508, 1588702753, 1169536393, 3675067356),
|
|
),
|
|
(
|
|
(9, 17363320784006640087),
|
|
(957989458, 2518334477, 1421725660, 3086155459),
|
|
),
|
|
(
|
|
(12, 480002904169484764),
|
|
(2732851467, 1762620729, 4055801988, 1277640511),
|
|
),
|
|
(
|
|
(15, 16803975943592702950),
|
|
(3479415043, 4022359553, 295994005, 3358606349),
|
|
),
|
|
(
|
|
(9, 11704776406047813044),
|
|
(1968928009, 710113752, 2442656196, 1587420279),
|
|
),
|
|
(
|
|
(10, 16357891985431864516),
|
|
(1271733898, 4197047399, 3727213786, 2338547348),
|
|
),
|
|
(
|
|
(2, 17423369006318065007),
|
|
(544294336, 1911284083, 3299147734, 3231058347),
|
|
),
|
|
((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)),
|
|
((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)),
|
|
(
|
|
(6, 6269787272229682235),
|
|
(2548857855, 1216457374, 1012973562, 2999759647),
|
|
),
|
|
]
|
|
|
|
for (worker_id, base_seed), exp in test_cases:
|
|
self.assertEqual(exp, _generate_state(base_seed, worker_id))
|
|
|
|
def test_error(self):
|
|
self._test_error(
|
|
self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True)
|
|
)
|
|
|
|
def test_error_workers(self):
|
|
self._test_error(
|
|
self._get_data_loader(
|
|
ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4
|
|
)
|
|
)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
|
|
def test_partial_workers(self):
|
|
r"""Check that workers exit even if the iterator is not exhausted."""
|
|
if TEST_CUDA:
|
|
pin_memory_configs = (True, False)
|
|
else:
|
|
pin_memory_configs = (False,)
|
|
|
|
for pin_memory in pin_memory_configs:
|
|
loader = iter(
|
|
self._get_data_loader(
|
|
self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory
|
|
)
|
|
)
|
|
workers = loader._workers
|
|
if pin_memory:
|
|
pin_memory_thread = loader._pin_memory_thread
|
|
for i, _ in enumerate(loader):
|
|
if i == 10:
|
|
break
|
|
assert i == 10
|
|
del loader
|
|
for w in workers:
|
|
w.join(JOIN_TIMEOUT)
|
|
self.assertFalse(w.is_alive(), "subprocess not terminated")
|
|
if pin_memory:
|
|
pin_memory_thread.join(JOIN_TIMEOUT)
|
|
self.assertFalse(pin_memory_thread.is_alive())
|
|
|
|
# Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
|
|
@skipIfRocm
|
|
@unittest.skipIf(not HAS_PSUTIL, "psutil not found")
|
|
@slowTest
|
|
def test_proper_exit(self):
|
|
(
|
|
r"""There might be ConnectionResetError or leaked semaphore warning """
|
|
r"""(due to dirty process exit), but they are all safe to ignore"""
|
|
)
|
|
|
|
# TODO: test the case where the pin_memory_thread triggers an
|
|
# error/fatal signal. I haven't found out how to properly do that.
|
|
|
|
for (
|
|
is_iterable_dataset,
|
|
use_workers,
|
|
pin_memory,
|
|
hold_iter_reference,
|
|
) in itertools.product([True, False], repeat=4):
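# 16 combinations of (is_iterable_dataset, use_workers, pin_memory, hold_iter_reference).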
|
|
# `hold_iter_reference` specifies whether we hold a reference to the
|
|
# iterator. This is interesting because Python 3 error tracebacks hold a
|
|
# reference to the frames, which hold references to all the local
|
|
# variables including the iterator, and then the iterator dtor may
|
|
# not be called before the process ends. It is important to check that the
|
|
# processes still exit in both cases.
|
|
|
|
if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
|
|
# This test runs in a subprocess, which can only initialize CUDA with spawn.
|
|
# DataLoader with pin_memory=True initializes CUDA when its iterator is constructed.
|
|
# On Windows, pin_memory sometimes causes CUDA OOM.
|
|
continue
|
|
|
|
# `exit_method` controls the way the loader process ends.
|
|
# - `*_kill` means that `*` is killed by OS.
|
|
# - `*_error` means that `*` raises an error.
|
|
# - `None` means that no error happens.
|
|
# In all cases, all processes should end properly.
|
|
if use_workers:
|
|
# TODO: Fix test for 'loader_kill' that would cause running out of shared memory.
|
|
# Killing the loader process would prevent the DataLoader iterator from cleaning up all queues
|
|
# and worker processes
|
|
exit_methods = [None, "loader_error", "worker_error", "worker_kill"]
|
|
persistent_workers = self.persistent_workers
|
|
else:
|
|
exit_methods = [None, "loader_error", "loader_kill"]
|
|
persistent_workers = False
|
|
|
|
for exit_method in exit_methods:
|
|
if exit_method == "worker_kill":
|
|
# FIXME: This sometimes hangs. See #16608.
|
|
continue
|
|
|
|
desc = []
|
|
desc.append(f"is_iterable_dataset={is_iterable_dataset}")
|
|
desc.append(f"use_workers={use_workers}")
|
|
desc.append(f"pin_memory={pin_memory}")
|
|
desc.append(f"hold_iter_reference={hold_iter_reference}")
|
|
desc.append(f"exit_method={exit_method}")
|
|
desc = "test_proper_exit with " + ", ".join(desc)
|
|
|
|
# Event that the loader process uses to signal testing process
|
|
# that various things are setup, including that the worker pids
|
|
# are specified in `worker_pids` array.
|
|
loader_setup_event = mp.Event()
|
|
|
|
# Event that this process has finished setting up, and the
|
|
# loader process can now proceed to trigger error events or
|
|
# finish normally.
|
|
tester_setup_event = mp.Event()
|
|
|
|
loader_p = ErrorTrackingProcess(
|
|
target=_test_proper_exit,
|
|
args=(
|
|
is_iterable_dataset,
|
|
use_workers,
|
|
pin_memory,
|
|
exit_method,
|
|
hold_iter_reference,
|
|
loader_setup_event,
|
|
tester_setup_event,
|
|
persistent_workers,
|
|
),
|
|
disable_stderr=False,
|
|
)
|
|
loader_p.start()
|
|
loader_psutil_p = psutil.Process(loader_p.pid)
|
|
|
|
# Wait for loader process to set everything up, e.g., starting
|
|
# workers.
|
|
loader_setup_event.wait(timeout=JOIN_TIMEOUT)
|
|
if not loader_setup_event.is_set():
|
|
fail_msg = (
|
|
desc + ": loader process failed to setup within given time"
|
|
)
|
|
if loader_p.exception is not None:
|
|
fail_msg += f", and had exception {loader_p.exception}"
|
|
elif not loader_p.is_alive():
|
|
fail_msg += f", and exited with code {loader_p.exitcode} but had no exception"
|
|
else:
|
|
fail_msg += ", and is still alive."
|
|
if loader_p.is_alive():
|
|
# this may kill the process, needs to run after the above lines
|
|
loader_p.print_traces_of_all_threads()
|
|
self.fail(fail_msg)
|
|
|
|
# We are certain that the workers have started now.
|
|
worker_psutil_ps = loader_psutil_p.children()
|
|
|
|
def fail(reason):
|
|
report_psutil_attrs = [
|
|
"pid",
|
|
"name",
|
|
"cpu_times",
|
|
"io_counters",
|
|
"memory_full_info",
|
|
"num_ctx_switches",
|
|
"open_files",
|
|
"threads",
|
|
"status",
|
|
"nice",
|
|
"ionice",
|
|
]
|
|
if reason is None:
|
|
err_msg = desc
|
|
else:
|
|
err_msg = f"{desc}: {reason}"
|
|
err_msg += "\nLoader info:\n\t"
|
|
if loader_psutil_p.is_running():
|
|
err_msg += str(
|
|
loader_psutil_p.as_dict(attrs=report_psutil_attrs)
|
|
)
|
|
# this may kill the process, needs to run after the above line
|
|
loader_p.print_traces_of_all_threads()
|
|
else:
|
|
err_msg += f"exited with code {loader_p.exitcode}"
|
|
if use_workers:
|
|
err_msg += "\nWorker(s) info:"
|
|
for idx, worker_psutil_p in enumerate(worker_psutil_ps):
|
|
err_msg += f"\n\tWorker {idx}:\n\t\t"
|
|
if worker_psutil_p.is_running():
|
|
err_msg += str(
|
|
worker_psutil_p.as_dict(attrs=report_psutil_attrs)
|
|
)
|
|
# this may kill the process, needs to run after the above line
|
|
print_traces_of_all_threads(worker_psutil_p.pid)
|
|
else:
|
|
err_msg += "exited with unknown code"
|
|
self.fail(err_msg)
|
|
|
|
tester_setup_event.set()
|
|
|
|
try:
|
|
loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
|
|
if loader_p.is_alive():
|
|
fail_reason = "loader process did not terminate"
|
|
if loader_p.exception is not None:
|
|
fail(
|
|
fail_reason
|
|
+ f", and had exception {loader_p.exception}"
|
|
)
|
|
else:
|
|
fail(fail_reason + ", and had no exception")
|
|
_, alive = psutil.wait_procs(
|
|
worker_psutil_ps,
|
|
timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT),
|
|
)
|
|
if len(alive) > 0:
|
|
fail(
|
|
"worker process (pid(s) {}) did not terminate".format(
|
|
", ".join(str(p.pid) for p in alive)
|
|
)
|
|
)
|
|
if exit_method is None:
|
|
if loader_p.exitcode != 0:
|
|
fail(
|
|
f"loader process had nonzero exitcode {loader_p.exitcode}"
|
|
)
|
|
else:
|
|
if loader_p.exitcode == 0:
|
|
fail("loader process had zero exitcode")
|
|
if exit_method == "loader_error":
|
|
if not isinstance(
|
|
loader_p.exception, RuntimeError
|
|
) or "Loader error" not in str(loader_p.exception):
|
|
fail(
|
|
f"loader process did not raise expected exception, but had {loader_p.exception}"
|
|
)
|
|
elif exit_method == "worker_kill":
|
|
if isinstance(loader_p.exception, RuntimeError):
|
|
if "DataLoader worker (pid" not in str(
|
|
loader_p.exception
|
|
):
|
|
fail(
|
|
f"loader process did not raise expected exception, but had {loader_p.exception}"
|
|
)
|
|
elif isinstance(loader_p.exception, ConnectionRefusedError):
|
|
# Sometimes, when the worker is being killed and is freeing its
|
|
# resources, the unpickling in the loader process may run into
|
|
# a `ConnectionRefusedError` because it cannot open a socket to receive
|
|
# the resource. In such cases, the worker may not have fully exited,
|
|
# and the loader can't know this via `is_alive` check or `SIGCHLD`
|
|
# handler. So we permit this as an allowed error as well.
|
|
# After all, we are happy as long as it terminates.
|
|
pass
|
|
else:
|
|
fail(
|
|
f"loader process did not raise expected exception, but had {loader_p.exception}"
|
|
)
|
|
elif exit_method == "worker_error":
|
|
if not isinstance(
|
|
loader_p.exception, RuntimeError
|
|
) or "Worker error" not in str(loader_p.exception):
|
|
fail(
|
|
f"loader process did not raise expected exception, but had {loader_p.exception}"
|
|
)
|
|
finally:
|
|
loader_p.terminate()
|
|
|
|
def test_len(self):
|
|
def check_len(dl, expected):
|
|
self.assertEqual(len(dl), expected)
|
|
n = 0
|
|
for _ in dl:
|
|
n += 1
|
|
self.assertEqual(n, expected)
|
|
|
|
check_len(self.dataset, 100)
|
|
check_len(self._get_data_loader(self.dataset, batch_size=2), 50)
|
|
check_len(self._get_data_loader(self.dataset, batch_size=3), 34)
|
|
|
|
def test_iterabledataset_len(self):
|
|
class IterableDataset(torch.utils.data.IterableDataset):
|
|
def __len__(self):
|
|
return 10
|
|
|
|
def __iter__(self):
|
|
return iter(range(10))
|
|
|
|
iterable_loader = DataLoader(IterableDataset(), batch_size=1)
|
|
self.assertEqual(len(iterable_loader), 10)
|
|
iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True)
|
|
self.assertEqual(len(iterable_loader), 10)
|
|
|
|
iterable_loader = DataLoader(IterableDataset(), batch_size=2)
|
|
self.assertEqual(len(iterable_loader), 5)
|
|
iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True)
|
|
self.assertEqual(len(iterable_loader), 5)
|
|
|
|
iterable_loader = DataLoader(IterableDataset(), batch_size=3)
|
|
self.assertEqual(len(iterable_loader), 4)
|
|
iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True)
|
|
self.assertEqual(len(iterable_loader), 3)
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_numpy_scalars(self):
|
|
import numpy as np
|
|
|
|
class ScalarDataset(torch.utils.data.Dataset):
|
|
def __init__(self, dtype):
|
|
self.dtype = dtype
|
|
|
|
def __getitem__(self, i):
|
|
return self.dtype()
|
|
|
|
def __len__(self):
|
|
return 4
|
|
|
|
dtypes = {
|
|
np.float64: torch.DoubleTensor,
|
|
np.float32: torch.FloatTensor,
|
|
np.float16: torch.HalfTensor,
|
|
np.int64: torch.LongTensor,
|
|
np.int32: torch.IntTensor,
|
|
np.int16: torch.ShortTensor,
|
|
np.int8: torch.CharTensor,
|
|
np.uint8: torch.ByteTensor,
|
|
}
|
|
for dt, tt in dtypes.items():
|
|
dset = ScalarDataset(dt)
|
|
loader = self._get_data_loader(dset, batch_size=2)
|
|
batch = next(iter(loader))
|
|
self.assertIsInstance(batch, tt)
|
|
|
|
def test_default_convert_mapping_keep_type(self):
|
|
data = CustomDict({"a": 1, "b": 2})
|
|
converted = _utils.collate.default_convert(data)
|
|
|
|
self.assertEqual(converted, data)
|
|
|
|
def test_default_convert_sequence_keep_type(self):
|
|
data = CustomList([1, 2, 3])
|
|
converted = _utils.collate.default_convert(data)
|
|
|
|
self.assertEqual(converted, data)
|
|
|
|
def test_default_convert_sequence_dont_keep_type(self):
|
|
data = range(2)
|
|
converted = _utils.collate.default_convert(data)
|
|
|
|
self.assertEqual(converted, [0, 1])
|
|
|
|
def test_default_collate_dtype(self):
|
|
arr = [1, 2, -1]
|
|
collated = _utils.collate.default_collate(arr)
|
|
self.assertEqual(collated, torch.tensor(arr))
|
|
self.assertEqual(collated.dtype, torch.int64)
|
|
|
|
arr = [1.1, 2.3, -0.9]
|
|
collated = _utils.collate.default_collate(arr)
|
|
self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
|
|
|
|
arr = [True, False]
|
|
collated = _utils.collate.default_collate(arr)
|
|
self.assertEqual(collated, torch.tensor(arr))
|
|
self.assertEqual(collated.dtype, torch.bool)
|
|
|
|
# Should be a no-op
|
|
arr = ["a", "b", "c"]
|
|
self.assertEqual(arr, _utils.collate.default_collate(arr))
|
|
|
|
def test_default_collate_mapping_keep_type(self):
|
|
batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
|
|
collated = _utils.collate.default_collate(batch)
|
|
|
|
expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
|
|
self.assertEqual(collated, expected)
|
|
|
|
def test_default_collate_sequence_keep_type(self):
|
|
batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
|
|
collated = _utils.collate.default_collate(batch)
|
|
|
|
expected = CustomList(
|
|
[
|
|
torch.tensor([1, 4]),
|
|
torch.tensor([2, 5]),
|
|
torch.tensor([3, 6]),
|
|
]
|
|
)
|
|
self.assertEqual(collated, expected)
|
|
|
|
def test_default_collate_sequence_dont_keep_type(self):
|
|
batch = [range(2), range(2)]
|
|
collated = _utils.collate.default_collate(batch)
|
|
|
|
self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_default_collate_bad_numpy_types(self):
|
|
import numpy as np
|
|
|
|
# Should be a no-op
|
|
arr = np.array(["a", "b", "c"])
|
|
self.assertEqual(arr, _utils.collate.default_collate(arr))
|
|
|
|
arr = np.array([[["a", "b", "c"]]])
|
|
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
|
|
|
|
arr = np.array([object(), object(), object()])
|
|
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
|
|
|
|
arr = np.array([[[object(), object(), object()]]])
|
|
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_default_collate_numpy_memmap(self):
|
|
import numpy as np
|
|
|
|
with tempfile.TemporaryFile() as f:
|
|
arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
|
|
arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
|
|
arr_memmap[:] = arr[:]
|
|
arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
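# Collating rows read back through the read-only memmap should still yield a
# regular tensor with the same values.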
|
|
tensor = _utils.collate.default_collate(list(arr_new))
|
|
|
|
self.assertTrue(
|
|
(tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
|
|
)
|
|
|
|
def test_default_collate_bad_sequence_type(self):
|
|
batch = [["X"], ["X", "X"]]
|
|
self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
|
|
self.assertRaises(
|
|
RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
|
|
)
|
|
|
|
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
|
def test_default_collate_shared_tensor(self):
|
|
import numpy as np
|
|
|
|
t_in = torch.zeros(1)
|
|
n_in = np.zeros(1)
|
|
|
|
self.assertEqual(t_in.is_shared(), False)
|
|
|
|
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
|
|
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
|
|
|
|
# FIXME: fix the following hack that makes `default_collate` believe
|
|
# that it is in a worker process (since it tests
|
|
# `get_worker_info() != None`), even though it is not.
|
|
old = _utils.worker._worker_info
|
|
try:
|
|
_utils.worker._worker_info = "x"
|
|
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
|
|
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
|
|
finally:
|
|
_utils.worker._worker_info = old
|
|
|
|
def test_excessive_thread_creation_warning(self):
|
|
with self.assertWarnsRegex(
|
|
UserWarning,
|
|
r"excessive worker creation might get DataLoader running slow or even freeze",
|
|
):
|
|
dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
|
|
|
|
|
|
class TestDataLoaderDeviceType(TestCase):
|
|
@parametrize(
|
|
"context",
|
|
[ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
|
|
)
|
|
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
|
|
def test_nested_tensor_multiprocessing(self, device, context):
|
|
# The 'fork' multiprocessing context doesn't work for CUDA so skip it
|
|
if "cuda" in device and context == "fork":
|
|
# TODO: Skip this in a better way when the test framework allows
|
|
return
|
|
|
|
dataset = [
|
|
torch.nested.nested_tensor([torch.randn(5)], device=device)
|
|
for _ in range(10)
|
|
]
|
|
|
|
pin_memory_settings = [False]
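# pin_memory is exercised only when the nested tensors live on CPU and CUDA is available.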
|
|
if device == "cpu" and torch.cuda.is_available():
|
|
pin_memory_settings.append(True)
|
|
|
|
for pin_memory in pin_memory_settings:
|
|
loader = torch.utils.data.DataLoader(
|
|
dataset,
|
|
batch_size=1,
|
|
num_workers=4,
|
|
collate_fn=_clone_collate,
|
|
pin_memory=pin_memory,
|
|
multiprocessing_context=context,
|
|
)
|
|
|
|
for i, batch in enumerate(loader):
|
|
self.assertEqual(batch[0], dataset[i])
|
|
|
|
# Error case: default collate_fn doesn't currently support batches of nested tensors.
|
|
# Following the current semantics, we'd need to stack them, which isn't possible atm.
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "not currently supported by the default collate_fn"
|
|
):
|
|
loader = torch.utils.data.DataLoader(
|
|
dataset,
|
|
batch_size=1,
|
|
num_workers=4,
|
|
multiprocessing_context=context,
|
|
)
|
|
|
|
next(iter(loader))
|
|
|
|
|
|
class IntegrationTestDataLoaderDataPipe(TestCase):
|
|
r"""
|
|
Verify the behavior of certain ``DataPipes`` with ``DataLoader``
|
|
"""
|
|
|
|
def test_shuffler_iterdatapipe(self):
|
|
r"""
|
|
Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader``
|
|
to generate different seeds deterministically per epoch.
|
|
"""
|
|
exp = list(range(100))
|
|
|
|
def _create_dp(buffer_size):
|
|
input_ds = dp.iter.IterableWrapper(exp)
|
|
return input_ds.shuffle(buffer_size=buffer_size).sharding_filter()
|
|
|
|
for bs in (5, 20, 33):
|
|
# Test Deterministic
|
|
for num_workers, pw in itertools.product((0, 1, 2), (True, False)):
|
|
if num_workers == 0 and pw:
|
|
continue
|
|
|
|
shuffle_dp = _create_dp(bs)
|
|
|
|
mp_ctx = "spawn" if num_workers > 0 else None
|
|
dl = DataLoader(
|
|
shuffle_dp,
|
|
num_workers=num_workers,
|
|
shuffle=True,
|
|
multiprocessing_context=mp_ctx,
|
|
persistent_workers=pw,
|
|
)
|
|
|
|
# No seed
|
|
dl_res_ns = list(dl)
|
|
self.assertEqual(sorted(dl_res_ns), exp)
|
|
|
|
# Same seeds
|
|
dl_res = []
|
|
for epoch in range(2):
|
|
torch.manual_seed(123)
|
|
dl_res.append(list(dl))
|
|
self.assertEqual(dl_res[0], dl_res[1])
|
|
self.assertEqual(sorted(dl_res[0]), exp)
|
|
|
|
# Different seeds
|
|
torch.manual_seed(321)
|
|
dl_res.append(list(dl))
|
|
|
|
self.assertEqual(len(dl_res[0]), len(dl_res[2]))
|
|
self.assertNotEqual(dl_res[0], dl_res[2])
|
|
self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2]))
|
|
|
|
if dl._iterator is not None:
|
|
dl._iterator._shutdown_workers()
|
|
dl._iterator = None
|
|
del dl
|
|
|
|
|
|
class StringDataset(Dataset):
|
|
def __init__(self) -> None:
|
|
self.s = "12345"
|
|
|
|
def __len__(self):
|
|
return len(self.s)
|
|
|
|
def __getitem__(self, ndx):
|
|
return (self.s[ndx], ndx)
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
class TestStringDataLoader(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.dataset = StringDataset()
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_shuffle_pin_memory(self):
|
|
loader = DataLoader(
|
|
self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
|
|
)
|
|
for s, n in loader:
|
|
self.assertIsInstance(s[0], str)
|
|
self.assertTrue(n.is_pinned())
|
|
|
|
|
|
class DictDataset(Dataset):
|
|
def __len__(self):
|
|
return 4
|
|
|
|
def __getitem__(self, ndx):
|
|
return {
|
|
"a_tensor": torch.empty(4, 2).fill_(ndx),
|
|
"another_dict": {
|
|
"a_number": ndx,
|
|
},
|
|
}
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
class TestDictDataLoader(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.dataset = DictDataset()
|
|
|
|
def test_sequential_batch(self):
|
|
for persistent_workers in (False, True):
|
|
if persistent_workers:
|
|
loader = DataLoader(
|
|
self.dataset,
|
|
batch_size=2,
|
|
shuffle=False,
|
|
persistent_workers=persistent_workers,
|
|
num_workers=1,
|
|
)
|
|
else:
|
|
loader = DataLoader(
|
|
self.dataset,
|
|
batch_size=2,
|
|
shuffle=False,
|
|
persistent_workers=persistent_workers,
|
|
)
|
|
batch_size = loader.batch_size
|
|
for i, sample in enumerate(loader):
|
|
idx = i * batch_size
|
|
self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"})
|
|
self.assertEqual(set(sample["another_dict"].keys()), {"a_number"})
|
|
|
|
t = sample["a_tensor"]
|
|
self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
|
|
self.assertTrue((t[0] == idx).all())
|
|
self.assertTrue((t[1] == idx + 1).all())
|
|
|
|
n = sample["another_dict"]["a_number"]
|
|
self.assertEqual(n.size(), torch.Size([batch_size]))
|
|
self.assertEqual(n[0], idx)
|
|
self.assertEqual(n[1], idx + 1)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_pin_memory(self):
|
|
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
|
|
for sample in loader:
|
|
self.assertTrue(sample["a_tensor"].is_pinned())
|
|
self.assertTrue(sample["another_dict"]["a_number"].is_pinned())
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_pin_memory_device(self):
|
|
loader = DataLoader(
|
|
self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
|
|
)
|
|
for sample in loader:
|
|
self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
|
|
self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_pin_memory_with_only_device(self):
|
|
loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
|
|
for sample in loader:
|
|
self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
|
|
self.assertFalse(
|
|
sample["another_dict"]["a_number"].is_pinned(device="cuda")
|
|
)
|
|
|
|
|
|
class DummyDataset(torch.utils.data.Dataset):
|
|
def __init__(self) -> None:
|
|
self.data = list(range(10))
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
if torch.is_tensor(idx):
|
|
idx = idx.tolist()
|
|
# The persistent workers always maintain the original
|
|
# dataset through the dataloader lifetime
|
|
# so the attributes will remain the same as the
|
|
# first time the workers were spawned (dataloader iteration)
|
|
assert self.start == 0
|
|
return self.data[idx]
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
@unittest.skipIf(
|
|
TEST_WITH_ASAN,
|
|
"DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
|
|
)
|
|
class TestDataLoaderPersistentWorkers(TestDataLoader):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.persistent_workers = True
|
|
|
|
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
|
|
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
|
|
def test_fd_limit_exceeded(self):
|
|
# See NOTE [ DataLoader on Linux and open files limit ]
|
|
import subprocess
|
|
|
|
subprocess.check_output(
|
|
[
|
|
sys.executable,
|
|
"-c",
|
|
"""\
|
|
import torch
|
|
import resource
|
|
from torch.utils.data import DataLoader, IterableDataset
|
|
|
|
class RandomDataset(IterableDataset):
|
|
def __init__(self, len, size):
|
|
super().__init__()
|
|
self.len = len
|
|
self.size = size
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.len <= 0:
|
|
raise StopIteration
|
|
self.len -= 1
|
|
return torch.randn(self.size)
|
|
|
|
try:
|
|
keep_fds_alive = []
|
|
resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
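# With only 100 fds allowed, keeping every shared tensor alive should exhaust
# file-descriptor based sharing and raise the informative RuntimeError checked below.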
|
|
for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
|
|
num_workers=1, persistent_workers=True):
|
|
random_t.max(dim=0)
|
|
keep_fds_alive.append(random_t)
|
|
except RuntimeError as e:
|
|
assert "ulimit -n" in str(e)
|
|
assert "set_sharing_strategy" in str(e)
|
|
""",
|
|
]
|
|
)
|
|
|
|
def test_dataset_not_reset(self):
|
|
dataset = DummyDataset()
|
|
pin_memory_configs = [False]
|
|
if TEST_CUDA:
|
|
pin_memory_configs.append(True)
|
|
for pin_memory in pin_memory_configs:
|
|
dataloader = self._get_data_loader(
|
|
dataset, num_workers=2, pin_memory=pin_memory
|
|
)
|
|
dataset.start = 0
|
|
for i in range(10):
|
|
for x in dataloader:
|
|
pass
|
|
# Changing the start value here doesn't have any effect on the dataset
|
|
# cached by the workers, since they are not recreated between epochs
|
|
# and can cache values safely
|
|
dataset.start = i
|
|
|
|
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
|
|
@unittest.skipIf(IS_WINDOWS, "Needs fork")
|
|
def test_early_exit(self):
|
|
import subprocess
|
|
|
|
proc = subprocess.check_output(
|
|
[
|
|
sys.executable,
|
|
"-c",
|
|
"""\
|
|
import torch
|
|
from torch.utils.data import DataLoader, IterableDataset
|
|
|
|
class RandomDataset(IterableDataset):
|
|
def __init__(self, len, size):
|
|
super().__init__()
|
|
self.len = len
|
|
self.size = size
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.len <= 0:
|
|
raise StopIteration
|
|
self.len -= 1
|
|
return torch.randn(self.size)
|
|
|
|
if __name__ == '__main__':
|
|
dl = DataLoader(
|
|
RandomDataset(64, (28, 28)),
|
|
batch_size=16,
|
|
num_workers=2,
|
|
pin_memory=True,
|
|
persistent_workers=True,
|
|
multiprocessing_context="fork",
|
|
)
|
|
|
|
for _ in dl:
|
|
break
|
|
""",
|
|
]
|
|
)
|
|
|
|
|
|
class NamedTupleDataset(Dataset):
|
|
from collections import namedtuple
|
|
|
|
Batch = namedtuple("Batch", ["data", "label", "random_tensor"])
|
|
Data = namedtuple("Data", ["positive", "negative"])
|
|
|
|
def __len__(self):
|
|
return 4
|
|
|
|
def __getitem__(self, ndx):
|
|
return self.Batch(
|
|
data=self.Data(positive=ndx, negative=-ndx),
|
|
label=str(ndx),
|
|
random_tensor=torch.randn(3),
|
|
)
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
class TestNamedTupleDataLoader(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.dataset = NamedTupleDataset()
|
|
|
|
def test_dataloader_with_namedtuple(self):
|
|
# auto-collation
|
|
loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA)
|
|
for batch in loader:
|
|
self.assertIsInstance(batch, NamedTupleDataset.Batch)
|
|
self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
|
|
self.assertIsInstance(batch.data, NamedTupleDataset.Data)
|
|
self.assertIsInstance(batch.data.positive, torch.Tensor)
|
|
self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA)
|
|
# no auto-collation
|
|
loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA)
|
|
for batch in loader:
|
|
self.assertIsInstance(batch, NamedTupleDataset.Batch)
|
|
self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
|
|
self.assertIsInstance(batch.data, NamedTupleDataset.Data)
|
|
self.assertNotIsInstance(batch.data.positive, torch.Tensor)
|
|
|
|
|
|
class SimpleCustomBatch:
|
|
def __init__(self, data):
|
|
transposed_data = list(zip(*data))
|
|
self.inp = torch.stack(transposed_data[0], 0)
|
|
self.tgt = torch.stack(transposed_data[1], 0)
|
|
|
|
def pin_memory(self):
|
|
self.inp = self.inp.pin_memory()
|
|
self.tgt = self.tgt.pin_memory()
|
|
return self
|
|
|
|
def is_pinned(self):
|
|
return self.inp.is_pinned() and self.tgt.is_pinned()
|
|
|
|
|
|
# Workaround for https://github.com/pytorch/pytorch/issues/50661
|
|
# Classes from `__main__` cannot be correctly unpickled in a spawned module
|
|
# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
|
|
self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])
|
|
|
|
|
|
def collate_wrapper(batch):
|
|
return self_module.SimpleCustomBatch(batch)
|
|
|
|
|
|
def collate_into_packed_sequence(batch):
|
|
data = torch.stack([sample[0] for sample in batch], 1)
|
|
t, b = data.size()
|
|
lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
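# Random per-sequence lengths in [1, t); enforce_sorted=False lets pack_padded_sequence handle the unsorted batch.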
|
|
return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
|
|
|
|
|
|
def collate_into_packed_sequence_batch_first(batch):
|
|
data = torch.stack([sample[0] for sample in batch], 0)
|
|
b, t = data.size()
|
|
lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
|
|
return torch.nn.utils.rnn.pack_padded_sequence(
|
|
data, lengths, batch_first=True, enforce_sorted=False
|
|
)
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
class TestCustomPinFn(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
|
|
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
|
|
self.dataset = TensorDataset(inps, tgts)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_custom_batch_pin(self):
|
|
test_cases = [
|
|
(collate_wrapper, self_module.SimpleCustomBatch),
|
|
(collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
|
|
(
|
|
collate_into_packed_sequence_batch_first,
|
|
torch.nn.utils.rnn.PackedSequence,
|
|
),
|
|
]
|
|
for collate_fn, elem_cls in test_cases:
|
|
loader = DataLoader(
|
|
self.dataset, batch_size=2, collate_fn=collate_fn, pin_memory=True
|
|
)
|
|
for sample in loader:
|
|
self.assertIsInstance(sample, elem_cls)
|
|
self.assertTrue(sample.is_pinned())
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
|
def test_custom_batch_pin_worker(self):
|
|
test_cases = [
|
|
(collate_wrapper, self_module.SimpleCustomBatch),
|
|
(collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
|
|
(
|
|
collate_into_packed_sequence_batch_first,
|
|
torch.nn.utils.rnn.PackedSequence,
|
|
),
|
|
]
|
|
for collate_fn, elem_cls in test_cases:
|
|
loader = DataLoader(
|
|
self.dataset,
|
|
batch_size=2,
|
|
collate_fn=collate_fn,
|
|
pin_memory=True,
|
|
num_workers=1,
|
|
)
|
|
for sample in loader:
|
|
self.assertIsInstance(sample, elem_cls)
|
|
self.assertTrue(sample.is_pinned())
|
|
|
|
|
|
class TestWorkerQueueDataset(Dataset):
|
|
def __init__(self, data):
|
|
self.data = data
|
|
self.worker_id = None
|
|
|
|
def worker_init_fn(self, worker_id):
|
|
self.worker_id = worker_id
|
|
|
|
def __getitem__(self, item):
|
|
return self.worker_id, self.data[item]
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TSAN,
|
|
"Fails with TSAN with the following error: starting new threads after multi-threaded "
|
|
"fork is not supported. Dying (set die_after_fork=0 to override)",
|
|
)
|
|
@unittest.skipIf(
|
|
TEST_WITH_ASAN,
|
|
"Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727",
|
|
)
|
|
class TestIndividualWorkerQueue(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.dataset = TestWorkerQueueDataset(list(range(128)))
|
|
|
|
def _run_ind_worker_queue_test(self, batch_size, num_workers):
|
|
loader = DataLoader(
|
|
self.dataset,
|
|
batch_size=batch_size,
|
|
shuffle=False,
|
|
num_workers=num_workers,
|
|
timeout=5,
|
|
worker_init_fn=self.dataset.worker_init_fn,
|
|
)
|
|
current_worker_idx = 0
|
|
for i, (worker_ids, sample) in enumerate(loader):
|
|
self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
|
|
self.assertEqual(
|
|
sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size))
|
|
)
|
|
current_worker_idx += 1
|
|
if current_worker_idx == num_workers:
|
|
current_worker_idx = 0
|
|
|
|
def test_ind_worker_queue(self):
|
|
max_num_workers = None
|
|
if hasattr(os, "sched_getaffinity"):
|
|
try:
|
|
max_num_workers = len(os.sched_getaffinity(0))
|
|
except Exception:
|
|
pass
|
|
if max_num_workers is None:
|
|
cpu_count = os.cpu_count()
|
|
if cpu_count is not None:
|
|
# Use half the number of CPUs
|
|
max_num_workers = cpu_count // 2
|
|
|
|
if max_num_workers is None:
|
|
max_num_workers = 1
|
|
|
|
for batch_size in (8, 16, 32, 64):
|
|
for num_workers in range(0, min(6, max_num_workers)):
|
|
self._run_ind_worker_queue_test(
|
|
batch_size=batch_size, num_workers=num_workers + 1
|
|
)
|
|
|
|
|
|
class SetAffinityDataset(IterableDataset):
|
|
def __iter__(self):
|
|
torch.randperm(1)
|
|
after = os.sched_getaffinity(0)
|
|
return iter(after)
|
|
|
|
|
|
@unittest.skipIf(
|
|
not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available"
|
|
)
|
|
class TestSetAffinity(TestCase):
|
|
def test_set_affinity_in_worker_init(self):
|
|
# Query the current affinity mask to avoid setting a disallowed one
|
|
old_affinity = os.sched_getaffinity(0)
|
|
if not old_affinity:
|
|
self.skipTest("No affinity information")
|
|
# Choose an arbitrary CPU from the currently allowed set
|
|
expected_affinity = list(old_affinity)[-1]
|
|
|
|
def worker_set_affinity(_):
|
|
os.sched_setaffinity(0, [expected_affinity])
|
|
|
|
dataset = SetAffinityDataset()
|
|
|
|
dataloader = torch.utils.data.DataLoader(
|
|
dataset, num_workers=2, worker_init_fn=worker_set_affinity
|
|
)
|
|
for sample in dataloader:
|
|
self.assertEqual(sample, [expected_affinity])
|
|
|
|
|
|
class ConvDataset(Dataset):
|
|
def __init__(self) -> None:
|
|
self.x = torch.ones(1, 1, 24000)
|
|
# Call convolution on the parent process
|
|
self[0]
|
|
|
|
def __len__(self):
|
|
return 1
|
|
|
|
def __getitem__(self, index):
|
|
return torch.nn.functional.conv1d(self.x, torch.ones(1, 1, 2))
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Needs fork")
|
|
@unittest.skipIf(
|
|
TEST_WITH_ASAN,
|
|
"This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492",
|
|
)
|
|
class TestConvAfterFork(TestCase):
|
|
# Tests crash reported in https://github.com/pytorch/pytorch/issues/53565
|
|
def test_conv_after_fork(self):
|
|
loader = DataLoader(ConvDataset(), num_workers=1)
|
|
for x in loader:
|
|
self.assertEqual(x.shape, (1, 1, 1, 23999))
|
|
|
|
|
|
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|