pytorch/test/inductor/test_distributed_patterns.py
Chinmay Kuchinad 55f01a48af [ROCm] Enable and fix several FSDP + Inductor distributed unit tests (#165011)
This PR enables a number of distributed unit tests and applies necessary fixes to ensure they pass on ROCm platforms. The changes have been successfully tested on both MI200 and MI300 hardware.

This work addresses the following issues:
- https://github.com/ROCm/frameworks-internal/issues/13586
- https://github.com/ROCm/frameworks-internal/issues/13578

**Enabled Tests**

The following tests have been enabled and are now passing:
1. test_compiled_autograd_ctx
2. test_simple_mlp_fullgraph_backend_aot_eager
3. test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition
4. test_simple_mlp_fullgraph_backend_inductor
5. test_nested_fully_shard_backend_aot_eager
6. test_nested_fully_shard_backend_aot_eager_decomp_partition
7. test_nested_fully_shard_backend_inductor_fullgraph_True
8. test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition
9. test_transformer_backend_aot_eager
10. test_transformer_backend_aot_eager_decomp_partition
11. test_storage_resize_zero_gpu
12. test_storage_resize_nonzero_gpu
13. test_fake_distributed_inductor

**Tests skipped due to upstream issues:** (an illustrative skip-decorator sketch follows the list)
1. test_nested_fully_shard_backend_inductor_fullgraph_False
2. test_transformer_backend_inductor_fullgraph_True
3. test_transformer_backend_inductor_fullgraph_True_graph_partition
4. test_transformer_backend_inductor_fullgraph_False
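
For context, gating tests per platform in PyTorch's test suite is usually done with the conditional-skip decorators from `torch.testing._internal.common_utils`, and "enabling" a test typically amounts to removing such a decorator. The sketch below is illustrative only and is not the exact change made in this PR; `skipIfRocm` is one such existing decorator.

```python
# Minimal illustrative sketch, not the actual diff of this PR.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm


class ExampleRocmGatingTests(TestCase):
    @skipIfRocm  # still blocked by an upstream issue on ROCm
    def test_blocked_upstream(self):
        self.assertEqual(torch.ones(1).item(), 1.0)

    def test_enabled_everywhere(self):
        # Previously gated test, now enabled by dropping the skip decorator.
        self.assertEqual(torch.zeros(2).sum().item(), 0.0)


if __name__ == "__main__":
    run_tests()
```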

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165011
Approved by: https://github.com/jeffdaily
2025-10-10 14:10:54 +00:00


# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu


# Fake distributed
WORLD_SIZE = 2
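
# init_fake_distributed() builds a single-process stand-in for an FSDP-wrapped
# linear layer: all_gather/reduce_scatter are faked with cat/narrow, and the
# forward/backward hooks swap mod._parameters["weight"] between the sharded and
# unsharded parameters while resizing the unsharded storage to mimic FSDP's
# unshard-before-compute / reshard-after-compute flow.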
def init_fake_distributed(device="cpu"):
    @torch.no_grad
    def all_gather(t):
        return torch.cat([t] * WORLD_SIZE, 0)

    @torch.no_grad
    def reduce_scatter(t):
        # clone since reduce_scatter input and output should not be aliases.
        return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()

    def fw_pre_hook(mod, inp):
        mod.unsharded_weight.untyped_storage().resize_(
            mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
        )
        with (
            torch.no_grad(),
            torch.autograd._unsafe_preserve_version_counter(mod.unsharded_weight),
        ):
            torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
        mod._parameters["weight"] = mod.unsharded_weight

    # Forward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    def fw_post_hook(mod, inp, out):
        mod._parameters["weight"] = mod.sharded_weight
        mod.unsharded_weight.untyped_storage().resize_(0)

    def bw_pre_hook(mod, gO):
        mod.unsharded_weight.untyped_storage().resize_(
            mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
        )
        with (
            torch.no_grad(),
            torch.autograd._unsafe_preserve_version_counter(mod.unsharded_weight),
        ):
            torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
        mod._parameters["weight"] = mod.unsharded_weight

    # Backward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    def bw_post_hook(mod, gI, gO):
        grad = mod.weight.grad
        new_grad = reduce_scatter(grad)
        mod._parameters["weight"] = mod.sharded_weight
        mod.weight.grad = new_grad
        mod.unsharded_weight.untyped_storage().resize_(0)

    torch.manual_seed(1234)
    m = nn.Linear(20, 10, bias=False, device=device)

    # Mimics eager 1st iteration
    m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
    m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight))
    m.unsharded_weight.untyped_storage().resize_(0)

    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    m.register_forward_pre_hook(fw_pre_hook)
    m.register_forward_hook(fw_post_hook)
    return m, torch.rand(2, 20, requires_grad=True, device=device)
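
# init_module_bw_hooks() builds an nn.Linear(10, 10) with full backward
# pre/post hooks that count their invocations and perturb the gradients;
# with allow_eager=False the hooks assert they only ever run while Dynamo
# is compiling.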
def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)
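
# steps() drives four forward/backward iterations so that both the first
# backward (param.grad is None) and subsequent backwards (param.grad set)
# are exercised.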
def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out

class DistributedPatternTests(TestCase):
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)
        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)

        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd._enable(
            functools.partial(torch.compile, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_nested_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            def run():
                y = x.sin()
                closure_var = y + 1
                y.register_hook(lambda grad: grad + obj.val + closure_var)
                z = y.sin()
                return z

            return run()

        opt = torch.compile(fn, fullgraph=True)
        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)

        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd._enable(
            functools.partial(torch.compile, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)
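
    # The storage-resize tests below check that torch.compile(fullgraph=True)
    # can trace untyped_storage().resize_() calls, the mechanism used above in
    # init_fake_distributed() (and by FSDP) to free and re-allocate parameter
    # storage.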
    @torch.no_grad()
    def _test_storage_resize_zero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
            return torch.cos(y)

        x = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    def test_storage_resize_zero_cpu(self):
        self._test_storage_resize_zero("cpu")

    @requires_gpu()
    def test_storage_resize_zero_gpu(self):
        self._test_storage_resize_zero(GPU_TYPE)

    @torch.no_grad()
    def _test_storage_resize_nonzero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10, device=device)
        out = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    def test_storage_resize_nonzero_cpu(self):
        self._test_storage_resize_nonzero("cpu")

    @requires_gpu()
    def test_storage_resize_nonzero_gpu(self):
        self._test_storage_resize_nonzero(GPU_TYPE)
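
    # The tests below exercise the private version-counter escape hatches
    # (torch._C._autograd._unsafe_set_version_counter and
    # torch.autograd._unsafe_preserve_version_counter) under torch.compile,
    # checking that in-place copies inside the compiled region do not bump
    # the tensor's _version.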
    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter((w,), (v,))
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)
            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)
            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter((w,), (v,))
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version
        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd._enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(
            bw_cnt.op_count, 111
        )  # Number of ops in the Dynamo-produced graphs

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        with compiled_autograd._enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd._enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd._enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)

    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    @torch._dynamo.config.patch("graph_break_on_nn_param_ctor", False)
    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    @torch._dynamo.config.patch("graph_break_on_nn_param_ctor", False)
    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)
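
    # The tests below run the fake-FSDP module from init_fake_distributed()
    # under torch.compile plus compiled autograd and compare weights, inputs,
    # and outputs (and their grads) against eager execution.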
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd._enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)

        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    @skipIfXpu
    @requires_gpu()
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_inductor(self):
        m1, inp1 = init_fake_distributed(GPU_TYPE)
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed(GPU_TYPE)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd._enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")