mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dtensor] add support for fused optimizer with parameters across multiple meshes (#157682)
We are seeing more and more use cases where parameters in a model (under the same optimizer group) are put on different meshes. E.g. - when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not TP mesh (see https://github.com/pytorch/pytorch/pull/153268). - in [dp2ep Expert Parallel](https://github.com/pytorch/torchtitan/pull/1324), the routed experts are sharded on the (global FSDP \ EP) mesh for smaller FSDP and on the EP mesh for EP, whereas other params are sharded on the global FSDP mesh for FSDP. This PR is, in some sense, a continuation of https://github.com/pytorch/pytorch/pull/147869 to tackle the problem when fused optimizers are used. In such cases, the [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` has a scalar tensor arg `state_steps` which gets automatically cast to DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though the it could correspond to different meshes. To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and followup redistribute problems, we manually set the target mesh and placements to be the same as input mesh and placements, so that no redistribute will be triggered. This also helps bypass the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns infinite cost due to cross mesh redistribute. Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157682 Approved by: https://github.com/wanchaol
This commit is contained in:
committed by
PyTorch MergeBot
parent
777eca9f16
commit
ed911747c2
@ -9,6 +9,7 @@ from torch.distributed.tensor import (
|
||||
distribute_module,
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
init_device_mesh,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
@ -606,6 +607,115 @@ class TestDTensorOptimizer(DTensorTestBase):
|
||||
mesh, mod, opt, dist_mod, dist_opt, inp, atol=1.3e-5, rtol=1e-4
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_admaw_fused_across_meshes(self):
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(
|
||||
self.device_type, mesh_shape, mesh_dim_names=("x", "y")
|
||||
)
|
||||
mesh_flatten = mesh_2d[("x", "y")]._flatten(mesh_dim_name="mesh_flatten")
|
||||
|
||||
# lr as a Tensor is not supported for capturable=False and foreach=True
|
||||
adamw_float_lr_configs = [
|
||||
{"lr": 0.1, "foreach": False},
|
||||
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
|
||||
{"lr": 0.1, "weight_decay": 0.05},
|
||||
{
|
||||
"lr": 0.1,
|
||||
"betas": (0.6, 0.66),
|
||||
"eps": 1e-6,
|
||||
"weight_decay": 0.05,
|
||||
"amsgrad": True,
|
||||
},
|
||||
{
|
||||
"lr": 0.1,
|
||||
"betas": (0.6, 0.66),
|
||||
"eps": 1e-6,
|
||||
"weight_decay": 0.05,
|
||||
"maximize": True,
|
||||
"amsgrad": True,
|
||||
},
|
||||
]
|
||||
fused_adamw_float_lr_configs = [
|
||||
{"lr": 0.1, "weight_decay": 0.05, "fused": True},
|
||||
{
|
||||
"lr": 0.1,
|
||||
"betas": (0.6, 0.66),
|
||||
"eps": 1e-6,
|
||||
"weight_decay": 0.05,
|
||||
"amsgrad": True,
|
||||
"fused": True,
|
||||
},
|
||||
{
|
||||
"lr": 0.1,
|
||||
"betas": (0.6, 0.66),
|
||||
"eps": 1e-6,
|
||||
"weight_decay": 0.05,
|
||||
"maximize": True,
|
||||
"amsgrad": True,
|
||||
"fused": True,
|
||||
},
|
||||
]
|
||||
# lr could be a Tensor or a float when fused=True for adamW optimizer
|
||||
fused_adamw_tensor_lr_configs = [
|
||||
{**config, "lr": torch.tensor(0.1)}
|
||||
for config in fused_adamw_float_lr_configs
|
||||
]
|
||||
fused_adamw_tensor_lr_configs.extend(
|
||||
[
|
||||
{**config, "lr": torch.tensor([0.1])}
|
||||
for config in fused_adamw_float_lr_configs
|
||||
]
|
||||
)
|
||||
adamw_configs = [
|
||||
*adamw_float_lr_configs,
|
||||
*fused_adamw_float_lr_configs,
|
||||
*fused_adamw_tensor_lr_configs,
|
||||
]
|
||||
|
||||
# shard function to do full sharding on all parameters of a module
|
||||
def _shard_fn_2d(name, module, device_mesh):
|
||||
if isinstance(module, nn.Linear):
|
||||
for name, param in module.named_parameters():
|
||||
dist_param = torch.nn.Parameter(
|
||||
distribute_tensor(param, device_mesh, [Replicate(), Shard(0)])
|
||||
)
|
||||
# make sure partial sum get cleared after backward()
|
||||
dist_param.register_hook(
|
||||
lambda grad: grad.redistribute(
|
||||
placements=[Replicate(), Shard(0)]
|
||||
)
|
||||
)
|
||||
module.register_parameter(name, dist_param)
|
||||
|
||||
# prepare input
|
||||
def _input_fn_2d(mod, inputs, device_mesh):
|
||||
# split the input tensor to be sharded input on a 2d mesh
|
||||
dist_inp = DTensor.from_local(
|
||||
inputs[0], device_mesh, [Replicate(), Shard(0)], run_check=False
|
||||
)
|
||||
return dist_inp
|
||||
|
||||
for config in adamw_configs:
|
||||
mod = MLPModule(self.device_type)
|
||||
opt = torch.optim.AdamW(mod.parameters(), **config)
|
||||
|
||||
mod_copy = deepcopy(mod)
|
||||
# MLPModule.net1 is sharded on the flatten mesh
|
||||
distribute_module(
|
||||
mod_copy.net1, mesh_flatten, shard_fn, input_fn, output_fn
|
||||
)
|
||||
# MLPModule.net2 is sharded on the 2d mesh
|
||||
distribute_module(
|
||||
mod_copy.net2, mesh_2d, _shard_fn_2d, _input_fn_2d, output_fn
|
||||
)
|
||||
dist_opt = torch.optim.AdamW(mod_copy.parameters(), **config)
|
||||
|
||||
# use ones to make sure the single machine model have the same input
|
||||
# on different ranks
|
||||
inp = torch.ones(8, 10, device=self.device_type)
|
||||
self._assert_optimizer(None, mod, opt, mod_copy, dist_opt, inp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
||||
@ -25,6 +25,7 @@ from torch.distributed.tensor.placement_types import (
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
@ -485,6 +486,7 @@ def common_pointwise_strategy(
|
||||
followed_strategy: OpStrategy,
|
||||
followed_strategy_index: int,
|
||||
linearity: int = -1,
|
||||
scalar_tensor_idx: Optional[int] = None,
|
||||
) -> OpStrategy:
|
||||
"""
|
||||
Common strategy for pointwise operations.
|
||||
@ -501,6 +503,8 @@ def common_pointwise_strategy(
|
||||
2: the binary operation supports multiplicative linearity, where it requires
|
||||
the primary operand to be partial, and the other operands to be replicate,
|
||||
output propagates partial.
|
||||
scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh
|
||||
to be different from the mesh of followed_strategy
|
||||
"""
|
||||
# handle broadcasting
|
||||
common_shape = torch.broadcast_shapes(
|
||||
@ -538,16 +542,34 @@ def common_pointwise_strategy(
|
||||
redistribute_costs: list[list[float]] = []
|
||||
for input_idx, input_arg in enumerate(args_schema):
|
||||
if isinstance(input_arg, OpStrategy):
|
||||
input_arg_spec = input_arg.strategies[0].output_spec
|
||||
|
||||
# sanity check that all args that follow the same strategy
|
||||
# are on the same DeviceMesh
|
||||
if input_arg.mesh != followed_strategy.mesh:
|
||||
raise ValueError(
|
||||
f"Could not run pointwise computation across different mesh: "
|
||||
f"Found {input_arg.mesh} and {followed_strategy.mesh}!"
|
||||
)
|
||||
# For the scalar tensor arg in fused ops, do not follow followed_strategy;
|
||||
# instead, let the input mesh and the Replicate placements propagate through.
|
||||
if input_idx == scalar_tensor_idx:
|
||||
assert all(p == Replicate() for p in input_arg_spec.placements)
|
||||
input_arg_target_spec = DTensorSpec(
|
||||
mesh=input_arg.mesh,
|
||||
placements=input_arg_spec.placements,
|
||||
tensor_meta=input_arg_spec.tensor_meta,
|
||||
)
|
||||
input_specs.append(input_arg_target_spec)
|
||||
redistribute_costs.append(
|
||||
generate_redistribute_costs(
|
||||
input_arg, input_arg_target_spec
|
||||
)
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not run pointwise computation across different mesh: "
|
||||
f"Found {input_arg.mesh} and {followed_strategy.mesh}!"
|
||||
)
|
||||
|
||||
# every arg follow the out_placements, but need to handle broadcasting
|
||||
input_arg_spec = input_arg.strategies[0].output_spec
|
||||
input_arg_dims_map = infer_broadcast_dims_map(
|
||||
common_shape, input_arg_spec.shape
|
||||
)
|
||||
@ -680,11 +702,13 @@ def list_pointwise_strategy(
|
||||
OpStrategy: generated strategy
|
||||
"""
|
||||
|
||||
def args_tuple_strategies(args_schema: tuple[object, ...]) -> list[TupleStrategy]:
|
||||
def args_tuple_strategies(
|
||||
args_schema: tuple[object, ...],
|
||||
) -> list[Optional[TupleStrategy]]:
|
||||
first_arg = args_schema[0]
|
||||
assert isinstance(first_arg, TupleStrategy)
|
||||
strategy_len = len(first_arg.childs)
|
||||
tuple_strategies: list[TupleStrategy] = []
|
||||
tuple_strategies: list[Optional[TupleStrategy]] = []
|
||||
for arg_idx, arg in enumerate(args_schema):
|
||||
if isinstance(arg, TupleStrategy):
|
||||
# every tuple strategy should have the same length
|
||||
@ -699,19 +723,28 @@ def list_pointwise_strategy(
|
||||
raise RuntimeError(
|
||||
f"list op only supports tuple strategy! {op_schema}"
|
||||
)
|
||||
else:
|
||||
# insert None as placeholder so that the idx of arg is kept
|
||||
tuple_strategies.append(None)
|
||||
return tuple_strategies
|
||||
|
||||
args_strategies = args_tuple_strategies(op_schema.args_schema)
|
||||
follow_strategy: TupleStrategy = args_strategies[0]
|
||||
follow_strategy: TupleStrategy = not_none(args_strategies[0])
|
||||
list_strategy: list[OpStrategy] = []
|
||||
|
||||
for child_idx, child_strtgy in enumerate(follow_strategy.childs):
|
||||
assert isinstance(child_strtgy, OpStrategy)
|
||||
args_schema: list[OpStrategy] = [
|
||||
cast(OpStrategy, arg_strategy.childs[child_idx])
|
||||
args_schema: list[Optional[OpStrategy]] = [
|
||||
cast(OpStrategy, arg_strategy.childs[child_idx]) if arg_strategy else None
|
||||
for arg_strategy in args_strategies
|
||||
]
|
||||
pointwise_strategy: OpStrategy = common_pointwise_strategy(
|
||||
args_schema, child_strtgy, linearity
|
||||
args_schema,
|
||||
child_strtgy,
|
||||
linearity,
|
||||
scalar_tensor_idx=_FUSED_OP_SCALAR_IDX
|
||||
if op_schema.op in fused_ops
|
||||
else None,
|
||||
)
|
||||
list_strategy.append(pointwise_strategy)
|
||||
return TupleStrategy(list_strategy)
|
||||
@ -745,6 +778,13 @@ fused_ops = [
|
||||
aten._fused_adamw_.tensor_lr,
|
||||
]
|
||||
|
||||
|
||||
# The state_steps arg of fused adam / adamw is a Replicate scalar tensor, which will be put on
|
||||
# the compute_mesh of an op across all parameter groups, even when not all parameter groups
|
||||
# are on the same device mesh. This idx will help avoid hitting exceptions or unnecessary
|
||||
# redistribute during sharding propagation.
|
||||
_FUSED_OP_SCALAR_IDX = 5
|
||||
|
||||
for op in fused_ops:
|
||||
register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
|
||||
list_pointwise_strategy
|
||||
|
||||
Reference in New Issue
Block a user