mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
More accurate is_bw and prompt parents cleanup for ModuleTracker utils (#125634)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125634 Approved by: https://github.com/soulitzer, https://github.com/Chillee
This commit is contained in:
@ -60,6 +60,14 @@ class TestModuleTracker(TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_bw_detection(self):
|
||||
mod = torch.nn.Linear(2, 2)
|
||||
|
||||
with ModuleTracker() as tracker:
|
||||
mod(torch.rand(2, requires_grad=True)).sum().backward()
|
||||
self.assertFalse(tracker.is_bw)
|
||||
self.assertEqual(tracker.parents, {"Global"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user