Revert "Deprecate DataLoader pin_memory_device param (#146821)"

This reverts commit ab655816b8f76f511fb2262d45276d8d1b13d59c.

Reverted https://github.com/pytorch/pytorch/pull/146821 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/146821#issuecomment-3052093902))
PyTorch MergeBot
2025-07-09 10:29:31 +00:00
parent 6f23f53599
commit b83d8827bc
2 changed files with 66 additions and 41 deletions
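For context, a minimal usage sketch of the parameter this revert restores (not part of the commit; the dataset, batch size, and device string are arbitrary, and a CUDA build with at least one GPU is assumed): with pin_memory_device no longer deprecated, callers can again name the pinning device explicitly instead of relying on the current accelerator.

# Illustrative sketch only, not part of this commit. Dataset, batch size, and the
# device string are arbitrary; assumes a CUDA build with at least one GPU.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))

loader = DataLoader(
    dataset,
    batch_size=8,
    pin_memory=True,
    pin_memory_device="cuda",  # explicit again after this revert; omit to use the current accelerator
)

for features, labels in loader:
    assert features.is_pinned()  # batches are pinned in host memory before transfer
    features = features.to("cuda", non_blocking=True)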

torch/utils/data/_utils/pin_memory.py

@@ -21,7 +21,16 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
     torch.set_num_threads(1)
     torch.multiprocessing._set_thread_name("pt_data_pin")

-    torch.accelerator.set_device_index(device_id)
+    if device == "cuda":
+        torch.cuda.set_device(device_id)
+    elif device == "xpu":
+        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
+    elif device == torch._C._get_privateuse1_backend_name():
+        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
+        custom_device_mod.set_device(device_id)
+    elif device is None:
+        torch.accelerator.set_device_index(device_id)

     def do_one_step():
         try:
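The removed line above relied on the generic torch.accelerator API, while the restored branches call each backend directly. On a machine whose current accelerator is CUDA the two resolve to the same device; a rough equivalence sketch, assuming a CUDA build with at least one visible GPU and PyTorch 2.6+ for the torch.accelerator calls:

# Rough equivalence sketch: assumes a CUDA build with at least one GPU and
# PyTorch 2.6+ (for the torch.accelerator API).
import torch

torch.cuda.set_device(0)                  # per-backend call restored by this revert
assert torch.cuda.current_device() == 0

torch.accelerator.set_device_index(0)     # generic call used by the reverted change
assert torch.accelerator.current_device_index() == 0
assert torch.accelerator.current_accelerator().type == "cuda"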

torch/utils/data/dataloader.py

@@ -190,8 +190,9 @@ class DataLoader(Generic[_T_co]):
         persistent_workers (bool, optional): If ``True``, the data loader will not shut down
             the worker processes after a dataset has been consumed once. This allows to
             maintain the workers `Dataset` instances alive. (default: ``False``)
-        pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator<accelerators>`
-            will be used as the device if ``pin_memory=True``.
+        pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is
+            ``True``. If not given, the current :ref:`accelerator<accelerators>` will be the
+            default. This argument is discouraged and subject to deprecated.
         in_order (bool, optional): If ``False``, the data loader will not enforce that batches
             are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
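The restored docstring describes a fallback when only pin_memory is set; a short sketch of that default path (dataset and sizes arbitrary, and some accelerator such as CUDA or XPU is assumed to be available):

# Default-path sketch: pin_memory=True with no pin_memory_device, so the current
# accelerator (assumed available here, e.g. CUDA or XPU) is used for pinning.
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(32, dtype=torch.float32).reshape(16, 2))
loader = DataLoader(ds, batch_size=4, pin_memory=True)  # pin_memory_device left at its "" default

(batch,) = next(iter(loader))
assert batch.is_pinned()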
@@ -286,39 +287,8 @@ class DataLoader(Generic[_T_co]):
         self.dataset = dataset
         self.num_workers = num_workers
         self.prefetch_factor = prefetch_factor
-        if pin_memory and pin_memory_device:
-            warnings.warn(
-                "pin_memory_device is deprecated, the current accelerator will be used as the device,"
-                f"ignore pin_memory_device='{pin_memory_device}'."
-            )
-        if pin_memory and not torch.accelerator.is_available():
-            warn_msg = (
-                "'pin_memory' argument is set as true but no accelerator is found, "
-                "then device pinned memory won't be used."
-            )
-            warnings.warn(warn_msg)
-        self.pin_memory = pin_memory and torch.accelerator.is_available()
-        self.pin_memory_device = (
-            acc.type
-            if self.pin_memory
-            and (acc := torch.accelerator.current_accelerator()) is not None
-            else ""
-        )
-        # Currently, pin_memory would raise error on the MPS backend (see
-        # https://github.com/pytorch/pytorch/issues/86060), so forcibly
-        # disable pin_memory on MPS. Remove this restriction once pinned
-        # memory allocation for MPS is fixed.
-        if self.pin_memory_device == "mps":
-            self.pin_memory = False
-            warn_msg = (
-                "'pin_memory' argument is set as true but not supported on MPS now, "
-                "then device pinned memory won't be used."
-            )
-            warnings.warn(warn_msg)
+        self.pin_memory = pin_memory
+        self.pin_memory_device = pin_memory_device
         self.timeout = timeout
         self.worker_init_fn = worker_init_fn
         self.multiprocessing_context = multiprocessing_context
@@ -684,10 +654,45 @@ class _BaseDataLoaderIter:
         ws, rank = _get_distributed_settings()
         self._world_size = ws
         self._rank = rank
-        self._pin_memory = loader.pin_memory
-        self._pin_memory_device = (
-            None if len(loader.pin_memory_device) == 0 else loader.pin_memory_device
-        )
+        # If pin_memory_device not set, default behaviour is current accelerator.
+        # If pin_memory_device is set but pin_memory is not set, the default
+        # behaviour false.
+        if len(loader.pin_memory_device) == 0:
+            if loader.pin_memory and not torch.accelerator.is_available():
+                warn_msg = (
+                    "'pin_memory' argument is set as true but no accelerator is found, "
+                    "then device pinned memory won't be used."
+                )
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory and torch.accelerator.is_available()
+            self._pin_memory_device = None
+
+            # Currently, pin_memory would raise error on the MPS backend (see
+            # https://github.com/pytorch/pytorch/issues/86060), so forcibly
+            # disable pin_memory on MPS. Remove this restriction once pinned
+            # memory allocation for MPS is fixed.
+            if (
+                self._pin_memory
+                and (acc := torch.accelerator.current_accelerator()) is not None
+                and acc.type == "mps"
+            ):
+                self._pin_memory = False
+                warn_msg = (
+                    "'pin_memory' argument is set as true but not supported on MPS now, "
+                    "then device pinned memory won't be used."
+                )
+                warnings.warn(warn_msg)
+        else:
+            if not loader.pin_memory:
+                warn_msg = (
+                    "'pin_memory_device' is set but 'pin_memory' argument is not set, "
+                    "then device pinned memory won't be used."
+                    "please set 'pin_memory' to true, if you need to use the device pin memory"
+                )
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory
+            self._pin_memory_device = loader.pin_memory_device
         self._timeout = loader.timeout
         self._collate_fn = loader.collate_fn
         self._sampler_iter = iter(self._index_sampler)
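A small sketch of the corner case the restored else branch guards against: pin_memory_device supplied without pin_memory=True. Creating the iterator emits the warning shown above and batches stay unpinned (dataset and device string are arbitrary; a CUDA build is assumed so the device string is meaningful):

# Warning-path sketch: pin_memory_device given but pin_memory left False, so the
# iterator warns and no pinned memory is used. Assumes a CUDA build so the
# device string is meaningful.
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(16, 4))
loader = DataLoader(ds, batch_size=4, pin_memory_device="cuda")

(batch,) = next(iter(loader))  # creating the iterator emits the warning shown above
assert not batch.is_pinned()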
@@ -1173,7 +1178,18 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
             # Queue is not type-annotated
             self._data_queue = queue.Queue()  # type: ignore[var-annotated]
-            current_device = torch.accelerator.current_device_index()
+            current_device = -1
+            if self._pin_memory_device == "cuda":
+                current_device = torch.cuda.current_device()
+            elif self._pin_memory_device == "xpu":
+                current_device = torch.xpu.current_device()
+            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
+                custom_device_mod = getattr(
+                    torch, torch._C._get_privateuse1_backend_name()
+                )
+                current_device = custom_device_mod.current_device()
+            elif self._pin_memory_device is None:
+                current_device = torch.accelerator.current_device_index()
             pin_memory_thread = threading.Thread(
                 target=_utils.pin_memory._pin_memory_loop,
                 args=(