Replace references to _DataLoaderIter with _BaseDataLoaderIter (#27105)

Summary:
Back in April, malmaud added type annotations for `dataloader.py`. However, at about the same time, SsnL in https://github.com/pytorch/pytorch/issues/19228 replaced `_DataLoaderIter` with `_BaseDataLoaderIter` and two subclasses, `_SingleProcessDataLoaderIter`, and `_MultiProcessingDataLoaderIter`. However - probably because these changes happened in parallel at roughly the same time, the type stubs and several other references in the codebase were never updated to match this refactoring.

I've gone ahead and done the updates to reflect the refactoring in https://github.com/pytorch/pytorch/issues/19228, which fixes the specific type stub/impelementation mismatch pointed out in https://github.com/pytorch/pytorch/issues/26673, although not the broader problem that pytorch doesn't have a test to make sure that the `.pyi` type stub files match the real API defined in `.py` files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27105

Differential Revision: D17813641

Pulled By: ezyang

fbshipit-source-id: ed7ac025c8d6ad3f298dd073347ec83bb4b6600c
This commit is contained in:
Nathan Goldbaum
2019-10-08 11:49:08 -07:00
committed by Facebook Github Bot
parent d57124823b
commit f522bde121
8 changed files with 16 additions and 14 deletions

View File

@ -155,7 +155,7 @@ static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) {
}
int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
if (worker_pids.find(key) != worker_pids.end()) {
throw ValueError("_set_worker_pids should be called only once for each _DataLoaderIter.");
throw ValueError("_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
}
PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
if (!PyTuple_Check(child_pids)) {
@ -182,7 +182,7 @@ static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_i
int64_t key = THPUtils_unpackLong(loader_id);
auto it = worker_pids.find(key);
if (it == worker_pids.end()) {
throw ValueError("Cannot find worker information for _DataLoaderIter with id %ld.", key);
throw ValueError("Cannot find worker information for _BaseDataLoaderIter with id %ld.", key);
}
worker_pids.erase(it);

View File

@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing

View File

@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter to fetch
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""

View File

@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter to put
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
fetched tensors into pinned memory.
These **needs** to be in global scope since Py2 doesn't support serializing

View File

@ -9,8 +9,8 @@ libraries users call in the workers. In this file and `DataLoader.cpp`, we make
our best effort to provide some error message to users when such unfortunate
events happen.
When a _DataLoaderIter starts worker processes, their pids are registered in a
defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
via `_set_worker_pids`.
When an error happens in a worker process, the main process received a SIGCHLD,

View File

@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter workers.
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.

View File

@ -1,4 +1,4 @@
r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is

View File

@ -28,12 +28,14 @@ class DataLoader(Generic[T_co]):
worker_init_fn: _worker_init_fn_t=...) -> None: ...
def __len__(self) -> int: ...
# We quote '_DataLoaderIter' since it isn't defined yet and the definition can't be moved up since
# '_DataLoaderIter' references 'DataLoader'. Pending updates of PEP 484 will fix this.
def __iter__(self) -> '_DataLoaderIter':...
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic
# analyzer is used that obviates the need for this but we leave the quoting in to support older
# versions of mypy
def __iter__(self) -> '_BaseDataLoaderIter':...
class _DataLoaderIter:
class _BaseDataLoaderIter:
def __init__(self, loader: DataLoader) -> None:...
def __len__(self) -> int: ...
def __iter__(self) -> _DataLoaderIter: ...
def __iter__(self) -> _BaseDataLoaderIter: ...
def __next__(self) -> Any: ...