Compare commits

...

2 Commits

Author SHA1 Message Date
079fd19cc0 [c10d] support dynamic shapes for all_to_all_single_autograd
ghstack-source-id: afd2f9ebf593a5987a9d7d69b9b875460d6d79a0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157521
2025-07-03 09:00:00 -07:00
30274c6c83 [DTensor][FSDP2] necessary changes to FSDP and TP to unblock EP
ghstack-source-id: 62a57a5117bbb2d959ae7ccca6d0d2a3131d7d98
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157216
2025-06-29 01:05:01 -07:00
5 changed files with 154 additions and 25 deletions
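For orientation, here is a minimal sketch of the usage the top commit enables: compiling all_to_all_single_autograd with a dimension marked dynamic, so the equal-split sizes are traced as SymInt expressions instead of baked-in constants. It mirrors the new test below; the fake backend and world_size=2 are illustrative choices, not part of the diff.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    all_to_all_single_autograd,
    wait_tensor,
)
from torch.testing._internal.distributed.fake_pg import FakeStore

# Single-process "fake" process group, as in the test below.
dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

@torch.compile(fullgraph=True)
def fn(x):
    # None split sizes -> equal splits, computed from the (symbolic) dim size.
    return all_to_all_single_autograd(x, None, None, group=dist.group.WORLD)

x = torch.randn(8, 8, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)  # first dim becomes a SymInt during tracing
out = wait_tensor(fn(x))

dist.destroy_process_group()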

View File

@@ -0,0 +1,121 @@
# Owner(s): ["module: dynamo"]
from unittest import skipIf

import torch
import torch.distributed as dist
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
from torch.distributed._functional_collectives import (
    all_to_all_single_autograd,
    wait_tensor,
)
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.distributed.fake_pg import FakeStore


def normalize_graph(gm):
    return normalize_gm(gm.print_readable(print_output=False))


@skipIf(not dist.is_available(), "requires distributed")
class TestFakeDistributed(DynamoTestCase):
    def setUp(self):
        # Use FakeProcessGroup to run tests on a single process
        self.store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=self.store)

    def tearDown(self):
        dist.destroy_process_group()

    def test_all_to_all_single_autograd(self):
        backend = AotEagerAndRecordGraphs()

        @torch.compile(fullgraph=True, backend=backend)
        def fn(x):
            return all_to_all_single_autograd(
                x,
                None,  # Will use equal splits
                None,  # Will use equal splits
                group=dist.group.WORLD,
            )

        # Test backed shapes
        x = torch.randn(8, 8, requires_grad=True)
        torch._dynamo.mark_dynamic(x, 0)
        torch._dynamo.mark_dynamic(x, 1)
        wait_tensor(fn(x))
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", primals_3: "f32[s77, s27]"):
        floordiv: "Sym((s77//2))" = primals_1 // 2
        all_to_all_single: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(primals_3, [floordiv, floordiv], [floordiv, floordiv], '0'); primals_3 = None
        wait_tensor: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None
        return (wait_tensor, primals_1, primals_2, floordiv)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", floordiv: "Sym((s77//2))", tangents_1: "f32[2*((s77//2)), s27]"):
        all_to_all_single_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None
        wait_tensor_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None
        return (None, None, wait_tensor_1)
""",  # noqa: B950
        )

        backend.fw_graphs.clear()
        backend.bw_graphs.clear()

        # Test unbacked shapes
        x = torch.randn(8, 8, 8, requires_grad=True)
        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)
        torch._dynamo.decorators.mark_unbacked(x, 2)
        wait_tensor(fn(x))
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
        ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
        floordiv: "Sym((u0//2))" = primals_1 // 2
        all_to_all_single: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(primals_4, [floordiv, floordiv], [floordiv, floordiv], '0'); primals_4 = None
        wait_tensor: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None
        return (wait_tensor, primals_1, primals_2, primals_3, floordiv)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", floordiv: "Sym((u0//2))", tangents_1: "f32[2*((u0//2)), u1, u2]"):
        all_to_all_single_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None
        wait_tensor_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None
        return (None, None, None, wait_tensor_1)
""",  # noqa: B950
        )


instantiate_parametrized_tests(TestFakeDistributed)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

View File

@@ -197,10 +197,19 @@ at::Tensor reduce_scatter_tensor(
 at::Tensor all_to_all_single(
     const at::Tensor& input,
-    std::vector<int64_t> output_split_sizes,
-    std::vector<int64_t> input_split_sizes,
+    c10::SymIntArrayRef _output_split_sizes,
+    c10::SymIntArrayRef _input_split_sizes,
     // NOLINTNEXTLINE(performance-unnecessary-value-param)
     std::string group_name) {
+  std::vector<int64_t> output_split_sizes;
+  std::vector<int64_t> input_split_sizes;
+  for (const auto& size : _output_split_sizes) {
+    output_split_sizes.emplace_back(size.expect_int());
+  }
+  for (const auto& size : _input_split_sizes) {
+    input_split_sizes.emplace_back(size.expect_int());
+  }
   TORCH_CHECK(input.is_contiguous());
   std::vector<int64_t> output_sizes = input.sizes().vec();
   output_sizes[0] = std::accumulate(
@@ -338,14 +347,14 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
       torch::autograd::AutogradContext* ctx,
       const at::Tensor& input,
       // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      std::vector<int64_t> output_split_sizes,
+      at::SymIntArrayRef output_split_sizes,
       // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      std::vector<int64_t> input_split_sizes,
+      at::SymIntArrayRef input_split_sizes,
       // NOLINTNEXTLINE(performance-unnecessary-value-param)
       std::string group_name) {
     // swap sizes for backwards pass
-    ctx->saved_data["output_split_sizes"] = input_split_sizes;
-    ctx->saved_data["input_split_sizes"] = output_split_sizes;
+    ctx->saved_data["output_split_sizes"] = input_split_sizes.vec();
+    ctx->saved_data["input_split_sizes"] = output_split_sizes.vec();
     ctx->saved_data["group_name"] = group_name;
     return c10::Dispatcher::singleton()
@@ -357,10 +366,10 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
   static torch::autograd::variable_list backward(
       torch::autograd::AutogradContext* ctx,
       const torch::autograd::variable_list& grad_out_list) {
-    const std::vector<int64_t>& output_split_sizes =
-        ctx->saved_data["output_split_sizes"].toIntVector();
-    const std::vector<int64_t>& input_split_sizes =
-        ctx->saved_data["input_split_sizes"].toIntVector();
+    std::vector<c10::SymInt> output_split_sizes =
+        ctx->saved_data["output_split_sizes"].toSymIntVector();
+    std::vector<c10::SymInt> input_split_sizes =
+        ctx->saved_data["input_split_sizes"].toSymIntVector();
     const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
     DCHECK(grad_out_list.size() == 1);
@@ -385,8 +394,8 @@ class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
 at::Tensor all_to_all_single_autograd(
     const at::Tensor& input,
-    const std::vector<int64_t>& output_split_sizes,
-    const std::vector<int64_t>& input_split_sizes,
+    at::SymIntArrayRef output_split_sizes,
+    at::SymIntArrayRef input_split_sizes,
     const std::string& group_name) {
   return AllToAllSingle::apply(
       input, output_split_sizes, input_split_sizes, group_name);

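Taken together, the C++ changes above keep split sizes symbolic through the autograd path: the AllToAllSingle Function and the all_to_all_single_autograd entry point now take at::SymIntArrayRef, the swapped sizes are saved and restored as SymInt vectors for backward, and only the non-autograd all_to_all_single kernel materializes concrete ints via expect_int(). The following is a rough Python rendering of what that autograd Function does; it is a conceptual sketch only, not the real bindings or implementation.

import torch

class AllToAllSingleSketch(torch.autograd.Function):
    # Conceptual sketch of the C++ AllToAllSingle above (illustrative only).
    @staticmethod
    def forward(ctx, x, output_split_sizes, input_split_sizes, group_name):
        # Swap sizes for the backward pass: gradients travel in the opposite
        # direction, so the input/output split layouts trade places. The saved
        # values may now be SymInts rather than plain ints.
        ctx.bw_output_split_sizes = list(input_split_sizes)
        ctx.bw_input_split_sizes = list(output_split_sizes)
        ctx.group_name = group_name
        return torch.ops._c10d_functional.all_to_all_single(
            x, output_split_sizes, input_split_sizes, group_name
        )

    @staticmethod
    def backward(ctx, grad_out):
        grad_in = torch.ops._c10d_functional.all_to_all_single(
            grad_out.contiguous(),
            ctx.bw_output_split_sizes,
            ctx.bw_input_split_sizes,
            ctx.group_name,
        )
        grad_in = torch.ops._c10d_functional.wait_tensor(grad_in)
        return grad_in, None, None, None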
View File

@@ -60,8 +60,8 @@ C10_EXPORT at::Tensor reduce_scatter_tensor(
 C10_EXPORT at::Tensor all_to_all_single(
     const at::Tensor& input,
-    std::vector<int64_t> output_split_sizes,
-    std::vector<int64_t> input_split_sizes,
+    at::SymIntArrayRef output_split_sizes,
+    at::SymIntArrayRef input_split_sizes,
     // NOLINTNEXTLINE(performance-unnecessary-value-param)
     std::string group_name);

View File

@@ -300,13 +300,13 @@ class FSDPParam:
             assert tp_mesh.mesh_dim_names is not None, name_dims_error
             submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
             self._spmd_mesh = dp_global_mesh[submesh_names]
-            if len(self._tp_spec.placements) != 1:
-                raise NotImplementedError(
-                    f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
-                )
+            # if len(self._tp_spec.placements) != 1:
+            #     raise NotImplementedError(
+            #         f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
+            #     )
             split_factor = self._tp_spec.num_shards_map[shard_dim]
-            assert 2 <= self._spmd_mesh.ndim <= 3, (
-                f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
+            assert 2 <= self._spmd_mesh.ndim <= 4, (
+                f"_spmd_mesh.ndim can only be 2, 3, or 4 but got {self._spmd_mesh.ndim}."
             )
             self._spmd_placements: tuple[Placement, ...]
             dp_shard_tp_placement = (
@@ -315,11 +315,11 @@ class FSDPParam:
                     if split_factor > 1
                     else fsdp_placement
                 ),
-                self._tp_spec.placements[0],
+                *self._tp_spec.placements,
             )
-            if self._spmd_mesh.ndim == 2:
+            if dp_mesh.ndim == 1:  # FSDP
                 self._spmd_placements = dp_shard_tp_placement
-            else:
+            else:  # HSDP
                 assert self.mesh_info.replicate_mesh_dim == 0
                 self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
             self._sharding_spec = DTensorSpec(

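The FSDPParam changes above drop the 1-D TP restriction: every TP placement is splatted into the SPMD placement tuple instead of only placements[0], the FSDP/HSDP branch now keys off dp_mesh.ndim rather than the combined mesh ndim, and the combined mesh may be up to 4-D. Below is a small sketch of how the placements compose for a hypothetical 2-D TP spec (the placement values are assumptions for illustration, and the _StridedShard/split_factor case is ignored).

from torch.distributed.tensor import Replicate, Shard

fsdp_placement = Shard(0)                # FSDP shards parameters on dim 0
tp_placements = (Shard(0), Replicate())  # hypothetical 2-D TP spec, e.g. tp x ep

# *tp_placements instead of tp_placements[0]: all TP dims are kept.
dp_shard_tp_placement = (fsdp_placement, *tp_placements)

# dp_mesh.ndim == 1 (FSDP): use the tuple as-is -> 3-D placements here.
spmd_placements_fsdp = dp_shard_tp_placement
# Otherwise (HSDP): prepend Replicate() for the replicate dim -> up to 4-D.
spmd_placements_hsdp = (Replicate(), *dp_shard_tp_placement)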
View File

@@ -6,7 +6,6 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
-from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
 from torch.distributed.tensor.parallel.style import ParallelStyle
@@ -71,7 +70,7 @@ def parallelize_module(  # type: ignore[return]
     torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
     device_mesh = device_mesh or _mesh_resources.get_current_mesh()
-    _validate_tp_mesh_dim(device_mesh)
+    # _validate_tp_mesh_dim(device_mesh)
     if parallelize_plan is None:
         warnings.warn(
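
With _validate_tp_mesh_dim commented out above, parallelize_module no longer rejects a device mesh that is not strictly 1-D. The sketch below shows the kind of call this unblocks; the mesh layout, dim names, and plan are hypothetical, and whether a given ParallelStyle actually supports a multi-dim mesh is up to that style, not something the diff claims.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Hypothetical 8-rank mesh with a 2-D parallel sub-mesh ("tp" x "ep").
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "tp", "ep"))

model = nn.Sequential(nn.Linear(16, 16))
# Previously the entry point raised because mesh["tp", "ep"] is 2-D;
# now the call reaches the per-module plan.
model = parallelize_module(model, mesh["tp", "ep"], {"0": ColwiseParallel()})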