mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Replace references to _DataLoaderIter with _BaseDataLoaderIter (#27105)
Summary: Back in April, malmaud added type annotations for `dataloader.py`. At about the same time, SsnL in https://github.com/pytorch/pytorch/issues/19228 replaced `_DataLoaderIter` with `_BaseDataLoaderIter` and two subclasses, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Probably because these changes happened in parallel, the type stubs and several other references in the codebase were never updated to match this refactoring. I've gone ahead and done the updates to reflect the refactoring in https://github.com/pytorch/pytorch/issues/19228, which fixes the specific type stub/implementation mismatch pointed out in https://github.com/pytorch/pytorch/issues/26673, although not the broader problem that PyTorch doesn't have a test to make sure that the `.pyi` type stub files match the real API defined in `.py` files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/27105 Differential Revision: D17813641 Pulled By: ezyang fbshipit-source-id: ed7ac025c8d6ad3f298dd073347ec83bb4b6600c
This commit is contained in:
		
				
					committed by
					
						
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							d57124823b
						
					
				
				
					commit
					f522bde121
				
			@ -155,7 +155,7 @@ static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) {
 | 
			
		||||
  }
 | 
			
		||||
  int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
 | 
			
		||||
  if (worker_pids.find(key) != worker_pids.end()) {
 | 
			
		||||
    throw ValueError("_set_worker_pids should be called only once for each _DataLoaderIter.");
 | 
			
		||||
    throw ValueError("_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
 | 
			
		||||
  }
 | 
			
		||||
  PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
 | 
			
		||||
  if (!PyTuple_Check(child_pids)) {
 | 
			
		||||
@ -182,7 +182,7 @@ static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_i
 | 
			
		||||
  int64_t key = THPUtils_unpackLong(loader_id);
 | 
			
		||||
  auto it = worker_pids.find(key);
 | 
			
		||||
  if (it == worker_pids.end()) {
 | 
			
		||||
    throw ValueError("Cannot find worker information for _DataLoaderIter with id %ld.", key);
 | 
			
		||||
    throw ValueError("Cannot find worker information for _BaseDataLoaderIter with id %ld.", key);
 | 
			
		||||
  }
 | 
			
		||||
  worker_pids.erase(it);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
 | 
			
		||||
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
 | 
			
		||||
collate samples fetched from dataset into Tensor(s).
 | 
			
		||||
 | 
			
		||||
These **needs** to be in global scope since Py2 doesn't support serializing
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
r""""Contains definitions of the methods used by the _DataLoaderIter to fetch
 | 
			
		||||
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
 | 
			
		||||
data from an iterable-style or map-style dataset. This logic is shared in both
 | 
			
		||||
single- and multi-processing data loading.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
r""""Contains definitions of the methods used by the _DataLoaderIter to put
 | 
			
		||||
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
 | 
			
		||||
fetched tensors into pinned memory.
 | 
			
		||||
 | 
			
		||||
These **needs** to be in global scope since Py2 doesn't support serializing
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,8 @@ libraries users call in the workers. In this file and `DataLoader.cpp`, we make
 | 
			
		||||
our best effort to provide some error message to users when such unfortunate
 | 
			
		||||
events happen.
 | 
			
		||||
 | 
			
		||||
When a _DataLoaderIter starts worker processes, their pids are registered in a
 | 
			
		||||
defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
 | 
			
		||||
When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
 | 
			
		||||
defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
 | 
			
		||||
via `_set_worker_pids`.
 | 
			
		||||
 | 
			
		||||
When an error happens in a worker process, the main process received a SIGCHLD,
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
r""""Contains definitions of the methods used by the _DataLoaderIter workers.
 | 
			
		||||
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
 | 
			
		||||
 | 
			
		||||
These **needs** to be in global scope since Py2 doesn't support serializing
 | 
			
		||||
static methods.
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
 | 
			
		||||
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
 | 
			
		||||
 | 
			
		||||
To support these two classes, in `./_utils` we define many utility methods and
 | 
			
		||||
functions to be run in multiprocessing. E.g., the data loading worker loop is
 | 
			
		||||
 | 
			
		||||
@ -28,12 +28,14 @@ class DataLoader(Generic[T_co]):
 | 
			
		||||
                 worker_init_fn: _worker_init_fn_t=...) -> None: ...
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int: ...
 | 
			
		||||
    # We quote '_DataLoaderIter' since it isn't defined yet and the definition can't be moved up since
 | 
			
		||||
    # '_DataLoaderIter' references 'DataLoader'. Pending updates of PEP 484 will fix this.
 | 
			
		||||
    def __iter__(self) -> '_DataLoaderIter':...
 | 
			
		||||
    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
 | 
			
		||||
    # since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic
 | 
			
		||||
    # analyzer is used that obviates the need for this but we leave the quoting in to support older
 | 
			
		||||
    # versions of mypy
 | 
			
		||||
    def __iter__(self) -> '_BaseDataLoaderIter':...
 | 
			
		||||
 | 
			
		||||
class _DataLoaderIter:
 | 
			
		||||
class _BaseDataLoaderIter:
 | 
			
		||||
    def __init__(self, loader: DataLoader) -> None:...
 | 
			
		||||
    def __len__(self) -> int: ...
 | 
			
		||||
    def __iter__(self) -> _DataLoaderIter: ...
 | 
			
		||||
    def __iter__(self) -> _BaseDataLoaderIter: ...
 | 
			
		||||
    def __next__(self) -> Any: ...
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user