import itertools
import os
import pickle
import random
import tempfile
import warnings
import tarfile
import zipfile
import numpy as np
from PIL import Image
from unittest import skipIf

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import \
    (IterDataPipe, RandomSampler, DataLoader,
     construct_time_validation, runtime_validation)

from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Set, Union

import torch.utils.data.datapipes as dp
from torch.utils.data.datapipes.utils.decoder import (
    basichandlers as decoder_basichandlers,
    imagehandler as decoder_imagehandler)

try:
    import torchvision.transforms
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


T_co = TypeVar('T_co', covariant=True)


def create_temp_dir_and_files():
    # The temp dir and files within it will be released and deleted in tearDown().
    # Adding `noqa: P201` to avoid the lint warning about not releasing the dir handle within this function.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name
    with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.txt') as f:
        temp_file1_name = f.name
    with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.byte') as f:
        temp_file2_name = f.name
    with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.empty') as f:
        temp_file3_name = f.name

    with open(temp_file1_name, 'w') as f1:
        f1.write('0123456789abcdef')
    with open(temp_file2_name, 'wb') as f2:
        f2.write(b"0123456789abcdef")

    temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path)  # noqa: P201
    temp_sub_dir_path = temp_sub_dir.name
    with tempfile.NamedTemporaryFile(dir=temp_sub_dir_path, delete=False, suffix='.txt') as f:
        temp_sub_file1_name = f.name
    with tempfile.NamedTemporaryFile(dir=temp_sub_dir_path, delete=False, suffix='.byte') as f:
        temp_sub_file2_name = f.name

    with open(temp_sub_file1_name, 'w') as f1:
        f1.write('0123456789abcdef')
    with open(temp_sub_file2_name, 'wb') as f2:
        f2.write(b"0123456789abcdef")

    return [(temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
            (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]


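# For reference: `create_temp_dir_and_files` returns
#   [(temp_dir, txt_path, byte_path, empty_path),
#    (temp_sub_dir, sub_txt_path, sub_byte_path)],
# which `setUp` below unpacks into `self.temp_dir`/`self.temp_files` and their
# sub-directory counterparts.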
class TestIterableDataPipeBasic(TestCase):

    def setUp(self):
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0][0]
        self.temp_files = ret[0][1:]
        self.temp_sub_dir = ret[1][0]
        self.temp_sub_files = ret[1][1:]

    def tearDown(self):
        try:
            self.temp_sub_dir.cleanup()
            self.temp_dir.cleanup()
        except Exception as e:
            warnings.warn("TestIterableDataPipeBasic was not able to cleanup temp dir due to {}".format(str(e)))

    def test_listdirfiles_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        datapipe = dp.iter.ListDirFiles(temp_dir, '')

        count = 0
        for pathname in datapipe:
            count = count + 1
            self.assertTrue(pathname in self.temp_files)
        self.assertEqual(count, len(self.temp_files))

        count = 0
        datapipe = dp.iter.ListDirFiles(temp_dir, '', recursive=True)
        for pathname in datapipe:
            count = count + 1
            self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files))
        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))


    def test_loadfilesfromdisk_iterable_datapipe(self):
        # test import datapipe class directly
        from torch.utils.data.datapipes.iter import ListDirFiles, LoadFilesFromDisk

        temp_dir = self.temp_dir.name
        datapipe1 = ListDirFiles(temp_dir, '')
        datapipe2 = LoadFilesFromDisk(datapipe1)

        count = 0
        for rec in datapipe2:
            count = count + 1
            self.assertTrue(rec[0] in self.temp_files)
            self.assertTrue(rec[1].read() == open(rec[0], 'rb').read())
        self.assertEqual(count, len(self.temp_files))


    def test_readfilesfromtar_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
        with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
            tar.add(self.temp_files[0])
            tar.add(self.temp_files[1])
            tar.add(self.temp_files[2])
        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar')
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.ReadFilesFromTar(datapipe2)
        # read extracted files before reaching the end of the tarfile
        count = 0
        for rec, temp_file in zip(datapipe3, self.temp_files):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
            self.assertEqual(rec[1].read(), open(temp_file, 'rb').read())
        self.assertEqual(count, len(self.temp_files))
        # read extracted files after reaching the end of the tarfile
        count = 0
        data_refs = []
        for rec in datapipe3:
            count = count + 1
            data_refs.append(rec)
        self.assertEqual(count, len(self.temp_files))
        for i in range(0, count):
            self.assertEqual(os.path.basename(data_refs[i][0]), os.path.basename(self.temp_files[i]))
            self.assertEqual(data_refs[i][1].read(), open(self.temp_files[i], 'rb').read())


    def test_readfilesfromzip_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip")
        with zipfile.ZipFile(temp_zipfile_pathname, 'w') as myzip:
            myzip.write(self.temp_files[0])
            myzip.write(self.temp_files[1])
            myzip.write(self.temp_files[2])
        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.zip')
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.ReadFilesFromZip(datapipe2)
        # read extracted files before reaching the end of the zipfile
        count = 0
        for rec, temp_file in zip(datapipe3, self.temp_files):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
            self.assertEqual(rec[1].read(), open(temp_file, 'rb').read())
        self.assertEqual(count, len(self.temp_files))
        # read extracted files after reaching the end of the zipfile
        count = 0
        data_refs = []
        for rec in datapipe3:
            count = count + 1
            data_refs.append(rec)
        self.assertEqual(count, len(self.temp_files))
        for i in range(0, count):
            self.assertEqual(os.path.basename(data_refs[i][0]), os.path.basename(self.temp_files[i]))
            self.assertEqual(data_refs[i][1].read(), open(self.temp_files[i], 'rb').read())


    def test_routeddecoder_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
        img = Image.new('RGB', (2, 2), color='red')
        img.save(temp_pngfile_pathname)
        datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt'])
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.RoutedDecoder(datapipe2, handlers=[decoder_imagehandler('rgb')])
        datapipe3.add_handler(decoder_basichandlers)

        for rec in datapipe3:
            ext = os.path.splitext(rec[0])[1]
            if ext == '.png':
                expected = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single)
                self.assertTrue(np.array_equal(rec[1], expected))
            else:
                self.assertTrue(rec[1] == open(rec[0], 'rb').read().decode('utf-8'))


    def test_groupbykey_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
        file_list = [
            "a.png", "b.png", "c.json", "a.json", "c.png", "b.json", "d.png",
            "d.json", "e.png", "f.json", "g.png", "f.png", "g.json", "e.json",
            "h.txt", "h.json"]
        with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
            for file_name in file_list:
                file_pathname = os.path.join(temp_dir, file_name)
                with open(file_pathname, 'w') as f:
                    f.write('12345abcde')
                tar.add(file_pathname)

        datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar')
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.ReadFilesFromTar(datapipe2)
        datapipe4 = dp.iter.GroupByKey(datapipe3, group_size=2)

        expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"),
                           ("f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")]

        count = 0
        for rec, expected in zip(datapipe4, expected_result):
            count = count + 1
            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
            self.assertEqual(rec[0][1].read(), b'12345abcde')
            self.assertEqual(rec[1][1].read(), b'12345abcde')
        self.assertEqual(count, 8)


class IDP_NoLen(IterDataPipe):
    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp

    def __iter__(self):
        for i in self.input_dp:
            yield i


class IDP(IterDataPipe):
    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp
        self.length = len(input_dp)

    def __iter__(self):
        for i in self.input_dp:
            yield i

    def __len__(self):
        return self.length


def _fake_fn(data, *args, **kwargs):
    return data

def _fake_filter_fn(data, *args, **kwargs):
    return data >= 5

def _worker_init_fn(worker_id):
    random.seed(123)


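# Illustrative sketch (not part of the original suite): the functional calls used
# throughout the tests below, e.g. `input_dp.map(fn)`, are assumed to be equivalent
# to constructing the DataPipe class directly, e.g. `dp.iter.Map(input_dp, fn)`.
def _example_functional_chaining():
    source = IDP(range(5))
    # `_fake_fn` is the identity, so mapping leaves the elements unchanged.
    mapped = source.map(_fake_fn)
    return list(mapped)  # [0, 1, 2, 3, 4]

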
class TestFunctionalIterDataPipe(TestCase):

    def test_picklable(self):
        arr = range(10)
        picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
            (dp.iter.Map, IDP(arr), (), {}),
            (dp.iter.Map, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
            (dp.iter.Collate, IDP(arr), (), {}),
            (dp.iter.Collate, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}),
            (dp.iter.Filter, IDP(arr), (_fake_filter_fn, (0, ), {'test': True}), {}),
        ]
        for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes:
            p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs))  # type: ignore

        unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [
            (dp.iter.Map, IDP(arr), (lambda x: x, ), {}),
            (dp.iter.Collate, IDP(arr), (lambda x: x, ), {}),
            (dp.iter.Filter, IDP(arr), (lambda x: x >= 5, ), {}),
        ]
        for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes:
            with warnings.catch_warnings(record=True) as wa:
                datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore
                self.assertEqual(len(wa), 1)
                self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
            with self.assertRaises(AttributeError):
                p = pickle.dumps(datapipe)  # type: ignore

    def test_concat_datapipe(self):
        input_dp1 = IDP(range(10))
        input_dp2 = IDP(range(5))

        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
            dp.iter.Concat()

        with self.assertRaisesRegex(TypeError, r"Expected all inputs to be `IterDataPipe`"):
            dp.iter.Concat(input_dp1, ())  # type: ignore

        concat_dp = input_dp1.concat(input_dp2)
        self.assertEqual(len(concat_dp), 15)
        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))

        # Test Reset
        self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))

        input_dp_nl = IDP_NoLen(range(5))

        concat_dp = input_dp1.concat(input_dp_nl)
        with self.assertRaises(NotImplementedError):
            len(concat_dp)

        self.assertEqual(list(d for d in concat_dp), list(range(10)) + list(range(5)))

    def test_map_datapipe(self):
        input_dp = IDP(range(10))

        def fn(item, dtype=torch.float, *, sum=False):
            data = torch.tensor(item, dtype=dtype)
            return data if not sum else data.sum()

        map_dp = input_dp.map(fn)
        self.assertEqual(len(input_dp), len(map_dp))
        for x, y in zip(map_dp, input_dp):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        map_dp = input_dp.map(fn=fn, fn_args=(torch.int, ), fn_kwargs={'sum': True})
        self.assertEqual(len(input_dp), len(map_dp))
        for x, y in zip(map_dp, input_dp):
            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

        from functools import partial
        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
        self.assertEqual(len(input_dp), len(map_dp))
        for x, y in zip(map_dp, input_dp):
            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

        input_dp_nl = IDP_NoLen(range(10))
        map_dp_nl = input_dp_nl.map()
        with self.assertRaises(NotImplementedError):
            len(map_dp_nl)
        for x, y in zip(map_dp_nl, input_dp_nl):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

    def test_collate_datapipe(self):
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        input_dp = IDP(arrs)

        def _collate_fn(batch):
            return torch.tensor(sum(batch), dtype=torch.float)

        collate_dp = input_dp.collate(collate_fn=_collate_fn)
        self.assertEqual(len(input_dp), len(collate_dp))
        for x, y in zip(collate_dp, input_dp):
            self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float))

        input_dp_nl = IDP_NoLen(arrs)
        collate_dp_nl = input_dp_nl.collate()
        with self.assertRaises(NotImplementedError):
            len(collate_dp_nl)
        for x, y in zip(collate_dp_nl, input_dp_nl):
            self.assertEqual(x, torch.tensor(y))

    def test_batch_datapipe(self):
        arrs = list(range(10))
        input_dp = IDP(arrs)
        with self.assertRaises(AssertionError):
            input_dp.batch(batch_size=0)

        # By default, the last (partial) batch is not dropped
        bs = 3
        batch_dp = input_dp.batch(batch_size=bs)
        self.assertEqual(len(batch_dp), 4)
        for i, batch in enumerate(batch_dp):
            self.assertEqual(len(batch), 1 if i == 3 else bs)
            self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])

        # Drop the last batch
        bs = 4
        batch_dp = input_dp.batch(batch_size=bs, drop_last=True)
        self.assertEqual(len(batch_dp), 2)
        for i, batch in enumerate(batch_dp):
            self.assertEqual(len(batch), bs)
            self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])

        input_dp_nl = IDP_NoLen(range(10))
        batch_dp_nl = input_dp_nl.batch(batch_size=2)
        with self.assertRaises(NotImplementedError):
            len(batch_dp_nl)

    def test_bucket_batch_datapipe(self):
        input_dp = IDP(range(20))
        with self.assertRaises(AssertionError):
            input_dp.bucket_batch(batch_size=0)

        input_dp_nl = IDP_NoLen(range(20))
        bucket_dp_nl = input_dp_nl.bucket_batch(batch_size=7)
        with self.assertRaises(NotImplementedError):
            len(bucket_dp_nl)

        # Test Bucket Batch without sort_key
        def _helper(**kwargs):
            arrs = list(range(100))
            random.shuffle(arrs)
            input_dp = IDP(arrs)
            bucket_dp = input_dp.bucket_batch(**kwargs)
            if kwargs["sort_key"] is None:
                # BatchDataset as reference
                ref_dp = input_dp.batch(batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last'])
                for batch, rbatch in zip(bucket_dp, ref_dp):
                    self.assertEqual(batch, rbatch)
            else:
                bucket_size = bucket_dp.bucket_size
                bucket_num = (len(input_dp) - 1) // bucket_size + 1
                it = iter(bucket_dp)
                for i in range(bucket_num):
                    ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size])
                    bucket: List = []
                    while len(bucket) < len(ref):
                        try:
                            batch = next(it)
                            bucket += batch
                        # If drop_last is set, the iterator may stop early
                        except StopIteration:
                            break
                    if len(bucket) != len(ref):
                        ref = ref[:len(bucket)]
                    # Sorted bucket
                    self.assertEqual(bucket, ref)

        _helper(batch_size=7, drop_last=False, sort_key=None)
        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None)

        # Test Bucket Batch with sort_key
        def _sort_fn(data):
            return data

        _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)

    def test_filter_datapipe(self):
        input_ds = IDP(range(10))

        def _filter_fn(data, val, clip=False):
            if clip:
                return data >= val
            return True

        filter_dp = input_ds.filter(filter_fn=_filter_fn, fn_args=(5, ))
        for data, exp in zip(filter_dp, range(10)):
            self.assertEqual(data, exp)

        filter_dp = input_ds.filter(filter_fn=_filter_fn, fn_kwargs={'val': 5, 'clip': True})
        for data, exp in zip(filter_dp, range(5, 10)):
            self.assertEqual(data, exp)

        with self.assertRaises(NotImplementedError):
            len(filter_dp)

        def _non_bool_fn(data):
            return 1

        filter_dp = input_ds.filter(filter_fn=_non_bool_fn)
        with self.assertRaises(ValueError):
            temp = list(d for d in filter_dp)

    def test_sampler_datapipe(self):
        input_dp = IDP(range(10))
        # Default SequentialSampler
        sampled_dp = dp.iter.Sampler(input_dp)  # type: ignore
        self.assertEqual(len(sampled_dp), 10)
        for i, x in enumerate(sampled_dp):
            self.assertEqual(x, i)

        # RandomSampler
        random_sampled_dp = dp.iter.Sampler(input_dp, sampler=RandomSampler, sampler_kwargs={'replacement': True})  # type: ignore

        # Requires `__len__` to build SamplerDataPipe
        input_dp_nolen = IDP_NoLen(range(10))
        with self.assertRaises(AssertionError):
            sampled_dp = dp.iter.Sampler(input_dp_nolen)

    def test_shuffle_datapipe(self):
        exp = list(range(20))
        input_ds = IDP(exp)

        with self.assertRaises(AssertionError):
            shuffle_dp = input_ds.shuffle(buffer_size=0)

        for bs in (5, 20, 25):
            shuffle_dp = input_ds.shuffle(buffer_size=bs)
            self.assertEqual(len(shuffle_dp), len(input_ds))

            random.seed(123)
            res = list(d for d in shuffle_dp)
            self.assertEqual(sorted(res), exp)

            # Test Deterministic
            for num_workers in (0, 1):
                random.seed(123)
                dl = DataLoader(shuffle_dp, num_workers=num_workers, worker_init_fn=_worker_init_fn)
                dl_res = list(d for d in dl)
                self.assertEqual(res, dl_res)

        shuffle_dp_nl = IDP_NoLen(range(20)).shuffle(buffer_size=5)
        with self.assertRaises(NotImplementedError):
            len(shuffle_dp_nl)

    @skipIfNoTorchVision
    def test_transforms_datapipe(self):
        torch.set_default_dtype(torch.float)
        # A sequence of numpy random numbers representing 3-channel images
        w = h = 32
        inputs = [np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) for i in range(10)]
        tensor_inputs = [torch.tensor(x, dtype=torch.float).permute(2, 0, 1) / 255. for x in inputs]

        input_dp = IDP(inputs)
        # A plain Python function is rejected with TypeError
        with self.assertRaisesRegex(TypeError, r"`transforms` are required to be"):
            input_dp.transforms(_fake_fn)

        # transforms.Compose of several transforms
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Pad(1, fill=1, padding_mode='constant'),
        ])
        tsfm_dp = input_dp.transforms(transforms)
        self.assertEqual(len(tsfm_dp), len(input_dp))
        for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs):
            self.assertEqual(tsfm_data[:, 1:(h + 1), 1:(w + 1)], input_data)

        # nn.Sequential of several transforms (required to be instances of nn.Module)
        input_dp = IDP(tensor_inputs)
        transforms = nn.Sequential(
            torchvision.transforms.Pad(1, fill=1, padding_mode='constant'),
        )
        tsfm_dp = input_dp.transforms(transforms)
        self.assertEqual(len(tsfm_dp), len(input_dp))
        for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs):
            self.assertEqual(tsfm_data[:, 1:(h + 1), 1:(w + 1)], input_data)

        # Single transform
        input_dp = IDP_NoLen(inputs)  # type: ignore
        transform = torchvision.transforms.ToTensor()
        tsfm_dp = input_dp.transforms(transform)
        with self.assertRaises(NotImplementedError):
            len(tsfm_dp)
        for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs):
            self.assertEqual(tsfm_data, input_data)

    def test_zip_datapipe(self):
        with self.assertRaises(TypeError):
            dp.iter.Zip(IDP(range(10)), list(range(10)))  # type: ignore

        zipped_dp = dp.iter.Zip(IDP(range(10)), IDP_NoLen(range(5)))  # type: ignore
        with self.assertRaises(NotImplementedError):
            len(zipped_dp)
        exp = list((i, i) for i in range(5))
        self.assertEqual(list(d for d in zipped_dp), exp)

        zipped_dp = dp.iter.Zip(IDP(range(10)), IDP(range(5)))
        self.assertEqual(len(zipped_dp), 5)
        self.assertEqual(list(zipped_dp), exp)
        # Reset
        self.assertEqual(list(zipped_dp), exp)


class TestTyping(TestCase):
    def test_subtype(self):
        from torch.utils.data._typing import issubtype

        basic_type = (int, str, bool, float, complex,
                      list, tuple, dict, set, T_co)
        for t in basic_type:
            self.assertTrue(issubtype(t, t))
            self.assertTrue(issubtype(t, Any))
            if t == T_co:
                self.assertTrue(issubtype(Any, t))
            else:
                self.assertFalse(issubtype(Any, t))
        for t1, t2 in itertools.product(basic_type, basic_type):
            if t1 == t2 or t2 == T_co:
                self.assertTrue(issubtype(t1, t2))
            else:
                self.assertFalse(issubtype(t1, t2))

        T = TypeVar('T', int, str)
        S = TypeVar('S', bool, Union[str, int], Tuple[int, T])  # type: ignore
        types = ((int, Optional[int]),
                 (List, Union[int, list]),
                 (Tuple[int, str], S),
                 (Tuple[int, str], tuple),
                 (T, S),
                 (S, T_co),
                 (T, Union[S, Set]))
        for sub, par in types:
            self.assertTrue(issubtype(sub, par))
            self.assertFalse(issubtype(par, sub))

        subscriptable_types = {
            List: 1,
            Tuple: 2,  # use 2 parameters
            Set: 1,
            Dict: 2,
        }
        for subscript_type, n in subscriptable_types.items():
            for ts in itertools.combinations(types, n):
                subs, pars = zip(*ts)
                sub = subscript_type[subs]  # type: ignore
                par = subscript_type[pars]  # type: ignore
                self.assertTrue(issubtype(sub, par))
                self.assertFalse(issubtype(par, sub))
                # Non-recursive check
                self.assertTrue(issubtype(par, sub, recursive=False))

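    # A quick intuition for `issubtype`, mirroring the assertions above
    # (illustrative note): a parameterized type is a subtype of its bare generic
    # (e.g. Tuple[int, str] vs. tuple), while the reverse direction only passes
    # with `recursive=False`, which compares the outer containers and ignores
    # the type parameters.
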
    def test_issubinstance(self):
        from torch.utils.data._typing import issubinstance

        basic_data = (1, '1', True, 1., complex(1., 0.))
        basic_type = (int, str, bool, float, complex)
        S = TypeVar('S', bool, Union[str, int])
        for d in basic_data:
            self.assertTrue(issubinstance(d, Any))
            self.assertTrue(issubinstance(d, T_co))
            if type(d) in (bool, int, str):
                self.assertTrue(issubinstance(d, S))
            else:
                self.assertFalse(issubinstance(d, S))
            for t in basic_type:
                if type(d) == t:
                    self.assertTrue(issubinstance(d, t))
                else:
                    self.assertFalse(issubinstance(d, t))
        # list/set
        dt = (([1, '1', 2], List), (set({1, '1', 2}), Set))
        for d, t in dt:
            self.assertTrue(issubinstance(d, t))
            self.assertTrue(issubinstance(d, t[T_co]))  # type: ignore
            self.assertFalse(issubinstance(d, t[int]))  # type: ignore

        # dict
        d = dict({'1': 1, '2': 2.})
        self.assertTrue(issubinstance(d, Dict))
        self.assertTrue(issubinstance(d, Dict[str, T_co]))
        self.assertFalse(issubinstance(d, Dict[str, int]))

        # tuple
        d = (1, '1', 2)
        self.assertTrue(issubinstance(d, Tuple))
        self.assertTrue(issubinstance(d, Tuple[int, str, T_co]))
        self.assertFalse(issubinstance(d, Tuple[int, Any]))
        self.assertFalse(issubinstance(d, Tuple[int, int, int]))

    # Static annotation checking at class-definition time
    def test_compile_time(self):
        with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"):
            class InvalidDP1(IterDataPipe[int]):
                def __iter__(self) -> str:  # type: ignore
                    yield 0

        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
            class InvalidDP2(IterDataPipe[Tuple]):
                def __iter__(self) -> Iterator[int]:  # type: ignore
                    yield 0

        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):
            class InvalidDP3(IterDataPipe[Tuple[int, str]]):
                def __iter__(self) -> Iterator[tuple]:  # type: ignore
                    yield (0, )

        class DP1(IterDataPipe[Tuple[int, str]]):
            def __init__(self, length):
                self.length = length

            def __iter__(self) -> Iterator[Tuple[int, str]]:
                for d in range(self.length):
                    yield d, str(d)

        self.assertTrue(issubclass(DP1, IterDataPipe))
        dp1 = DP1(10)
        self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))
        dp2 = DP1(5)
        self.assertEqual(dp1.type, dp2.type)

        with self.assertRaisesRegex(TypeError, r"Can not subclass a DataPipe"):
            class InvalidDP4(DP1[tuple]):  # type: ignore
                def __iter__(self) -> Iterator[tuple]:  # type: ignore
                    yield (0, )

        class DP2(IterDataPipe[T_co]):
            def __iter__(self) -> Iterator[T_co]:
                for d in range(10):
                    yield d  # type: ignore

        self.assertTrue(issubclass(DP2, IterDataPipe))
        dp1 = DP2()  # type: ignore
        self.assertTrue(DP2.type.issubtype(dp1.type) and dp1.type.issubtype(DP2.type))
        dp2 = DP2()  # type: ignore
        self.assertEqual(dp1.type, dp2.type)

        class DP3(IterDataPipe[Tuple[T_co, str]]):
r""" DataPipe without fixed type with __init__ function"""
            def __init__(self, datasource):
                self.datasource = datasource

            def __iter__(self) -> Iterator[Tuple[T_co, str]]:
                for d in self.datasource:
                    yield d, str(d)

        self.assertTrue(issubclass(DP3, IterDataPipe))
        dp1 = DP3(range(10))  # type: ignore
        self.assertTrue(DP3.type.issubtype(dp1.type) and dp1.type.issubtype(DP3.type))
        dp2 = DP3(5)  # type: ignore
        self.assertEqual(dp1.type, dp2.type)

        class DP4(IterDataPipe[tuple]):
            r"""DataPipe without __iter__ annotation"""
            def __iter__(self):
                raise NotImplementedError

        self.assertTrue(issubclass(DP4, IterDataPipe))
        dp = DP4()
        self.assertTrue(dp.type.param == tuple)

        class DP5(IterDataPipe):
            r"""DataPipe without type annotation"""
            def __iter__(self) -> Iterator[str]:
                raise NotImplementedError

        self.assertTrue(issubclass(DP5, IterDataPipe))
        dp = DP5()  # type: ignore
        self.assertTrue(dp.type.param == Any)

        class DP6(IterDataPipe[int]):
            r"""DataPipe with plain Iterator"""
            def __iter__(self) -> Iterator:
                raise NotImplementedError

        self.assertTrue(issubclass(DP6, IterDataPipe))
        dp = DP6()  # type: ignore
        self.assertTrue(dp.type.param == int)

    def test_construct_time(self):
        class DP0(IterDataPipe[Tuple]):
            @construct_time_validation
            def __init__(self, dp: IterDataPipe):
                self.dp = dp

            def __iter__(self) -> Iterator[Tuple]:
                for d in self.dp:
                    yield d, str(d)

        class DP1(IterDataPipe[int]):
            @construct_time_validation
            def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
                self.dp = dp

            def __iter__(self) -> Iterator[int]:
                for a, b in self.dp:
                    yield a

        # Non-DataPipe input with DataPipe hint
        datasource = [(1, '1'), (2, '2'), (3, '3')]
        with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"):
            dp = DP0(datasource)

        dp = DP0(IDP(range(10)))
        with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"):
            dp = DP1(dp)

        with self.assertRaisesRegex(TypeError, r"Can not decorate"):
            class InvalidDP1(IterDataPipe[int]):
                @construct_time_validation
                def __iter__(self):
                    yield 0

    def test_runtime(self):
        class DP(IterDataPipe[Tuple[int, T_co]]):
            def __init__(self, datasource):
                self.ds = datasource

            @runtime_validation
            def __iter__(self) -> Iterator[Tuple[int, T_co]]:
                for d in self.ds:
                    yield d

        dss = ([(1, '1'), (2, '2')],
               [(1, 1), (2, '2')])
        for ds in dss:
            dp = DP(ds)  # type: ignore
            self.assertEqual(list(d for d in dp), ds)
            # Reset __iter__
            self.assertEqual(list(d for d in dp), ds)

        dss = ([(1, 1), ('2', 2)],  # type: ignore
               [[1, '1'], [2, '2']],  # type: ignore
               [1, '1', 2, '2'])
        for ds in dss:
            dp = DP(ds)
            with self.assertRaisesRegex(RuntimeError, r"Expected an instance of subtype"):
                list(d for d in dp)


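# Minimal sketch (illustrative, not part of the original suite) of the runtime
# validation exercised in TestTyping.test_runtime: `runtime_validation` is assumed
# to check each value yielded by `__iter__` against the DataPipe's type parameter
# and raise a RuntimeError on mismatch.
class _ExampleValidatedDP(IterDataPipe[int]):
    def __init__(self, source):
        self.source = source

    @runtime_validation
    def __iter__(self) -> Iterator[int]:
        # Yields elements unchanged; a non-int element in `source` would trigger
        # the validation error at iteration time.
        for d in self.source:
            yield d

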
if __name__ == '__main__':
    run_tests()