Add TensorLR variant for fused Adagrad on CPU (#153078)

This PR adds a tensor LR variant for the CPU Adagrad(fused=True).

I copied the behavior from the tensor LR variant of CPU Adam(fused=True), where `lr.item()` is cast to a double and passed to the default (scalar-LR) function.
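At the optimizer level, this means `fused=True` can now be combined with a tensor learning rate on CPU. A minimal sketch of the intended usage (not part of this PR; the model, shapes, and hyperparameters are illustrative assumptions):

```python
import torch

# Illustrative sketch (not from this PR): a 0-dim tensor learning rate
# with the fused CPU Adagrad. Before this change, the fused CPU kernel
# only accepted a double lr.
model = torch.nn.Linear(8, 2)
opt = torch.optim.Adagrad(model.parameters(), lr=torch.tensor(0.01), fused=True)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
opt.step()  # now reaches _fused_adagrad_kernel_cpu_ with a tensor lr
```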
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153078
Approved by: https://github.com/janeyx99
Author: Meet Patel
Date: 2025-05-14 02:23:29 +00:00
Committed by: PyTorch MergeBot
Commit: 9ad9a04ca7 (parent: d51bc27378)
3 changed files with 47 additions and 12 deletions

aten/src/ATen/native/FusedAdagrad.cpp

@@ -11,7 +11,6 @@
 #include <ATen/ops/_fused_adagrad_native.h>
 #endif
 
 namespace at::native {
-
 void _fused_adagrad_kernel_cpu_(
@@ -31,28 +30,54 @@ void _fused_adagrad_kernel_cpu_(
   const float* found_inf_ptr =
       found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
   if (found_inf_ptr && *found_inf_ptr == 1.0) {
-        return;
+    return;
   }
   size_t n_tensors = params.size();
   TORCH_CHECK(grads.size() == n_tensors);
   TORCH_CHECK(state_sums.size() == n_tensors);
   TORCH_CHECK(state_steps.size() == n_tensors);
-  for (size_t i = 0; i < n_tensors; i++){
+  for (size_t i = 0; i < n_tensors; i++) {
     fused_adagrad_stub(
-      kCPU,
-      params[i],
-      grads[i],
-      state_sums[i],
-      state_steps[i],
-      lr,
+        kCPU,
+        params[i],
+        grads[i],
+        state_sums[i],
+        state_steps[i],
+        lr,
+        lr_decay,
+        weight_decay,
+        eps,
+        maximize,
+        grad_scale_ptr);
+  }
+}
+
+void _fused_adagrad_kernel_cpu_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList state_sums,
+    at::TensorList state_steps,
+    const at::Tensor& lr,
+    const double lr_decay,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const std::optional<at::Tensor>& grad_scale,
+    const std::optional<at::Tensor>& found_inf) {
+  _fused_adagrad_kernel_cpu_(
+      params,
+      grads,
+      state_sums,
+      state_steps,
+      lr.item<double>(),
       lr_decay,
       weight_decay,
       eps,
       maximize,
-      grad_scale_ptr);
-  }
+      grad_scale,
+      found_inf);
 }
 
 DEFINE_DISPATCH(fused_adagrad_stub);
 
-}
+} // namespace at::native
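The tensor-LR overload is a thin forwarding wrapper: unlike the capturable CUDA paths, the CPU fused path has no reason to keep the learning rate on-device, so it extracts the scalar once with `lr.item<double>()` and calls the existing kernel. A minimal sketch of exercising both overloads through the low-level Python binding (tensor contents here are illustrative assumptions):

```python
import torch

# Illustrative state: one 3-element "parameter" and its Adagrad state.
params = [torch.zeros(3)]
grads = [torch.ones(3)]
state_sums = [torch.zeros(3)]
state_steps = [torch.tensor(1.0)]
kwargs = dict(lr_decay=0.0, weight_decay=0.0, eps=1e-10, maximize=False)

# Existing overload: scalar (float) lr.
torch._fused_adagrad_(params, grads, state_sums, state_steps, lr=0.1, **kwargs)

# New overload: 0-dim tensor lr. The kernel extracts lr.item<double>()
# and forwards to the scalar-lr path, so each call applies one ordinary
# fused Adagrad step in place.
torch._fused_adagrad_(params, grads, state_sums, state_steps,
                      lr=torch.tensor(0.1), **kwargs)
```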

aten/src/ATen/native/native_functions.yaml

@@ -15856,6 +15856,13 @@
     CPU: _fused_adagrad_kernel_cpu_
   autogen: _fused_adagrad, _fused_adagrad.out
 
+- func: _fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _fused_adagrad_kernel_cpu_
+  autogen: _fused_adagrad.tensor_lr, _fused_adagrad.tensor_lr_out
+
 # This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
 - func: _propagate_xla_data(Tensor input, Tensor output) -> ()
   variants: function
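For reference, the `tensor_lr` suffix in the schema name becomes the overload name in the dispatcher, and `autogen` generates the functional `.tensor_lr` and `.tensor_lr_out` variants listed in the next file. A minimal sketch, assuming a build that includes this change, of addressing the two overloads:

```python
import torch

# The suffix after the dot in the schema name is the overload name;
# both overloads should resolve on builds that include this change.
packet = torch.ops.aten._fused_adagrad_
print(packet.default)    # original overload: float lr
print(packet.tensor_lr)  # new overload: Tensor lr
```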

test/expect/HasDecompTest.test_has_decomposition.expect

@@ -356,7 +356,10 @@ aten::_functional_sym_constrain_range
 aten::_functional_sym_constrain_range_for_size
 aten::_fused_adagrad
 aten::_fused_adagrad.out
+aten::_fused_adagrad.tensor_lr
+aten::_fused_adagrad.tensor_lr_out
 aten::_fused_adagrad_
+aten::_fused_adagrad_.tensor_lr
 aten::_fused_adam
 aten::_fused_adam.out
 aten::_fused_adam.tensor_lr