Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[optim][adagrad] default to foreach when CUDA + differentiable=False (#92716)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92716 Approved by: https://github.com/albanD
committed by PyTorch MergeBot
parent 6f1727b288
commit 9b4a778420
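In user-facing terms, leaving `foreach` unset now selects the multi-tensor implementation automatically when the parameters live on CUDA and `differentiable=False`, instead of always falling back to the single-tensor loop. A minimal sketch of that behavior (the model and hyperparameters are illustrative, not from the PR):

import torch
from torch import nn, optim

# Any CUDA-resident parameters behave the same way; nn.Linear is just an example.
model = nn.Linear(10, 10).cuda()

# foreach is left at its default (None): with this change the optimizer picks the
# foreach path because the params are on CUDA and differentiable=False, rather
# than always using the for-loop implementation.
opt = optim.Adagrad(model.parameters(), lr=0.01)

loss = model(torch.randn(4, 10, device="cuda")).sum()
loss.backward()
opt.step()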
@@ -2,7 +2,7 @@ import torch
 from torch import Tensor

 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
-                        _differentiable_doc, _maximize_doc)
+                        _default_to_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional

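The `_group_tensors_by_device_and_dtype` import reflects that foreach kernels can only fuse tensors sharing a device and dtype. The snippet below is a hand-rolled illustration of that bucketing idea, not the helper's actual return format:

import torch
from collections import defaultdict

def bucket_by_device_and_dtype(tensors):
    # Foreach kernels operate on flat lists of same-device, same-dtype tensors,
    # so mixed inputs have to be split into homogeneous buckets first.
    buckets = defaultdict(list)
    for t in tensors:
        buckets[(t.device, t.dtype)].append(t)
    return buckets

tensors = [torch.randn(3), torch.randn(3, dtype=torch.float64)]
for key, group in bucket_by_device_and_dtype(tensors).items():
    print(key, len(group))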
@@ -172,15 +172,14 @@ Adagrad.__doc__ = r"""Implements Adagrad algorithm.
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-10)
-        foreach (bool, optional): whether foreach implementation of optimizer
-            is used (default: None)
+        {foreach}
         {maximize}
         {differentiable}

     .. _Adaptive Subgradient Methods for Online Learning and Stochastic
         Optimization: http://jmlr.org/papers/v12/duchi11a.html

-""".format(maximize=_maximize_doc, differentiable=_differentiable_doc)
+""".format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)
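The docstring change swaps the hand-written `foreach` entry for the shared `{foreach}` placeholder, so every optimizer documents the flag with the same text. A toy version of that templating pattern (the fragment below is made up, not the real `_foreach_doc`):

_example_foreach_doc = "foreach (bool, optional): use the multi-tensor implementation (default: None)"

doc_template = """Implements a toy optimizer.

Args:
    lr (float): learning rate
    {foreach}
"""

print(doc_template.format(foreach=_example_foreach_doc))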
@@ -191,7 +190,7 @@ def adagrad(
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
     has_sparse_grad: bool = None,
-    foreach: bool = None,
+    foreach: Optional[bool] = None,
     differentiable: bool = False,
     *,
     lr: float,
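The signature tweak only fixes the typing: a parameter declared `bool` with a `None` default is mis-annotated, and `Optional[bool]` is the accurate spelling for a tri-state switch. A small illustration of that convention (hypothetical function, not from the file):

from typing import Optional

def f(flag: Optional[bool] = None) -> str:
    # Optional[bool] is the honest annotation for a tri-state switch:
    # None (decide automatically), True, or False.
    return "auto" if flag is None else ("on" if flag else "off")

print(f(), f(True), f(False))  # auto on off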
@@ -211,8 +210,7 @@ def adagrad(
     )

     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        foreach = _default_to_foreach([params, grads, state_sums, state_steps], differentiable=differentiable)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
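The behavioral change is the call to `_default_to_foreach`. Its implementation is not shown in this diff; based on the PR title, a plausible sketch of the decision it encodes (hypothetical code, not the actual helper) is:

import torch
from typing import List

def default_to_foreach_sketch(tensorlists: List[List[torch.Tensor]], differentiable: bool) -> bool:
    # Hypothetical reconstruction of the rule named in the PR title: prefer the
    # foreach path only when the step is not differentiable and every tensor
    # involved already lives on a CUDA device.
    if differentiable:
        return False
    return all(t.is_cuda for tl in tensorlists for t in tl if t is not None)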