Add TensorLR variant for fused Adagrad on CPU (#153078)
This PR adds a tensor LR variant for CPU Adagrad(fused=True). I copied the behavior from the tensor LR variant of CPU Adam(fused=True), where `lr.item()` is cast to a double and passed to the default (float LR) overload.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153078
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: d51bc27378
Commit: 9ad9a04ca7
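For orientation, here is a minimal sketch of how the tensor-LR path might be exercised from Python. It assumes torch.optim.Adagrad accepts a scalar Tensor learning rate together with fused=True and forwards it to the fused kernel; the model, shapes, and hyperparameters are illustrative only.

import torch

# Illustrative sketch only: a scalar-tensor learning rate combined with the
# fused CPU Adagrad path. Assumes the optimizer forwards the Tensor lr down to
# the fused kernel, which is what the tensor_lr overload in this PR supports.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adagrad(
    model.parameters(), lr=torch.tensor(0.01), fused=True)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()  # intended to reach _fused_adagrad_.tensor_lr on CPU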
@@ -11,7 +11,6 @@
#include <ATen/ops/_fused_adagrad_native.h>
#endif

namespace at::native {

void _fused_adagrad_kernel_cpu_(
@@ -31,28 +30,54 @@ void _fused_adagrad_kernel_cpu_(
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(state_sums.size() == n_tensors);
  TORCH_CHECK(state_steps.size() == n_tensors);
  for (size_t i = 0; i < n_tensors; i++) {
    fused_adagrad_stub(
        kCPU,
        params[i],
        grads[i],
        state_sums[i],
        state_steps[i],
        lr,
        lr_decay,
        weight_decay,
        eps,
        maximize,
        grad_scale_ptr);
  }
}

void _fused_adagrad_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList state_sums,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double lr_decay,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  _fused_adagrad_kernel_cpu_(
      params,
      grads,
      state_sums,
      state_steps,
      lr.item<double>(),
      lr_decay,
      weight_decay,
      eps,
      maximize,
      grad_scale,
      found_inf);
}

DEFINE_DISPATCH(fused_adagrad_stub);

} // namespace at::native
@@ -15856,6 +15856,13 @@
    CPU: _fused_adagrad_kernel_cpu_
  autogen: _fused_adagrad, _fused_adagrad.out

- func: _fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _fused_adagrad_kernel_cpu_
  autogen: _fused_adagrad.tensor_lr, _fused_adagrad.tensor_lr_out

# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
  variants: function
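The schema above defines the ATen-level entry point. Below is a minimal sketch of calling it directly, assuming the op is bound in the torch namespace the same way its float-LR sibling _fused_adagrad_ is; values are illustrative, and in practice the Adagrad optimizer drives this op rather than user code.

import torch

# Hypothetical direct invocation of the new overload; argument names follow the
# schema above. A Tensor lr selects the .tensor_lr variant, a Python float the
# default variant.
params = [torch.randn(3)]
grads = [torch.randn(3)]
state_sums = [torch.zeros(3)]
state_steps = [torch.tensor(0.0)]

torch._fused_adagrad_(
    params,
    grads,
    state_sums,
    state_steps,
    lr=torch.tensor(0.01),
    lr_decay=0.0,
    weight_decay=0.0,
    eps=1e-10,
    maximize=False,
)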
@@ -356,7 +356,10 @@ aten::_functional_sym_constrain_range
aten::_functional_sym_constrain_range_for_size
aten::_fused_adagrad
aten::_fused_adagrad.out
aten::_fused_adagrad.tensor_lr
aten::_fused_adagrad.tensor_lr_out
aten::_fused_adagrad_
aten::_fused_adagrad_.tensor_lr
aten::_fused_adam
aten::_fused_adam.out
aten::_fused_adam.tensor_lr