Issue 68576 prefetch factor (#88972)

Fixes #68576
This PR allows setting `prefetch_factor=None`, making it truly optional according to the documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88972
Approved by: https://github.com/kit1980
This commit is contained in:
Dmitry Tomshin
2022-11-18 00:10:48 +00:00
committed by PyTorch MergeBot
parent 2b3ac879a7
commit 57e05e822d

View File

@@ -217,7 +217,7 @@ class DataLoader(Generic[T_co]):
timeout: float
sampler: Union[Sampler, Iterable]
pin_memory_device: str
prefetch_factor: int
prefetch_factor: Optional[int]
_iterator : Optional['_BaseDataLoaderIter']
__initialized = False
@@ -228,7 +228,7 @@ class DataLoader(Generic[T_co]):
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
*, prefetch_factor: Optional[int] = None,
persistent_workers: bool = False,
pin_memory_device: str = ""):
torch._C._log_api_usage_once("python.data_loader")
@@ -240,10 +240,13 @@ class DataLoader(Generic[T_co]):
if timeout < 0:
raise ValueError('timeout option should be non-negative')
if num_workers == 0 and prefetch_factor != 2:
if num_workers == 0 and prefetch_factor is not None:
raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
'let num_workers > 0 to enable multiprocessing.')
assert prefetch_factor > 0
'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
elif num_workers > 0 and prefetch_factor is None:
prefetch_factor = 2
elif prefetch_factor is not None and prefetch_factor < 0:
raise ValueError('prefetch_factor option should be non-negative')
if persistent_workers and num_workers == 0:
raise ValueError('persistent_workers option needs num_workers > 0')
@@ -581,7 +584,6 @@ class _BaseDataLoaderIter(object):
ws, rank = _get_distributed_settings()
self._world_size = ws
self._rank = rank
self._prefetch_factor = loader.prefetch_factor
# for other backends, pin_memory_device need to set. if not set
# default behaviour is CUDA device. if pin_memory_device is selected
# and pin_memory is not set, the default behaviour false.
@@ -991,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
self._prefetch_factor = loader.prefetch_factor
assert self._num_workers > 0
assert self._prefetch_factor > 0