pytorch/torch/distributed/_shard/sharding_plan/api.py
joncrall 4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00


import abc
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_spec import ShardingSpec


@dataclass
class ShardingPlan(object):
    """
    Representation of a sharding plan, which describes how to shard a module
    across hosts. `plan` is used to shard module parameters according to the
    provided specs. `output_plan` and `return_local_tensor` are optional: they
    specify the output layout of a module via a spec, and which module outputs
    to convert back to local tensors for data-parallel processing.

    Args:
        plan (Dict[str, Union[:class:`torch.distributed._shard.sharding_spec.ShardingSpec`,
              :class:`torch.distributed._shard.sharder.Sharder`]]):
            a dict describing how to shard a module. There are currently two ways
            to shard a module:

                1. Directly shard a module parameter with a `ShardingSpec`, keyed
                   by the name of the parameter.
                2. Shard a submodule by applying a `Sharder` to it, keyed by the
                   name of the submodule.
        output_plan (Dict[str, :class:`torch.distributed._shard.sharding_spec.ShardingSpec`], optional):
            a dict specifying the layout of the output of each listed module that
            produces a ShardedTensor, keyed by module name (``""`` as a key means
            the root module).
            Default: `None`
        return_local_tensor (List[str], optional): a list of module names; for each
            listed module, its sharded output is returned as a Tensor built from
            its local shards, so that further processing can continue in a
            data-parallel fashion (``""`` in the list means the root module).
            Default: `None`

    Example:
        Suppose we want to shard a module with two linear layers and then run it
        with DDP. We also want to convert the output of the second linear layer
        back to a local tensor so that DDP can process it. We can do it as follows:

        >>> class MyModule(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.fc1 = nn.Linear(10, 20)
        >>>         self.gelu = nn.GELU()
        >>>         self.fc2 = nn.Linear(20, 10)
        >>>         self.relu = nn.ReLU()
        >>>
        >>>     def forward(self, input):
        >>>         return self.relu(self.fc2(self.gelu(self.fc1(input))))
        >>> # xdoctest: +SKIP("Undefined spec1, spec2, output_spec")
        >>> sharding_plan = ShardingPlan(
        >>>     plan={
        >>>         "fc1.weight": spec1,
        >>>         "fc2.weight": spec2
        >>>     },
        >>>     output_plan={
        >>>         "fc2": output_spec
        >>>     },
        >>>     return_local_tensor=["fc2"]
        >>> )
    """

    plan: Dict[str, Union[ShardingSpec, Sharder]]
    output_plan: Optional[Dict[str, ShardingSpec]] = None
    return_local_tensor: Optional[List[str]] = None


class ShardingPlanner(abc.ABC):
    """
    Default ShardingPlanner interface; it can be extended to
    implement advanced sharding strategies.
    """

    @abc.abstractmethod
    def build_plan(self, module: nn.Module) -> ShardingPlan:
        """
        Given an nn.Module, define how to shard the module across
        ranks and return a ShardingPlan.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.

        Returns:
            A :class:`torch.distributed._shard.sharding_plan.ShardingPlan` object that
            represents how to shard the module.
        """
        pass
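The sketch below roughly illustrates the `ShardingPlanner.build_plan` contract: walk a module's parameters and return a plan keyed by parameter name. It is torch-free and entirely hypothetical. `ChunkSpec`, `Plan`, `Planner`, and `AlternatingLinearPlanner` are stand-ins invented here and are not PyTorch APIs. The planner also receives a `{parameter_name: shape}` dict instead of an `nn.Module`, so the sketch runs without torch. The strategy is one assumed choice, loosely modeled on column-then-row sharding of consecutive linear layers: it shards successive 2-D `.weight` parameters alternately along dim 0 and dim 1 and leaves biases unsharded. The placement strings only mimic the `"rank:0/cuda:0"` style seen in PyTorch sharding examples.

```python
import abc
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class ChunkSpec:
    # Hypothetical stand-in for a ShardingSpec: split dimension `dim`
    # into one chunk per placement string.
    dim: int
    placements: Tuple[str, ...]


@dataclass
class Plan:
    # Hypothetical stand-in with the same three fields as ShardingPlan
    # (Sharder values are omitted for brevity).
    plan: Dict[str, ChunkSpec]
    output_plan: Optional[Dict[str, ChunkSpec]] = None
    return_local_tensor: Optional[List[str]] = None


class Planner(abc.ABC):
    # Mirrors ShardingPlanner, but takes {param_name: shape} instead of
    # an nn.Module so that this sketch runs without torch.
    @abc.abstractmethod
    def build_plan(self, param_shapes: Dict[str, Tuple[int, ...]]) -> Plan:
        ...


class AlternatingLinearPlanner(Planner):
    # Hypothetical strategy: shard successive 2-D ".weight" parameters
    # alternately along dim 0 and dim 1; other params stay unsharded.
    def __init__(self, world_size: int) -> None:
        self.placements = tuple(f"rank:{r}/cuda:{r}" for r in range(world_size))

    def build_plan(self, param_shapes: Dict[str, Tuple[int, ...]]) -> Plan:
        plan: Dict[str, ChunkSpec] = {}
        next_dim = 0
        # Dicts preserve insertion order, so parameters are visited in the
        # order the caller listed them (e.g. module registration order).
        for name, shape in param_shapes.items():
            if name.endswith(".weight") and len(shape) == 2:
                plan[name] = ChunkSpec(dim=next_dim, placements=self.placements)
                next_dim = 1 - next_dim  # alternate between dim 0 and dim 1
        return Plan(plan=plan)


planner = AlternatingLinearPlanner(world_size=2)
shapes = {
    "fc1.weight": (20, 10), "fc1.bias": (20,),
    "fc2.weight": (10, 20), "fc2.bias": (10,),
}
for name, spec in planner.build_plan(shapes).plan.items():
    print(name, spec.dim)  # prints "fc1.weight 0", then "fc2.weight 1"
```

With a real module, a dict such as `{name: tuple(p.shape) for name, p in module.named_parameters()}` could feed the same logic. A real planner would instead return a `ShardingPlan` whose values are actual `ShardingSpec` or `Sharder` objects.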