[optim] add fused_adam/adamw_kernel support for CPU device (#123074)

On par with `CUDA` implementation.

For `autocast` logic, same with `CUDA` + `Fused Adam`:
 - check inf in `gradscalar.step`
 - In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.

**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follow 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**:
**5.1x** on 56 core SPR
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)

```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```

**Note: Fused kernel accuracy**
The accuracy failure in CI shows a little higher than default tolerance
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debug it step by step and unfortunately we may not able to make the `fused kernel` exactly same with `non fused` one due to compiler optimizations.
For example, in non-fused impl
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in fused impl
```
  exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
  //  std::cout << "exp_avg_sq " <<   exp_avg_sq_ptr[d] << std::endl;
  exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
      scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep `std::cout`, I can get exactly same results in UT
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment out it, there will be a difference
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than default one.

Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
This commit is contained in:
Jane Xu
2024-04-19 09:54:05 +00:00
committed by PyTorch MergeBot
parent 9a71d12d92
commit b412b75b42
10 changed files with 827 additions and 78 deletions

View File

@ -0,0 +1,175 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam.h>
#include <ATen/ops/_fused_adam_native.h>
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif
namespace at {
namespace native {
void _fused_adam_kernel_cpu_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
const float* grad_scale_ptr =
grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
const float* found_inf_ptr =
found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
if (found_inf_ptr && *found_inf_ptr == 1.0) {
return;
}
size_t n_tensors = params.size();
TORCH_CHECK(grads.size() == n_tensors);
TORCH_CHECK(exp_avgs.size() == n_tensors);
TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
if (amsgrad) {
TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
} else {
TORCH_CHECK(max_exp_avg_sqs.size() == 0);
}
TORCH_CHECK(state_steps.size() == n_tensors);
at::Tensor max_exp_avg_sq = at::Tensor();
for (size_t i = 0; i < n_tensors; i++){
if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
fused_adam_stub(
kCPU,
params[i],
grads[i],
exp_avgs[i],
exp_avg_sqs[i],
max_exp_avg_sq,
state_steps[i],
lr,
beta1,
beta2,
weight_decay,
eps,
amsgrad,
maximize,
grad_scale_ptr,
ADAM_MODE::ORIGINAL);
}
}
// The following overload simply has a Tensor lr
void _fused_adam_kernel_cpu_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const at::Tensor& lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
_fused_adam_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}
void _fused_adamw_kernel_cpu_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
const float* grad_scale_ptr =
grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
const float* found_inf_ptr =
found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
if (found_inf_ptr && *found_inf_ptr == 1.0) {
return;
}
size_t n_tensors = params.size();
TORCH_CHECK(grads.size() == n_tensors);
TORCH_CHECK(exp_avgs.size() == n_tensors);
TORCH_CHECK(exp_avg_sqs.size() == n_tensors);
if (amsgrad) {
TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
} else {
TORCH_CHECK(max_exp_avg_sqs.size() == 0);
}
TORCH_CHECK(state_steps.size() == n_tensors);
at::Tensor max_exp_avg_sq = at::Tensor();
for (size_t i = 0; i < n_tensors; i++){
if (amsgrad) max_exp_avg_sq = max_exp_avg_sqs[i];
fused_adam_stub(
kCPU,
params[i],
grads[i],
exp_avgs[i],
exp_avg_sqs[i],
max_exp_avg_sq,
state_steps[i],
lr,
beta1,
beta2,
weight_decay,
eps,
amsgrad,
maximize,
grad_scale_ptr,
ADAM_MODE::ADAMW);
}
}
// The following overload simply has a Tensor lr
void _fused_adamw_kernel_cpu_(
at::TensorList params,
at::TensorList grads,
at::TensorList exp_avgs,
at::TensorList exp_avg_sqs,
at::TensorList max_exp_avg_sqs,
at::TensorList state_steps,
const at::Tensor& lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const c10::optional<at::Tensor>& grad_scale,
const c10::optional<at::Tensor>& found_inf) {
_fused_adamw_kernel_cpu_(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr.item<double>(), beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
}
DEFINE_DISPATCH(fused_adam_stub);
}
}

View File

@ -0,0 +1,30 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at {
namespace native {
enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
using fused_adam_fn = void (*)(
const at::Tensor& param,
const at::Tensor& grad,
const at::Tensor& exp_avg,
const at::Tensor& exp_avg_sq,
const at::Tensor& max_exp_avg_sq,
const at::Tensor& state_step,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const float* grad_scale_ptr,
const ADAM_MODE);
DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
}
}

View File

@ -0,0 +1,379 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedAdam.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {
namespace{
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
void>::
type inline adam_math(
scalar_t* param_ptr,
scalar_t* exp_avg_ptr,
scalar_t* exp_avg_sq_ptr,
scalar_t* grad_ptr,
scalar_t* max_exp_avg_sq_ptr,
double lr,
double bias_correction1,
double bias_correction2,
double exp_avg_grad_coefficient,
double exp_avg_sq_grad_coefficient,
double bias_correction2_sqrt,
double eps,
double weight_decay,
double beta2,
bool amsgrad,
bool maximize,
const float* grad_scale_ptr,
int64_t size
){
double step_size = lr / bias_correction1;
using lpVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<opmath_t>;
lpVec grad_vec_to_store;
int64_t d = 0;
fVec param_vec1, param_vec2;
fVec grad_vec1, grad_vec2;
fVec exp_avg_vec1, exp_avg_vec2;
fVec exp_avg_sq_vec1, exp_avg_sq_vec2;
fVec max_exp_avg_sq_vec1, max_exp_avg_sq_vec2;
for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
lpVec param_lpvec = lpVec::loadu(param_ptr + d);
std::tie(param_vec1, param_vec2) = vec::convert_to_float<scalar_t>(param_lpvec);
lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
std::tie(grad_vec1, grad_vec2) = vec::convert_to_float<scalar_t>(grad_lpvec);
if (grad_scale_ptr) {
grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
grad_vec_to_store.store(grad_ptr + d);
}
if (maximize){
grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
}
if (weight_decay != 0.f){
if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
grad_vec1 += param_vec1 * fVec(opmath_t(weight_decay));
grad_vec2 += param_vec2 * fVec(opmath_t(weight_decay));
} else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
param_vec1 = param_vec1 * fVec(opmath_t(1 - lr * weight_decay));
param_vec2 = param_vec2 * fVec(opmath_t(1 - lr * weight_decay));
}
}
lpVec exp_avg_lpvec = lpVec::loadu(exp_avg_ptr + d);
std::tie(exp_avg_vec1, exp_avg_vec2) = vec::convert_to_float<scalar_t>(exp_avg_lpvec);
// exp_avg.lerp_(grad, 1 - beta1)
const fVec lerp_weight = fVec(opmath_t(exp_avg_grad_coefficient));
auto mask = lerp_weight.abs() < fVec(0.5);
auto coeff = fVec::blendv(lerp_weight - fVec(1), lerp_weight, mask);
auto base1 = fVec::blendv(grad_vec1, exp_avg_vec1, mask);
exp_avg_vec1 = vec::fmadd(coeff, grad_vec1 - exp_avg_vec1, base1);
auto base2 = fVec::blendv(grad_vec2, exp_avg_vec2, mask);
exp_avg_vec2 = vec::fmadd(coeff, grad_vec2 - exp_avg_vec2, base2);
lpVec exp_avg_sq_lpvec = lpVec::loadu(exp_avg_sq_ptr + d);
std::tie(exp_avg_sq_vec1, exp_avg_sq_vec2) = vec::convert_to_float<scalar_t>(exp_avg_sq_lpvec);
exp_avg_sq_vec1 = exp_avg_sq_vec1 * fVec(opmath_t(beta2)) +
fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec1 * grad_vec1;
exp_avg_sq_vec2 = exp_avg_sq_vec2 * fVec(opmath_t(beta2)) +
fVec(opmath_t(exp_avg_sq_grad_coefficient)) * grad_vec2 * grad_vec2;
vec::convert_from_float<scalar_t>(exp_avg_vec1, exp_avg_vec2).store(exp_avg_ptr + d);
vec::convert_from_float<scalar_t>(exp_avg_sq_vec1, exp_avg_sq_vec2).store(exp_avg_sq_ptr + d);
fVec denom_vec1, denom_vec2;
if (amsgrad) {
lpVec max_exp_avg_sq_lpvec = lpVec::loadu(max_exp_avg_sq_ptr + d);
std::tie(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2) = vec::convert_to_float<scalar_t>(max_exp_avg_sq_lpvec);
max_exp_avg_sq_vec1 = maximum(max_exp_avg_sq_vec1, exp_avg_sq_vec1);
max_exp_avg_sq_vec2 = maximum(max_exp_avg_sq_vec2, exp_avg_sq_vec2);
vec::convert_from_float<scalar_t>(max_exp_avg_sq_vec1, max_exp_avg_sq_vec2).store(max_exp_avg_sq_ptr + d);
denom_vec1 =
(max_exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
denom_vec2 =
(max_exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
} else {
denom_vec1 =
(exp_avg_sq_vec1.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
denom_vec2 =
(exp_avg_sq_vec2.sqrt() / fVec(opmath_t(bias_correction2_sqrt))) + fVec(opmath_t(eps));
}
param_vec1 = param_vec1 + fVec(opmath_t(-step_size)) * exp_avg_vec1 / denom_vec1;
param_vec2 = param_vec2 + fVec(opmath_t(-step_size)) * exp_avg_vec2 / denom_vec2;
vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
}
scalar_t grad_val_to_store;
for (; d < size; d++) {
opmath_t grad_val = grad_ptr[d];
opmath_t param_val = param_ptr[d];
if (grad_scale_ptr) {
grad_val = grad_ptr[d] / float(*grad_scale_ptr);
grad_val_to_store = scalar_t(grad_val);
grad_ptr[d] = grad_val_to_store;
}
if (maximize) grad_val = -grad_val;
if (weight_decay != 0.f){
if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
grad_val += param_val * opmath_t(weight_decay);
} else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
param_val = param_val * opmath_t(1 - lr * weight_decay);
}
}
// exp_avg.lerp_(grad, 1 - beta1)
opmath_t exp_avg_var = exp_avg_ptr[d];
auto is_lerp_weight_small = std::abs(opmath_t(exp_avg_grad_coefficient)) < opmath_t(0.5);
if (is_lerp_weight_small) {
exp_avg_var = exp_avg_var + opmath_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_var);
} else {
exp_avg_var = grad_val - (grad_val - exp_avg_var) * (opmath_t(1) - opmath_t(exp_avg_grad_coefficient));
}
exp_avg_ptr[d] = scalar_t(exp_avg_var);
opmath_t exp_avg_sq_var = exp_avg_sq_ptr[d];
exp_avg_sq_var = exp_avg_sq_var * opmath_t(beta2);
exp_avg_sq_var = exp_avg_sq_var +
opmath_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
exp_avg_sq_ptr[d] = scalar_t(exp_avg_sq_var);
opmath_t demon_val;
if (amsgrad) {
opmath_t max_exp_avg_sq_var = max_exp_avg_sq_ptr[d];
max_exp_avg_sq_var = std::max(max_exp_avg_sq_var, exp_avg_sq_var);
max_exp_avg_sq_ptr[d] =
scalar_t(max_exp_avg_sq_var);
demon_val =
std::sqrt(max_exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
} else {
demon_val = std::sqrt(exp_avg_sq_var) / opmath_t(bias_correction2_sqrt) + opmath_t(eps);
}
param_ptr[d] = param_val - opmath_t(step_size) * exp_avg_var / demon_val;
}
}
template <typename scalar_t, typename opmath_t, ADAM_MODE adam_mode>
typename std::enable_if<
std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
void>::
type inline adam_math(
scalar_t* param_ptr,
scalar_t* exp_avg_ptr,
scalar_t* exp_avg_sq_ptr,
scalar_t* grad_ptr,
scalar_t* max_exp_avg_sq_ptr,
double lr,
double bias_correction1,
double bias_correction2,
double exp_avg_grad_coefficient,
double exp_avg_sq_grad_coefficient,
double bias_correction2_sqrt,
double eps,
double weight_decay,
double beta2,
bool amsgrad,
bool maximize,
const float* grad_scale_ptr,
int64_t size
){
double step_size = lr / bias_correction1;
using Vec = at::vec::Vectorized<scalar_t>;
Vec grad_vec_to_store;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec param_vec = Vec::loadu(param_ptr + d);
Vec grad_vec = Vec::loadu(grad_ptr + d);
if (grad_scale_ptr) {
grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
grad_vec_to_store = grad_vec;
grad_vec_to_store.store(grad_ptr + d);
}
if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
if (weight_decay != 0.f){
if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
grad_vec += param_vec * Vec(scalar_t(weight_decay));
} else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
param_vec = param_vec * Vec(scalar_t(1 - lr * weight_decay));
}
}
Vec exp_avg_vec = Vec::loadu(exp_avg_ptr + d);
// exp_avg.lerp_(grad, 1 - beta1)
const Vec lerp_weight = Vec(scalar_t(exp_avg_grad_coefficient));
auto mask = lerp_weight.abs() < Vec(0.5);
auto coeff = Vec::blendv(lerp_weight - Vec(1), lerp_weight, mask);
auto base = Vec::blendv(grad_vec, exp_avg_vec, mask);
exp_avg_vec = vec::fmadd(coeff, grad_vec - exp_avg_vec, base);
Vec exp_avg_sq_vec = Vec::loadu(exp_avg_sq_ptr + d) * Vec(scalar_t(beta2)) +
Vec(scalar_t(exp_avg_sq_grad_coefficient)) * grad_vec * grad_vec;
exp_avg_vec.store(exp_avg_ptr + d);
exp_avg_sq_vec.store(exp_avg_sq_ptr + d);
Vec denom_vec;
if (amsgrad) {
Vec max_exp_avg_sq_vec =
maximum(Vec::loadu(max_exp_avg_sq_ptr + d), exp_avg_sq_vec);
max_exp_avg_sq_vec.store(max_exp_avg_sq_ptr + d);
denom_vec =
(max_exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
} else {
denom_vec =
(exp_avg_sq_vec.sqrt() / Vec(scalar_t(bias_correction2_sqrt))) + Vec(scalar_t(eps));
}
param_vec = param_vec + Vec(scalar_t(-step_size)) * exp_avg_vec / denom_vec;
param_vec.store(param_ptr + d);
}
scalar_t grad_val_to_store;
for (; d < size; d++) {
scalar_t grad_val = grad_ptr[d];
if (grad_scale_ptr) {
grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
grad_val_to_store = grad_val;
grad_ptr[d] = grad_val_to_store;
}
if (maximize) grad_val = -grad_val;
if (weight_decay != 0.f){
if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
grad_val += param_ptr[d] * scalar_t(weight_decay);
} else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
param_ptr[d] = param_ptr[d] * scalar_t(1 - lr * weight_decay);
}
}
// exp_avg.lerp_(grad, 1 - beta1)
auto is_lerp_weight_small = std::abs(scalar_t(exp_avg_grad_coefficient)) < scalar_t(0.5);
if (is_lerp_weight_small) {
exp_avg_ptr[d] = exp_avg_ptr[d] + scalar_t(exp_avg_grad_coefficient) * (grad_val - exp_avg_ptr[d]);
} else {
exp_avg_ptr[d] = grad_val - (grad_val - exp_avg_ptr[d]) * (scalar_t(1) - scalar_t(exp_avg_grad_coefficient));
}
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] * scalar_t(beta2);
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
scalar_t demon_val;
if (amsgrad) {
max_exp_avg_sq_ptr[d] =
std::max(max_exp_avg_sq_ptr[d], exp_avg_sq_ptr[d]);
demon_val =
std::sqrt(max_exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
} else {
demon_val = std::sqrt(exp_avg_sq_ptr[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
}
param_ptr[d] = param_ptr[d] - scalar_t(step_size) * exp_avg_ptr[d] / demon_val;
}
}
template <typename scalar_t, ADAM_MODE adam_mode>
void adam_fused_step_impl(
const at::Tensor& param,
const at::Tensor& grad,
const at::Tensor& exp_avg,
const at::Tensor& exp_avg_sq,
const at::Tensor& max_exp_avg_sq,
const at::Tensor& state_step,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const float* grad_scale_ptr) {
using opmath_t = at::opmath_type<scalar_t>;
double step = state_step.item<float>();
scalar_t* param_data = param.data_ptr<scalar_t>();
scalar_t* exp_avg_data = exp_avg.data_ptr<scalar_t>();
scalar_t* exp_avg_sq_data = exp_avg_sq.data_ptr<scalar_t>();
scalar_t* max_exp_avg_sq_data = amsgrad ? max_exp_avg_sq.data_ptr<scalar_t>() : nullptr;
scalar_t* grad_data = grad.data_ptr<scalar_t>();
// need to use double here to align with non-fused adam
double bias_correction1 = 1 - std::pow(beta1, step);
double bias_correction2 = 1 - std::pow(beta2, step);
double exp_avg_grad_coefficient = 1 - beta1;
double exp_avg_sq_grad_coefficient = 1 - beta2;
double bias_correction2_sqrt = std::sqrt(bias_correction2);
constexpr size_t cache_line_size = 64;
constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);
auto adam_fn = [&](int64_t begin, int64_t end) {
// local pointers
begin *= cache_line_aligned_task_unit;
end = std::min(end * cache_line_aligned_task_unit, param.numel());
scalar_t* param_ptr = param_data + begin;
scalar_t* exp_avg_ptr = exp_avg_data + begin;
scalar_t* exp_avg_sq_ptr = exp_avg_sq_data + begin;
scalar_t* grad_ptr = grad_data + begin;
scalar_t* max_exp_avg_sq_ptr = amsgrad ? max_exp_avg_sq_data + begin : nullptr;
const int64_t size = end - begin;
adam_math<scalar_t, opmath_t, adam_mode>(
param_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
grad_ptr,
max_exp_avg_sq_ptr,
lr,
bias_correction1,
bias_correction2,
exp_avg_grad_coefficient,
exp_avg_sq_grad_coefficient,
bias_correction2_sqrt,
eps,
weight_decay,
beta2,
amsgrad,
maximize,
grad_scale_ptr,
size
);
};
at::parallel_for(
0, num_units, 0, adam_fn);
}
void fused_adam_kernel(
const at::Tensor& param,
const at::Tensor& grad,
const at::Tensor& exp_avg,
const at::Tensor& exp_avg_sq,
const at::Tensor& max_exp_avg_sq,
const at::Tensor& state_step,
const double lr,
const double beta1,
const double beta2,
const double weight_decay,
const double eps,
const bool amsgrad,
const bool maximize,
const float* grad_scale_ptr,
const ADAM_MODE adam_mode
) {
Tensor grad_contiguous = grad.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adam_kernel", [&] {
if(adam_mode == ADAM_MODE::ORIGINAL){
adam_fused_step_impl<scalar_t, ADAM_MODE::ORIGINAL>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
} else {
adam_fused_step_impl<scalar_t, ADAM_MODE::ADAMW>(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, state_step, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_ptr);
}
});
}
}
REGISTER_DISPATCH(fused_adam_stub, &fused_adam_kernel);
} // namespace at::native

View File

@ -15517,6 +15517,7 @@
# Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
variants: function
dispatch:
CPU: _fused_adam_kernel_cpu_
CUDA: _fused_adam_kernel_cuda_
autogen: _fused_adam, _fused_adam.out
@ -15526,6 +15527,7 @@
device_check: NoCheck
variants: function
dispatch:
CPU: _fused_adam_kernel_cpu_
CUDA: _fused_adam_kernel_cuda_
autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out
@ -15533,6 +15535,7 @@
# Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
variants: function
dispatch:
CPU: _fused_adamw_kernel_cpu_
CUDA: _fused_adamw_kernel_cuda_
autogen: _fused_adamw, _fused_adamw.out
@ -15542,6 +15545,7 @@
device_check: NoCheck
variants: function
dispatch:
CPU: _fused_adamw_kernel_cpu_
CUDA: _fused_adamw_kernel_cuda_
autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out

View File

@ -1168,6 +1168,7 @@ aten_native_source_codegen_list = [
"aten/src/ATen/native/cpu/SpmmReduceKernel.cpp",
"aten/src/ATen/native/cpu/SparseFactories.cpp",
"aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
"aten/src/ATen/native/cpu/FusedAdamKernel.cpp",
]
# This aten native source file list will not go through aten codegen process
@ -1402,6 +1403,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/xnnpack/OpContext.cpp",
"aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp",
"aten/src/ATen/native/xnnpack/Shim.cpp",
"aten/src/ATen/native/FusedAdam.cpp",
# Files not in native, but depends on native symbols
# "aten/src/ATen/TensorIndexing.cpp",
"aten/src/ATen/TensorIterator.cpp",

View File

@ -21,9 +21,10 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_optimizers import (
optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM)
instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes)
from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase
from torch.testing._internal.common_cuda import _create_scaling_case
from torch.testing._internal.common_dtype import floating_types_and
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
@ -581,6 +582,49 @@ class TestOptimRenewed(TestCase):
self.assertTrue(a1_grad_imags.all_popped())
self.assertTrue(losses.all_popped())
def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None):
# why 7? iteration 7 is where we start to see differences for RAdam
# params interacting with the small eps value, because that's right
# after rho_t becomes greater than 5 in step 6.
if assert_eq_kwargs is None:
assert_eq_kwargs = {}
kIterations = 7
tracker = TensorTracker(assert_eq_kwargs)
for i in range(kIterations):
state, updated_params = [], []
if not isinstance(inputs, list):
inputs = [inputs, inputs]
for input, model, optimizer in zip(inputs, models, optimizers):
optimizer.zero_grad()
# Test that step behaves as expected (a no-op) when grads are set to None
if i != 3:
output = model(input)
loss = output.sum()
loss.backward()
optimizer.step()
state.append(optimizer.state)
updated_params.append(model.parameters())
og_state, new_state = state
for og_p, new_p in zip(updated_params[0], updated_params[1]):
tracker.add(og_p)
tracker.pop_check_set(new_p, self)
# check that optimizer states are the same
og_p_state = og_state[og_p]
new_p_state = new_state[new_p]
if assert_step_dtype is not None:
if torch.is_tensor(og_p_state.get("step", None)):
self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
if torch.is_tensor(new_p_state.get("step", None)):
self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
for k in og_p_state:
tracker.add(og_p_state[k])
tracker.pop_check_set(new_p_state[k], self)
self.assertTrue(tracker.all_popped())
def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
"""
@ -589,16 +633,12 @@ class TestOptimRenewed(TestCase):
for provided optimizer configurations.
"""
assert flag in ("foreach", "fused")
assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
# why 7? iteration 7 is where we start to see differences for RAdam
# params interacting with the small eps value, because that's right
# after rho_t becomes greater than 5 in step 6.
kIterations = 7
optim_inputs = optim_info.optim_inputs_func(device=device)
optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
optim_cls = optim_info.optim_cls
for optim_input in optim_inputs:
updated_params, state = [], []
models, optimizers = [], []
kwargs = deepcopy(optim_input.kwargs)
if kwargs.get("capturable", False) and str(device) == "cpu":
# capturable is not supported on CPU
@ -626,39 +666,10 @@ class TestOptimRenewed(TestCase):
params = list(model.parameters()) + [empty_param]
optimizer = optim_cls(params, **kwargs)
models.append(model)
optimizers.append(optimizer)
for i in range(kIterations):
optimizer.zero_grad()
# Test that step behaves as expected (a no-op) when grads are set to None
if i != 3:
output = model(input)
loss = output.sum()
loss.backward()
optimizer.step()
if assert_step_dtype is not None:
p_state = optimizer.state[params[0]]
if torch.is_tensor(p_state.get("step", None)):
self.assertEqual(p_state["step"].dtype, assert_step_dtype)
state.append(optimizer.state)
updated_params.append(model.parameters())
assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
og_state, new_state = state
for og_p, new_p in zip(updated_params[0], updated_params[1]):
self.assertEqual(og_p, new_p, **assert_eq_kwargs)
# check that optimizer states are the same
og_p_state = og_state[og_p]
new_p_state = new_state[new_p]
for k in og_p_state:
self.assertEqual(og_p_state[k], new_p_state[k], **assert_eq_kwargs)
self._compare_between(input, models, optimizers, assert_eq_kwargs, assert_step_dtype)
@skipMPS # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
@optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64])
@ -847,16 +858,23 @@ class TestOptimRenewed(TestCase):
self.assertLessEqual(mt_max_mem, expected_max_mem)
@onlyCUDA
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float64])
@onlyNativeDeviceTypes
@optims(
[optim for optim in optim_db if "fused" in optim.supported_impls],
dtypes=floating_types_and(torch.bfloat16, torch.float16, )
)
def test_fused_matches_forloop(self, device, dtype, optim_info):
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
self._test_derived_optimizers(device, dtype, optim_info, "fused")
@onlyCUDA
@largeTensorTest("64GB", "cuda")
@onlyNativeDeviceTypes
@largeTensorTest("64GB")
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float16])
def test_fused_large_tensor(self, device, dtype, optim_info):
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
@ -1304,10 +1322,11 @@ class TestOptimRenewed(TestCase):
# Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
fused = state_dict_cpu["param_groups"][0].get("fused", False)
new_state_dict = optimizer_cuda.state_dict()
for state_cpu, state_cuda in zip(state_dict_cpu["state"].values(), new_state_dict["state"].values()):
if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable else "cpu")
self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable or fused else "cpu")
for _ in range(5):
optimizer.step(closure)
@ -1615,6 +1634,104 @@ class TestOptimRenewed(TestCase):
res2 = optim_neg_inf.step(closure)
self.assertEqual(type(res1), type(res2))
@onlyCUDA
@optims(
[optim for optim in optim_db if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on],
dtypes=floating_types_and(torch.bfloat16, torch.float16,)
)
def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device="cpu")
for optim_input in optim_inputs:
inpts, models, optimizers = [], [], []
for dev in ('cpu', 'cuda'):
kwargs = optim_input.kwargs
kwargs["fused"] = True
inpt = torch.tensor(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
).reshape(3, 2)
torch.manual_seed(1)
model = torch.nn.Sequential(
torch.nn.Linear(2, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
)
model.to(dtype=dtype, device=dev)
# foreach/fused optimizers should be tested with a
# zero_size tensor as its last param.
# ref: https://github.com/pytorch/pytorch/issues/100701
empty_param = torch.empty((), device=dev, dtype=dtype, requires_grad=True)
empty_param.grad = torch.rand_like(empty_param)
params = list(model.parameters()) + [empty_param]
optimizer = optim_cls(params, **kwargs)
inpts.append(inpt)
models.append(model)
optimizers.append(optimizer)
self._compare_between(inpts, models, optimizers)
@onlyCPU
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
# This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
# but only test Adam/AdamW on CPU
# TODO: haozhe, support SGD and unified this ut with the CUDA only one
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
optim_inputs = optim_info.optim_inputs_func(device=device)
optim_cls = optim_info.optim_cls
for optim_input in optim_inputs:
kwargs = optim_input.kwargs
for _separate_unscale in (True, False):
self._grad_scaling_autocast_fused_optimizers(
optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale):
(
mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _,
) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu')
kwargs = deepcopy(optimizer_kwargs)
kwargs["fused"] = False
if 'lr' not in optimizer_kwargs:
# _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
kwargs['lr'] = 1.0
opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
scaler = torch.cpu.amp.GradScaler(init_scale=128.0)
for input, target in data:
opt_control.zero_grad()
with torch.autocast('cpu', dtype=torch.half):
output_control = mod_control(input)
loss_control = loss_fn(output_control, target)
scaler.scale(loss_control).backward()
scaler.step(opt_control)
scaler.update()
opt_scaling.zero_grad()
with torch.autocast('cpu', dtype=torch.half):
output_scaling = mod_scaling(input)
loss_scaling = loss_fn(output_scaling, target)
scaler.scale(loss_scaling).backward()
if separate_unscale:
scaler.unscale_(opt_scaling)
scaler.step(opt_scaling)
scaler.update()
self.assertEqual(loss_control, loss_scaling,)
for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()):
self.assertEqual(param_control.grad, param_scaling.grad,)
self.assertEqual(param_control, param_scaling,)
state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling]
for k in state_control:
actual = state_scaling[k]
if k == "step":
actual = actual.squeeze()
self.assertEqual(state_control[k], actual,)
@onlyCUDA
@optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32])

View File

@ -47,8 +47,7 @@ from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA, onlyCPU,
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
skipMeta,
PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes,
get_all_device_types, skipXLA)
from typing import Tuple
import torch.backends.quantized
@ -5932,7 +5931,7 @@ else:
for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW):
self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True})
@onlyCUDA
@onlyNativeDeviceTypes
def test_grad_scaling_autocast_fused(self, device):
device = torch.device(device)
for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
@ -5952,8 +5951,6 @@ else:
{"foreach": False, "fused": True},
),
):
if device.type != "cuda":
optimizer_kwargs['fused'] = False
with self.subTest(optimizer=optimizer_ctor, optimizer_kwargs=optimizer_kwargs):
self._test_grads_invalidated_between_unscale_and_step(device.type, optimizer_ctor, optimizer_kwargs)

View File

@ -76,7 +76,7 @@ class Adam(Optimizer):
# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
fused_supported_devices = _get_fused_kernels_supported_devices()
fused_supported_devices = _get_fused_kernels_supported_devices() + ["cpu"]
if not all(
p.device.type in fused_supported_devices and torch.is_floating_point(p)
for pg in self.param_groups

View File

@ -75,7 +75,7 @@ class AdamW(Optimizer):
# Suppor AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
fused_supported_devices = _get_fused_kernels_supported_devices()
fused_supported_devices = _get_fused_kernels_supported_devices() + ["cpu"]
if not all(
p.device.type in fused_supported_devices and torch.is_floating_point(p)
for pg in self.param_groups

View File

@ -44,10 +44,7 @@ from torch.testing._internal.common_utils import (
skipIfTorchDynamo,
TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import (
_get_foreach_kernels_supported_devices,
_get_fused_kernels_supported_devices,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
class OptimizerInput:
@ -143,6 +140,7 @@ class OptimizerInfo:
skips=(), # Indicates which tests to skip
decorators=None, # Additional decorators to apply to generated tests
optim_error_inputs_func=None, # Function to generate optim inputs that error
supports_fused_on: Tuple[str] = (),
):
self.optim_cls = optim_cls
self.optim_inputs_func = optim_inputs_func
@ -160,6 +158,7 @@ class OptimizerInfo:
*(skips if skips else []),
)
self.optim_error_inputs_func = optim_error_inputs_func
self.supports_fused_on = supports_fused_on
def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
result = [set_single_threaded_if_parallel_tbb]
@ -291,7 +290,7 @@ def get_error_inputs_for_all_optims(device, dtype):
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
def optim_inputs_func_adadelta(device):
def optim_inputs_func_adadelta(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -340,7 +339,7 @@ def optim_error_inputs_func_adadelta(device, dtype):
return error_inputs
def optim_inputs_func_adagrad(device):
def optim_inputs_func_adagrad(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
@ -384,7 +383,7 @@ def optim_error_inputs_func_adagrad(device, dtype):
# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
# with all implementation code paths...
def optim_inputs_func_adam(device):
def optim_inputs_func_adam(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -399,7 +398,7 @@ def optim_inputs_func_adam(device):
),
]
return [
total = [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
@ -414,6 +413,19 @@ def optim_inputs_func_adam(device):
params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
),
] + (cuda_supported_configs if "cuda" in str(device) else [])
if dtype in (torch.float16,):
for input in total:
"""
Too small eps will make denom to be zero for low precision dtype
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
For example,
>>> a
tensor([0.], dtype=torch.float16)
>>> a + 1e-8
tensor([0.], dtype=torch.float16)
"""
input.kwargs["eps"] = 0.1
return total
def optim_error_inputs_func_adam(device, dtype):
@ -473,7 +485,7 @@ def optim_error_inputs_func_adam(device, dtype):
return error_inputs
def optim_inputs_func_adamax(device):
def optim_inputs_func_adamax(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -524,15 +536,15 @@ def optim_error_inputs_func_adamax(device, dtype):
return error_inputs
def optim_inputs_func_adamw(device):
return optim_inputs_func_adam(device)
def optim_inputs_func_adamw(device, dtype=None):
return optim_inputs_func_adam(device, dtype)
def optim_error_inputs_func_adamw(device, dtype):
return optim_error_inputs_func_adam(device, dtype)
def optim_inputs_func_asgd(device):
def optim_inputs_func_asgd(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -584,7 +596,7 @@ def optim_error_inputs_func_asgd(device, dtype):
return error_inputs
def optim_inputs_func_lbfgs(device):
def optim_inputs_func_lbfgs(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@ -605,7 +617,7 @@ def optim_error_inputs_func_lbfgs(device, dtype):
# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_nadam(device):
def optim_inputs_func_nadam(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -676,7 +688,7 @@ def optim_error_inputs_func_nadam(device, dtype):
# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_radam(device=None):
def optim_inputs_func_radam(device=None, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -738,7 +750,7 @@ def optim_error_inputs_func_radam(device, dtype):
return error_inputs
def optim_inputs_func_rmsprop(device):
def optim_inputs_func_rmsprop(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -799,7 +811,7 @@ def optim_error_inputs_func_rmsprop(device, dtype):
return error_inputs
def optim_inputs_func_rprop(device):
def optim_inputs_func_rprop(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -841,7 +853,7 @@ def optim_error_inputs_func_rprop(device, dtype):
return error_inputs
def optim_inputs_func_sgd(device):
def optim_inputs_func_sgd(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
@ -886,7 +898,7 @@ def optim_error_inputs_func_sgd(device, dtype):
return error_inputs
def optim_inputs_func_sparseadam(device):
def optim_inputs_func_sparseadam(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
@ -995,10 +1007,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
x
for x in optim_info.supported_impls
if x not in skip
and (
_get_device_type(device) in _get_fused_kernels_supported_devices()
or x != "fused"
)
and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
and (
_get_device_type(device) in _get_foreach_kernels_supported_devices()
or x != "foreach"
@ -1196,6 +1205,7 @@ optim_db: List[OptimizerInfo] = [
),
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
supports_fused_on=("cpu", "cuda"),
decorators=(
# Expected floating point error between fused and compiled forloop
DecorateInfo(
@ -1205,6 +1215,21 @@ optim_db: List[OptimizerInfo] = [
active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
and kwargs["dtype"] == torch.float64,
),
DecorateInfo(
# Note on tolerances:
# difference comes from the fact that the non fused kernel have
# more dtype cast operations. We have another test test_fused_cpu_matches_cuda
# to make sure there is no discrepancies between cuda fused kernel
# and cpu fused kernel
toleranceOverride(
{
torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
torch.float16: tol(atol=5e-3, rtol=5e-3),
}
),
"TestOptimRenewed",
"test_fused_matches_forloop",
),
),
skips=(
DecorateInfo(
@ -1364,6 +1389,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adamw,
optim_error_inputs_func=optim_error_inputs_func_adamw,
supported_impls=("foreach", "differentiable", "fused"),
supports_fused_on=("cpu", "cuda"),
decorators=(
# Expected error between compiled forloop and fused optimizers
DecorateInfo(
@ -1373,6 +1399,21 @@ optim_db: List[OptimizerInfo] = [
active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
and kwargs["dtype"] == torch.float64,
),
DecorateInfo(
toleranceOverride(
# Note on tolerances:
# difference comes from the fact that the non fused kernel have
# more dtype cast operations. We have another test test_fused_cpu_matches_cuda
# to make sure there is no discrepancies between cuda fused kernel
# and cpu fused kernel
{
torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
torch.float16: tol(atol=5e-3, rtol=5e-3),
}
),
"TestOptimRenewed",
"test_fused_matches_forloop",
),
),
skips=(
DecorateInfo(
@ -1865,6 +1906,7 @@ optim_db: List[OptimizerInfo] = [
},
[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
),
supports_fused_on=("cuda",),
skips=(
DecorateInfo(
skipIfTorchDynamo(
@ -2060,7 +2102,10 @@ class TensorTracker:
numerical discrepancies, and so when the test fails, it is likely a real problem.
"""
def __init__(self):
def __init__(self, assert_eq_kwargs=None):
if assert_eq_kwargs is None:
assert_eq_kwargs = {}
self.assert_eq_kwargs = assert_eq_kwargs
self.tensors = []
def add(self, tensor):
@ -2080,7 +2125,7 @@ class TensorTracker:
ref = self.tensors.pop(0)
testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
testcase.assertEqual(tensor_to_set, ref)
testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)
with torch.no_grad():
tensor_to_set.copy_(ref)