.. _reproducibility:

Reproducibility
===============

Completely reproducible results are not guaranteed across PyTorch releases,
individual commits, or different platforms. Furthermore, results may not be
reproducible between CPU and GPU executions, even when using identical seeds.

However, there are some steps you can take to limit the number of sources of
nondeterministic behavior for a specific platform, device, and PyTorch release.
First, you can control sources of randomness that can cause multiple executions
of your application to behave differently. Second, you can configure PyTorch
to avoid using nondeterministic algorithms for some operations, so that multiple
calls to those operations, given the same inputs, will produce the same result.

.. warning::

    Deterministic operations are often slower than nondeterministic operations, so
    single-run performance may decrease for your model. However, determinism may
    save time in development by facilitating experimentation, debugging, and
    regression testing.

Controlling sources of randomness
.................................

PyTorch random number generator
-------------------------------

You can use :meth:`torch.manual_seed()` to seed the RNG for all devices (both
CPU and CUDA)::

    import torch
    torch.manual_seed(0)
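
As a quick check, a minimal sketch (variable names are illustrative) confirms that
reseeding the default generator replays the same CPU draws. It also shows a separately
created :class:`torch.Generator`, passed through the :code:`generator` argument, which
keeps its own state and can be reseeded on its own::

    import torch

    # Reseeding the default generator replays the same sequence of draws.
    torch.manual_seed(0)
    first = torch.rand(3)
    torch.manual_seed(0)
    second = torch.rand(3)
    assert torch.equal(first, second)

    # A dedicated Generator has its own state; reseeding it replays its draws.
    g = torch.Generator()
    g.manual_seed(0)
    third = torch.rand(3, generator=g)
    g.manual_seed(0)
    fourth = torch.rand(3, generator=g)
    assert torch.equal(third, fourth)
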

Python
------

For custom operators, you might need to set the Python seed as well::

    import random
    random.seed(0)
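
After :code:`random.seed(0)`, the module-level functions replay the same sequence. A
minimal sketch, which also shows that a separate :code:`random.Random` instance (as a
library may create internally) keeps its own state and has to be seeded on its own::

    import random

    random.seed(0)
    first = [random.random() for _ in range(3)]
    random.seed(0)  # reseeding replays the same sequence
    second = [random.random() for _ in range(3)]
    assert first == second

    # An independent Random instance is not affected by random.seed();
    # seeding it explicitly with the same value reproduces the same draws.
    rng = random.Random(0)
    assert [rng.random() for _ in range(3)] == first
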

Random number generators in other libraries
-------------------------------------------

If you or any of the libraries you are using rely on NumPy, you can seed the global
NumPy RNG with::

    import numpy as np
    np.random.seed(0)

However, some applications and libraries may use NumPy Random Generator objects
rather than the global RNG
(`<https://numpy.org/doc/stable/reference/random/generator.html>`_), and those will
need to be seeded consistently as well.
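
For example, :code:`np.random.seed()` has no effect on a Generator created with
:code:`np.random.default_rng()`. A minimal sketch of seeding Generator objects
explicitly instead (variable names are illustrative)::

    import numpy as np

    np.random.seed(0)  # seeds only the legacy global RNG behind np.random.*

    # Generator objects carry their own state, so seed each one explicitly.
    rng_a = np.random.default_rng(0)
    rng_b = np.random.default_rng(0)
    draws_a = rng_a.random(3)
    draws_b = rng_b.random(3)
    assert np.array_equal(draws_a, draws_b)
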

If you are using any other libraries that use random number generators, refer to
the documentation for those libraries to see how to set consistent seeds for them.
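
The seeding calls from this section can be collected in one place. A minimal sketch,
assuming a hypothetical helper named :code:`seed_everything` (it is not part of PyTorch
or NumPy) that seeds only the three global RNGs discussed above; NumPy Generator objects
and RNGs owned by other libraries still need their own seeds::

    import random

    import numpy as np
    import torch


    def seed_everything(seed: int) -> None:
        # Hypothetical helper for illustration; not a PyTorch API.
        random.seed(seed)        # Python's global RNG
        np.random.seed(seed)     # NumPy's legacy global RNG
        torch.manual_seed(seed)  # PyTorch's RNG on all devices


    seed_everything(0)
    a = (random.random(), np.random.rand(), torch.rand(1).item())
    seed_everything(0)
    b = (random.random(), np.random.rand(), torch.rand(1).item())
    assert a == b
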

CUDA convolution benchmarking
-----------------------------

The cuDNN library, used by CUDA convolution operations, can be a source of nondeterminism
across multiple executions of an application. When a cuDNN convolution is called with a
new set of size parameters, an optional feature can run multiple convolution algorithms,
benchmarking them to find the fastest one. Then, the fastest algorithm will be used
consistently during the rest of the process for the corresponding set of size parameters.
Due to benchmarking noise and different hardware, the benchmark may select different
algorithms on subsequent runs, even on the same machine.

Disabling the benchmarking feature with :code:`torch.backends.cudnn.benchmark = False`
causes cuDNN to deterministically select an algorithm, possibly at the cost of reduced
performance.

However, if you do not need reproducibility across multiple executions of your application,
then performance might improve if the benchmarking feature is enabled with
:code:`torch.backends.cudnn.benchmark = True`.

Note that this setting is different from the :code:`torch.backends.cudnn.deterministic`
setting discussed below.
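
The trade-off described above comes down to one assignment, ideally made near the
start of the program before any convolution runs. A minimal sketch, where
:code:`reproducible` is a hypothetical switch standing in for however your script is
configured::

    import torch

    # Hypothetical switch for this sketch: True when runs must be reproducible.
    reproducible = True

    # Benchmarking picks the fastest cuDNN algorithm per input shape, but that
    # choice can differ between runs, so turn it off when reproducibility matters.
    torch.backends.cudnn.benchmark = not reproducible
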

Avoiding nondeterministic algorithms
....................................

:meth:`torch.use_deterministic_algorithms` lets you configure PyTorch to use
deterministic algorithms instead of nondeterministic ones where available, and
to throw an error if an operation is known to be nondeterministic (and without
a deterministic alternative).
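
The setting is process-wide, and its current value can be read back with
:meth:`torch.are_deterministic_algorithms_enabled`. A minimal sketch that enables it,
checks it, and restores the default afterwards (restoring is only needed if later code
should be allowed to use nondeterministic algorithms again)::

    import torch

    torch.use_deterministic_algorithms(True)
    assert torch.are_deterministic_algorithms_enabled()

    # From here on, deterministic implementations are used where available, and
    # operations that only have a nondeterministic implementation raise RuntimeError.

    torch.use_deterministic_algorithms(False)  # back to the default
    assert not torch.are_deterministic_algorithms_enabled()
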

Please check the documentation for :meth:`torch.use_deterministic_algorithms()`
for a full list of affected operations. If an operation does not act correctly
according to the documentation, or if you need a deterministic implementation
of an operation that does not have one, please submit an issue:
`<https://github.com/pytorch/pytorch/issues?q=label:%22topic:%20determinism%22>`_

For example, running the nondeterministic CUDA implementation of :meth:`torch.Tensor.index_add_`
will throw an error::

    >>> import torch
    >>> torch.use_deterministic_algorithms(True)
    >>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
    'torch.use_deterministic_algorithms(True)'. ...

When :meth:`torch.bmm` is called with sparse-dense CUDA tensors, it typically uses a
nondeterministic algorithm, but when the deterministic flag is turned on, its alternate
deterministic implementation will be used::

    >>> import torch
    >>> torch.use_deterministic_algorithms(True)
    >>> torch.bmm(torch.randn(2, 2, 2).to_sparse().cuda(), torch.randn(2, 2, 2).cuda())
    tensor([[[ 1.1900, -2.3409],
             [ 0.4796,  0.8003]],
            [[ 0.1509,  1.8027],
             [ 0.0333, -1.1444]]], device='cuda:0')

Furthermore, if you are using CUDA tensors and your CUDA version is 10.2 or greater, you
should set the environment variable :code:`CUBLAS_WORKSPACE_CONFIG` according to the CUDA documentation:
`<https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility>`_
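
One way to do this from Python is to set the variable before any CUDA work starts,
for example at the very top of the script before importing :code:`torch`; exporting it
in the shell that launches the process works just as well. A minimal sketch, assuming
the :code:`:4096:8` value (the cuBLAS documentation also lists :code:`:16:8`, which uses
less GPU memory but may limit performance)::

    import os

    # cuBLAS typically reads this variable when it is first initialized, so set it
    # before the first CUDA operation; doing so before importing torch is safest.
    # setdefault() keeps a value that was already exported by the shell.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # ... import torch and build the model after this point ...
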

CUDA convolution determinism
----------------------------

While disabling CUDA convolution benchmarking (discussed above) ensures that
CUDA selects the same algorithm each time an application is run, that algorithm
itself may be nondeterministic, unless either
:code:`torch.use_deterministic_algorithms(True)` or
:code:`torch.backends.cudnn.deterministic = True` is set. The latter setting
controls only this behavior, unlike :meth:`torch.use_deterministic_algorithms`,
which will make other PyTorch operations behave deterministically, too.
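
If only convolutions need to be deterministic, a minimal sketch combines the two cuDNN
flags from this note. It leaves every other operation unaffected, so prefer
:meth:`torch.use_deterministic_algorithms` when those need to be deterministic too::

    import torch

    torch.backends.cudnn.benchmark = False     # cuDNN picks the same algorithm every run
    torch.backends.cudnn.deterministic = True  # and only deterministic algorithms are used
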

CUDA RNN and LSTM
-----------------

In some versions of CUDA, RNNs and LSTM networks may have nondeterministic behavior.
See :class:`torch.nn.RNN` and :class:`torch.nn.LSTM` for details and workarounds.
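
DataLoader
..........

DataLoader will reseed workers following the :ref:`data-loading-randomness` algorithm.
Use the :code:`worker_init_fn` argument to seed NumPy and Python's :code:`random` module
in each worker and preserve reproducibility::

    import random
    import numpy
    import torch
    from torch.utils.data import DataLoader

    def seed_worker(worker_id):
        # Derive a 32-bit seed for NumPy and Python from this worker's PyTorch seed.
        worker_seed = torch.initial_seed() % 2**32
        numpy.random.seed(worker_seed)
        random.seed(worker_seed)

    # train_dataset, batch_size, and num_workers are assumed to be defined elsewhere.
    DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
    )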
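
By default, the worker base seed and, with :code:`shuffle=True`, the sampling order are
drawn from PyTorch's global RNG. Passing an explicitly seeded :class:`torch.Generator`
through the :code:`generator` argument ties them to a seed of your choosing instead. A
minimal single-process sketch (:code:`num_workers=0`, so no worker seeding happens here;
the hypothetical :code:`shuffled_order` helper and the ten-element :code:`TensorDataset`
are stand-ins for a real training setup) checking that two loaders seeded the same way
shuffle identically::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in dataset for the sketch: the integers 0..9.
    dataset = TensorDataset(torch.arange(10))

    def shuffled_order(seed):
        # A dedicated Generator makes the shuffle order independent of the global seed.
        g = torch.Generator()
        g.manual_seed(seed)
        loader = DataLoader(dataset, batch_size=5, shuffle=True, generator=g)
        # Each batch is a one-element list holding a tensor of 5 values.
        return [x.item() for (batch,) in loader for x in batch]

    first = shuffled_order(0)
    second = shuffled_order(0)
    assert first == second
    assert sorted(first) == list(range(10))

When :code:`num_workers` is greater than zero, the same seeded generator can be passed
alongside :code:`worker_init_fn=seed_worker`.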