mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Unify the output pathname of archive reader and extractor (#65424)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65424

This PR is a re-implementation of https://github.com/facebookexternal/torchdata/pull/93
The same PR has already landed in torchdata: https://github.com/facebookexternal/torchdata/pull/157

Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D31090447
Pulled By: ejguan
fbshipit-source-id: 45af1ad9b24310bebfd6e010f41cff398946ba65
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e331beef20
commit
96383ca704
@ -39,6 +39,7 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
for data in self.datapipe:
|
||||
validate_pathname_binary_tuple(data)
|
||||
pathname, data_stream = data
|
||||
+ folder_name = os.path.dirname(pathname)
|
||||
try:
|
||||
# typing.cast is used here to silence mypy's type checker
|
||||
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
|
||||
@ -49,7 +50,7 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
if extracted_fobj is None:
|
||||
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
|
||||
raise tarfile.ExtractError
|
||||
- inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
|
||||
+ inner_pathname = os.path.normpath(os.path.join(folder_name, tarinfo.name))
|
||||
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
|
@ -36,6 +36,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
for data in self.datapipe:
|
||||
validate_pathname_binary_tuple(data)
|
||||
pathname, data_stream = data
|
||||
+ folder_name = os.path.dirname(pathname)
|
||||
try:
|
||||
# typing.cast is used here to silence mypy's type checker
|
||||
zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
|
||||
@ -47,7 +48,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
|
||||
elif zipinfo.filename.endswith('/'):
|
||||
continue
|
||||
extracted_fobj = zips.open(zipinfo)
|
||||
- inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
|
||||
+ inner_pathname = os.path.normpath(os.path.join(folder_name, zipinfo.filename))
|
||||
yield (inner_pathname, extracted_fobj) # type: ignore[misc]
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
|
Reference in New Issue
Block a user