mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 07:24:58 +08:00
Issue 68576 prefetch factor (#88972)
Fixes #68576 This PR allows set the `prefetch_factor=None` making it really optional according to the documentation Pull Request resolved: https://github.com/pytorch/pytorch/pull/88972 Approved by: https://github.com/kit1980
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b3ac879a7
commit
57e05e822d
@ -217,7 +217,7 @@ class DataLoader(Generic[T_co]):
|
||||
timeout: float
|
||||
sampler: Union[Sampler, Iterable]
|
||||
pin_memory_device: str
|
||||
prefetch_factor: int
|
||||
prefetch_factor: Optional[int]
|
||||
_iterator : Optional['_BaseDataLoaderIter']
|
||||
__initialized = False
|
||||
|
||||
@ -228,7 +228,7 @@ class DataLoader(Generic[T_co]):
|
||||
pin_memory: bool = False, drop_last: bool = False,
|
||||
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
|
||||
multiprocessing_context=None, generator=None,
|
||||
*, prefetch_factor: int = 2,
|
||||
*, prefetch_factor: Optional[int] = None,
|
||||
persistent_workers: bool = False,
|
||||
pin_memory_device: str = ""):
|
||||
torch._C._log_api_usage_once("python.data_loader")
|
||||
@ -240,10 +240,13 @@ class DataLoader(Generic[T_co]):
|
||||
if timeout < 0:
|
||||
raise ValueError('timeout option should be non-negative')
|
||||
|
||||
if num_workers == 0 and prefetch_factor != 2:
|
||||
if num_workers == 0 and prefetch_factor is not None:
|
||||
raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
|
||||
'let num_workers > 0 to enable multiprocessing.')
|
||||
assert prefetch_factor > 0
|
||||
'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
|
||||
elif num_workers > 0 and prefetch_factor is None:
|
||||
prefetch_factor = 2
|
||||
elif prefetch_factor is not None and prefetch_factor < 0:
|
||||
raise ValueError('prefetch_factor option should be non-negative')
|
||||
|
||||
if persistent_workers and num_workers == 0:
|
||||
raise ValueError('persistent_workers option needs num_workers > 0')
|
||||
@ -581,7 +584,6 @@ class _BaseDataLoaderIter(object):
|
||||
ws, rank = _get_distributed_settings()
|
||||
self._world_size = ws
|
||||
self._rank = rank
|
||||
self._prefetch_factor = loader.prefetch_factor
|
||||
# for other backends, pin_memory_device need to set. if not set
|
||||
# default behaviour is CUDA device. if pin_memory_device is selected
|
||||
# and pin_memory is not set, the default behaviour false.
|
||||
@ -991,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
||||
def __init__(self, loader):
|
||||
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
|
||||
|
||||
self._prefetch_factor = loader.prefetch_factor
|
||||
|
||||
assert self._num_workers > 0
|
||||
assert self._prefetch_factor > 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user