Use the run_subtests utility instead of self.subTest (#94983)

Using the run_subtests utility instead of manually looping with self.subTest is better test practice.
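For context, here is a minimal sketch of the pattern this change applies. The test names (test_foo / _test_foo) and the ExampleTest class are hypothetical, and the import locations are assumed to be the usual ones at the time of this PR (FSDPTest in torch.testing._internal.common_fsdp, ShardingStrategy in torch.distributed.fsdp). run_subtests is the FSDPTest base-class helper used in the diff below: it takes a dict mapping config names to lists of values plus a test function, and calls that function once per configuration, passing each value as a keyword argument.

from torch.distributed.fsdp import ShardingStrategy
from torch.testing._internal.common_fsdp import FSDPTest


class ExampleTest(FSDPTest):
    # Before: loop manually and wrap each configuration in self.subTest.
    def test_foo_with_subtest(self):
        for sharding_strategy in [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ]:
            with self.subTest(sharding_strategy=sharding_strategy):
                self._test_foo(sharding_strategy=sharding_strategy)

    # After: declare the configurations in a dict and let run_subtests
    # invoke the helper once per configuration.
    def test_foo_with_run_subtests(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ]
            },
            self._test_foo,
        )

    def _test_foo(self, sharding_strategy: ShardingStrategy):
        ...  # exercise FSDP with the given sharding strategy

The diff below applies exactly this refactor to test_fsdp_not_all_outputs_used_in_loss.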

Related #84071

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94983
Approved by: https://github.com/awgu
Author:    Sahdev Zala
Date:      2023-02-16 22:13:10 +00:00
Committer: PyTorch MergeBot
Parent:    ee0e7f0529
Commit:    e0106e1850


@@ -95,6 +95,20 @@ class TestFSDPMisc(FSDPTest):
 
     @skip_if_lt_x_gpu(2)
     def test_fsdp_not_all_outputs_used_in_loss(self):
+        self.run_subtests(
+            {
+                "sharding_strategy": [
+                    ShardingStrategy.FULL_SHARD,
+                    ShardingStrategy.SHARD_GRAD_OP,
+                    ShardingStrategy.NO_SHARD,
+                ]
+            },
+            self._test_fsdp_not_all_outputs_used_in_loss,
+        )
+
+    def _test_fsdp_not_all_outputs_used_in_loss(
+        self, sharding_strategy: ShardingStrategy
+    ):
         class MyModule(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -120,58 +134,52 @@ class TestFSDPMisc(FSDPTest):
                 for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                     torch.testing.assert_close(p1, p2)
 
-        for sharding_strategy in [
-            ShardingStrategy.FULL_SHARD,
-            ShardingStrategy.SHARD_GRAD_OP,
-            ShardingStrategy.NO_SHARD,
-        ]:
-            with self.subTest(sharding_strategy=sharding_strategy):
-                fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
-                m = MyModule().cuda()
-                m_local = deepcopy(m)
-                local_m = m_local
-                prev_params = [p.clone() for p in m_local.parameters()]
+        fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
+        m = MyModule().cuda()
+        m_local = deepcopy(m)
+        local_m = m_local
+        prev_params = [p.clone() for p in m_local.parameters()]
 
-                m.lin1 = fsdp_ctor(m.lin1)
-                m = fsdp_ctor(m)
-                _check_equal(m_local, m)
+        m.lin1 = fsdp_ctor(m.lin1)
+        m = fsdp_ctor(m)
+        _check_equal(m_local, m)
 
-                opt = torch.optim.SGD(m.parameters(), lr=1e-3)
-                opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
+        opt = torch.optim.SGD(m.parameters(), lr=1e-3)
+        opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
 
-                for i in range(6):
-                    t = torch.ones(4, device="cuda")
-                    a, b = m(t)
-                    local_a, local_b = local_m(t)
-                    if i < 2:
-                        # use both params in loss computation. Later,
-                        # b will go unused and we check grads are the
-                        # same as local training.
-                        loss = (a @ b).sum()
-                        loss_local = (local_a @ local_b).sum()
-                    else:
-                        loss = a.sum()
-                        loss_local = local_a.sum()
+        for i in range(6):
+            t = torch.ones(4, device="cuda")
+            a, b = m(t)
+            local_a, local_b = local_m(t)
+            if i < 2:
+                # use both params in loss computation. Later,
+                # b will go unused and we check grads are the
+                # same as local training.
+                loss = (a @ b).sum()
+                loss_local = (local_a @ local_b).sum()
+            else:
+                loss = a.sum()
+                loss_local = local_a.sum()
 
-                    loss.backward()
-                    loss_local.backward()
-                    _check_resharded(m)
-                    opt.step()
-                    opt_local.step()
-                    _check_equal(m_local, m)
-                    # Ensure at least some change from previous params, otherwise
-                    # above check would be vacuously true.
-                    self.assertTrue(
-                        any(
-                            not torch.equal(p1, p2)
-                            for p1, p2 in zip(prev_params, m_local.parameters())
-                        )
-                    )
-                    prev_params = [p.clone() for p in local_m.parameters()]
-                    opt.zero_grad()
-                    opt_local.zero_grad()
+            loss.backward()
+            loss_local.backward()
+            _check_resharded(m)
+            opt.step()
+            opt_local.step()
+            _check_equal(m_local, m)
+            # Ensure at least some change from previous params, otherwise
+            # above check would be vacuously true.
+            self.assertTrue(
+                any(
+                    not torch.equal(p1, p2)
+                    for p1, p2 in zip(prev_params, m_local.parameters())
+                )
+            )
+            prev_params = [p.clone() for p in local_m.parameters()]
+            opt.zero_grad()
+            opt_local.zero_grad()
 
-            dist.barrier()
+        dist.barrier()
 
     @skip_if_lt_x_gpu(2)
     @parametrize("use_second_layer", [True, False])