mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Introducing DataChunk for DataPipes batching (#62768)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62768 This is part of TorchArrow DF support preparation, separating it to multiple PRs to simplify review process. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D30149090 Pulled By: VitalyFedyunin fbshipit-source-id: a36b5ff56e2ac6b06060014d4cd41b487754acb8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5e5de75f4d
commit
d3bdf345cb
0
test/delete.py
Normal file
0
test/delete.py
Normal file
@ -42,6 +42,7 @@ import torch.utils.data.sharding
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
DataChunk,
|
||||
IterDataPipe,
|
||||
MapDataPipe,
|
||||
RandomSampler,
|
||||
@ -108,6 +109,14 @@ def create_temp_dir_and_files():
|
||||
(temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]
|
||||
|
||||
|
||||
class TestDataChunk(TestCase):
|
||||
|
||||
def test_as_string(self):
|
||||
elements = list(range(10))
|
||||
chunk : DataChunk[int] = DataChunk(elements)
|
||||
self.assertEquals(str(chunk), str(elements))
|
||||
|
||||
|
||||
class TestIterableDataPipeBasic(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -141,7 +150,6 @@ class TestIterableDataPipeBasic(TestCase):
|
||||
self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files))
|
||||
self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))
|
||||
|
||||
|
||||
def test_loadfilesfromdisk_iterable_datapipe(self):
|
||||
# test import datapipe class directly
|
||||
from torch.utils.data.datapipes.iter import (
|
||||
@ -216,7 +224,6 @@ class TestIterableDataPipeBasic(TestCase):
|
||||
self.assertEqual(data_ref[1].read(), f.read())
|
||||
data_ref[1].close()
|
||||
|
||||
|
||||
def test_routeddecoder_iterable_datapipe(self):
|
||||
temp_dir = self.temp_dir.name
|
||||
temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
|
||||
@ -697,7 +704,6 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
for i in unbatch_dp:
|
||||
print(i)
|
||||
|
||||
|
||||
def test_bucket_batch_datapipe(self):
|
||||
input_dp = IDP(range(20))
|
||||
with self.assertRaises(AssertionError):
|
||||
@ -787,7 +793,8 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
for data, exp in zip(filter_dp, expected_dp1):
|
||||
self.assertEqual(data, exp)
|
||||
|
||||
filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False, filter_fn=_filter_fn, fn_kwargs={'val': 5})
|
||||
filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False,
|
||||
filter_fn=_filter_fn, fn_kwargs={'val': 5})
|
||||
expected_dp2: List[List[int]] = [[], [5, 6, 7, 8, 9]]
|
||||
self.assertEqual(len(list(filter_dp)), len(expected_dp2))
|
||||
for data, exp in zip(filter_dp, expected_dp2):
|
||||
@ -826,7 +833,6 @@ class TestFunctionalIterDataPipe(TestCase):
|
||||
for data2, exp2 in zip(filter_dp, expected_dp6):
|
||||
self.assertEqual(data2, exp2)
|
||||
|
||||
|
||||
def test_sampler_datapipe(self):
|
||||
input_dp = IDP(range(10))
|
||||
# Default SequentialSampler
|
||||
@ -1153,6 +1159,7 @@ class TestTyping(TestCase):
|
||||
|
||||
class DP3(IterDataPipe[Tuple[T_co, str]]):
|
||||
r""" DataPipe without fixed type with __init__ function"""
|
||||
|
||||
def __init__(self, datasource):
|
||||
self.datasource = datasource
|
||||
|
||||
@ -1168,6 +1175,7 @@ class TestTyping(TestCase):
|
||||
|
||||
class DP4(IterDataPipe[tuple]):
|
||||
r""" DataPipe without __iter__ annotation"""
|
||||
|
||||
def __iter__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1177,6 +1185,7 @@ class TestTyping(TestCase):
|
||||
|
||||
class DP5(IterDataPipe):
|
||||
r""" DataPipe without type annotation"""
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1187,6 +1196,7 @@ class TestTyping(TestCase):
|
||||
|
||||
class DP6(IterDataPipe[int]):
|
||||
r""" DataPipe with plain Iterator"""
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1206,7 +1216,6 @@ class TestTyping(TestCase):
|
||||
self.assertTrue(issubclass(DP8, IterDataPipe))
|
||||
self.assertTrue(DP8.type.param == Awaitable[str])
|
||||
|
||||
|
||||
def test_construct_time(self):
|
||||
class DP0(IterDataPipe[Tuple]):
|
||||
@argument_validation
|
||||
@ -1269,11 +1278,9 @@ class TestTyping(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
|
||||
list(dp)
|
||||
|
||||
|
||||
def test_reinforce(self):
|
||||
T = TypeVar('T', int, str)
|
||||
|
||||
|
||||
class DP(IterDataPipe[T]):
|
||||
def __init__(self, ds):
|
||||
self.ds = ds
|
||||
@ -1306,6 +1313,7 @@ class TestTyping(TestCase):
|
||||
with runtime_validation_disabled():
|
||||
self.assertEqual(list(d for d in dp), ds)
|
||||
|
||||
|
||||
class NumbersDataset(IterDataPipe):
|
||||
def __init__(self, size=10):
|
||||
self.size = size
|
||||
@ -1321,7 +1329,7 @@ class TestGraph(TestCase):
|
||||
numbers_dp = NumbersDataset(size=50)
|
||||
mapped_dp = numbers_dp.map(lambda x: x * 10)
|
||||
graph = torch.utils.data.graph.traverse(mapped_dp)
|
||||
expected : Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
|
||||
expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
|
||||
self.assertEqual(expected, graph)
|
||||
|
||||
# TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature
|
||||
@ -1377,5 +1385,6 @@ class TestSharding(TestCase):
|
||||
|
||||
self.assertEqual(sorted(expected), sorted(items))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -13,6 +13,7 @@ from torch.utils.data.dataset import (
|
||||
ConcatDataset,
|
||||
Dataset,
|
||||
Dataset as MapDataPipe,
|
||||
DataChunk,
|
||||
IterableDataset,
|
||||
IterableDataset as IterDataPipe,
|
||||
Subset,
|
||||
|
@ -1,6 +1,6 @@
|
||||
import warnings
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import IterDataPipe, _utils, functional_datapipe
|
||||
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
|
||||
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
|
||||
|
||||
try:
|
||||
@ -68,14 +68,17 @@ class MapIterDataPipe(IterDataPipe[T_co]):
|
||||
if nesting_level == 0:
|
||||
return self.fn(data, *self.args, **self.kwargs)
|
||||
elif nesting_level > 0:
|
||||
if not isinstance(data, list):
|
||||
if isinstance(data, DataChunk):
|
||||
return type(data)([self._apply(i, nesting_level - 1) for i in data.raw_iterator()])
|
||||
elif isinstance(data, list):
|
||||
return [self._apply(i, nesting_level - 1) for i in data]
|
||||
else:
|
||||
raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
|
||||
result = [self._apply(i, nesting_level - 1) for i in data]
|
||||
return result
|
||||
else:
|
||||
if isinstance(data, list):
|
||||
result = [self._apply(i, nesting_level) for i in data]
|
||||
return result
|
||||
if isinstance(data, DataChunk):
|
||||
return type(data)([self._apply(i, nesting_level) for i in data.raw_iterator()])
|
||||
elif isinstance(data, list):
|
||||
return [self._apply(i, nesting_level) for i in data]
|
||||
else:
|
||||
return self.fn(data, *self.args, **self.kwargs)
|
||||
|
||||
@ -162,6 +165,7 @@ class TransformsIterDataPipe(MapIterDataPipe):
|
||||
datapipe: Iterable DataPipe being transformed
|
||||
transforms: A transform or a sequence of transforms from torchvision or torchaudio.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
datapipe: IterDataPipe,
|
||||
transforms: Callable,
|
||||
|
@ -31,8 +31,7 @@ class SamplerIterDataPipe(IterDataPipe[T_co]):
|
||||
self.sampler_args = () if sampler_args is None else sampler_args
|
||||
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
|
||||
# https://github.com/python/mypy/pull/9629 will solve
|
||||
self.sampler = sampler(data_source=self.datapipe, *self.sampler_args,
|
||||
**self.sampler_kwargs) # type: ignore[misc]
|
||||
self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs) # type: ignore[misc]
|
||||
|
||||
def __iter__(self) -> Iterator[T_co]:
|
||||
return iter(self.sampler)
|
||||
|
@ -4,7 +4,7 @@ import warnings
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, Tuple, TypeVar, DefaultDict
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
@ -31,7 +31,7 @@ class ShardingFilterIterDataPipe(IterDataPipe):
|
||||
|
||||
|
||||
@functional_datapipe('batch')
|
||||
class BatchIterDataPipe(IterDataPipe[List[T_co]]):
|
||||
class BatchIterDataPipe(IterDataPipe[DataChunk[T_co]]):
|
||||
r""" :class:`BatchIterDataPipe`.
|
||||
|
||||
Iterable DataPipe to create mini-batches of data. An outer dimension will be added as
|
||||
@ -65,17 +65,18 @@ class BatchIterDataPipe(IterDataPipe[List[T_co]]):
|
||||
self.batch_size = batch_size
|
||||
self.drop_last = drop_last
|
||||
self.length = None
|
||||
self.wrapper_class = DataChunk
|
||||
|
||||
def __iter__(self) -> Iterator[List[T_co]]:
|
||||
def __iter__(self) -> Iterator[DataChunk[T_co]]:
|
||||
batch: List[T_co] = []
|
||||
for x in self.datapipe:
|
||||
batch.append(x)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
yield self.wrapper_class(batch)
|
||||
batch = []
|
||||
if len(batch) > 0:
|
||||
if not self.drop_last:
|
||||
yield batch
|
||||
yield self.wrapper_class(batch)
|
||||
batch = []
|
||||
|
||||
def __len__(self) -> int:
|
||||
@ -115,7 +116,7 @@ class UnBatchIterDataPipe(IterDataPipe):
|
||||
if unbatch_level < -1:
|
||||
raise ValueError("unbatch_level must be -1 or >= 0")
|
||||
if unbatch_level == -1:
|
||||
if isinstance(element, list):
|
||||
if isinstance(element, list) or isinstance(element, DataChunk):
|
||||
for item in element:
|
||||
for i in self._dive(item, unbatch_level=-1):
|
||||
yield i
|
||||
@ -124,11 +125,12 @@ class UnBatchIterDataPipe(IterDataPipe):
|
||||
elif unbatch_level == 0:
|
||||
yield element
|
||||
else:
|
||||
if not isinstance(element, list):
|
||||
if isinstance(element, list) or isinstance(element, DataChunk):
|
||||
for item in element:
|
||||
for i in self._dive(item, unbatch_level=unbatch_level - 1):
|
||||
yield i
|
||||
else:
|
||||
raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")
|
||||
for item in element:
|
||||
for i in self._dive(item, unbatch_level=unbatch_level - 1):
|
||||
yield i
|
||||
|
||||
|
||||
@functional_datapipe('bucket_batch')
|
||||
@ -175,11 +177,16 @@ class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
|
||||
def __iter__(self) -> Iterator[List[T_co]]:
|
||||
# Bucket without sorting remains same order, directly returns BatchDataset
|
||||
if self.sort_key is None:
|
||||
yield from BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last)
|
||||
for element in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
|
||||
if isinstance(element, DataChunk):
|
||||
yield list(element.raw_iterator())
|
||||
else:
|
||||
yield element
|
||||
else:
|
||||
bucket: List[T_co]
|
||||
batch: List[T_co] = []
|
||||
for bucket in self.bucket_ds:
|
||||
for bucket_or_chunk in self.bucket_ds:
|
||||
bucket = list(bucket_or_chunk)
|
||||
# In-place sort within bucket
|
||||
bucket.sort(key=self.sort_key)
|
||||
for start in range(0, len(bucket), self.batch_size):
|
||||
@ -255,6 +262,7 @@ class GroupByIterDataPipe(IterDataPipe):
|
||||
assert guaranteed_group_size > 0 and group_size is not None and guaranteed_group_size <= group_size
|
||||
self.guaranteed_group_size = guaranteed_group_size
|
||||
self.drop_remaining = drop_remaining
|
||||
self.wrapper_class = DataChunk
|
||||
|
||||
def _remove_biggest_key(self, buffer_elements, buffer_size):
|
||||
biggest_key = None
|
||||
@ -283,14 +291,14 @@ class GroupByIterDataPipe(IterDataPipe):
|
||||
key = self.group_key_fn(x)
|
||||
|
||||
if self.group_size is not None and self.group_size == len(buffer_elements[key]):
|
||||
yield buffer_elements[key]
|
||||
yield self.wrapper_class(buffer_elements[key])
|
||||
buffer_size -= len(buffer_elements[key])
|
||||
del buffer_elements[key]
|
||||
|
||||
if buffer_size == self.buffer_size:
|
||||
(result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
|
||||
if result_to_yield is not None:
|
||||
yield result_to_yield
|
||||
yield self.wrapper_class(result_to_yield)
|
||||
|
||||
buffer_elements[key].append(x)
|
||||
buffer_size += 1
|
||||
@ -298,7 +306,7 @@ class GroupByIterDataPipe(IterDataPipe):
|
||||
while buffer_size:
|
||||
(result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
|
||||
if result_to_yield is not None:
|
||||
yield result_to_yield
|
||||
yield self.wrapper_class(result_to_yield)
|
||||
|
||||
|
||||
@functional_datapipe('group_by_key')
|
||||
|
@ -1,4 +1,4 @@
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe
|
||||
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
|
||||
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
|
||||
|
||||
from .callable import MapIterDataPipe
|
||||
@ -45,12 +45,20 @@ class FilterIterDataPipe(MapIterDataPipe):
|
||||
if nesting_level == 0:
|
||||
return self._returnIfTrue(data)
|
||||
elif nesting_level > 0:
|
||||
if not isinstance(data, list):
|
||||
if isinstance(data, DataChunk):
|
||||
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1)
|
||||
for i in data.raw_iterator()])
|
||||
return type(data)(list(result))
|
||||
elif isinstance(data, list):
|
||||
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1) for i in data])
|
||||
return list(result)
|
||||
else:
|
||||
raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
|
||||
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level - 1) for i in data])
|
||||
return list(result)
|
||||
else: # Handling nesting_level == -1
|
||||
if isinstance(data, list):
|
||||
if isinstance(data, DataChunk):
|
||||
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level) for i in data.raw_iterator()])
|
||||
return type(data)(list(result))
|
||||
elif isinstance(data, list):
|
||||
result = filter(self._isNonEmpty, [self._applyFilter(i, nesting_level) for i in data])
|
||||
return list(result)
|
||||
else:
|
||||
@ -64,7 +72,10 @@ class FilterIterDataPipe(MapIterDataPipe):
|
||||
return data
|
||||
|
||||
def _isNonEmpty(self, data):
|
||||
return data is not None and not (data == [] and self.drop_empty_batches)
|
||||
r = data is not None and \
|
||||
not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
|
||||
return r
|
||||
|
||||
|
||||
def __len__(self):
|
||||
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
|
||||
|
@ -25,6 +25,32 @@ T_co = TypeVar('T_co', covariant=True)
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class DataChunk(List[T]):
|
||||
def __init__(self, items):
|
||||
self.items = items
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.items[key]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def as_str(self, indent=''):
|
||||
res = indent + "[" + ", ".join([str(i) for i in iter(self)]) + "]"
|
||||
return res
|
||||
|
||||
def __str__(self):
|
||||
return self.as_str()
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
for i in self.items:
|
||||
yield i
|
||||
|
||||
def raw_iterator(self):
|
||||
for i in self.items:
|
||||
yield i
|
||||
|
||||
|
||||
class Dataset(Generic[T_co]):
|
||||
r"""An abstract class representing a :class:`Dataset`.
|
||||
|
||||
|
Reference in New Issue
Block a user