Add testing regarding SparseAdam state_dicts (#130645)

Summary:
- Enabled the test_state_dict_deterministic unit test for SparseAdam by removing its skip from optim_db.
- Made gradients sparse while keeping weights dense in that test, since SparseAdam only supports sparse gradients (see the sketch below).
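
For context, a minimal standalone sketch of that pattern (illustrative names, not taken from the test itself): SparseAdam keeps the parameter dense but expects the gradient handed to step() to be sparse, so the dense .grad produced by backward() is converted with to_sparse().

    import torch

    w = torch.randn(4, 3, requires_grad=True)    # dense weight
    i = torch.randn(3)

    loss = w.mv(i).pow(2).sum()
    loss.backward()                              # produces a dense w.grad
    w.grad = w.grad.to_sparse()                  # SparseAdam rejects dense gradients

    optim = torch.optim.SparseAdam([w], lr=1e-2)
    optim.step()                                 # the weight itself stays dense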

Test Plan:
- Ran test_optim.py locally.

Fixes #116507

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130645
Approved by: https://github.com/janeyx99
Author: Jovian Anthony Jaison
Date: 2024-07-16 11:29:20 +00:00
Committed by: PyTorch MergeBot
Parent: 168e41009b
Commit: e57101d927
2 changed files with 5 additions and 7 deletions


@@ -1327,6 +1327,11 @@ class TestOptimRenewed(TestCase):
             optim.zero_grad()
             loss = (w.mv(i) + b).pow(2).sum()
             loss.backward()
+            if optim_info.only_supports_sparse_grads:
+                if w.grad is not None:
+                    w.grad = w.grad.to_sparse()
+                if b.grad is not None:
+                    b.grad = b.grad.to_sparse()
             return loss

         for optim_input in all_optim_inputs:


@@ -1825,13 +1825,6 @@ optim_db: List[OptimizerInfo] = [
                 skipIfMps,  # SparseAdam does not support MPS
                 "TestOptimRenewed",
             ),
-            DecorateInfo(
-                unittest.skip(
-                    "SparseAdam does not support dense gradients, see #116507"
-                ),
-                "TestOptimRenewed",
-                "test_state_dict_deterministic",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                 "TestOptimRenewed",