Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-06 17:24:59 +08:00
Use the run_subtests utility instead of self.subTest (#94983)

Using the run_subtests utility is a better test practice.

Related: #84071
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94983
Approved by: https://github.com/awgu
Committed by: PyTorch MergeBot
Parent: ee0e7f0529
Commit: e0106e1850
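The change replaces a hand-written for/self.subTest loop with FSDPTest.run_subtests, which takes a dict mapping sub-test keyword-argument names to lists of candidate values plus the test body, and runs the body once per value combination (see the diff below). The following is a rough, self-contained sketch of that contract, not the actual helper from torch.testing._internal; the function name, the print call, and the placeholder strategy strings are ours:

import itertools
from typing import Any, Callable, Dict, List


def run_subtests_sketch(
    subtest_config: Dict[str, List[Any]],
    test_fn: Callable[..., None],
    **common_kwargs: Any,
) -> None:
    # Run test_fn once for every combination of the candidate values;
    # each dict key becomes a keyword argument of test_fn.
    names = list(subtest_config.keys())
    for values in itertools.product(*(subtest_config[name] for name in names)):
        kwargs = dict(zip(names, values))
        print(f"running subtest with {kwargs}")
        test_fn(**kwargs, **common_kwargs)


# Mirrors the single-argument sweep in the diff, with strings standing in
# for the ShardingStrategy enum members.
run_subtests_sketch(
    {"sharding_strategy": ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"]},
    lambda sharding_strategy: None,
)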
@@ -95,6 +95,20 @@ class TestFSDPMisc(FSDPTest):
 
     @skip_if_lt_x_gpu(2)
     def test_fsdp_not_all_outputs_used_in_loss(self):
+        self.run_subtests(
+            {
+                "sharding_strategy": [
+                    ShardingStrategy.FULL_SHARD,
+                    ShardingStrategy.SHARD_GRAD_OP,
+                    ShardingStrategy.NO_SHARD,
+                ]
+            },
+            self._test_fsdp_not_all_outputs_used_in_loss,
+        )
+
+    def _test_fsdp_not_all_outputs_used_in_loss(
+        self, sharding_strategy: ShardingStrategy
+    ):
         class MyModule(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -120,58 +134,52 @@ class TestFSDPMisc(FSDPTest):
                 for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                     torch.testing.assert_close(p1, p2)
 
-        for sharding_strategy in [
-            ShardingStrategy.FULL_SHARD,
-            ShardingStrategy.SHARD_GRAD_OP,
-            ShardingStrategy.NO_SHARD,
-        ]:
-            with self.subTest(sharding_strategy=sharding_strategy):
-                fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
-                m = MyModule().cuda()
-                m_local = deepcopy(m)
-                local_m = m_local
-                prev_params = [p.clone() for p in m_local.parameters()]
+        fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
+        m = MyModule().cuda()
+        m_local = deepcopy(m)
+        local_m = m_local
+        prev_params = [p.clone() for p in m_local.parameters()]
 
-                m.lin1 = fsdp_ctor(m.lin1)
-                m = fsdp_ctor(m)
-                _check_equal(m_local, m)
+        m.lin1 = fsdp_ctor(m.lin1)
+        m = fsdp_ctor(m)
+        _check_equal(m_local, m)
 
-                opt = torch.optim.SGD(m.parameters(), lr=1e-3)
-                opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
+        opt = torch.optim.SGD(m.parameters(), lr=1e-3)
+        opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
 
-                for i in range(6):
-                    t = torch.ones(4, device="cuda")
-                    a, b = m(t)
-                    local_a, local_b = local_m(t)
-                    if i < 2:
-                        # use both params in loss computation. Later,
-                        # b will go unused and we check grads are the
-                        # same as local training.
-                        loss = (a @ b).sum()
-                        loss_local = (local_a @ local_b).sum()
-                    else:
-                        loss = a.sum()
-                        loss_local = local_a.sum()
+        for i in range(6):
+            t = torch.ones(4, device="cuda")
+            a, b = m(t)
+            local_a, local_b = local_m(t)
+            if i < 2:
+                # use both params in loss computation. Later,
+                # b will go unused and we check grads are the
+                # same as local training.
+                loss = (a @ b).sum()
+                loss_local = (local_a @ local_b).sum()
+            else:
+                loss = a.sum()
+                loss_local = local_a.sum()
 
-                    loss.backward()
-                    loss_local.backward()
-                    _check_resharded(m)
-                    opt.step()
-                    opt_local.step()
-                    _check_equal(m_local, m)
-                    # Ensure at least some change from previous params, otherwise
-                    # above check would be vacuously true.
-                    self.assertTrue(
-                        any(
-                            not torch.equal(p1, p2)
-                            for p1, p2 in zip(prev_params, m_local.parameters())
-                        )
-                    )
-                    prev_params = [p.clone() for p in local_m.parameters()]
-                    opt.zero_grad()
-                    opt_local.zero_grad()
+            loss.backward()
+            loss_local.backward()
+            _check_resharded(m)
+            opt.step()
+            opt_local.step()
+            _check_equal(m_local, m)
+            # Ensure at least some change from previous params, otherwise
+            # above check would be vacuously true.
+            self.assertTrue(
+                any(
+                    not torch.equal(p1, p2)
+                    for p1, p2 in zip(prev_params, m_local.parameters())
+                )
+            )
+            prev_params = [p.clone() for p in local_m.parameters()]
+            opt.zero_grad()
+            opt_local.zero_grad()
 
-                dist.barrier()
+        dist.barrier()
 
     @skip_if_lt_x_gpu(2)
     @parametrize("use_second_layer", [True, False])
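Continuing the illustrative sketch above: sweeping more than one argument covers the Cartesian product of the value lists, with the test body receiving one keyword argument per config key. The offload_params key here is hypothetical and not part of this PR:

# Two swept arguments produce len(strategies) * len(flags) sub-test runs.
run_subtests_sketch(
    {
        "sharding_strategy": ["FULL_SHARD", "NO_SHARD"],
        "offload_params": [False, True],
    },
    lambda sharding_strategy, offload_params: None,
)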