diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 40ba3eb37f82..77aa8da7784a 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -199,15 +199,8 @@ the model. For example: stage_index, num_stages, device, - input_args=example_input_microbatch, ) - -The ``PipelineStage`` requires an example argument ``input_args`` representing -the runtime input to the stage, which would be one microbatch worth of input -data. This argument is passed through the forward method of the stage module to -determine the input and output shapes required for communication. - When composing with other Data or Model parallelism techniques, ``output_args`` may also be required, if the output shape/dtype of the model chunk will be affected. @@ -421,7 +414,7 @@ are subclasses of ``PipelineScheduleMulti``. Logging ******* -You can turn on additional logging using the `TORCH_LOGS` environment variable from [`torch._logging`](https://pytorch.org/docs/main/logging.html#module-torch._logging): +You can turn on additional logging using the `TORCH_LOGS` environment variable from `torch._logging `_: * `TORCH_LOGS=+pp` will display `logging.DEBUG` messages and all levels above it. * `TORCH_LOGS=pp` will display `logging.INFO` messages and above. diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 54cc11a6ae33..33703e859ce8 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1143,6 +1143,13 @@ class Pipe(torch.nn.Module): class SplitPoint(Enum): + """ + Enum representing the points at which a split can occur in the execution of a submodule. + Attributes: + BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function. + END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function. + """ + BEGINNING = 1 END = 2