Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[optim] add fused_adam/adamw_kernel support for CPU device (#123074)
On par with the `CUDA` implementation.
For the `autocast` logic, same as `CUDA` + `Fused Adam`:
- check for `inf` in `GradScaler.step`
- in the fused kernel, if there is an `inf`, do nothing; otherwise, unscale the grad (and write it back) and update the param (see the sketch below)
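A minimal sketch of how this path is driven from Python on CPU (post-PR behavior; the model, data, and hyperparameters below are illustrative assumptions, not part of this PR):
```python
import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=0.1, fused=True)   # CPU fused path
scaler = torch.cpu.amp.GradScaler(init_scale=128.0)

x, y = torch.randn(4, 8), torch.randn(4, 8)
for _ in range(3):
    opt.zero_grad()
    with torch.autocast('cpu', dtype=torch.half):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    # scaler.step() checks for inf/nan; the fused kernel receives grad_scale/found_inf,
    # skips the update if an inf was found, otherwise unscales the grads in place
    # (writing them back) and updates the params.
    scaler.step(opt)
    scaler.update()
```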
**TestPlan**:
```
# extend CUDA only test for CPU fused adam
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused
# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict
# newly added test (follows 6b1f13ea2f/test/test_cuda.py#L1108)
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```
**Benchmark**:
**5.1x** speedup on a 56-core SPR
**Parameter size = 1M**
**Nparams = 10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
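The linked gist is external; as a rough, hypothetical stand-in (parameter sizes and iteration counts are assumptions, not the gist's exact settings), a benchmark of this shape exercises the same comparison:
```python
import time
import torch

# Hypothetical stand-in for the linked bench_adam.py: 10 parameters of 1M
# elements each, timing optimizer steps with fused=False vs fused=True on CPU.
params = [torch.randn(1_000_000, requires_grad=True) for _ in range(10)]
for p in params:
    p.grad = torch.randn_like(p)

def bench(fused, iters=100):
    opt = torch.optim.Adam(params, lr=1e-3, fused=fused)
    for _ in range(5):       # warm-up
        opt.step()
    start = time.time()
    for _ in range(iters):
        opt.step()
    return time.time() - start

print("non-fused", bench(False), "s")
print("fused    ", bench(True), "s")
```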
**Note: Fused kernel accuracy**
The accuracy failure in CI shows a difference slightly above the default tolerance:
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged it step by step, and unfortunately we may not be able to make the `fused kernel` produce exactly the same results as the `non-fused` one due to compiler optimizations.
For example, in the non-fused impl:
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in the fused impl:
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep the `std::cout`, I get exactly the same results in the UT:
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference:
```
===============param
0.6796758770942688
0.6796759366989136
```
So I set the tolerance a little higher than the default. (Presumably the print statement inhibits optimizations such as fused multiply-add contraction, which change the rounding in the last bits; this is an educated guess, not confirmed.)
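A small illustration of that tolerance choice, assuming the default float32 tolerances of `torch.testing.assert_close`; apart from the ~1.6e-5 gap reported by CI above, the numbers are made up:
```python
import torch

# The CI mismatch (~1.6e-5 absolute, ~1e-5 relative) sits just outside the default
# float32 tolerances of torch.testing.assert_close (rtol=1.3e-6, atol=1e-5).
expected = torch.tensor(1.5623)          # hypothetical reference value
actual = expected + 1.57e-5              # gap of the size reported by CI

try:
    torch.testing.assert_close(actual, expected)                     # default: fails
except AssertionError as err:
    print(err)

torch.testing.assert_close(actual, expected, atol=2e-5, rtol=1e-5)   # relaxed: passes
```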
Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
committed by PyTorch MergeBot
parent 9a71d12d92
commit b412b75b42

aten/src/ATen/native/FusedAdam.cpp (new file, 175 lines)
@@ -0,0 +1,175 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam.h>
#include <ATen/ops/_fused_adam_native.h>
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif
namespace at {

namespace native {

void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.size() == 0);
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++){
    if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
    fused_adam_stub(
        kCPU,
        params[i],
        grads[i],
        exp_avgs[i],
        exp_avg_sqs[i],
        max_exp_avg_sq,
        state_steps[i],
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale_ptr,
        ADAM_MODE::ORIGINAL);
  }
}

// The following overload simply has a Tensor lr
void _fused_adam_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  _fused_adam_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}

void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  TORCH_CHECK(exp_avgs.size() == n_tensors);
  TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
  if (amsgrad) {
    TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
  } else {
    TORCH_CHECK(max_exp_avg_sqs.size() == 0);
  }
  TORCH_CHECK(state_steps.size() == n_tensors);
  at::Tensor max_exp_avg_sq = at::Tensor();
  for (size_t i = 0; i < n_tensors; i++){
    if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
    fused_adam_stub(
        kCPU,
        params[i],
        grads[i],
        exp_avgs[i],
        exp_avg_sqs[i],
        max_exp_avg_sq,
        state_steps[i],
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale_ptr,
        ADAM_MODE::ADAMW);
  }
}

// The following overload simply has a Tensor lr
void _fused_adamw_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  _fused_adamw_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}


DEFINE_DISPATCH(fused_adam_stub);

}
}
aten/src/ATen/native/FusedAdam.h (new file, 30 lines)
@@ -0,0 +1,30 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at {

namespace native {

enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };

using fused_adam_fn = void (*)(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr,
    const ADAM_MODE);

DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);

}
}
aten/src/ATen/native/cpu/FusedAdamKernel.cpp (new file, 379 lines)
@@ -0,0 +1,379 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {

namespace{

template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
    std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
    void>::
    type inline adam_math(
  scalar_t* param_ptr,
  scalar_t* exp_avg_ptr,
  scalar_t* exp_avg_sq_ptr,
  scalar_t* grad_ptr,
  scalar_t* max_exp_avg_sq_ptr,
  double lr,
  double bias_correction1,
  double bias_correction2,
  double exp_avg_grad_coefficient,
  double exp_avg_sq_grad_coefficient,
  double bias_correction2_sqrt,
  double eps,
  double weight_decay,
  double beta2,
  bool amsgrad,
  bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  double step_size = lr / bias_correction1;
  using lpVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<opmath_t>;
  lpVec grad_vec_to_store;
  int64_t d = 0;
  fVec param_vec1, param_vec2;
  fVec grad_vec1, grad_vec2;
  fVec exp_avg_vec1, exp_avg_vec2;
  fVec exp_avg_sq_vec1, exp_avg_sq_vec2;
  fVec max_exp_avg_sq_vec1, max_exp_avg_sq_vec2;
  for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
    lpVec param_lpvec = lpVec::loadu(param_ptr + d);
    std::tie(param_vec1, param_vec2) = vec::convert_to_float<scalar_t>(param_lpvec);
    lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
    std::tie(grad_vec1, grad_vec2) = vec::convert_to_float<scalar_t>(grad_lpvec);
    if (grad_scale_ptr) {
      grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
      grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
      grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize){
      grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
      grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
    }
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
        grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec1 = param_vec1 * fVec(opmath_t(1 - lr * weight_decay));
        param_vec2 = param_vec2 * fVec(opmath_t(1 - lr * weight_decay));
      }
    }

    lpVec exp_avg_lpvec = lpVec::loadu(exp_avg_ptr + d);
    std::tie(exp_avg_vec1, exp_avg_vec2) = vec::convert_to_float<scalar_t>(exp_avg_lpvec);

    // exp_avg.lerp_(grad, 1 - beta1)
    const fVec lerp_weight = fVec(opmath_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < fVec(0.5);
    auto coeff = fVec::blendv(lerp_weight - fVec(1), lerp_weight, mask);

    auto base1 = fVec::blendv(grad_vec1, exp_avg_vec1, mask);
    exp_avg_vec1 = vec::fmadd(coeff, grad_vec1 - exp_avg_vec1, base1);

    auto base2 = fVec::blendv(grad_vec2, exp_avg_vec2, mask);
    exp_avg_vec2 = vec::fmadd(coeff, grad_vec2 - exp_avg_vec2, base2);

    lpVec exp_avg_sq_lpvec = lpVec::loadu(exp_avg_sq_ptr + d);
    std::tie(exp_avg_sq_vec1, exp_avg_sq_vec2) = vec::convert_to_float<scalar_t>(exp_avg_sq_lpvec);
    exp_avg_sq_vec1 = exp_avg_sq_vec1 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec1 * grad_vec1;
    exp_avg_sq_vec2 = exp_avg_sq_vec2 * fVec(opmath_t(beta2)) +
        fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec2 * grad_vec2;

    vec::convert_from_float<scalar_t>(exp_avg_vec1, exp_avg_vec2).store(exp_avg_ptr + d);
    vec::convert_from_float<scalar_t>(exp_avg_sq_vec1, exp_avg_sq_vec2).store(exp_avg_sq_ptr + d);

    fVec denom_vec1, denom_vec2;
    if (amsgrad) {
      lpVec max_exp_avg_sq_lpvec = lpVec::loadu(max_exp_avg_sq_ptr + d);
      std::tie(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2) = vec::convert_to_float<scalar_t>(max_exp_avg_sq_lpvec);
      max_exp_avg_sq_vec1 = maximum(max_exp_avg_sq_vec1, exp_avg_sq_vec1);
      max_exp_avg_sq_vec2 = maximum(max_exp_avg_sq_vec2, exp_avg_sq_vec2);
      vec::convert_from_float<scalar_t>(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2).store(max_exp_avg_sq_ptr + d);
      denom_vec1 =
          (max_exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (max_exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    } else {
      denom_vec1 =
          (exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
      denom_vec2 =
          (exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
    }
    param_vec1 = param_vec1 + fVec(opmath_t(-step_size)) * exp_avg_vec1 / denom_vec1;
    param_vec2 = param_vec2 + fVec(opmath_t(-step_size)) * exp_avg_vec2 / denom_vec2;
    vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
  }
  scalar_t grad_val_to_store;
  for (; d < size; d++) {
    opmath_t grad_val = grad_ptr[d];
    opmath_t param_val = param_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / float(*grad_scale_ptr);
      grad_val_to_store = scalar_t(grad_val);
      grad_ptr[d] = grad_val_to_store;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_val * opmath_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_val = param_val * opmath_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    opmath_t exp_avg_var = exp_avg_ptr[d];
    auto is_lerp_weight_small = std::abs(opmath_t(exp_avg_grad_coefficient)) < opmath_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_var = exp_avg_var + opmath_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_var);
    } else {
      exp_avg_var = grad_val - (grad_val - exp_avg_var) * (opmath_t(1) - opmath_t(exp_avg_grad_coefficient));
    }
    exp_avg_ptr[d] = scalar_t(exp_avg_var);
    opmath_t exp_avg_sq_var = exp_avg_sq_ptr[d];
    exp_avg_sq_var = exp_avg_sq_var * opmath_t(beta2);
    exp_avg_sq_var = exp_avg_sq_var +
        opmath_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    exp_avg_sq_ptr[d] = scalar_t(exp_avg_sq_var);
    opmath_t demon_val;
    if (amsgrad) {
      opmath_t max_exp_avg_sq_var = max_exp_avg_sq_ptr[d];
      max_exp_avg_sq_var = std::max(max_exp_avg_sq_var, exp_avg_sq_var);
      max_exp_avg_sq_ptr[d] =
          scalar_t(max_exp_avg_sq_var);
      demon_val =
          std::sqrt(max_exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    } else {
      demon_val = std::sqrt(exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
    }
    param_ptr[d] = param_val - opmath_t(step_size) * exp_avg_var / demon_val;
  }
}


template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
    std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
    void>::
    type inline adam_math(
  scalar_t* param_ptr,
  scalar_t* exp_avg_ptr,
  scalar_t* exp_avg_sq_ptr,
  scalar_t* grad_ptr,
  scalar_t* max_exp_avg_sq_ptr,
  double lr,
  double bias_correction1,
  double bias_correction2,
  double exp_avg_grad_coefficient,
  double exp_avg_sq_grad_coefficient,
  double bias_correction2_sqrt,
  double eps,
  double weight_decay,
  double beta2,
  bool amsgrad,
  bool maximize,
  const float* grad_scale_ptr,
  int64_t size
){
  double step_size = lr / bias_correction1;
  using Vec = at::vec::Vectorized<scalar_t>;
  Vec grad_vec_to_store;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec param_vec = Vec::loadu(param_ptr + d);
    Vec grad_vec = Vec::loadu(grad_ptr + d);
    if (grad_scale_ptr) {
      grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
      grad_vec_to_store = grad_vec;
      grad_vec_to_store.store(grad_ptr + d);
    }
    if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_vec += param_vec * Vec(scalar_t(weight_decay));
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_vec = param_vec * Vec(scalar_t(1 - lr * weight_decay));
      }
    }
    Vec exp_avg_vec = Vec::loadu(exp_avg_ptr + d);
    // exp_avg.lerp_(grad, 1 - beta1)
    const Vec lerp_weight = Vec(scalar_t(exp_avg_grad_coefficient));
    auto mask = lerp_weight.abs() < Vec(0.5);
    auto coeff = Vec::blendv(lerp_weight - Vec(1), lerp_weight, mask);
    auto base = Vec::blendv(grad_vec, exp_avg_vec, mask);
    exp_avg_vec = vec::fmadd(coeff, grad_vec - exp_avg_vec, base);

    Vec exp_avg_sq_vec = Vec::loadu(exp_avg_sq_ptr + d) * Vec(scalar_t(beta2)) +
        Vec(scalar_t(exp_avg_sq_grad_coefficient)) * grad_vec * grad_vec;
    exp_avg_vec.store(exp_avg_ptr + d);
    exp_avg_sq_vec.store(exp_avg_sq_ptr + d);

    Vec denom_vec;
    if (amsgrad) {
      Vec max_exp_avg_sq_vec =
          maximum(Vec::loadu(max_exp_avg_sq_ptr + d), exp_avg_sq_vec);
      max_exp_avg_sq_vec.store(max_exp_avg_sq_ptr + d);
      denom_vec =
          (max_exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
    } else {
      denom_vec =
          (exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
    }
    param_vec = param_vec + Vec(scalar_t(-step_size)) * exp_avg_vec / denom_vec;
    param_vec.store(param_ptr + d);
  }
  scalar_t grad_val_to_store;
  for (; d < size; d++) {
    scalar_t grad_val = grad_ptr[d];
    if (grad_scale_ptr) {
      grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
      grad_val_to_store = grad_val;
      grad_ptr[d] = grad_val_to_store;
    }
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.f){
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad_val += param_ptr[d] * scalar_t(weight_decay);
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param_ptr[d] = param_ptr[d] * scalar_t(1 - lr * weight_decay);
      }
    }
    // exp_avg.lerp_(grad, 1 - beta1)
    auto is_lerp_weight_small = std::abs(scalar_t(exp_avg_grad_coefficient)) < scalar_t(0.5);
    if (is_lerp_weight_small) {
      exp_avg_ptr[d] = exp_avg_ptr[d] + scalar_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_ptr[d]);
    } else {
      exp_avg_ptr[d] = grad_val - (grad_val - exp_avg_ptr[d]) * (scalar_t(1) - scalar_t(exp_avg_grad_coefficient));
    }
    exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] * scalar_t(beta2);
    exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
        scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
    scalar_t demon_val;
    if (amsgrad) {
      max_exp_avg_sq_ptr[d] =
          std::max(max_exp_avg_sq_ptr[d], exp_avg_sq_ptr[d]);
      demon_val =
          std::sqrt(max_exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    } else {
      demon_val = std::sqrt(exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    }
    param_ptr[d] = param_ptr[d] - scalar_t(step_size) * exp_avg_ptr[d] / demon_val;
  }
}


template <typename scalar_t, ADAM_MODE adam_mode>
void adam_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  double step = state_step.item<float>();
  scalar_t* param_data = param.data_ptr<scalar_t>();
  scalar_t* exp_avg_data = exp_avg.data_ptr<scalar_t>();
  scalar_t* exp_avg_sq_data = exp_avg_sq.data_ptr<scalar_t>();
  scalar_t* max_exp_avg_sq_data = amsgrad ? max_exp_avg_sq.data_ptr<scalar_t>() : nullptr;
  scalar_t* grad_data = grad.data_ptr<scalar_t>();

  // need to use double here to align with non-fused adam
  double bias_correction1 = 1 - std::pow(beta1, step);
  double bias_correction2 = 1 - std::pow(beta2, step);
  double exp_avg_grad_coefficient = 1 - beta1;
  double exp_avg_sq_grad_coefficient = 1 - beta2;
  double bias_correction2_sqrt = std::sqrt(bias_correction2);


  constexpr size_t cache_line_size = 64;
  constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
  size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);

  auto adam_fn = [&](int64_t begin, int64_t end) {
    // local pointers
    begin *= cache_line_aligned_task_unit;
    end = std::min(end * cache_line_aligned_task_unit, param.numel());
    scalar_t* param_ptr = param_data + begin;
    scalar_t* exp_avg_ptr = exp_avg_data + begin;
    scalar_t* exp_avg_sq_ptr = exp_avg_sq_data + begin;
    scalar_t* grad_ptr = grad_data + begin;
    scalar_t* max_exp_avg_sq_ptr = amsgrad ? max_exp_avg_sq_data + begin : nullptr;

    const int64_t size = end - begin;
    adam_math<scalar_t, opmath_t, adam_mode>(
      param_ptr,
      exp_avg_ptr,
      exp_avg_sq_ptr,
      grad_ptr,
      max_exp_avg_sq_ptr,
      lr,
      bias_correction1,
      bias_correction2,
      exp_avg_grad_coefficient,
      exp_avg_sq_grad_coefficient,
      bias_correction2_sqrt,
      eps,
      weight_decay,
      beta2,
      amsgrad,
      maximize,
      grad_scale_ptr,
      size
    );
  };
  at::parallel_for(
      0, num_units, 0, adam_fn);
}

void fused_adam_kernel(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& exp_avg,
    const at::Tensor& exp_avg_sq,
    const at::Tensor& max_exp_avg_sq,
    const at::Tensor& state_step,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const float* grad_scale_ptr,
    const ADAM_MODE adam_mode
  ) {
  Tensor grad_contiguous = grad.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adam_kernel", [&] {
    if(adam_mode == ADAM_MODE::ORIGINAL){
      adam_fused_step_impl<scalar_t, ADAM_MODE::ORIGINAL>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
    } else {
      adam_fused_step_impl<scalar_t, ADAM_MODE::ADAMW>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
    }

  });
}

}

REGISTER_DISPATCH(fused_adam_stub, &fused_adam_kernel);
} // namespace at::native
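As a reading aid, the per-element math that `adam_math` above implements can be restated in eager PyTorch. This sketch ignores grad scaling, AMSGrad, and the low-precision/vectorized split, and mirrors the non-fused single-tensor step rather than the shipped kernel:
```python
import torch

def adam_step_reference(param, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2,
                        eps, weight_decay=0.0, adamw=False, maximize=False):
    # Eager-mode restatement of the per-element math in FusedAdamKernel.cpp
    # (no grad scaling, no AMSGrad); tensors are updated in place.
    if maximize:
        grad = -grad
    if weight_decay != 0:
        if adamw:
            param.mul_(1 - lr * weight_decay)        # ADAM_MODE::ADAMW
        else:
            grad = grad + weight_decay * param       # ADAM_MODE::ORIGINAL
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    exp_avg.lerp_(grad, 1 - beta1)                   # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # second moment
    denom = exp_avg_sq.sqrt().div_(bias_correction2 ** 0.5).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)   # param -= step_size * m / denom
```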
aten/src/ATen/native/native_functions.yaml

@@ -15517,6 +15517,7 @@
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
  variants: function
  dispatch:
    CPU: _fused_adam_kernel_cpu_
    CUDA: _fused_adam_kernel_cuda_
  autogen: _fused_adam, _fused_adam.out

@@ -15526,6 +15527,7 @@
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _fused_adam_kernel_cpu_
    CUDA: _fused_adam_kernel_cuda_
  autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out

@@ -15533,6 +15535,7 @@
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
  variants: function
  dispatch:
    CPU: _fused_adamw_kernel_cpu_
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw, _fused_adamw.out

@@ -15542,6 +15545,7 @@
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _fused_adamw_kernel_cpu_
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out
build_variables.bzl

@@ -1168,6 +1168,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/SpmmReduceKernel.cpp",
    "aten/src/ATen/native/cpu/SparseFactories.cpp",
    "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
    "aten/src/ATen/native/cpu/FusedAdamKernel.cpp",
]

# This aten native source file list will not go through aten codegen process
@@ -1402,6 +1403,7 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/xnnpack/OpContext.cpp",
    "aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp",
    "aten/src/ATen/native/xnnpack/Shim.cpp",
    "aten/src/ATen/native/FusedAdam.cpp",
    # Files not in native, but depends on native symbols
    # "aten/src/ATen/TensorIndexing.cpp",
    "aten/src/ATen/TensorIterator.cpp",
test/test_optim.py

@@ -21,9 +21,10 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_optimizers import (
    optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM)
    instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes)
from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase

from torch.testing._internal.common_cuda import _create_scaling_case
from torch.testing._internal.common_dtype import floating_types_and

FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}

@@ -581,6 +582,49 @@ class TestOptimRenewed(TestCase):
            self.assertTrue(a1_grad_imags.all_popped())
            self.assertTrue(losses.all_popped())

    def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None):
        # why 7? iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        kIterations = 7
        tracker = TensorTracker(assert_eq_kwargs)
        for i in range(kIterations):
            state, updated_params = [], []
            if not isinstance(inputs, list):
                inputs = [inputs, inputs]
            for input, model, optimizer in zip(inputs, models, optimizers):
                optimizer.zero_grad()

                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 3:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()
                state.append(optimizer.state)
                updated_params.append(model.parameters())

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                tracker.add(og_p)
                tracker.pop_check_set(new_p, self)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]
                if assert_step_dtype is not None:
                    if torch.is_tensor(og_p_state.get("step", None)):
                        self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
                    if torch.is_tensor(new_p_state.get("step", None)):
                        self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
                for k in og_p_state:
                    tracker.add(og_p_state[k])
                    tracker.pop_check_set(new_p_state[k], self)

        self.assertTrue(tracker.all_popped())

    def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
        """
@@ -589,16 +633,12 @@ class TestOptimRenewed(TestCase):
        for provided optimizer configurations.
        """
        assert flag in ("foreach", "fused")
        assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION

        # why 7? iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
        kIterations = 7

        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            updated_params, state = [], []
            models, optimizers = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if kwargs.get("capturable", False) and str(device) == "cpu":
                # capturable is not supported on CPU
@@ -626,39 +666,10 @@ class TestOptimRenewed(TestCase):
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                models.append(model)
                optimizers.append(optimizer)

            for i in range(kIterations):
                optimizer.zero_grad()

                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 3:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()

                if assert_step_dtype is not None:
                    p_state = optimizer.state[params[0]]
                    if torch.is_tensor(p_state.get("step", None)):
                        self.assertEqual(p_state["step"].dtype, assert_step_dtype)

                state.append(optimizer.state)
                updated_params.append(model.parameters())

            assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                self.assertEqual(og_p, new_p, **assert_eq_kwargs)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]

                for k in og_p_state:
                    self.assertEqual(og_p_state[k], new_p_state[k], **assert_eq_kwargs)

            self._compare_between(input, models, optimizers, assert_eq_kwargs, assert_step_dtype)

    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
    @optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64])
@@ -847,16 +858,23 @@ class TestOptimRenewed(TestCase):
        self.assertLessEqual(mt_max_mem, expected_max_mem)


    @onlyCUDA
    @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float64])
    @onlyNativeDeviceTypes
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=floating_types_and(torch.bfloat16, torch.float16, )
    )
    def test_fused_matches_forloop(self, device, dtype, optim_info):
        if device not in optim_info.supports_fused_on:
            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
        self._test_derived_optimizers(device, dtype, optim_info, "fused")


    @onlyCUDA
    @largeTensorTest("64GB", "cuda")
    @onlyNativeDeviceTypes
    @largeTensorTest("64GB")
    @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float16])
    def test_fused_large_tensor(self, device, dtype, optim_info):
        if device not in optim_info.supports_fused_on:
            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
@@ -1304,10 +1322,11 @@ class TestOptimRenewed(TestCase):

        # Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
        capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
        fused = state_dict_cpu["param_groups"][0].get("fused", False)
        new_state_dict = optimizer_cuda.state_dict()
        for state_cpu, state_cuda in zip(state_dict_cpu["state"].values(), new_state_dict["state"].values()):
            if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
                self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable else "cpu")
                self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable or fused else "cpu")

        for _ in range(5):
            optimizer.step(closure)
@@ -1615,6 +1634,104 @@ class TestOptimRenewed(TestCase):
        res2 = optim_neg_inf.step(closure)
        self.assertEqual(type(res1), type(res2))

    @onlyCUDA
    @optims(
        [optim for optim in optim_db if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on],
        dtypes=floating_types_and(torch.bfloat16, torch.float16,)
    )
    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device="cpu")
        for optim_input in optim_inputs:
            inpts, models, optimizers = [], [], []
            for dev in ('cpu', 'cuda'):
                kwargs = optim_input.kwargs
                kwargs["fused"] = True
                inpt = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=dev)

                # foreach/fused optimizers should be tested with a
                # zero_size tensor as its last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty((), device=dev, dtype=dtype, requires_grad=True)
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                inpts.append(inpt)
                models.append(model)
                optimizers.append(optimizer)
            self._compare_between(inpts, models, optimizers)

    @onlyCPU
    @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
        # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
        # but only test Adam/AdamW on CPU
        # TODO: haozhe, support SGD and unified this ut with the CUDA only one
        if device not in optim_info.supports_fused_on:
            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            kwargs = optim_input.kwargs
            for _separate_unscale in (True, False):
                self._grad_scaling_autocast_fused_optimizers(
                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)

    def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale):
        (
            mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _,
        ) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu')
        kwargs = deepcopy(optimizer_kwargs)
        kwargs["fused"] = False
        if 'lr' not in optimizer_kwargs:
            # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
            kwargs['lr'] = 1.0
        opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)

        scaler = torch.cpu.amp.GradScaler(init_scale=128.0)
        for input, target in data:
            opt_control.zero_grad()
            with torch.autocast('cpu', dtype=torch.half):
                output_control = mod_control(input)
                loss_control = loss_fn(output_control, target)
            scaler.scale(loss_control).backward()
            scaler.step(opt_control)
            scaler.update()

            opt_scaling.zero_grad()
            with torch.autocast('cpu', dtype=torch.half):
                output_scaling = mod_scaling(input)
                loss_scaling = loss_fn(output_scaling, target)
            scaler.scale(loss_scaling).backward()
            if separate_unscale:
                scaler.unscale_(opt_scaling)
            scaler.step(opt_scaling)
            scaler.update()

            self.assertEqual(loss_control, loss_scaling,)
            for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()):
                self.assertEqual(param_control.grad, param_scaling.grad,)
                self.assertEqual(param_control, param_scaling,)

                state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling]

                for k in state_control:
                    actual = state_scaling[k]
                    if k == "step":
                        actual = actual.squeeze()
                    self.assertEqual(state_control[k], actual,)

    @onlyCUDA
    @optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32])
test/test_torch.py

@@ -47,8 +47,7 @@ from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA, onlyCPU,
    dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
    skipMeta,
    PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
    skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
    get_all_device_types, skipXLA)
from typing import Tuple
import torch.backends.quantized
@@ -5932,7 +5931,7 @@ else:
        for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW):
            self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True})

    @onlyCUDA
    @onlyNativeDeviceTypes
    def test_grad_scaling_autocast_fused(self, device):
        device = torch.device(device)
        for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
@@ -5952,8 +5951,6 @@ else:
                {"foreach": False, "fused": True},
            ),
        ):
            if device.type != "cuda":
                optimizer_kwargs['fused'] = False
            with self.subTest(optimizer=optimizer_ctor, optimizer_kwargs=optimizer_kwargs):
                self._test_grads_invalidated_between_unscale_and_step(device.type, optimizer_ctor, optimizer_kwargs)
torch/optim/adam.py

@@ -76,7 +76,7 @@ class Adam(Optimizer):
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            fused_supported_devices = _get_fused_kernels_supported_devices()
            fused_supported_devices = _get_fused_kernels_supported_devices() + ["cpu"]
            if not all(
                p.device.type in fused_supported_devices and torch.is_floating_point(p)
                for pg in self.param_groups
torch/optim/adamw.py

@@ -75,7 +75,7 @@ class AdamW(Optimizer):
            # Suppor AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            fused_supported_devices = _get_fused_kernels_supported_devices()
            fused_supported_devices = _get_fused_kernels_supported_devices() + ["cpu"]
            if not all(
                p.device.type in fused_supported_devices and torch.is_floating_point(p)
                for pg in self.param_groups
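With `"cpu"` appended to `fused_supported_devices` in both constructors, a fused optimizer can now be built over CPU parameters; a minimal usage sketch (model and hyperparameters are arbitrary):
```python
import torch

# Before this change, fused=True on CPU params raised because "cpu" was not in
# fused_supported_devices; after it, construction and stepping work on CPU.
model = torch.nn.Linear(4, 4)                    # CPU float32 params
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, fused=True)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
opt.step()                                       # dispatches to the CPU fused kernel
opt.zero_grad()
```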
torch/testing/_internal/common_optimizers.py

@@ -44,10 +44,7 @@ from torch.testing._internal.common_utils import (
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices


class OptimizerInput:
@@ -143,6 +140,7 @@ class OptimizerInfo:
        skips=(),  # Indicates which tests to skip
        decorators=None,  # Additional decorators to apply to generated tests
        optim_error_inputs_func=None,  # Function to generate optim inputs that error
        supports_fused_on: Tuple[str] = (),
    ):
        self.optim_cls = optim_cls
        self.optim_inputs_func = optim_inputs_func
@@ -160,6 +158,7 @@ class OptimizerInfo:
            *(skips if skips else []),
        )
        self.optim_error_inputs_func = optim_error_inputs_func
        self.supports_fused_on = supports_fused_on

    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
        result = [set_single_threaded_if_parallel_tbb]
@@ -291,7 +290,7 @@ def get_error_inputs_for_all_optims(device, dtype):
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.


def optim_inputs_func_adadelta(device):
def optim_inputs_func_adadelta(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -340,7 +339,7 @@ def optim_error_inputs_func_adadelta(device, dtype):
    return error_inputs


def optim_inputs_func_adagrad(device):
def optim_inputs_func_adagrad(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
@@ -384,7 +383,7 @@ def optim_error_inputs_func_adagrad(device, dtype):

# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
# with all implementation code paths...
def optim_inputs_func_adam(device):
def optim_inputs_func_adam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -399,7 +398,7 @@ def optim_inputs_func_adam(device):
        ),
    ]

    return [
    total = [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
        OptimizerInput(
@@ -414,6 +413,19 @@ def optim_inputs_func_adam(device):
            params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
        ),
    ] + (cuda_supported_configs if "cuda" in str(device) else [])
    if dtype in (torch.float16,):
        for input in total:
            """
            Too small eps will make denom to be zero for low precision dtype
            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
            For example,
            >>> a
            tensor([0.], dtype=torch.float16)
            >>> a + 1e-8
            tensor([0.], dtype=torch.float16)
            """
            input.kwargs["eps"] = 0.1
    return total


def optim_error_inputs_func_adam(device, dtype):
@@ -473,7 +485,7 @@ def optim_error_inputs_func_adam(device, dtype):
    return error_inputs


def optim_inputs_func_adamax(device):
def optim_inputs_func_adamax(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -524,15 +536,15 @@ def optim_error_inputs_func_adamax(device, dtype):
    return error_inputs


def optim_inputs_func_adamw(device):
    return optim_inputs_func_adam(device)
def optim_inputs_func_adamw(device, dtype=None):
    return optim_inputs_func_adam(device, dtype)


def optim_error_inputs_func_adamw(device, dtype):
    return optim_error_inputs_func_adam(device, dtype)


def optim_inputs_func_asgd(device):
def optim_inputs_func_asgd(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -584,7 +596,7 @@ def optim_error_inputs_func_asgd(device, dtype):
    return error_inputs


def optim_inputs_func_lbfgs(device):
def optim_inputs_func_lbfgs(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@@ -605,7 +617,7 @@ def optim_error_inputs_func_lbfgs(device, dtype):


# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_nadam(device):
def optim_inputs_func_nadam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -676,7 +688,7 @@ def optim_error_inputs_func_nadam(device, dtype):


# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_radam(device=None):
def optim_inputs_func_radam(device=None, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -738,7 +750,7 @@ def optim_error_inputs_func_radam(device, dtype):
    return error_inputs


def optim_inputs_func_rmsprop(device):
def optim_inputs_func_rmsprop(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -799,7 +811,7 @@ def optim_error_inputs_func_rmsprop(device, dtype):
    return error_inputs


def optim_inputs_func_rprop(device):
def optim_inputs_func_rprop(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
@@ -841,7 +853,7 @@ def optim_error_inputs_func_rprop(device, dtype):
    return error_inputs


def optim_inputs_func_sgd(device):
def optim_inputs_func_sgd(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
@@ -886,7 +898,7 @@ def optim_error_inputs_func_sgd(device, dtype):
    return error_inputs


def optim_inputs_func_sparseadam(device):
def optim_inputs_func_sparseadam(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
@@ -995,10 +1007,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
        x
        for x in optim_info.supported_impls
        if x not in skip
        and (
            _get_device_type(device) in _get_fused_kernels_supported_devices()
            or x != "fused"
        )
        and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
        and (
            _get_device_type(device) in _get_foreach_kernels_supported_devices()
            or x != "foreach"
@@ -1196,6 +1205,7 @@ optim_db: List[OptimizerInfo] = [
        ),
        optim_error_inputs_func=optim_error_inputs_func_adam,
        supported_impls=("foreach", "differentiable", "fused"),
        supports_fused_on=("cpu", "cuda"),
        decorators=(
            # Expected floating point error between fused and compiled forloop
            DecorateInfo(
@@ -1205,6 +1215,21 @@ optim_db: List[OptimizerInfo] = [
                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
                and kwargs["dtype"] == torch.float64,
            ),
            DecorateInfo(
                # Note on tolerances:
                # difference comes from the fact that the non fused kernel have
                # more dtype cast operations. We have another test test_fused_cpu_matches_cuda
                # to make sure there is no discrepancies between cuda fused kernel
                # and cpu fused kernel
                toleranceOverride(
                    {
                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
                        torch.float16: tol(atol=5e-3, rtol=5e-3),
                    }
                ),
                "TestOptimRenewed",
                "test_fused_matches_forloop",
            ),
        ),
        skips=(
            DecorateInfo(
@@ -1364,6 +1389,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_adamw,
        optim_error_inputs_func=optim_error_inputs_func_adamw,
        supported_impls=("foreach", "differentiable", "fused"),
        supports_fused_on=("cpu", "cuda"),
        decorators=(
            # Expected error between compiled forloop and fused optimizers
            DecorateInfo(
@@ -1373,6 +1399,21 @@ optim_db: List[OptimizerInfo] = [
                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
                and kwargs["dtype"] == torch.float64,
            ),
            DecorateInfo(
                toleranceOverride(
                    # Note on tolerances:
                    # difference comes from the fact that the non fused kernel have
                    # more dtype cast operations. We have another test test_fused_cpu_matches_cuda
                    # to make sure there is no discrepancies between cuda fused kernel
                    # and cpu fused kernel
                    {
                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
                        torch.float16: tol(atol=5e-3, rtol=5e-3),
                    }
                ),
                "TestOptimRenewed",
                "test_fused_matches_forloop",
            ),
        ),
        skips=(
            DecorateInfo(
@@ -1865,6 +1906,7 @@ optim_db: List[OptimizerInfo] = [
            },
            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
        ),
        supports_fused_on=("cuda",),
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
@@ -2060,7 +2102,10 @@ class TensorTracker:
    numerical discrepancies, and so when the test fails, it is likely a real problem.
    """

    def __init__(self):
    def __init__(self, assert_eq_kwargs=None):
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        self.assert_eq_kwargs = assert_eq_kwargs
        self.tensors = []

    def add(self, tensor):
@@ -2080,7 +2125,7 @@ class TensorTracker:
        ref = self.tensors.pop(0)

        testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
        testcase.assertEqual(tensor_to_set, ref)
        testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)

        with torch.no_grad():
            tensor_to_set.copy_(ref)