Fix SparseAdam consuming iterator (#86210)

Fixes https://github.com/pytorch/pytorch/issues/86209
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86210
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Tongzhou Wang
2022-10-06 23:11:22 +00:00
committed by PyTorch MergeBot
parent f0977c4658
commit 5ed75ec1d7

View File

@ -40,7 +40,9 @@ class SparseAdam(Optimizer):
sparse_params = []
for index, param in enumerate(params):
if isinstance(param, dict):
for d_index, d_param in enumerate(param.get("params", [])):
# given param group, convert given params to a list first before iterating
param['params'] = list(param.get("params", []))
for d_index, d_param in enumerate(param['params']):
if d_param.is_sparse:
sparse_params.append([index, d_index])
elif param.is_sparse: