Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Using `fsdp.set_` for the in-place update of the unsharded_param causes difficult-to-debug errors when enabling Traceable FSDP2 on TorchTune models. This PR changes it to use `fsdp.copy_`, which fixes the error and also strictly follows eager semantics: if the user explicitly stores an alias of the unsharded_param while their module code runs, that alias is updated correctly when the unsharded_param is copied into via `fsdp.copy_`, whereas swapping out the unsharded_param's storage via `set_` would leave that user-saved alias stale. This PR also implements a graph pass that removes the resizes and the copy when a resize_(full) -> copy_ -> resize_(0) pattern is detected.

------

Test commands:

- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_copy_`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_partitioner_cse_respects_mutation_boundaries`
- `pytest -rA test/dynamo/test_repros.py::ReproTests::test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients`
- `pytest -rA test/inductor/test_pattern_matcher.py::TestPatternMatcher::test_mutation_op_matching`
- `python test/inductor/test_distributed_patterns.py DistributedPatternTests.test_fake_distributed_aot_eager`
- `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 PYTORCH_TEST_WITH_CROSSREF=1 python test/functorch/test_aotdispatch.py TestEagerFusionOpInfoCPU.test_aot_autograd_exhaustive_norm_cpu_float32`
- `python test/distributed/test_inductor_collectives.py TestCollectivesInductor.test_backwards`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133730
Approved by: https://github.com/bdhirsh
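As a rough illustration of the aliasing point above, here is a minimal sketch using plain `Tensor.copy_` and `Tensor.set_` rather than the `torch.ops.fsdp.*` ops; the tensor names are illustrative only and not part of this PR:

```python
import torch

# copy_ writes through the existing storage, so a user-saved alias of the
# unsharded param observes the update (the eager semantics described above).
unsharded = torch.zeros(4)
saved_alias = unsharded.view(-1)  # alias stored by user module code
unsharded.copy_(torch.ones(4))    # analogous to fsdp.copy_
assert torch.equal(saved_alias, torch.ones(4))

# A set_-style storage swap leaves the alias pointing at the old storage,
# so it silently goes stale.
unsharded = torch.zeros(4)
saved_alias = unsharded.view(-1)
unsharded.set_(torch.ones(4))     # analogous to the previous fsdp.set_ path
assert not torch.equal(saved_alias, unsharded)
```

The hooks in the test file below exercise the same resize_(full) -> copy_ -> resize_(0) sequence that the new graph pass targets: `fw_pre_hook`/`bw_pre_hook` resize `unsharded_weight`'s storage to full size and `torch.ops.fsdp.copy_` the all-gathered shard into it, and `fw_post_hook`/`bw_post_hook` resize that storage back to 0.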
466 lines · 15 KiB · Python
# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu


# Fake distributed
WORLD_SIZE = 2

def init_fake_distributed(device="cpu"):
|
|
@torch.no_grad
|
|
def all_gather(t):
|
|
return torch.cat([t] * WORLD_SIZE, 0)
|
|
|
|
@torch.no_grad
|
|
def reduce_scatter(t):
|
|
# clone since reduce_scatter input and output should not be aliases.
|
|
return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()
|
|
|
|
def fw_pre_hook(mod, inp):
|
|
mod.unsharded_weight.untyped_storage().resize_(
|
|
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
|
|
)
|
|
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
|
|
mod.unsharded_weight
|
|
):
|
|
torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
|
|
mod._parameters["weight"] = mod.unsharded_weight
|
|
|
|
# Forward:
|
|
# mod.sharded_weight = local_shard (always)
|
|
# Before:
|
|
# mod.weight = local_shard
|
|
# mod.unsharded_weight = zero-sized allgather
|
|
# After:
|
|
# mod.weight = local_shard
|
|
# mod.unsharded_weight = zero-sized allgather
|
|
|
|
def fw_post_hook(mod, inp, out):
|
|
mod._parameters["weight"] = mod.sharded_weight
|
|
mod.unsharded_weight.untyped_storage().resize_(0)
|
|
|
|
def bw_pre_hook(mod, gO):
|
|
mod.unsharded_weight.untyped_storage().resize_(
|
|
mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
|
|
)
|
|
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
|
|
mod.unsharded_weight
|
|
):
|
|
torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
|
|
mod._parameters["weight"] = mod.unsharded_weight
|
|
|
|
# Backward:
|
|
# mod.sharded_weight = local_shard (always)
|
|
# Before:
|
|
# mod.weight = local_shard
|
|
# mod.unsharded_weight = zero-sized allgather
|
|
# After:
|
|
# mod.weight = local_shard
|
|
# mod.unsharded_weight = zero-sized allgather
|
|
|
|
def bw_post_hook(mod, gI, gO):
|
|
grad = mod.weight.grad
|
|
new_grad = reduce_scatter(grad)
|
|
mod._parameters["weight"] = mod.sharded_weight
|
|
mod.weight.grad = new_grad
|
|
mod.unsharded_weight.untyped_storage().resize_(0)
|
|
|
|
torch.manual_seed(1234)
|
|
m = nn.Linear(20, 10, bias=False, device=device)
|
|
|
|
# Mimics eager 1st iteration
|
|
m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
|
|
m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight))
|
|
m.unsharded_weight.untyped_storage().resize_(0)
|
|
|
|
m.register_full_backward_pre_hook(bw_pre_hook)
|
|
m.register_full_backward_hook(bw_post_hook)
|
|
m.register_forward_pre_hook(fw_pre_hook)
|
|
m.register_forward_hook(fw_post_hook)
|
|
return m, torch.rand(2, 20, requires_grad=True, device=device)
|
|
|
|
|
|
def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)


def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out

class DistributedPatternTests(TestCase):
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(functools.partial(torch.compile, fullgraph=True)):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)
    @torch.no_grad()
    def _test_storage_resize_zero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
            return torch.cos(y)

        x = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    def test_storage_resize_zero_cpu(self):
        self._test_storage_resize_zero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_zero_gpu(self):
        self._test_storage_resize_zero(GPU_TYPE)

    @torch.no_grad()
    def _test_storage_resize_nonzero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10, device=device)
        out = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    def test_storage_resize_nonzero_cpu(self):
        self._test_storage_resize_nonzero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_nonzero_gpu(self):
        self._test_storage_resize_nonzero(GPU_TYPE)
    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter(w, v)
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)
            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)

            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter(w, v)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())
    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version

        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())
    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(bw_cnt.op_count, 48)

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        with compiled_autograd.enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)
    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)
    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)
        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    @skipIfRocm
    @skipIfXpu
    @requires_gpu()
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_inductor(self):
        m1, inp1 = init_fake_distributed(GPU_TYPE)
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed(GPU_TYPE)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")