# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These need to be in global scope since Py2 doesn't support serializing
static methods.
"""

import os
import queue
import random
from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING, Union

import torch
from torch._utils import ExceptionWrapper

from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS, HAS_NUMPY

if TYPE_CHECKING:
    from torch.utils.data import Dataset


if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through the OS is to let the worker have a process
    # handle of the manager and ask if the process status has changed.
    class ManagerWatchdog:
        def __init__(self):
            self.manager_pid = os.getppid()

            # mypy cannot detect this code is windows only
            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)  # type: ignore[attr-defined]
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore[attr-defined]

            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
            return not self.manager_dead
else:
    # On POSIX, the worker is reparented when the manager process dies, so a
    # change in the parent PID reliably signals that the manager is gone.
    class ManagerWatchdog:  # type: ignore[no-redef]
        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead

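# Typical usage (see `_worker_loop` below): poll `watchdog.is_alive()` between
# blocking queue reads so the worker notices promptly when the manager dies.
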
_worker_info: Optional["WorkerInfo"] = None


class WorkerInfo:
    id: int
    num_workers: int
    seed: int
    dataset: 'Dataset'
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError(f"Cannot assign attributes to {self.__class__.__name__} objects")
        return super().__setattr__(key, val)

    def __repr__(self):
        items = []
        for k in self.__keys:
            items.append(f'{k}={getattr(self, k)}')
        return f"{self.__class__.__name__}({', '.join(items)})"


def get_worker_info() -> Optional[WorkerInfo]:
    r"""Returns the information about the current
    :class:`~torch.utils.data.DataLoader` iterator worker process.

    When called in a worker, this returns an object guaranteed to have the
    following attributes:

    * :attr:`id`: the current worker id.
    * :attr:`num_workers`: the total number of workers.
    * :attr:`seed`: the random seed set for the current worker. This value is
      determined by the main process RNG and the worker id. See
      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
      that this will be a different object in a different process than the one
      in the main process.

    When called in the main process, this returns ``None``.

    .. note::
       When used in a :attr:`worker_init_fn` passed over to
       :class:`~torch.utils.data.DataLoader`, this method can be useful to
       set up each worker process differently, for instance, using ``worker_id``
       to configure the ``dataset`` object to only read a specific fraction of a
       sharded dataset, or to use ``seed`` to seed other libraries used in dataset
       code.
    """
    return _worker_info
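
# A minimal sketch (illustrative only, not part of this module) of using
# ``get_worker_info`` in a ``worker_init_fn`` to shard an iterable dataset.
# The ``start``/``end`` attributes are hypothetical fields of the user's own
# dataset class, not a DataLoader API:
#
#     def worker_init_fn(worker_id):
#         info = get_worker_info()  # never None here, since this runs in a worker
#         per_worker = (info.dataset.end - info.dataset.start) // info.num_workers
#         info.dataset.start += worker_id * per_worker
#         info.dataset.end = info.dataset.start + per_worker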


r"""Dummy class used to signal the end of an IterableDataset"""
@dataclass(frozen=True)
class _IterableDatasetStopIteration:
    worker_id: int


r"""Dummy class used to resume the fetching when worker reuse is enabled"""
@dataclass(frozen=True)
class _ResumeIteration:
    seed: Optional[int] = None

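# A rough sketch of the control-message protocol these classes implement,
# assuming the queue wiring set up by the DataLoader (see `_worker_loop`
# below for the worker side):
#
#     main -> worker (index_queue): _ResumeIteration(seed)    # restart a reused worker
#     worker -> main (data_queue):  (_ResumeIteration, None)  # acknowledgement
#     worker -> main (data_queue):  (idx, _IterableDatasetStopIteration(worker_id))
#                                   # this worker's IterableDataset shard is exhausted

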
# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# It's MIT licensed, here is the copyright:

# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# This function generates an array of int32 as the seed for
# `numpy.random`, in order to prevent state collision due to the same
# seed and algorithm being used by the `numpy.random` and `random` modules.
# TODO: Implement a `SeedSequence`-like object for `torch.random`
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43b0d7e5
    MULT_A = 0x931e8875
    INIT_B = 0x8b51f9dd
    MULT_B = 0x58f38ded
    MIX_MULT_L = 0xca01f9dd
    MIX_MULT_R = 0x4973f715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state
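

# Example of how `_worker_loop` below consumes this: NumPy's legacy global RNG
# accepts the four 32-bit words directly as its seed, giving each worker a
# distinct 128-bit NumPy state even though `random`/`torch` use base_seed + id:
#
#     np.random.seed(_generate_state(base_seed, worker_id))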


def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers, shared_seed):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C-side signal handlers for SIGBUS and SIGSEGV. Handlers
        # registered through the Python signal module only run after Python
        # returns from the C low-level handlers, which is likely too late once
        # the same fatal signal has fired again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np
            np.random.seed(np_seed)
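        # For example, with base_seed=12345 and num_workers=2, worker 0 seeds
        # `random`/`torch` with 12345 and worker 1 with 12346, while NumPy gets
        # a distinct 128-bit state from `_generate_state` in each worker.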

        from torch.utils.data import IterDataPipe
        from torch.utils.data.graph_settings import apply_random_seed

        shared_rng = torch.Generator()
        if isinstance(dataset, IterDataPipe):
            assert shared_seed is not None
            shared_rng.manual_seed(shared_seed)
            dataset = apply_random_seed(dataset, shared_rng)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where=f"in DataLoader worker process {worker_id}")

        # In Iterable mode, some workers can exit earlier than others because
        # the IterableDataset may behave differently across workers. When that
        # happens, an `_IterableDatasetStopIteration` object is sent over to
        # the main process with the ID of this worker, so that the main process
        # won't send more tasks to this worker, and will send `None` to this
        # worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing steps and just wait
        # for `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False

                if isinstance(dataset, IterDataPipe):
                    assert r.seed is not None
                    shared_rng.manual_seed(r.seed)
                    dataset = apply_random_seed(dataset, shared_rng)

                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set, but the final signal (`None`) hasn't
                # arrived yet. Keep looping until it does, skipping the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        # (1) to save future `next(...)` calls, and
                        # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where=f"in DataLoader worker process {worker_id}")
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyway.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()