mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
move GroupByFilename Dataset to DataPipe (#51709)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51709 Move GroupByFilename Dataset to DataPipe Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D26263585 Pulled By: glaringlee fbshipit-source-id: 00e3e13b47b89117f1ccfc4cd6239940a40d071e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
482b94ae51
commit
015cabf82a
@ -19,6 +19,7 @@ from torch.utils.data.datapipes.utils.decoder import (
|
||||
basichandlers as decoder_basichandlers,
|
||||
imagehandler as decoder_imagehandler)
|
||||
|
||||
|
||||
def create_temp_dir_and_files():
|
||||
# The temp dir and files within it will be released and deleted in tearDown().
|
||||
# Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function.
|
||||
@ -178,6 +179,38 @@ class TestIterableDataPipeBasic(TestCase):
|
||||
self.assertTrue(rec[1] == open(rec[0], 'rb').read().decode('utf-8'))
|
||||
|
||||
|
||||
def test_groupbykey_iterable_datapipe(self):
|
||||
temp_dir = self.temp_dir.name
|
||||
temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
|
||||
file_list = [
|
||||
"a.png", "b.png", "c.json", "a.json", "c.png", "b.json", "d.png",
|
||||
"d.json", "e.png", "f.json", "g.png", "f.png", "g.json", "e.json",
|
||||
"h.txt", "h.json"]
|
||||
with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
|
||||
for file_name in file_list:
|
||||
file_pathname = os.path.join(temp_dir, file_name)
|
||||
with open(file_pathname, 'w') as f:
|
||||
f.write('12345abcde')
|
||||
tar.add(file_pathname)
|
||||
|
||||
datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar')
|
||||
datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
|
||||
datapipe3 = dp.iter.ReadFilesFromTar(datapipe2)
|
||||
datapipe4 = dp.iter.GroupByKey(datapipe3, group_size=2)
|
||||
|
||||
expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), (
|
||||
"f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")]
|
||||
|
||||
count = 0
|
||||
for rec, expected in zip(datapipe4, expected_result):
|
||||
count = count + 1
|
||||
self.assertEqual(os.path.basename(rec[0][0]), expected[0])
|
||||
self.assertEqual(os.path.basename(rec[1][0]), expected[1])
|
||||
self.assertEqual(rec[0][1].read(), b'12345abcde')
|
||||
self.assertEqual(rec[1][1].read(), b'12345abcde')
|
||||
self.assertEqual(count, 8)
|
||||
|
||||
|
||||
class IDP_NoLen(IterDataPipe):
|
||||
def __init__(self, input_dp):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user