Unify the output pathname of archive reader and extractor (#65424)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65424

This PR is re-implementation for https://github.com/facebookexternal/torchdata/pull/93
Same PR has landed into torchdata https://github.com/facebookexternal/torchdata/pull/157

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D31090447

Pulled By: ejguan

fbshipit-source-id: 45af1ad9b24310bebfd6e010f41cff398946ba65
This commit is contained in:
Erjia Guan
2021-09-22 06:32:58 -07:00
committed by Facebook GitHub Bot
parent e331beef20
commit 96383ca704
2 changed files with 4 additions and 2 deletions

View File

@ -39,6 +39,7 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
for data in self.datapipe:
validate_pathname_binary_tuple(data)
pathname, data_stream = data
folder_name = os.path.dirname(pathname)
try:
# typing.cast is used here to silence mypy's type checker
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
@ -49,7 +50,7 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
if extracted_fobj is None:
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
raise tarfile.ExtractError
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
inner_pathname = os.path.normpath(os.path.join(folder_name, tarinfo.name))
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
except Exception as e:
warnings.warn(

View File

@ -36,6 +36,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
for data in self.datapipe:
validate_pathname_binary_tuple(data)
pathname, data_stream = data
folder_name = os.path.dirname(pathname)
try:
# typing.cast is used here to silence mypy's type checker
zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
@ -47,7 +48,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
elif zipinfo.filename.endswith('/'):
continue
extracted_fobj = zips.open(zipinfo)
inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
except Exception as e:
warnings.warn(