Wrote a test checking for KeyErrors.

This commit is contained in:
Samuel Park
2025-10-16 18:00:11 +00:00
parent 53c919957d
commit 8e64cd19c1

View File

@ -180,6 +180,42 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
del model
del optim
def _test_tracker_multihandler_hook(self):
"""Should run without KeyError."""
class TestModule(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.norm = nn.RMSNorm(dim)
self.output = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm(x)
x = self.output(x)
return x
gc.collect()
torch.manual_seed(42)
dev = torch.device(torch.accelerator.current_device_index())
with torch.device(dev):
model = TestModule(128)
mesh = init_device_mesh(dev.type, (self.world_size,))
fully_shard([model.norm, model.output], mesh=mesh)
fully_shard(model, mesh=mesh)
fmt = FSDPMemTracker(model)
with fmt:
inp = torch.randn(16, 128, device=dev)
y = model(inp)
loss = y.sum()
loss.backward()
del inp
del model
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
@property