Introducing DataChunk for DataPipes batching (#62768)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62768

This is part of the TorchArrow DataFrame support preparation; the work is split into multiple PRs to simplify the review process.
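For orientation before the diff: a minimal sketch (mine, not part of the commit) of the behavior this PR introduces, using only names that appear in the changes below.

from torch.utils.data import DataChunk  # re-exported by this change

# DataChunk subclasses List[T] but keeps its elements in `self.items`,
# so existing list-based consumers keep working while future
# DataFrame-backed chunk types can subclass it.
chunk: DataChunk[int] = DataChunk(list(range(3)))
assert list(chunk) == [0, 1, 2]
assert str(chunk) == str([0, 1, 2])  # mirrors the new TestDataChunk test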

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30149090

Pulled By: VitalyFedyunin

fbshipit-source-id: a36b5ff56e2ac6b06060014d4cd41b487754acb8
Commit d3bdf345cb (parent 5e5de75f4d), authored by Vitaly Fedyunin on 2021-08-06 08:34:58 -07:00, committed by Facebook GitHub Bot. 8 changed files with 97 additions and 39 deletions.

test/delete.py (new file, empty)

test/test_datapipe.py

@@ -42,6 +42,7 @@ import torch.utils.data.sharding
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils.data import (
DataLoader,
DataChunk,
IterDataPipe,
MapDataPipe,
RandomSampler,
@@ -108,6 +109,14 @@ def create_temp_dir_and_files():
(temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]
class TestDataChunk(TestCase):
def test_as_string(self):
elements = list(range(10))
chunk: DataChunk[int] = DataChunk(elements)
self.assertEqual(str(chunk), str(elements))
class TestIterableDataPipeBasic(TestCase):
def setUp(self):
@@ -141,7 +150,6 @@ class TestIterableDataPipeBasic(TestCase):
self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files))
self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))
def test_loadfilesfromdisk_iterable_datapipe(self):
# test import datapipe class directly
from torch.utils.data.datapipes.iter import (
@@ -216,7 +224,6 @@ class TestIterableDataPipeBasic(TestCase):
self.assertEqual(data_ref[1].read(), f.read())
data_ref[1].close()
def test_routeddecoder_iterable_datapipe(self):
temp_dir = self.temp_dir.name
temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
@@ -697,7 +704,6 @@ class TestFunctionalIterDataPipe(TestCase):
for i in unbatch_dp:
print(i)
def test_bucket_batch_datapipe(self):
input_dp = IDP(range(20))
with self.assertRaises(AssertionError):
@@ -787,7 +793,8 @@ class TestFunctionalIterDataPipe(TestCase):
for data, exp in zip(filter_dp, expected_dp1):
self.assertEqual(data, exp)
filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False, filter_fn=_filter_fn, fn_kwargs={'val': 5})
filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False,
filter_fn=_filter_fn, fn_kwargs={'val': 5})
expected_dp2: List[List[int]] = [[], [5, 6, 7, 8, 9]]
self.assertEqual(len(list(filter_dp)), len(expected_dp2))
for data, exp in zip(filter_dp, expected_dp2):
@@ -826,7 +833,6 @@ class TestFunctionalIterDataPipe(TestCase):
for data2, exp2 in zip(filter_dp, expected_dp6):
self.assertEqual(data2, exp2)
def test_sampler_datapipe(self):
input_dp = IDP(range(10))
# Default SequentialSampler
@@ -1153,6 +1159,7 @@ class TestTyping(TestCase):
class DP3(IterDataPipe[Tuple[T_co, str]]):
r""" DataPipe without fixed type with __init__ function"""
def __init__(self, datasource):
self.datasource = datasource
@@ -1168,6 +1175,7 @@ class TestTyping(TestCase):
class DP4(IterDataPipe[tuple]):
r""" DataPipe without __iter__ annotation"""
def __iter__(self):
raise NotImplementedError
@@ -1177,6 +1185,7 @@ class TestTyping(TestCase):
class DP5(IterDataPipe):
r""" DataPipe without type annotation"""
def __iter__(self) -> Iterator[str]:
raise NotImplementedError
@@ -1187,6 +1196,7 @@ class TestTyping(TestCase):
class DP6(IterDataPipe[int]):
r""" DataPipe with plain Iterator"""
def __iter__(self) -> Iterator:
raise NotImplementedError
@@ -1206,7 +1216,6 @@ class TestTyping(TestCase):
self.assertTrue(issubclass(DP8, IterDataPipe))
self.assertTrue(DP8.type.param == Awaitable[str])
def test_construct_time(self):
class DP0(IterDataPipe[Tuple]):
@argument_validation
@@ -1269,11 +1278,9 @@ class TestTyping(TestCase):
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
list(dp)
def test_reinforce(self):
T = TypeVar('T', int, str)
class DP(IterDataPipe[T]):
def __init__(self, ds):
self.ds = ds
@@ -1306,6 +1313,7 @@ class TestTyping(TestCase):
with runtime_validation_disabled():
self.assertEqual(list(d for d in dp), ds)
class NumbersDataset(IterDataPipe):
def __init__(self, size=10):
self.size = size
@@ -1321,7 +1329,7 @@ class TestGraph(TestCase):
numbers_dp = NumbersDataset(size=50)
mapped_dp = numbers_dp.map(lambda x: x * 10)
graph = torch.utils.data.graph.traverse(mapped_dp)
expected : Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
self.assertEqual(expected, graph)
# TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature
@@ -1377,5 +1385,6 @@ class TestSharding(TestCase):
self.assertEqual(sorted(expected), sorted(items))
if __name__ == '__main__':
run_tests()
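The reworked filter test above is easier to see in isolation. A hedged sketch of what it exercises, with a hypothetical _Numbers helper standing in for this file's local IDP wrapper:

from typing import Iterator
from torch.utils.data import IterDataPipe

class _Numbers(IterDataPipe):
    """Hypothetical stand-in for the IDP helper used by these tests."""
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self) -> Iterator[int]:
        yield from range(self.n)

def _filter_fn(data, val):
    return data >= val

input_ds = _Numbers(10).batch(5)  # two DataChunk batches: [0..4] and [5..9]
filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False,
                            filter_fn=_filter_fn, fn_kwargs={'val': 5})
# drop_empty_batches=False keeps the batch that was filtered down to nothing:
assert [list(b) for b in filter_dp] == [[], [5, 6, 7, 8, 9]]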

torch/utils/data/__init__.py

@@ -13,6 +13,7 @@ from torch.utils.data.dataset import (
ConcatDataset,
Dataset,
Dataset as MapDataPipe,
DataChunk,
IterableDataset,
IterableDataset as IterDataPipe,
Subset,

torch/utils/data/datapipes/iter/callable.py

@@ -1,6 +1,6 @@
import warnings
import torch.nn as nn
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
try:
@@ -68,14 +68,17 @@ class MapIterDataPipe(IterDataPipe[T_co]):
if nesting_level == 0:
return self.fn(data, *self.args, **self.kwargs)
elif nesting_level > 0:
if not isinstance(data, list):
if isinstance(data, DataChunk):
return type(data)([self._apply(i, nesting_level - 1) for i in data.raw_iterator()])
elif isinstance(data, list):
return [self._apply(i, nesting_level - 1) for i in data]
else:
raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
result = [self._apply(i, nesting_level - 1) for i in data]
return result
else:
if isinstance(data, list):
result = [self._apply(i, nesting_level) for i in data]
return result
if isinstance(data, DataChunk):
return type(data)([self._apply(i, nesting_level) for i in data.raw_iterator()])
elif isinstance(data, list):
return [self._apply(i, nesting_level) for i in data]
else:
return self.fn(data, *self.args, **self.kwargs)
@@ -162,6 +165,7 @@ class TransformsIterDataPipe(MapIterDataPipe):
datapipe: Iterable DataPipe being transformed
transforms: A transform or a sequence of transforms from torchvision or torchaudio.
"""
def __init__(self,
datapipe: IterDataPipe,
transforms: Callable,
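The substantive change in _apply above is type preservation: a DataChunk input is rebuilt through type(data)(...) from raw_iterator(), so DataChunk subclasses survive a map. A simplified standalone model of that dispatch (mine; error handling trimmed):

from torch.utils.data import DataChunk

def _apply_sketch(fn, data, nesting_level):
    # Simplified model of MapIterDataPipe._apply as changed above.
    if nesting_level == 0:
        return fn(data)
    # DataChunk must be tested before list: it subclasses List[T], so the
    # plain-list branch would otherwise swallow it and lose the chunk type.
    if isinstance(data, DataChunk):
        return type(data)([_apply_sketch(fn, i, nesting_level - 1)
                           for i in data.raw_iterator()])
    if isinstance(data, list):
        return [_apply_sketch(fn, i, nesting_level - 1) for i in data]
    if nesting_level < 0:
        return fn(data)  # nesting_level == -1 applies fn at the leaves
    raise IndexError(f"nesting_level {nesting_level} out of range (exceeds data pipe depth)")

out = _apply_sketch(lambda x: x * 10, DataChunk([1, 2, 3]), nesting_level=1)
assert isinstance(out, DataChunk) and list(out) == [10, 20, 30]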

torch/utils/data/datapipes/iter/combinatorics.py

@@ -31,8 +31,7 @@ class SamplerIterDataPipe(IterDataPipe[T_co]):
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
# https://github.com/python/mypy/pull/9629 will solve
self.sampler = sampler(data_source=self.datapipe, *self.sampler_args,
**self.sampler_kwargs) # type: ignore[misc]
self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs) # type: ignore[misc]
def __iter__(self) -> Iterator[T_co]:
return iter(self.sampler)

torch/utils/data/datapipes/iter/grouping.py

@@ -4,7 +4,7 @@ import warnings
from collections import defaultdict
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, Tuple, TypeVar, DefaultDict
T_co = TypeVar('T_co', covariant=True)
@@ -31,7 +31,7 @@ class ShardingFilterIterDataPipe(IterDataPipe):
@functional_datapipe('batch')
class BatchIterDataPipe(IterDataPipe[List[T_co]]):
class BatchIterDataPipe(IterDataPipe[DataChunk[T_co]]):
r""" :class:`BatchIterDataPipe`.
Iterable DataPipe to create mini-batches of data. An outer dimension will be added as
@@ -65,17 +65,18 @@ class BatchIterDataPipe(IterDataPipe[List[T_co]]):
self.batch_size = batch_size
self.drop_last = drop_last
self.length = None
self.wrapper_class = DataChunk
def __iter__(self) -> Iterator[List[T_co]]:
def __iter__(self) -> Iterator[DataChunk[T_co]]:
batch: List[T_co] = []
for x in self.datapipe:
batch.append(x)
if len(batch) == self.batch_size:
yield batch
yield self.wrapper_class(batch)
batch = []
if len(batch) > 0:
if not self.drop_last:
yield batch
yield self.wrapper_class(batch)
batch = []
def __len__(self) -> int:
@@ -115,7 +116,7 @@ class UnBatchIterDataPipe(IterDataPipe):
if unbatch_level < -1:
raise ValueError("unbatch_level must be -1 or >= 0")
if unbatch_level == -1:
if isinstance(element, list):
if isinstance(element, list) or isinstance(element, DataChunk):
for item in element:
for i in self._dive(item, unbatch_level=-1):
yield i
@@ -124,11 +125,12 @@ class UnBatchIterDataPipe(IterDataPipe):
elif unbatch_level == 0:
yield element
else:
if not isinstance(element, list):
if isinstance(element, list) or isinstance(element, DataChunk):
for item in element:
for i in self._dive(item, unbatch_level=unbatch_level - 1):
yield i
else:
raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")
for item in element:
for i in self._dive(item, unbatch_level=unbatch_level - 1):
yield i
@functional_datapipe('bucket_batch')
@ -175,11 +177,16 @@ class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
def __iter__(self) -> Iterator[List[T_co]]:
# Bucketing without sorting preserves the incoming order, so yield batches straight from BatchIterDataPipe
if self.sort_key is None:
yield from BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last)
for element in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
if isinstance(element, DataChunk):
yield list(element.raw_iterator())
else:
yield element
else:
bucket: List[T_co]
batch: List[T_co] = []
for bucket in self.bucket_ds:
for bucket_or_chunk in self.bucket_ds:
bucket = list(bucket_or_chunk)
# In-place sort within bucket
bucket.sort(key=self.sort_key)
for start in range(0, len(bucket), self.batch_size):
@@ -255,6 +262,7 @@ class GroupByIterDataPipe(IterDataPipe):
assert guaranteed_group_size > 0 and group_size is not None and guaranteed_group_size <= group_size
self.guaranteed_group_size = guaranteed_group_size
self.drop_remaining = drop_remaining
self.wrapper_class = DataChunk
def _remove_biggest_key(self, buffer_elements, buffer_size):
biggest_key = None
@@ -283,14 +291,14 @@ class GroupByIterDataPipe(IterDataPipe):
key = self.group_key_fn(x)
if self.group_size is not None and self.group_size == len(buffer_elements[key]):
yield buffer_elements[key]
yield self.wrapper_class(buffer_elements[key])
buffer_size -= len(buffer_elements[key])
del buffer_elements[key]
if buffer_size == self.buffer_size:
(result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
if result_to_yield is not None:
yield result_to_yield
yield self.wrapper_class(result_to_yield)
buffer_elements[key].append(x)
buffer_size += 1
@@ -298,7 +306,7 @@ class GroupByIterDataPipe(IterDataPipe):
while buffer_size:
(result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
if result_to_yield is not None:
yield result_to_yield
yield self.wrapper_class(result_to_yield)
@functional_datapipe('group_by_key')
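Net effect of the grouping changes: batch and group_by now wrap every emitted batch in self.wrapper_class, i.e. DataChunk. A small sketch (assuming the usual drop_last=False default and a hypothetical IterDataPipe wrapper):

from typing import Iterator
from torch.utils.data import DataChunk, IterDataPipe

class _Numbers(IterDataPipe):
    # Hypothetical helper; the real tests use a local IDP class.
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self) -> Iterator[int]:
        yield from range(self.n)

batches = list(_Numbers(10).batch(batch_size=4))
# The trailing partial batch is kept when drop_last is False:
assert [list(b) for b in batches] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# Each batch is now a DataChunk rather than a bare list:
assert all(isinstance(b, DataChunk) for b in batches)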

torch/utils/data/datapipes/iter/selecting.py

@@ -1,4 +1,4 @@
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
from .callable import MapIterDataPipe
@@ -45,12 +45,20 @@ class FilterIterDataPipe(MapIterDataPipe):
if nesting_level == 0:
return self._returnIfTrue(data)
elif nesting_level > 0:
if not isinstance(data, list):
if isinstance(data, DataChunk):
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1)
for i in data.raw_iterator()])
return type(data)(list(result))
elif isinstance(data, list):
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1) for i in data])
return list(result)
else:
raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1) for i in data])
return list(result)
else: # Handling nesting_level == -1
if isinstance(data, list):
if isinstance(data, DataChunk):
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level) for i in data.raw_iterator()])
return type(data)(list(result))
elif isinstance(data, list):
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level) for i in data])
return list(result)
else:
@@ -64,7 +72,10 @@ class FilterIterDataPipe(MapIterDataPipe):
return data
def _isNonEmpty(self, data):
return data is not None and not (data == [] and self.drop_empty_batches)
r = data is not None and \
not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
return r
def __len__(self):
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
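The _isNonEmpty rewrite swaps `data == []` for an explicit isinstance-plus-len check. That matters for DataChunk: as defined in this diff it never populates its underlying list storage (no super().__init__ call), so plain list equality compares the empty storage, while len() goes through the __len__ override. A short demonstration of the difference, grounded in the class body shown below:

from torch.utils.data import DataChunk

chunk = DataChunk([1, 2, 3])
assert len(chunk) == 3   # __len__ reports the wrapped items

# List equality inspects the (never-populated) built-in storage, so the
# old `data == []` test would have treated this non-empty chunk as empty:
assert chunk == []

# The new check only drops genuinely empty batches, and only when
# drop_empty_batches is set:
empty = DataChunk([])
assert isinstance(empty, list) and len(empty) == 0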

torch/utils/data/dataset.py

@@ -25,6 +25,32 @@ T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
class DataChunk(List[T]):
def __init__(self, items):
self.items = items
def __getitem__(self, key):
return self.items[key]
def __len__(self):
return len(self.items)
def as_str(self, indent=''):
res = indent + "[" + ", ".join([str(i) for i in iter(self)]) + "]"
return res
def __str__(self):
return self.as_str()
def __iter__(self) -> Iterator[T]:
for i in self.items:
yield i
def raw_iterator(self):
for i in self.items:
yield i
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.