[optim] add fused_adam/adamw_kernel support for CPU device (#123074)

On par with the `CUDA` implementation.

For the `autocast` logic, same as `CUDA` + `Fused Adam` (see the sketch below):
 - check for `inf` in `GradScaler.step`
 - in the fused kernel, if there is an `inf`, do nothing; otherwise, unscale the grad (and write it back) and update the param
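
A minimal usage sketch of that flow on CPU, assuming this PR's fused path; the model, sizes, and hyperparameters are illustrative, and the `GradScaler`/`autocast` calls mirror the ones used in the tests:
```
import torch

# Sketch: fused Adam on CPU driven by autocast + GradScaler.
model = torch.nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.1, fused=True)  # CPU fused kernel
scaler = torch.cpu.amp.GradScaler(init_scale=128.0)

x, y = torch.randn(4, 8), torch.randn(4, 1)
for _ in range(3):
    opt.zero_grad()
    with torch.autocast("cpu", dtype=torch.half):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    # GradScaler.step checks for inf; the fused kernel unscales the grads
    # in place and skips the parameter update when an inf is found.
    scaler.step(opt)
    scaler.update()
```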

**TestPlan**:
```
# extend CUDA-only tests to cover CPU fused adam/adamw
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follows 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**:
**5.1x** speedup on a 56-core SPR (Sapphire Rapids) machine
**Parameter size = 1M**
**Nparams = 10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)

```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
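
For reference, a minimal sketch of such a comparison (the actual script is in the gist linked above; the parameter count and sizes below just mirror the reported setup, and the step count is arbitrary):
```
import time
import torch

def bench(fused: bool, steps: int = 100) -> float:
    # 10 parameters of 1M elements each, matching the reported configuration.
    params = [torch.randn(1_000_000, requires_grad=True) for _ in range(10)]
    for p in params:
        p.grad = torch.randn_like(p)
    opt = torch.optim.Adam(params, lr=1e-3, fused=fused)
    start = time.time()
    for _ in range(steps):
        opt.step()
    return time.time() - start

print("non-fused", bench(False), "s")
print("fused", bench(True), "s")
```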

**Note: Fused kernel accuracy**
The accuracy failure in CI shows a mismatch slightly above the default tolerance:
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged it step by step, and unfortunately we may not be able to make the `fused kernel` produce exactly the same results as the `non-fused` one due to compiler optimizations.
For example, in the non-fused impl:
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in the fused impl:
```
  exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
  //  std::cout << "exp_avg_sq " <<   exp_avg_sq_ptr[d] << std::endl;
  exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
      scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep the `std::cout`, I get exactly the same results in the UT:
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference:
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than the default one.
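
To make the rounding argument concrete, here is a small, self-contained illustration (not the kernel itself) of how two mathematically equivalent orderings of the `exp_avg_sq` update can round differently in float32; whether a difference actually shows up depends on the values and the compiler:
```
import torch

torch.manual_seed(0)
beta2 = 0.999
exp_avg_sq = torch.rand(64, dtype=torch.float32)
grad = torch.rand(64, dtype=torch.float32)

# Non-fused ordering: in-place mul_ followed by a single addcmul_.
a = exp_avg_sq.clone()
a.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# Scalar-style ordering closer to the CPU kernel: separate multiply and add,
# each rounded independently (and possibly re-associated by the compiler).
b = beta2 * exp_avg_sq + (1 - beta2) * grad * grad

# The two orderings can disagree in the last few ULPs of float32.
print((a - b).abs().max())
```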

Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
Jane Xu, 2024-04-19 09:54:05 +00:00, committed by PyTorch MergeBot. Parent 9a71d12d92, commit b412b75b42.
10 changed files with 827 additions and 78 deletions


@@ -21,9 +21,10 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_optimizers import (
optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM)
instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes)
from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase
from torch.testing._internal.common_cuda import _create_scaling_case
from torch.testing._internal.common_dtype import floating_types_and
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
@@ -581,6 +582,49 @@ class TestOptimRenewed(TestCase):
self.assertTrue(a1_grad_imags.all_popped())
self.assertTrue(losses.all_popped())
def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None):
# why 7? iteration 7 is where we start to see differences for RAdam
# params interacting with the small eps value, because that's right
# after rho_t becomes greater than 5 in step 6.
if assert_eq_kwargs is None:
assert_eq_kwargs = {}
kIterations = 7
tracker = TensorTracker(assert_eq_kwargs)
for i in range(kIterations):
state, updated_params = [], []
if not isinstance(inputs, list):
inputs = [inputs, inputs]
for input, model, optimizer in zip(inputs, models, optimizers):
optimizer.zero_grad()
# Test that step behaves as expected (a no-op) when grads are set to None
if i != 3:
output = model(input)
loss = output.sum()
loss.backward()
optimizer.step()
state.append(optimizer.state)
updated_params.append(model.parameters())
og_state, new_state = state
for og_p, new_p in zip(updated_params[0], updated_params[1]):
tracker.add(og_p)
tracker.pop_check_set(new_p, self)
# check that optimizer states are the same
og_p_state = og_state[og_p]
new_p_state = new_state[new_p]
if assert_step_dtype is not None:
if torch.is_tensor(og_p_state.get("step", None)):
self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
if torch.is_tensor(new_p_state.get("step", None)):
self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
for k in og_p_state:
tracker.add(og_p_state[k])
tracker.pop_check_set(new_p_state[k], self)
self.assertTrue(tracker.all_popped())
def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
"""
@@ -589,16 +633,12 @@ class TestOptimRenewed(TestCase):
for provided optimizer configurations.
"""
assert flag in ("foreach", "fused")
assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
# why 7? iteration 7 is where we start to see differences for RAdam
# params interacting with the small eps value, because that's right
# after rho_t becomes greater than 5 in step 6.
kIterations = 7
optim_inputs = optim_info.optim_inputs_func(device=device)
optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
optim_cls = optim_info.optim_cls
for optim_input in optim_inputs:
updated_params, state = [], []
models, optimizers = [], []
kwargs = deepcopy(optim_input.kwargs)
if kwargs.get("capturable", False) and str(device) == "cpu":
# capturable is not supported on CPU
@@ -626,39 +666,10 @@ class TestOptimRenewed(TestCase):
params = list(model.parameters()) + [empty_param]
optimizer = optim_cls(params, **kwargs)
models.append(model)
optimizers.append(optimizer)
for i in range(kIterations):
optimizer.zero_grad()
# Test that step behaves as expected (a no-op) when grads are set to None
if i != 3:
output = model(input)
loss = output.sum()
loss.backward()
optimizer.step()
if assert_step_dtype is not None:
p_state = optimizer.state[params[0]]
if torch.is_tensor(p_state.get("step", None)):
self.assertEqual(p_state["step"].dtype, assert_step_dtype)
state.append(optimizer.state)
updated_params.append(model.parameters())
assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
og_state, new_state = state
for og_p, new_p in zip(updated_params[0], updated_params[1]):
self.assertEqual(og_p, new_p, **assert_eq_kwargs)
# check that optimizer states are the same
og_p_state = og_state[og_p]
new_p_state = new_state[new_p]
for k in og_p_state:
self.assertEqual(og_p_state[k], new_p_state[k], **assert_eq_kwargs)
self._compare_between(input, models, optimizers, assert_eq_kwargs, assert_step_dtype)
@skipMPS # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
@optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64])
@@ -847,16 +858,23 @@ class TestOptimRenewed(TestCase):
self.assertLessEqual(mt_max_mem, expected_max_mem)
@onlyCUDA
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float64])
@onlyNativeDeviceTypes
@optims(
[optim for optim in optim_db if "fused" in optim.supported_impls],
dtypes=floating_types_and(torch.bfloat16, torch.float16, )
)
def test_fused_matches_forloop(self, device, dtype, optim_info):
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
self._test_derived_optimizers(device, dtype, optim_info, "fused")
@onlyCUDA
@largeTensorTest("64GB", "cuda")
@onlyNativeDeviceTypes
@largeTensorTest("64GB")
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float16])
def test_fused_large_tensor(self, device, dtype, optim_info):
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
@@ -1304,10 +1322,11 @@ class TestOptimRenewed(TestCase):
# Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
fused = state_dict_cpu["param_groups"][0].get("fused", False)
new_state_dict = optimizer_cuda.state_dict()
for state_cpu, state_cuda in zip(state_dict_cpu["state"].values(), new_state_dict["state"].values()):
if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable else "cpu")
self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable or fused else "cpu")
for _ in range(5):
optimizer.step(closure)
@@ -1615,6 +1634,104 @@ class TestOptimRenewed(TestCase):
res2 = optim_neg_inf.step(closure)
self.assertEqual(type(res1), type(res2))
@onlyCUDA
@optims(
[optim for optim in optim_db if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on],
dtypes=floating_types_and(torch.bfloat16, torch.float16,)
)
def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device="cpu")
for optim_input in optim_inputs:
inpts, models, optimizers = [], [], []
for dev in ('cpu', 'cuda'):
kwargs = optim_input.kwargs
kwargs["fused"] = True
inpt = torch.tensor(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
).reshape(3, 2)
torch.manual_seed(1)
model = torch.nn.Sequential(
torch.nn.Linear(2, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
)
model.to(dtype=dtype, device=dev)
# foreach/fused optimizers should be tested with a
# zero_size tensor as its last param.
# ref: https://github.com/pytorch/pytorch/issues/100701
empty_param = torch.empty((), device=dev, dtype=dtype, requires_grad=True)
empty_param.grad = torch.rand_like(empty_param)
params = list(model.parameters()) + [empty_param]
optimizer = optim_cls(params, **kwargs)
inpts.append(inpt)
models.append(model)
optimizers.append(optimizer)
self._compare_between(inpts, models, optimizers)
@onlyCPU
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
# This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
# but only test Adam/AdamW on CPU
# TODO: haozhe, support SGD and unified this ut with the CUDA only one
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
optim_inputs = optim_info.optim_inputs_func(device=device)
optim_cls = optim_info.optim_cls
for optim_input in optim_inputs:
kwargs = optim_input.kwargs
for _separate_unscale in (True, False):
self._grad_scaling_autocast_fused_optimizers(
optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale):
(
mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _,
) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu')
kwargs = deepcopy(optimizer_kwargs)
kwargs["fused"] = False
if 'lr' not in optimizer_kwargs:
# _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
kwargs['lr'] = 1.0
opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
scaler = torch.cpu.amp.GradScaler(init_scale=128.0)
for input, target in data:
opt_control.zero_grad()
with torch.autocast('cpu', dtype=torch.half):
output_control = mod_control(input)
loss_control = loss_fn(output_control, target)
scaler.scale(loss_control).backward()
scaler.step(opt_control)
scaler.update()
opt_scaling.zero_grad()
with torch.autocast('cpu', dtype=torch.half):
output_scaling = mod_scaling(input)
loss_scaling = loss_fn(output_scaling, target)
scaler.scale(loss_scaling).backward()
if separate_unscale:
scaler.unscale_(opt_scaling)
scaler.step(opt_scaling)
scaler.update()
self.assertEqual(loss_control, loss_scaling,)
for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()):
self.assertEqual(param_control.grad, param_scaling.grad,)
self.assertEqual(param_control, param_scaling,)
state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling]
for k in state_control:
actual = state_scaling[k]
if k == "step":
actual = actual.squeeze()
self.assertEqual(state_control[k], actual,)
@onlyCUDA
@optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32])