Revert "Deprecate DataLoader pin_memory_device param (#146821)"
This reverts commit ab655816b8f76f511fb2262d45276d8d1b13d59c. Reverted https://github.com/pytorch/pytorch/pull/146821 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/146821#issuecomment-3052093902))
torch/utils/data/_utils/pin_memory.py

@@ -21,7 +21,16 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
     torch.set_num_threads(1)
 
     torch.multiprocessing._set_thread_name("pt_data_pin")
-    torch.accelerator.set_device_index(device_id)
+
+    if device == "cuda":
+        torch.cuda.set_device(device_id)
+    elif device == "xpu":
+        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
+    elif device == torch._C._get_privateuse1_backend_name():
+        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
+        custom_device_mod.set_device(device_id)
+    elif device is None:
+        torch.accelerator.set_device_index(device_id)
 
     def do_one_step():
         try:
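The restored dispatch keys off the `device` string the iterator passes in, including out-of-tree backends registered through the PrivateUse1 mechanism. A minimal sketch of that last lookup in isolation (the index `0` and the `hasattr` guard are illustrative, not part of this diff):

```python
import torch

# torch._C._get_privateuse1_backend_name() returns the name registered for
# the out-of-tree backend ("privateuseone" unless renamed via
# torch.utils.rename_privateuse1_backend); the restored branch uses it to
# find the matching device module on the `torch` namespace.
backend_name = torch._C._get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None)
if custom_device_mod is not None and hasattr(custom_device_mod, "set_device"):
    custom_device_mod.set_device(0)  # bind this thread to device index 0
```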
torch/utils/data/dataloader.py

@@ -190,8 +190,9 @@ class DataLoader(Generic[_T_co]):
         persistent_workers (bool, optional): If ``True``, the data loader will not shut down
             the worker processes after a dataset has been consumed once. This allows to
             maintain the workers `Dataset` instances alive. (default: ``False``)
-        pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator<accelerators>`
-            will be used as the device if ``pin_memory=True``.
+        pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is
+            ``True``. If not given, the current :ref:`accelerator<accelerators>` will be the
+            default. This argument is discouraged and subject to deprecated.
         in_order (bool, optional): If ``False``, the data loader will not enforce that batches
             are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
 
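With the docstring (and the behaviour restored below) back in place, the parameter is accepted and honoured again instead of triggering the deprecation warning. A minimal usage sketch, assuming a CUDA-capable build:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 3))

# Explicit device: batches are pinned for "cuda". Omitting
# pin_memory_device falls back to the current accelerator.
loader = DataLoader(dataset, batch_size=4, pin_memory=True,
                    pin_memory_device="cuda")
for (batch,) in loader:
    assert batch.is_pinned()
```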
@@ -286,39 +287,8 @@ class DataLoader(Generic[_T_co]):
         self.dataset = dataset
         self.num_workers = num_workers
         self.prefetch_factor = prefetch_factor
-        if pin_memory and pin_memory_device:
-            warnings.warn(
-                "pin_memory_device is deprecated, the current accelerator will be used as the device,"
-                f"ignore pin_memory_device='{pin_memory_device}'."
-            )
-
-        if pin_memory and not torch.accelerator.is_available():
-            warn_msg = (
-                "'pin_memory' argument is set as true but no accelerator is found, "
-                "then device pinned memory won't be used."
-            )
-            warnings.warn(warn_msg)
-
-        self.pin_memory = pin_memory and torch.accelerator.is_available()
-        self.pin_memory_device = (
-            acc.type
-            if self.pin_memory
-            and (acc := torch.accelerator.current_accelerator()) is not None
-            else ""
-        )
-
-        # Currently, pin_memory would raise error on the MPS backend (see
-        # https://github.com/pytorch/pytorch/issues/86060), so forcibly
-        # disable pin_memory on MPS. Remove this restriction once pinned
-        # memory allocation for MPS is fixed.
-        if self.pin_memory_device == "mps":
-            self.pin_memory = False
-            warn_msg = (
-                "'pin_memory' argument is set as true but not supported on MPS now, "
-                "then device pinned memory won't be used."
-            )
-            warnings.warn(warn_msg)
-
+        self.pin_memory = pin_memory
+        self.pin_memory_device = pin_memory_device
         self.timeout = timeout
         self.worker_init_fn = worker_init_fn
         self.multiprocessing_context = multiprocessing_context
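Net effect of this hunk: `__init__` goes back to storing the arguments verbatim, and all accelerator/MPS validation moves to `_BaseDataLoaderIter` (next hunk). A quick sketch of the stored state, reusing `dataset` from the example above and assuming the defaults:

```python
loader = DataLoader(dataset, batch_size=4, pin_memory=True)

# Post-revert, the constructor records exactly what was passed; note the
# default pin_memory_device is the empty string "", not None.
assert loader.pin_memory is True
assert loader.pin_memory_device == ""
```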
@@ -684,10 +654,45 @@ class _BaseDataLoaderIter:
         ws, rank = _get_distributed_settings()
         self._world_size = ws
         self._rank = rank
-        self._pin_memory = loader.pin_memory
-        self._pin_memory_device = (
-            None if len(loader.pin_memory_device) == 0 else loader.pin_memory_device
-        )
+        # If pin_memory_device not set, default behaviour is current accelerator.
+        # If pin_memory_device is set but pin_memory is not set, the default
+        # behaviour false.
+        if len(loader.pin_memory_device) == 0:
+            if loader.pin_memory and not torch.accelerator.is_available():
+                warn_msg = (
+                    "'pin_memory' argument is set as true but no accelerator is found, "
+                    "then device pinned memory won't be used."
+                )
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory and torch.accelerator.is_available()
+            self._pin_memory_device = None
+            # Currently, pin_memory would raise error on the MPS backend (see
+            # https://github.com/pytorch/pytorch/issues/86060), so forcibly
+            # disable pin_memory on MPS. Remove this restriction once pinned
+            # memory allocation for MPS is fixed.
+            if (
+                self._pin_memory
+                and (acc := torch.accelerator.current_accelerator()) is not None
+                and acc.type == "mps"
+            ):
+                self._pin_memory = False
+                warn_msg = (
+                    "'pin_memory' argument is set as true but not supported on MPS now, "
+                    "then device pinned memory won't be used."
+                )
+                warnings.warn(warn_msg)
+        else:
+            if not loader.pin_memory:
+                warn_msg = (
+                    "'pin_memory_device' is set but 'pin_memory' argument is not set, "
+                    "then device pinned memory won't be used."
+                    "please set 'pin_memory' to true, if you need to use the device pin memory"
+                )
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory
+            self._pin_memory_device = loader.pin_memory_device
         self._timeout = loader.timeout
         self._collate_fn = loader.collate_fn
         self._sampler_iter = iter(self._index_sampler)
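One consequence of the restored branch worth calling out: setting `pin_memory_device` without `pin_memory=True` only warns, it does not enable pinning. A small sketch of observing that warning at iterator creation, reusing `dataset` from above:

```python
import warnings

loader = DataLoader(dataset, batch_size=4, pin_memory=False,
                    pin_memory_device="cuda")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = iter(loader)  # _BaseDataLoaderIter.__init__ runs the restored checks
assert any("'pin_memory_device' is set" in str(w.message) for w in caught)
```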
@@ -1173,7 +1178,18 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
 
             # Queue is not type-annotated
             self._data_queue = queue.Queue()  # type: ignore[var-annotated]
-            current_device = torch.accelerator.current_device_index()
+            current_device = -1
+            if self._pin_memory_device == "cuda":
+                current_device = torch.cuda.current_device()
+            elif self._pin_memory_device == "xpu":
+                current_device = torch.xpu.current_device()
+            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
+                custom_device_mod = getattr(
+                    torch, torch._C._get_privateuse1_backend_name()
+                )
+                current_device = custom_device_mod.current_device()
+            elif self._pin_memory_device is None:
+                current_device = torch.accelerator.current_device_index()
             pin_memory_thread = threading.Thread(
                 target=_utils.pin_memory._pin_memory_loop,
                 args=(
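For reference, the restored lookup that feeds `_pin_memory_loop` reads as a mapping from the configured device string to the index the pin-memory thread binds to, with `-1` as the fall-through for unrecognised strings. A hypothetical standalone extraction (the helper name is ours, not PyTorch's):

```python
import torch

def resolve_pin_memory_device_index(pin_memory_device):
    # Hypothetical standalone version of the restored branch above.
    if pin_memory_device == "cuda":
        return torch.cuda.current_device()
    if pin_memory_device == "xpu":
        return torch.xpu.current_device()
    if pin_memory_device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        return custom_device_mod.current_device()
    if pin_memory_device is None:
        # No explicit device: defer to the current accelerator.
        return torch.accelerator.current_device_index()
    return -1  # unknown backend string: leave the thread unbound
```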