More accurate is_bw and prompt parents cleanup for ModuleTracker utils (#125634)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125634
Approved by: https://github.com/soulitzer, https://github.com/Chillee
This commit is contained in:
albanD
2024-05-07 20:57:33 +00:00
committed by PyTorch MergeBot
parent fdfef759a6
commit c5e04a4479
2 changed files with 30 additions and 10 deletions

View File

@ -60,6 +60,14 @@ class TestModuleTracker(TestCase):
],
)
def test_bw_detection(self):
mod = torch.nn.Linear(2, 2)
with ModuleTracker() as tracker:
mod(torch.rand(2, requires_grad=True)).sum().backward()
self.assertFalse(tracker.is_bw)
self.assertEqual(tracker.parents, {"Global"})
if __name__ == "__main__":
run_tests()