[dtensor] add support for fused optimizer with parameters across multiple meshes (#157682)

We are seeing more and more use cases where parameters of a model (within the same optimizer group) are placed on different meshes, e.g. (a minimal sketch of such a setup follows this list):
- when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not the TP mesh (see https://github.com/pytorch/pytorch/pull/153268).
- in [dp2ep Expert Parallel](https://github.com/pytorch/torchtitan/pull/1324), the routed experts are sharded on the (global FSDP \ EP) mesh for a smaller FSDP and on the EP mesh for EP, whereas the other params are sharded on the global FSDP mesh for FSDP.
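As a rough illustration (not code from this PR), the sketch below puts two parameters of one model on different meshes and hands both to a single fused `AdamW` group; the 4-rank setup and the `"dp"` / `"tp"` mesh dim names are made up for the example:

```python
import torch
from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh

# assumes torch.distributed is already initialized with 4 ranks
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
mesh_dp = mesh_2d["dp"]  # 1D sub-mesh; some params are sharded only here

# two parameters of the same model, placed on different meshes
w1 = torch.nn.Parameter(
    distribute_tensor(torch.randn(8, 8), mesh_2d, [Shard(0), Shard(1)])
)
w2 = torch.nn.Parameter(distribute_tensor(torch.randn(8, 8), mesh_dp, [Shard(0)]))

# both parameters land in one fused optimizer group; opt.step() dispatches
# _fused_adamw_ with DTensor args living on two different meshes
opt = torch.optim.AdamW([w1, w2], lr=0.1, fused=True)
```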

This PR is, in some sense, a continuation of https://github.com/pytorch/pytorch/pull/147869, tackling the problem when fused optimizers are used. In such cases, [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` takes a scalar tensor arg `state_steps`, which gets automatically cast to a DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though it could correspond to parameters on different meshes.

To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and the follow-up redistribute problems, we manually set the target mesh and placements to be the same as the input mesh and placements, so that no redistribute is triggered. This also bypasses the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns an infinite cost due to a cross-mesh redistribute.
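
A simplified sketch of that idea (the helper name below is made up; in the PR the logic is inlined in `common_pointwise_strategy`, see the diff further down):

```python
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Replicate


def _passthrough_spec_for_scalar_tensor(input_arg):
    # input_arg: OpStrategy of the Replicate scalar tensor arg (e.g. state_steps).
    # Reuse its own mesh/placements as the target spec instead of following the
    # followed strategy's mesh, so no cross-mesh error or redistribute is triggered.
    spec = input_arg.strategies[0].output_spec
    assert all(p == Replicate() for p in spec.placements)
    return DTensorSpec(
        mesh=input_arg.mesh,  # keep the arg's own mesh
        placements=spec.placements,  # keep the Replicate placements
        tensor_meta=spec.tensor_meta,
    )
```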

Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157682
Approved by: https://github.com/wanchaol
Author: Tianyu Liu
Date: 2025-07-07 20:49:17 -07:00
Committed by: PyTorch MergeBot
Commit: ed911747c2 (parent: 777eca9f16)
2 changed files with 162 additions and 12 deletions


@@ -9,6 +9,7 @@ from torch.distributed.tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
@@ -606,6 +607,115 @@ class TestDTensorOptimizer(DTensorTestBase):
                mesh, mod, opt, dist_mod, dist_opt, inp, atol=1.3e-5, rtol=1e-4
            )

    @with_comms
    def test_adamw_fused_across_meshes(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("x", "y")
        )
        mesh_flatten = mesh_2d[("x", "y")]._flatten(mesh_dim_name="mesh_flatten")

        # lr as a Tensor is not supported for capturable=False and foreach=True
        adamw_float_lr_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "amsgrad": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
            },
        ]
        fused_adamw_float_lr_configs = [
            {"lr": 0.1, "weight_decay": 0.05, "fused": True},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "amsgrad": True,
                "fused": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
                "fused": True,
            },
        ]
        # lr could be a Tensor or a float when fused=True for adamW optimizer
        fused_adamw_tensor_lr_configs = [
            {**config, "lr": torch.tensor(0.1)}
            for config in fused_adamw_float_lr_configs
        ]
        fused_adamw_tensor_lr_configs.extend(
            [
                {**config, "lr": torch.tensor([0.1])}
                for config in fused_adamw_float_lr_configs
            ]
        )
        adamw_configs = [
            *adamw_float_lr_configs,
            *fused_adamw_float_lr_configs,
            *fused_adamw_tensor_lr_configs,
        ]

        # shard function to shard all parameters of a module on a 2d mesh
        def _shard_fn_2d(name, module, device_mesh):
            if isinstance(module, nn.Linear):
                for name, param in module.named_parameters():
                    dist_param = torch.nn.Parameter(
                        distribute_tensor(param, device_mesh, [Replicate(), Shard(0)])
                    )
                    # make sure the partial sum gets cleared after backward()
                    dist_param.register_hook(
                        lambda grad: grad.redistribute(
                            placements=[Replicate(), Shard(0)]
                        )
                    )
                    module.register_parameter(name, dist_param)

        # prepare input
        def _input_fn_2d(mod, inputs, device_mesh):
            # split the input tensor to be sharded input on a 2d mesh
            dist_inp = DTensor.from_local(
                inputs[0], device_mesh, [Replicate(), Shard(0)], run_check=False
            )
            return dist_inp

        for config in adamw_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.AdamW(mod.parameters(), **config)

            mod_copy = deepcopy(mod)
            # MLPModule.net1 is sharded on the flattened mesh
            distribute_module(
                mod_copy.net1, mesh_flatten, shard_fn, input_fn, output_fn
            )
            # MLPModule.net2 is sharded on the 2d mesh
            distribute_module(
                mod_copy.net2, mesh_2d, _shard_fn_2d, _input_fn_2d, output_fn
            )
            dist_opt = torch.optim.AdamW(mod_copy.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(None, mod, opt, mod_copy, dist_opt, inp)


if __name__ == "__main__":
    run_tests()


@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence
from typing import cast
from typing import cast, Optional

import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@@ -25,6 +25,7 @@ from torch.distributed.tensor.placement_types import (
    Replicate,
    Shard,
)

from torch.utils._typing_utils import not_none

aten = torch.ops.aten
@@ -485,6 +486,7 @@ def common_pointwise_strategy(
    followed_strategy: OpStrategy,
    followed_strategy_index: int,
    linearity: int = -1,
    scalar_tensor_idx: Optional[int] = None,
) -> OpStrategy:
    """
    Common strategy for pointwise operations.
@@ -501,6 +503,8 @@
            2: the binary operation supports multiplicative linearity, where it requires
                the primary operand to be partial, and the other operands to be replicate,
                output propagates partial.
        scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh
            to be different from the mesh of followed_strategy
    """
    # handle broadcasting
    common_shape = torch.broadcast_shapes(
@@ -538,16 +542,34 @@
        redistribute_costs: list[list[float]] = []
        for input_idx, input_arg in enumerate(args_schema):
            if isinstance(input_arg, OpStrategy):
                input_arg_spec = input_arg.strategies[0].output_spec
                # sanity check that all args that follow the same strategy
                # are on the same DeviceMesh
                if input_arg.mesh != followed_strategy.mesh:
                    raise ValueError(
                        f"Could not run pointwise computation across different mesh: "
                        f"Found {input_arg.mesh} and {followed_strategy.mesh}!"
                    )
                    # For the scalar tensor arg in fused ops, do not follow followed_strategy;
                    # instead, let the input mesh and the Replicate placements propagate through.
                    if input_idx == scalar_tensor_idx:
                        assert all(p == Replicate() for p in input_arg_spec.placements)
                        input_arg_target_spec = DTensorSpec(
                            mesh=input_arg.mesh,
                            placements=input_arg_spec.placements,
                            tensor_meta=input_arg_spec.tensor_meta,
                        )
                        input_specs.append(input_arg_target_spec)
                        redistribute_costs.append(
                            generate_redistribute_costs(
                                input_arg, input_arg_target_spec
                            )
                        )
                        continue
                    else:
                        raise ValueError(
                            f"Could not run pointwise computation across different mesh: "
                            f"Found {input_arg.mesh} and {followed_strategy.mesh}!"
                        )
                # every arg follow the out_placements, but need to handle broadcasting
                input_arg_spec = input_arg.strategies[0].output_spec
                input_arg_dims_map = infer_broadcast_dims_map(
                    common_shape, input_arg_spec.shape
                )
@@ -680,11 +702,13 @@ def list_pointwise_strategy(
        OpStrategy: generated strategy
    """

    def args_tuple_strategies(args_schema: tuple[object, ...]) -> list[TupleStrategy]:
    def args_tuple_strategies(
        args_schema: tuple[object, ...],
    ) -> list[Optional[TupleStrategy]]:
        first_arg = args_schema[0]
        assert isinstance(first_arg, TupleStrategy)
        strategy_len = len(first_arg.childs)
        tuple_strategies: list[TupleStrategy] = []
        tuple_strategies: list[Optional[TupleStrategy]] = []
        for arg_idx, arg in enumerate(args_schema):
            if isinstance(arg, TupleStrategy):
                # every tuple strategy should have the same length
@@ -699,19 +723,28 @@
                    raise RuntimeError(
                        f"list op only supports tuple strategy! {op_schema}"
                    )
            else:
                # insert None as placeholder so that the idx of arg is kept
                tuple_strategies.append(None)
        return tuple_strategies

    args_strategies = args_tuple_strategies(op_schema.args_schema)
    follow_strategy: TupleStrategy = args_strategies[0]
    follow_strategy: TupleStrategy = not_none(args_strategies[0])
    list_strategy: list[OpStrategy] = []
    for child_idx, child_strtgy in enumerate(follow_strategy.childs):
        assert isinstance(child_strtgy, OpStrategy)
        args_schema: list[OpStrategy] = [
            cast(OpStrategy, arg_strategy.childs[child_idx])
        args_schema: list[Optional[OpStrategy]] = [
            cast(OpStrategy, arg_strategy.childs[child_idx]) if arg_strategy else None
            for arg_strategy in args_strategies
        ]
        pointwise_strategy: OpStrategy = common_pointwise_strategy(
            args_schema, child_strtgy, linearity
            args_schema,
            child_strtgy,
            linearity,
            scalar_tensor_idx=_FUSED_OP_SCALAR_IDX
            if op_schema.op in fused_ops
            else None,
        )
        list_strategy.append(pointwise_strategy)
    return TupleStrategy(list_strategy)
@@ -745,6 +778,13 @@ fused_ops = [
    aten._fused_adamw_.tensor_lr,
]

# The state_steps arg of fused adam / adamw is a Replicate scalar tensor, which will be put on
# the compute_mesh of an op across all parameter groups, even when not all parameter groups
# are on the same device mesh. This idx will help avoid hitting exceptions or unnecessary
# redistribute during sharding propagation.
_FUSED_OP_SCALAR_IDX = 5

for op in fused_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_pointwise_strategy