pytorch/docs/source/elastic/train_script.rst
Can Balioglu 65e6194aeb Introduce the torchrun entrypoint (#64049)
Summary:
This PR introduces a new `torchrun` entrypoint that simply "points" to `python -m torch.distributed.run`. It is shorter and less error-prone to type, and it reads better than the rather cryptic `python -m ...` command line. Along with the new entrypoint, the documentation is updated so that mentions of `torch.distributed.run` are replaced with `torchrun`.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse agolynski SciPioneer H-Huang mrzzd cbalioglu gcramer23

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64049

Reviewed By: cbalioglu

Differential Revision: D30584041

Pulled By: kiukchung

fbshipit-source-id: d99db3b5d12e7bf9676bab70e680d4b88031ae2d
2021-08-26 20:17:48 -07:00

.. _elastic_train_script:

Train script
-------------

If your train script works with ``torch.distributed.launch``, it will continue
to work with ``torchrun``, with these differences:

1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
   ``MASTER_ADDR``, and ``MASTER_PORT``.

2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The
   default ``rdzv_backend`` creates a non-elastic rendezvous where
   ``rdzv_endpoint`` holds the master address.

3. Make sure you have ``load_checkpoint(path)`` and ``save_checkpoint(path)``
   logic in your script. When any number of workers fail, we restart all the
   workers with the same program arguments, so you will lose progress up to
   the most recent checkpoint (see `elastic launch <run.html>`_). A minimal
   sketch of such checkpoint helpers is shown after the example at the end of
   this page.

4. The ``use_env`` flag has been removed. If you previously parsed the local
   rank from the ``--local_rank`` option, you now need to read it from the
   ``LOCAL_RANK`` environment variable (e.g. ``int(os.environ["LOCAL_RANK"])``),
   as shown in the snippet right after this list.

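As a rough illustration of the last point, the helper below reads the local
rank from the environment instead of a ``--local_rank`` argument. This is only
a sketch; ``get_local_rank`` is an illustrative name used on this page and is
not part of the ``torchrun`` API:

.. code-block:: python

   import os


   def get_local_rank() -> int:
       # torchrun exports LOCAL_RANK (along with RANK, WORLD_SIZE,
       # MASTER_ADDR, and MASTER_PORT) for every worker it spawns, so the
       # script no longer needs to accept a --local_rank argument.
       return int(os.environ["LOCAL_RANK"])

A common use is to pin each worker to its device, e.g.
``torch.cuda.set_device(get_local_rank())``, before initializing the process
group.
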
Below is an expository example of a training script that checkpoints on each
epoch; hence the worst-case progress lost on failure is one full epoch's worth
of training.

.. code-block:: python

   import sys

   import torch.distributed


   def main():
       args = parse_args(sys.argv[1:])
       state = load_checkpoint(args.checkpoint_path)
       initialize(state)

       # torch.distributed.run ensures that this will work
       # by exporting all the env vars needed to initialize the process group
       torch.distributed.init_process_group(backend=args.backend)

       for i in range(state.epoch, state.total_num_epochs):
           for batch in iter(state.dataset):
               train(batch, state.model)

           state.epoch += 1
           save_checkpoint(state)

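The ``parse_args``, ``initialize``, ``train``, ``load_checkpoint``, and
``save_checkpoint`` helpers above are intentionally left undefined. As a rough
sketch only (not a torchelastic API; ``TrainState`` and the placeholder
``nn.Linear`` model are assumptions made for illustration), the two checkpoint
helpers could look like this:

.. code-block:: python

   import os
   from dataclasses import dataclass, field

   import torch
   import torch.nn as nn


   @dataclass
   class TrainState:
       # Hypothetical, minimal container: only the fields the checkpoint
       # helpers below actually touch.
       model: nn.Module = field(default_factory=lambda: nn.Linear(10, 10))
       epoch: int = 0


   def save_checkpoint(state: TrainState, path: str = "checkpoint.pt") -> None:
       # Write to a temporary file and rename so that a worker killed
       # mid-write cannot leave a corrupt checkpoint behind.
       tmp_path = path + ".tmp"
       torch.save({"model": state.model.state_dict(), "epoch": state.epoch}, tmp_path)
       os.replace(tmp_path, path)


   def load_checkpoint(path: str = "checkpoint.pt") -> TrainState:
       state = TrainState()
       if os.path.exists(path):
           # map_location="cpu" keeps the load device-agnostic; the model
           # can be moved to the right GPU afterwards.
           snapshot = torch.load(path, map_location="cpu")
           state.model.load_state_dict(snapshot["model"])
           state.epoch = snapshot["epoch"]
       return state

The atomic write matters under elastic restarts: any worker may be killed at
any point, and the replacement workers must still find a readable checkpoint at
the next rendezvous.
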
For concrete examples of torchelastic-compliant train scripts, visit
our `examples <examples.html>`_ page.