[optim][adagrad] default to foreach when CUDA + differentiable=False (#92716)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92716
Approved by: https://github.com/albanD
Jane Xu
2023-01-20 22:58:47 +00:00
committed by PyTorch MergeBot
parent 6f1727b288
commit 9b4a778420


@@ -2,7 +2,7 @@ import torch
 from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
-                        _differentiable_doc, _maximize_doc)
+                        _default_to_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional
@@ -172,15 +172,14 @@ Adagrad.__doc__ = r"""Implements Adagrad algorithm.
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-10)
-        foreach (bool, optional): whether foreach implementation of optimizer
-            is used (default: None)
+        {foreach}
         {maximize}
         {differentiable}
     .. _Adaptive Subgradient Methods for Online Learning and Stochastic
         Optimization: http://jmlr.org/papers/v12/duchi11a.html
-    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc)
+    """.format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)
@@ -191,7 +190,7 @@ def adagrad(
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
     has_sparse_grad: bool = None,
-    foreach: bool = None,
+    foreach: Optional[bool] = None,
     differentiable: bool = False,
     *,
     lr: float,
@@ -211,8 +210,7 @@ def adagrad(
     )
     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        foreach = _default_to_foreach([params, grads, state_sums, state_steps], differentiable=differentiable)
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
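The `_default_to_foreach` helper itself lives in torch/optim/optimizer.py and is not shown in this diff. A rough, simplified sketch of the kind of check implied by the commit title (illustrative name and logic, not the actual implementation):

from typing import List

import torch
from torch import Tensor

def _maybe_default_to_foreach(tensorlists: List[List[Tensor]], differentiable: bool = False) -> bool:
    # Illustrative only: prefer the multi-tensor (foreach) path when autograd
    # through the optimizer step is not requested and every tensor involved
    # lives on a CUDA device; otherwise stay on the single-tensor path.
    if torch.jit.is_scripting() or differentiable:
        return False
    return all(
        t is None or t.is_cuda
        for tensorlist in tensorlists
        for t in tensorlist
    )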