Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add testing regarding SparseAdam state_dicts (#130645)
Summary:
- Updated SparseAdam to run the test_state_dict_deterministic unit test.
- Made gradients sparse while keeping weights dense in the above test.

Test Plan:
- Ran test_optim.py locally.

Fixes #116507

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130645
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: 168e41009b
Commit: e57101d927
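
For context on the summary above: SparseAdam only accepts sparse gradients, which normally come from layers such as nn.Embedding(sparse=True); the updated test instead keeps its weights dense and converts their dense gradients with to_sparse() before stepping (see the first hunk below). A minimal illustrative sketch of the optimizer's intended usage, not code from this PR:

import torch
import torch.nn as nn

# Embedding with sparse=True emits sparse COO gradients for its weight,
# which is the gradient layout SparseAdam expects.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, sparse=True)
optim = torch.optim.SparseAdam(emb.parameters(), lr=1e-2)

idx = torch.tensor([1, 3, 3, 7])
loss = emb(idx).pow(2).sum()

optim.zero_grad()
loss.backward()   # emb.weight.grad is a sparse tensor
optim.step()      # only the touched embedding rows are updated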
@@ -1327,6 +1327,11 @@ class TestOptimRenewed(TestCase):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
             loss.backward()
+            if optim_info.only_supports_sparse_grads:
+                if w.grad is not None:
+                    w.grad = w.grad.to_sparse()
+                if b.grad is not None:
+                    b.grad = b.grad.to_sparse()
             return loss

         for optim_input in all_optim_inputs:
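
The hunk above converts the closure's dense gradients to sparse when the optimizer only supports sparse grads, so test_state_dict_deterministic can now cover SparseAdam. Roughly, that test checks that resuming from an optimizer's state_dict reproduces the exact same parameter updates; a simplified, hand-rolled sketch of that property for SparseAdam (paraphrased, not the actual test body):

import torch

def fwd_bwd(optim, w, b, i):
    # Same shape as the test's closure: dense forward/backward, then convert
    # the dense grads to sparse because SparseAdam requires sparse gradients.
    optim.zero_grad()
    loss = (w.mv(i) + b).pow(2).sum()
    loss.backward()
    w.grad = w.grad.to_sparse()
    b.grad = b.grad.to_sparse()
    return loss

torch.manual_seed(0)
i = torch.randn(3)
w1 = torch.randn(4, 3, requires_grad=True)
b1 = torch.randn(4, requires_grad=True)
opt1 = torch.optim.SparseAdam([w1, b1], lr=1e-2)
fwd_bwd(opt1, w1, b1, i)
opt1.step()

# Clone the params and resume from opt1's state_dict in a fresh optimizer.
w2 = w1.detach().clone().requires_grad_(True)
b2 = b1.detach().clone().requires_grad_(True)
opt2 = torch.optim.SparseAdam([w2, b2], lr=1e-2)
opt2.load_state_dict(opt1.state_dict())

# One more identical step with each optimizer should give equal parameters.
for opt, w, b in ((opt1, w1, b1), (opt2, w2, b2)):
    fwd_bwd(opt, w, b, i)
    opt.step()

assert torch.equal(w1, w2) and torch.equal(b1, b2)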
@@ -1825,13 +1825,6 @@ optim_db: List[OptimizerInfo] = [
                 skipIfMps,  # SparseAdam does not support MPS
                 "TestOptimRenewed",
             ),
-            DecorateInfo(
-                unittest.skip(
-                    "SparseAdam does not support dense gradients, see #116507"
-                ),
-                "TestOptimRenewed",
-                "test_state_dict_deterministic",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                 "TestOptimRenewed",
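
The removed DecorateInfo entry above had been skipping test_state_dict_deterministic for SparseAdam entirely, because the optimizer rejects dense gradients; with the test now converting its grads to sparse, the blanket skip is no longer needed. A tiny sketch of the failure mode that motivated the original skip (the exact error text is indicative and may differ across versions):

import torch

w = torch.randn(4, 3, requires_grad=True)
optim = torch.optim.SparseAdam([w], lr=1e-2)

w.sum().backward()     # autograd produces a dense gradient
try:
    optim.step()       # SparseAdam refuses dense gradients
except RuntimeError as err:
    print(err)         # roughly: "SparseAdam does not support dense gradients ..."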