Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
fix(optim): adagrad sparse multitensor incorrect early exit (#110454)
Fixes https://github.com/pytorch/pytorch/issues/110444#issuecomment-1745181530

This PR: Passes
Main:
```
test/optim/test_optim.py::TestOptim::test_adagrad_sparse FAILED [0.0058s]
=================================== FAILURES ===================================
________________________ TestOptim.test_adagrad_sparse ________________________
Traceback (most recent call last):
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/test/optim/test_optim.py", line 1448, in test_adagrad_sparse
    self._test_rosenbrock_sparse(
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/test/optim/test_optim.py", line 128, in _test_rosenbrock_sparse
    self.assertEqual(params, params_c, atol=1e-6, rtol=1e-6)
  File "/home/jonch/Desktop/Programming/mlsys/pytorch/torch/testing/_internal/common_utils.py", line 3309, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 0.09999999999993325 at index (1,) (up to 1e-06 allowed)
Greatest relative difference: 0.06249999999996089 at index (1,) (up to 1e-06 allowed)
```

CC: @janeyx99

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110454
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: ecdd1bcf03
Commit: c99de9f37c
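The root cause: `_multi_tensor_adagrad` walks parameters in (device, dtype) groups and falls back to the single-tensor implementation for any group whose gradients are sparse. Before this fix it `return`ed from inside that per-group loop, so every group after the first sparse one was silently skipped (in the updated test below, the float64 parameter was never stepped). A minimal sketch of the pattern, with hypothetical names (`step_all_groups`, `single_tensor_step`, `foreach_step`), not the actual torch.optim source:

```python
def step_all_groups(grouped, single_tensor_step, foreach_step):
    # Each entry is one (device, dtype) bucket of params/grads/state.
    for group in grouped:
        if any(g.is_sparse for g in group["grads"]):
            # Buggy version: `return single_tensor_step(group)` exits the whole
            # loop after the FIRST sparse bucket, skipping the remaining ones.
            single_tensor_step(group)  # fix: handle this bucket...
            continue                   # ...then keep iterating
        foreach_step(group)            # dense fast path (foreach kernels)
```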
test/optim/test_optim.py

@@ -60,34 +60,37 @@ class TestOptim(TestCase):
         scheduler_constructors=None,
         sparse_only=False,
         maximize=False,
+        multi_tensor=False
     ):
         if scheduler_constructors is None:
             scheduler_constructors = []
+        # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
-        param_t = torch.tensor([1.5, 1.5])
+        if multi_tensor:
+            params_t = [torch.tensor([1.5, 1.5]), torch.tensor([1.5, 1.5], dtype=torch.float64)]
+        else:
+            params_t = [torch.tensor([1.5, 1.5])]
 
-        param = Parameter(param_t)
-        optimizer = constructor([param])
+        params = [Parameter(param_t) for param_t in params_t]
+        optimizer = constructor(params)
         schedulers = []
         for scheduler_constructor in scheduler_constructors:
             schedulers.append(scheduler_constructor(optimizer))
 
         if not sparse_only:
-            param_c = Parameter(param_t.clone())
-            optimizer_c = constructor([param_c])
+            params_c = [Parameter(param_t.clone()) for param_t in params_t]
+            optimizer_c = constructor(params_c)
 
         solution = torch.tensor([1, 1])
         with torch.no_grad():
-            initial_dist = param.dist(solution)
+            initial_dist = sum([param.dist(solution) for param in params])
 
-        def eval(param, sparse_grad, w):
-            # Depending on w, provide only the x or y gradient
-            optimizer.zero_grad()
-            loss = rosenbrock(param)
-            loss.backward()
+        def get_grad(param, sparse_grad):
             grad = drosenbrock(param)
             # NB: We torture test the optimizer by returning an
             # uncoalesced sparse tensor
 
+            # Depending on w, provide only the x or y gradient
+            if sparse_grad:
                 if w:
                     i = torch.LongTensor([[0, 0]])
                     x = grad[0]
@@ -96,31 +99,54 @@ class TestOptim(TestCase):
                     i = torch.LongTensor([[1, 1]])
                     y = grad[1]
                     v = torch.tensor([y - y / 4.0, y / 4.0])
-            x = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
-            with torch.no_grad():
-                if sparse_grad:
-                    param.grad = x
+                grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
             else:
-                    param.grad = x.to_dense()
+                if w:
+                    grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
+                else:
+                    grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
+            return grad_out
+
+        def eval(params, sparse_grad, w):
+            optimizer.zero_grad()
+            if multi_tensor:
+                loss = sum(rosenbrock(param) for param in params)
+            else:
+                loss = rosenbrock(params[0])
+            loss.backward()
+
+            grads_out = [get_grad(param, sparse_grad) for param in params]
+            with torch.no_grad():
+                params[0].grad = grads_out[0]
+                if multi_tensor:
+                    params[1].grad = grads_out[1].to(dtype=torch.float64)
             return loss
 
         for i in range(2000):
             # Do cyclic coordinate descent
             w = i % 2
-            optimizer.step(functools.partial(eval, param, True, w))
+            optimizer.step(functools.partial(eval, params, True, w))
             for scheduler in schedulers:
                 if isinstance(scheduler, ReduceLROnPlateau):
-                    scheduler.step(rosenbrock(param))
+                    scheduler.step(rosenbrock(params[0]))
                 else:
                     scheduler.step()
             if not sparse_only:
-                optimizer_c.step(functools.partial(eval, param_c, False, w))
-                self.assertEqual(param, param_c)
+                optimizer_c.step(functools.partial(eval, params_c, False, w))
+                # Tolerance is increased due to floating point error from different
+                # code path for dense case: x v.s. x - x / 4.0 + x / 4.0
+                self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
 
         if not maximize:
-            self.assertLessEqual(param.dist(solution), initial_dist)
+            self.assertLessEqual(
+                sum([param.dist(solution) for param in params]),
+                initial_dist
+            )
         else:
-            self.assertGreaterEqual(rosenbrock(param), rosenbrock(param_t))
+            self.assertGreaterEqual(
+                sum([rosenbrock(param) for param in params]),
+                sum([rosenbrock(param_t) for param_t in params_t]),
+            )
 
     def _test_basic_cases_template(
         self,
@@ -597,7 +623,7 @@ class TestOptim(TestCase):
     def test_sgd_sparse(self):
         for foreach in (False, True):
             self._test_rosenbrock_sparse(
-                lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach)
+                lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach),
             )
             self._test_rosenbrock_sparse(
                 lambda params: optim.SGD(params, lr=0.0048, foreach=foreach),
@@ -1432,7 +1458,8 @@ class TestOptim(TestCase):
     def test_adagrad_sparse(self):
         for foreach in (False, True):
             self._test_rosenbrock_sparse(
-                lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach)
+                lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach),
+                multi_tensor=foreach,
             )
             self._test_rosenbrock_sparse(
                 lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach),
@@ -1440,6 +1467,7 @@ class TestOptim(TestCase):
                     lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
                     lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
                 ],
+                multi_tensor=foreach,
             )
 
     def test_adagrad_complex(self):
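The widened tolerance in the hunk above exists because the dense path receives the gradient as a plain tensor `x`, while the sparse path builds it from the uncoalesced pair `(x / 4.0, x - x / 4.0)`; recombining those parts does not always reproduce `x` bit-for-bit. A quick illustration (my own example, not part of the PR):

```python
import torch

x = torch.rand(10_000)
# Recombining the uncoalesced pieces can be off by one ulp for some elements,
# which is why the test compares with atol=5e-6/rtol=5e-6 instead of exact equality.
recombined = (x - x / 4.0) + x / 4.0
print(torch.equal(x, recombined))           # typically False
print((x - recombined).abs().max().item())  # tiny, on the order of float32 eps
```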
torch/optim/adagrad.py

@@ -324,7 +324,7 @@ def _multi_tensor_adagrad(
         device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
 
         if device_has_sparse_grad:
-            return _single_tensor_adagrad(
+            _single_tensor_adagrad(
                 device_params,
                 device_grads,
                 device_state_sums,
@@ -337,6 +337,7 @@ def _multi_tensor_adagrad(
                 maximize=False,
                 differentiable=differentiable,
             )
+            continue
 
         if maximize:
             device_grads = torch._foreach_neg(device_grads)
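For context, a hypothetical repro of what the new `multi_tensor=foreach` test exercises (my own sketch, not code from the PR): two parameters with different dtypes land in separate (device, dtype) groups inside the foreach path, and both must still be stepped when their gradients are sparse.

```python
import torch

p32 = torch.nn.Parameter(torch.tensor([1.5, 1.5]))
p64 = torch.nn.Parameter(torch.tensor([1.5, 1.5], dtype=torch.float64))
opt = torch.optim.Adagrad([p32, p64], lr=0.1, foreach=True)

# Hand the optimizer sparse gradients, as the rosenbrock test does.
idx = torch.tensor([[0, 1]])
p32.grad = torch.sparse_coo_tensor(idx, torch.tensor([1.0, 1.0]), (2,))
p64.grad = torch.sparse_coo_tensor(idx, torch.ones(2, dtype=torch.float64), (2,))

before = (p32.detach().clone(), p64.detach().clone())
opt.step()
# With the early `return`, only one dtype group would have been updated;
# after the fix both parameters change.
print(not torch.equal(p32, before[0]), not torch.equal(p64, before[1]))
```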