mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51488 Test Plan: Imported from OSS Reviewed By: H-Huang Differential Revision: D26209888 Pulled By: ejguan fbshipit-source-id: cb8bc852b1e4d72be81e0297308a43954cd95332
		
			
				
	
	
		
			51 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			51 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import tempfile
 | 
						|
import warnings
 | 
						|
 | 
						|
from torch.testing._internal.common_utils import (TestCase, run_tests)
 | 
						|
from torch.utils.data.datasets import ListDirFilesIterableDataset, LoadFilesFromDiskIterableDataset
 | 
						|
 | 
						|
 | 
						|
def create_temp_dir_and_files():
    """Create a temporary directory containing three empty files.

    Returns:
        A 4-tuple ``(temp_dir, path1, path2, path3)``. ``temp_dir`` is the
        ``tempfile.TemporaryDirectory`` object itself; the caller owns it and
        is responsible for calling ``temp_dir.cleanup()``. The remaining three
        items are the path names of the files created inside that directory.
    """
    # The directory handle is intentionally returned un-released so the caller
    # (tearDown() in the test class) can delete it and everything inside it.
    # `noqa: P201` silences the lint warning about not releasing the handle
    # within this function.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name

    temp_file_names = []
    for _ in range(3):
        # delete=False keeps the file on disk after its handle is closed.
        # Closing the handle right away avoids leaking file descriptors, and
        # open handles can prevent the directory from being removed on some
        # platforms (e.g. Windows).
        with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) as temp_file:
            temp_file_names.append(temp_file.name)

    return (temp_dir,) + tuple(temp_file_names)
 | 
						|
 | 
						|
 | 
						|
class TestIterableDatasetBasic(TestCase):
    """Basic tests for the file-listing and file-loading iterable datasets.

    Every test runs against a fresh temporary directory of files that is
    created in setUp() and removed again in tearDown().
    """

    def setUp(self):
        """Create the temp directory and record its handle and file paths."""
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0]
        self.temp_files = ret[1:]

    def tearDown(self):
        """Remove the temp directory, downgrading any failure to a warning."""
        try:
            self.temp_dir.cleanup()
        except Exception as e:
            # Deliberately best-effort: a cleanup failure should not turn a
            # passing test into an error, so it is only reported.
            warnings.warn("TestIterableDatasetBasic was not able to cleanup temp dir due to {}".format(str(e)))

    def test_listdirfiles_iterable_dataset(self):
        # Every path name yielded for the temp dir must be one of the files
        # created in setUp().
        # NOTE(review): the '' second argument is presumably an empty file-name
        # mask -- confirm against ListDirFilesIterableDataset.
        temp_dir = self.temp_dir.name
        dataset = ListDirFilesIterableDataset(temp_dir, '')
        for pathname in dataset:
            self.assertIn(pathname, self.temp_files)

    def test_loadfilesfromdisk_iterable_dataset(self):
        # Each record is indexed as (path name, readable stream): the path must
        # be one of the files from setUp() and the stream content must match
        # what is on disk.
        temp_dir = self.temp_dir.name
        dataset1 = ListDirFilesIterableDataset(temp_dir, '')
        dataset2 = LoadFilesFromDiskIterableDataset(dataset1)

        for rec in dataset2:
            self.assertIn(rec[0], self.temp_files)
            # Close the reference handle explicitly instead of leaking it.
            # NOTE(review): rec[1] is left open because its type is not visible
            # here -- consider closing it too if it supports close().
            with open(rec[0], 'rb') as expected:
                self.assertEqual(rec[1].read(), expected.read())
 | 
						|
 | 
						|
 | 
						|
# Invoke the imported run_tests() only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
 |