pytorch/test/inductor/test_distributed_patterns.py
rzou 5531fafffe [compiled autograd] Proxy opaque nodes for built-in autograd nodes (#143296)
This PR is on the way to getting compiled autograd's initial capture to
stop specializing on Tensor metadata.

This PR changes compiled autograd's initial capture to proxy an opaque
(w.r.t. Dynamo) function into the graph for all built-in codegen'ed
autograd nodes and validate_outputs.
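
For intuition, "opaque (w.r.t. Dynamo)" means the call lands in the captured
graph as a single node that Dynamo does not trace into. The snippet below is
only an analogy built on the pre-existing `torch._dynamo.allow_in_graph`
mechanism, not the registration path this PR adds:

```python
import torch


@torch._dynamo.allow_in_graph
def opaque_fn(x):
    # Dynamo does not trace into this body; the call shows up in the
    # captured FX graph as a single call_function node.
    return x * 2


@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return opaque_fn(x.sin())


f(torch.randn(4))
```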

We changed each codegen'ed apply_with_saved (e.g.
MulBackward0::apply_with_saved) to call into Python to proxy a function
(compiled_autograd.ops.MulBackward0) into the graph. Then, we use the
node's InputMetadata to "guess" at the properties of the output Tensors
to create some new FakeTensors.
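
A rough Python sketch of the "guess the outputs" step: allocate fresh
FakeTensors with the guessed properties. The (shape, dtype, device) records
below are a hypothetical stand-in for the C++ node's InputMetadata, not the
actual data structure used:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def guess_fake_outputs(metadata):
    # metadata: hypothetical list of (shape, dtype, device) records.
    with FakeTensorMode():
        # torch.empty under FakeTensorMode produces FakeTensors;
        # no real storage is allocated.
        return [
            torch.empty(shape, dtype=dtype, device=device)
            for shape, dtype, device in metadata
        ]


fakes = guess_fake_outputs([((2, 20), torch.float32, "cpu")])
```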

Some details:
- MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to call
  into Python via libtorch_python. There is an indirection
  (PyCompilerInterface) to do this.
- MulBackward0::apply_with_saved passes a C++ function to Python. To make
  our lives easier, every codegen'ed apply_with_saved passes a C++
  function with the same signature
  `(variable_list, ivalue_list) -> variable_list` (a Python analogue is
  sketched after this list).
- We define how to pack arbitrary C++ types into IValue via a helper
  IValuePacker struct and codegen functional variants of each builtin
  C++ autograd node (e.g. MulBackward0_apply_functional_ivalue).
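
As referenced in the list above, here is a hypothetical Python analogue of one
of those functional variants for MulBackward0. The real
MulBackward0_apply_functional_ivalue is codegen'ed C++ that unpacks its saved
values from IValues via IValuePacker; the names and saved-value layout below
are illustrative only:

```python
from typing import List

import torch


def mul_backward0_apply_functional(
    grads: List[torch.Tensor],  # the "variable_list" of incoming gradients
    saved: List[object],        # the "ivalue_list" of unpacked saved values
) -> List[torch.Tensor]:        # the "variable_list" of produced gradients
    (grad_output,) = grads
    self_, other = saved
    # d(self * other)/d(self) = other, d(self * other)/d(other) = self
    return [grad_output * other, grad_output * self_]


grad_self, grad_other = mul_backward0_apply_functional(
    [torch.ones(3)], [torch.randn(3), torch.randn(3)]
)
```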

MulBackward0 before this PR:
https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de

MulBackward0 after this PR:
https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296
Approved by: https://github.com/jansel
2025-01-22 21:50:29 +00:00

# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu


# Fake distributed
WORLD_SIZE = 2


def init_fake_distributed(device="cpu"):
    @torch.no_grad
    def all_gather(t):
        return torch.cat([t] * WORLD_SIZE, 0)

    @torch.no_grad
    def reduce_scatter(t):
        # clone since reduce_scatter input and output should not be aliases.
        return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()

    def fw_pre_hook(mod, inp):
        mod.unsharded_weight.untyped_storage().resize_(
            mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
        )
        with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
            mod.unsharded_weight
        ):
            torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
        mod._parameters["weight"] = mod.unsharded_weight

    # Forward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    def fw_post_hook(mod, inp, out):
        mod._parameters["weight"] = mod.sharded_weight
        mod.unsharded_weight.untyped_storage().resize_(0)

    def bw_pre_hook(mod, gO):
        mod.unsharded_weight.untyped_storage().resize_(
            mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
        )
        with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
            mod.unsharded_weight
        ):
            torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight))
        mod._parameters["weight"] = mod.unsharded_weight

    # Backward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    def bw_post_hook(mod, gI, gO):
        grad = mod.weight.grad
        new_grad = reduce_scatter(grad)
        mod._parameters["weight"] = mod.sharded_weight
        mod.weight.grad = new_grad
        mod.unsharded_weight.untyped_storage().resize_(0)

    torch.manual_seed(1234)
    m = nn.Linear(20, 10, bias=False, device=device)

    # Mimics eager 1st iteration
    m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
    m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight))
    m.unsharded_weight.untyped_storage().resize_(0)

    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    m.register_forward_pre_hook(fw_pre_hook)
    m.register_forward_hook(fw_post_hook)
    return m, torch.rand(2, 20, requires_grad=True, device=device)


def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)


def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out
class DistributedPatternTests(TestCase):
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)

        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd._enable(
            functools.partial(torch.compile, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_nested_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            def run():
                y = x.sin()
                closure_var = y + 1
                y.register_hook(lambda grad: grad + obj.val + closure_var)
                z = y.sin()
                return z

            return run()

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)

        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd._enable(
            functools.partial(torch.compile, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    @torch.no_grad()
    def _test_storage_resize_zero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
            return torch.cos(y)

        x = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    def test_storage_resize_zero_cpu(self):
        self._test_storage_resize_zero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_zero_gpu(self):
        self._test_storage_resize_zero(GPU_TYPE)

    @torch.no_grad()
    def _test_storage_resize_nonzero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10, device=device)
        out = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    def test_storage_resize_nonzero_cpu(self):
        self._test_storage_resize_nonzero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_nonzero_gpu(self):
        self._test_storage_resize_nonzero(GPU_TYPE)
    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter(w, v)
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)
            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)
            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter(w, v)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version
        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())
    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd._enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(
            bw_cnt.op_count, 72
        )  # Number of ops in the Dynamo-produced graphs

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        with compiled_autograd._enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd._enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd._enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)
    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd._enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)

        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    @skipIfRocm
    @skipIfXpu
    @requires_gpu()
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_inductor(self):
        m1, inp1 = init_fake_distributed(GPU_TYPE)
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed(GPU_TYPE)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd._enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")