pytorch/test/distributed/_tools/test_mod_tracker.py
Catherine Lee 5b764267f4 [testing] Add test owner labels for some distributed tests (#163174)
I am trying to give some test files better owner labels than `module: unknown`. I am not sure about them, but they seem pretty reasonable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163174
Approved by: https://github.com/ezyang
2025-09-26 18:19:04 +00:00


# Owner(s): ["oncall: distributed"]
from copy import copy

import torch
from torch.distributed._tools.mod_tracker import ModTracker
from torch.testing._internal.common_utils import (
    run_tests,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.utils.checkpoint import checkpoint
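

# A rough usage sketch of ModTracker, inferred only from how the tests below
# exercise it (illustrative, not a spec; `model` and `x` are placeholders):
#
#     with ModTracker() as tracker:
#         out = model(x)            # tracker.parents: FQNs of the modules
#         out.sum().backward()      # currently on the stack; tracker.is_bw:
#                                   # whether we are inside the backward pass
#
# tracker.register_user_hooks(pre_fw, post_fw, pre_bw, post_bw) attaches user
# callbacks around each tracked module's forward/backward, and
# tracker.get_known_fqn(mod) maps a module instance to its fully qualified name.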


class TestModTracker(TestCase):
    # https://github.com/pytorch/pytorch/issues/127112
    @xfailIfTorchDynamo
    def test_module_hierarchy(self):
        seen_fw = []
        seen_bw = []

        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                seen_fw.append((copy(tracker.parents), tracker.is_bw))
                x.register_hook(
                    lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
                )
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = torch.nn.ModuleDict({"nest": Foo()})
                self.c = torch.nn.ModuleList([Foo()])

            def forward(self, x):
                x = self.c[0](x)
                return self.b["nest"](self.a(x))

        mod = Mod()

        with ModTracker() as tracker:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()

        self.assertEqual(
            seen_fw,
            [
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
            ],
        )
        self.assertEqual(
            seen_bw,
            [
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
            ],
        )
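
    # Sanity check for the is_bw flag: once backward() inside the tracker
    # context has finished, the tracker should report that it is no longer in
    # backward and that only the implicit "Global" parent remains active.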
    def test_bw_detection(self):
        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            mod(torch.rand(2, requires_grad=True)).sum().backward()
            self.assertFalse(tracker.is_bw)
            self.assertEqual(tracker.parents, {"Global"})
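
    # Exercises register_user_hooks: the pre/post forward and backward hooks
    # should fire in nesting order with each module's FQN, and if the module is
    # deleted before backward runs, the backward hooks are expected to see None
    # instead of a known FQN.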
    @xfailIfTorchDynamo
    def test_user_hooks(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.foo(x).relu_()

        mt = ModTracker()
        test_op = []

        def hook(mod, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mod = Bar()
        mt.register_user_hooks(
            lambda m, inp: hook(m, "pre_fw"),
            lambda m, inp, op: hook(m, "post_fw"),
            lambda m, gop: hook(m, "pre_bw"),
            lambda m, ginp: hook(m, "post_bw"),
        )
        with mt:
            mod(torch.rand(10, 10, requires_grad=True)).sum().backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", "Bar", True, True),
            ("pre_bw", "Bar.foo", True, True),
            ("post_bw", "Bar", True, True),
            ("post_bw", "Bar.foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)

        with self.assertRaises(AssertionError):
            mt.register_user_hooks(lambda x, y: x, None, None, None)

        test_op.clear()
        with mt:
            loss = mod(torch.rand(10, 10, requires_grad=True)).sum()
            del mod
            loss.backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", None, False, True),
            ("pre_bw", None, False, True),
            ("post_bw", None, False, True),
            ("post_bw", None, False, True),
        ]
        self.assertEqual(test_op, expected_op)
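
    # With activation checkpointing (torch.utils.checkpoint), the checkpointed
    # layer's forward is re-run during backward, so the user hooks should see an
    # extra pre_fw/post_fw pair for Foo.linears.1 with is_bw=True.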
    @xfailIfTorchDynamo
    def test_ac(self):
        class Foo(torch.nn.Module):
            def __init__(self, n_layers: int, dim: int, use_ac: bool = False):
                super().__init__()
                self.linears = torch.nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.linears.append(torch.nn.Linear(dim, dim))

            def forward(self, x):
                for i, block in enumerate(self.linears):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                    assert x is not None
                    x = torch.nn.functional.relu(x)
                return x

        bsz = 2
        dim = 8
        n_layers = 2
        test_op = []

        def hook(mod, mt, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mt = ModTracker()
        mt.register_user_hooks(
            lambda m, i: hook(m, mt, "pre_fw"),
            lambda m, i, o: hook(m, mt, "post_fw"),
            lambda m, go: hook(m, mt, "pre_bw"),
            lambda m, gi: hook(m, mt, "post_bw"),
        )
        model = Foo(n_layers, dim, True)
        x = torch.randn(bsz, dim)
        with mt:
            model(x).sum().backward()
        expected_op = [
            ("pre_fw", "Foo", True, False),
            ("pre_fw", "Foo.linears.0", True, False),
            ("post_fw", "Foo.linears.0", True, False),
            ("pre_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo", True, False),
            ("pre_bw", "Foo", True, True),
            ("pre_bw", "Foo.linears.1", True, True),
            ("pre_fw", "Foo.linears.1", True, True),
            ("post_fw", "Foo.linears.1", True, True),
            ("post_bw", "Foo.linears.1", True, True),
            ("pre_bw", "Foo.linears.0", True, True),
            ("post_bw", "Foo.linears.0", True, True),
            ("post_bw", "Foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)


if __name__ == "__main__":
    run_tests()