# Owner(s): ["oncall: distributed"]

from copy import copy

import torch
from torch.distributed._tools.mod_tracker import ModTracker
from torch.testing._internal.common_utils import (
    run_tests,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.utils.checkpoint import checkpoint


class TestModTracker(TestCase):
    # https://github.com/pytorch/pytorch/issues/127112
    @xfailIfTorchDynamo
    def test_module_hierarchy(self):
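        # Check that ModTracker.parents reports the live module hierarchy
        # during forward and backward, including FQNs reached through
        # ModuleDict ("Mod.b.nest") and ModuleList ("Mod.c.0") containers.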
        seen_fw = []
        seen_bw = []

        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                seen_fw.append((copy(tracker.parents), tracker.is_bw))
                x.register_hook(
                    lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
                )
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = torch.nn.ModuleDict({"nest": Foo()})
                self.c = torch.nn.ModuleList([Foo()])

            def forward(self, x):
                x = self.c[0](x)
                return self.b["nest"](self.a(x))

        mod = Mod()

        with ModTracker() as tracker:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()

        self.assertEqual(
            seen_fw,
            [
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
            ],
        )

        self.assertEqual(
            seen_bw,
            [
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
            ],
        )

    def test_bw_detection(self):
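        # Once backward has finished, the tracker should report that we are
        # no longer in backward and only the "Global" parent remains.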
        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            mod(torch.rand(2, requires_grad=True)).sum().backward()
            self.assertFalse(tracker.is_bw)
            self.assertEqual(tracker.parents, {"Global"})

    @xfailIfTorchDynamo
    def test_user_hooks(self):
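        # register_user_hooks installs pre/post forward and backward
        # callbacks; each call records (hook name, module FQN, whether the
        # FQN is in mt.parents, is_bw) so ordering can be asserted.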
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.foo(x).relu_()

        mt = ModTracker()
        test_op = []

        def hook(mod, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mod = Bar()

        mt.register_user_hooks(
            lambda m, inp: hook(m, "pre_fw"),
            lambda m, inp, op: hook(m, "post_fw"),
            lambda m, gop: hook(m, "pre_bw"),
            lambda m, ginp: hook(m, "post_bw"),
        )
        with mt:
            mod(torch.rand(10, 10, requires_grad=True)).sum().backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", "Bar", True, True),
            ("pre_bw", "Bar.foo", True, True),
            ("post_bw", "Bar", True, True),
            ("post_bw", "Bar.foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)
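        # Registering a second set of user hooks on the same instance is
        # rejected.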
        with self.assertRaises(AssertionError):
            mt.register_user_hooks(lambda x, y: x, None, None, None)
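        # If the module is deleted before backward runs, the backward hooks
        # can no longer resolve its FQN and receive None instead.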
        test_op.clear()
        with mt:
            loss = mod(torch.rand(10, 10, requires_grad=True)).sum()
            del mod
            loss.backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", None, False, True),
            ("pre_bw", None, False, True),
            ("post_bw", None, False, True),
            ("post_bw", None, False, True),
        ]
        self.assertEqual(test_op, expected_op)

    @xfailIfTorchDynamo
    def test_ac(self):
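        # With activation checkpointing, the checkpointed layer's forward is
        # replayed during backward, so ModTracker should see fw hooks firing
        # while is_bw is True.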
        class Foo(torch.nn.Module):
            def __init__(self, n_layers: int, dim: int, use_ac: bool = False):
                super().__init__()
                self.linears = torch.nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.linears.append(torch.nn.Linear(dim, dim))

            def forward(self, x):
                for i, block in enumerate(self.linears):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                    assert x is not None
                    x = torch.nn.functional.relu(x)
                return x

        bsz = 2
        dim = 8
        n_layers = 2
        test_op = []

        def hook(mod, mt, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mt = ModTracker()
        mt.register_user_hooks(
            lambda m, i: hook(m, mt, "pre_fw"),
            lambda m, i, o: hook(m, mt, "post_fw"),
            lambda m, go: hook(m, mt, "pre_bw"),
            lambda m, gi: hook(m, mt, "post_bw"),
        )
        model = Foo(n_layers, dim, True)
        x = torch.randn(bsz, dim)
        with mt:
            model(x).sum().backward()
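        # linears.1 is checkpointed: its pre_fw/post_fw fire a second time
        # during backward (is_bw=True) when the block is recomputed.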
        expected_op = [
            ("pre_fw", "Foo", True, False),
            ("pre_fw", "Foo.linears.0", True, False),
            ("post_fw", "Foo.linears.0", True, False),
            ("pre_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo", True, False),
            ("pre_bw", "Foo", True, True),
            ("pre_bw", "Foo.linears.1", True, True),
            ("pre_fw", "Foo.linears.1", True, True),
            ("post_fw", "Foo.linears.1", True, True),
            ("post_bw", "Foo.linears.1", True, True),
            ("pre_bw", "Foo.linears.0", True, True),
            ("post_bw", "Foo.linears.0", True, True),
            ("post_bw", "Foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)


if __name__ == "__main__":
    run_tests()