mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Deprecate DataLoader pin_memory_device param (#146821)"
This reverts commit ab655816b8f76f511fb2262d45276d8d1b13d59c. Reverted https://github.com/pytorch/pytorch/pull/146821 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/146821#issuecomment-3052093902))
This commit is contained in:
@ -21,7 +21,16 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
|
||||
torch.set_num_threads(1)
|
||||
|
||||
torch.multiprocessing._set_thread_name("pt_data_pin")
|
||||
torch.accelerator.set_device_index(device_id)
|
||||
|
||||
if device == "cuda":
|
||||
torch.cuda.set_device(device_id)
|
||||
elif device == "xpu":
|
||||
torch.xpu.set_device(device_id) # type: ignore[attr-defined]
|
||||
elif device == torch._C._get_privateuse1_backend_name():
|
||||
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
|
||||
custom_device_mod.set_device(device_id)
|
||||
elif device is None:
|
||||
torch.accelerator.set_device_index(device_id)
|
||||
|
||||
def do_one_step():
|
||||
try:
|
||||
|
@ -190,8 +190,9 @@ class DataLoader(Generic[_T_co]):
|
||||
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
|
||||
the worker processes after a dataset has been consumed once. This allows to
|
||||
maintain the workers `Dataset` instances alive. (default: ``False``)
|
||||
pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator<accelerators>`
|
||||
will be used as the device if ``pin_memory=True``.
|
||||
pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is
|
||||
``True``. If not given, the current :ref:`accelerator<accelerators>` will be the
|
||||
default. This argument is discouraged and subject to deprecated.
|
||||
in_order (bool, optional): If ``False``, the data loader will not enforce that batches
|
||||
are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
|
||||
|
||||
@ -286,39 +287,8 @@ class DataLoader(Generic[_T_co]):
|
||||
self.dataset = dataset
|
||||
self.num_workers = num_workers
|
||||
self.prefetch_factor = prefetch_factor
|
||||
|
||||
if pin_memory and pin_memory_device:
|
||||
warnings.warn(
|
||||
"pin_memory_device is deprecated, the current accelerator will be used as the device,"
|
||||
f"ignore pin_memory_device='{pin_memory_device}'."
|
||||
)
|
||||
if pin_memory and not torch.accelerator.is_available():
|
||||
warn_msg = (
|
||||
"'pin_memory' argument is set as true but no accelerator is found, "
|
||||
"then device pinned memory won't be used."
|
||||
)
|
||||
warnings.warn(warn_msg)
|
||||
|
||||
self.pin_memory = pin_memory and torch.accelerator.is_available()
|
||||
self.pin_memory_device = (
|
||||
acc.type
|
||||
if self.pin_memory
|
||||
and (acc := torch.accelerator.current_accelerator()) is not None
|
||||
else ""
|
||||
)
|
||||
|
||||
# Currently, pin_memory would raise error on the MPS backend (see
|
||||
# https://github.com/pytorch/pytorch/issues/86060), so forcibly
|
||||
# disable pin_memory on MPS. Remove this restriction once pinned
|
||||
# memory allocation for MPS is fixed.
|
||||
if self.pin_memory_device == "mps":
|
||||
self.pin_memory = False
|
||||
warn_msg = (
|
||||
"'pin_memory' argument is set as true but not supported on MPS now, "
|
||||
"then device pinned memory won't be used."
|
||||
)
|
||||
warnings.warn(warn_msg)
|
||||
|
||||
self.pin_memory = pin_memory
|
||||
self.pin_memory_device = pin_memory_device
|
||||
self.timeout = timeout
|
||||
self.worker_init_fn = worker_init_fn
|
||||
self.multiprocessing_context = multiprocessing_context
|
||||
@ -684,10 +654,45 @@ class _BaseDataLoaderIter:
|
||||
ws, rank = _get_distributed_settings()
|
||||
self._world_size = ws
|
||||
self._rank = rank
|
||||
self._pin_memory = loader.pin_memory
|
||||
self._pin_memory_device = (
|
||||
None if len(loader.pin_memory_device) == 0 else loader.pin_memory_device
|
||||
)
|
||||
# If pin_memory_device not set, default behaviour is current accelerator.
|
||||
# If pin_memory_device is set but pin_memory is not set, the default
|
||||
# behaviour false.
|
||||
if len(loader.pin_memory_device) == 0:
|
||||
if loader.pin_memory and not torch.accelerator.is_available():
|
||||
warn_msg = (
|
||||
"'pin_memory' argument is set as true but no accelerator is found, "
|
||||
"then device pinned memory won't be used."
|
||||
)
|
||||
warnings.warn(warn_msg)
|
||||
|
||||
self._pin_memory = loader.pin_memory and torch.accelerator.is_available()
|
||||
self._pin_memory_device = None
|
||||
# Currently, pin_memory would raise error on the MPS backend (see
|
||||
# https://github.com/pytorch/pytorch/issues/86060), so forcibly
|
||||
# disable pin_memory on MPS. Remove this restriction once pinned
|
||||
# memory allocation for MPS is fixed.
|
||||
if (
|
||||
self._pin_memory
|
||||
and (acc := torch.accelerator.current_accelerator()) is not None
|
||||
and acc.type == "mps"
|
||||
):
|
||||
self._pin_memory = False
|
||||
warn_msg = (
|
||||
"'pin_memory' argument is set as true but not supported on MPS now, "
|
||||
"then device pinned memory won't be used."
|
||||
)
|
||||
warnings.warn(warn_msg)
|
||||
else:
|
||||
if not loader.pin_memory:
|
||||
warn_msg = (
|
||||
"'pin_memory_device' is set but 'pin_memory' argument is not set, "
|
||||
"then device pinned memory won't be used."
|
||||
"please set 'pin_memory' to true, if you need to use the device pin memory"
|
||||
)
|
||||
warnings.warn(warn_msg)
|
||||
|
||||
self._pin_memory = loader.pin_memory
|
||||
self._pin_memory_device = loader.pin_memory_device
|
||||
self._timeout = loader.timeout
|
||||
self._collate_fn = loader.collate_fn
|
||||
self._sampler_iter = iter(self._index_sampler)
|
||||
@ -1173,7 +1178,18 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
||||
|
||||
# Queue is not type-annotated
|
||||
self._data_queue = queue.Queue() # type: ignore[var-annotated]
|
||||
current_device = torch.accelerator.current_device_index()
|
||||
current_device = -1
|
||||
if self._pin_memory_device == "cuda":
|
||||
current_device = torch.cuda.current_device()
|
||||
elif self._pin_memory_device == "xpu":
|
||||
current_device = torch.xpu.current_device()
|
||||
elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
|
||||
custom_device_mod = getattr(
|
||||
torch, torch._C._get_privateuse1_backend_name()
|
||||
)
|
||||
current_device = custom_device_mod.current_device()
|
||||
elif self._pin_memory_device is None:
|
||||
current_device = torch.accelerator.current_device_index()
|
||||
pin_memory_thread = threading.Thread(
|
||||
target=_utils.pin_memory._pin_memory_loop,
|
||||
args=(
|
||||
|
Reference in New Issue
Block a user