[optim] add fused_adam/adamw_kernel support for CPU device (#123074)
On par with the `CUDA` implementation.
The `autocast` logic is the same as for `CUDA` + `Fused Adam`:
- check for `inf` in `gradscaler.step`
- in the fused kernel, if there is an `inf`, do nothing; otherwise, unscale the grad (and write it back) and update the param (see the sketch below)
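A minimal sketch of that contract in plain C++ (illustrative only: the pointer names `grad_scale_ptr`/`found_inf_ptr`, the signature, and the scalar handling are assumptions for this sketch, not the actual ATen kernel):
```cpp
#include <cmath>
#include <cstdint>

// Per-element Adam step following the contract above: if gradscaler
// found an inf, the whole step is a no-op; otherwise unscale the grad
// in place (write-back) and update the parameter.
// Names (grad_scale_ptr, found_inf_ptr, step_size, ...) are assumptions
// for this sketch, not the actual kernel signature.
template <typename scalar_t>
void fused_adam_step_sketch(
    scalar_t* param, scalar_t* grad, scalar_t* exp_avg, scalar_t* exp_avg_sq,
    int64_t n,
    double step_size,              // lr / bias_correction1
    double beta1, double beta2, double eps,
    double bias_correction2_sqrt,  // sqrt(1 - beta2^t)
    const float* grad_scale_ptr, const float* found_inf_ptr) {
  if (found_inf_ptr != nullptr && *found_inf_ptr == 1.f) {
    return;  // inf seen by gradscaler: touch nothing
  }
  for (int64_t d = 0; d < n; d++) {
    if (grad_scale_ptr != nullptr) {
      grad[d] = grad[d] / scalar_t(*grad_scale_ptr);  // unscale + write back
    }
    exp_avg[d] = scalar_t(beta1) * exp_avg[d] + scalar_t(1 - beta1) * grad[d];
    exp_avg_sq[d] = scalar_t(beta2) * exp_avg_sq[d] +
        scalar_t(1 - beta2) * grad[d] * grad[d];
    scalar_t denom =
        std::sqrt(exp_avg_sq[d]) / scalar_t(bias_correction2_sqrt) + scalar_t(eps);
    param[d] = param[d] - scalar_t(step_size) * exp_avg[d] / denom;
  }
}
```
Folding the unscale into the update loop avoids a separate pass over the gradients before the step.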
**TestPlan**:
```
# extend CUDA-only tests to cover CPU fused adam/adamw
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused
# extend existing fused tests
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict
# newly added test (following 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```
**Benchmark**:
**5.1x** speedup on a 56-core SPR machine
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
**Note: fused kernel accuracy**
The accuracy failure in CI shows differences a little above the default tolerance:
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged it step by step, and unfortunately we may not be able to make the `fused kernel` exactly match the `non-fused` one due to compiler optimizations.
For example, the non-fused impl has
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and the fused impl has
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep the `std::cout`, I get exactly the same results in the UT:
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference:
```
===============param
0.6796758770942688
0.6796759366989136
```
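A likely mechanism (a generic illustration, not a claim about the exact codegen of this kernel) is floating-point contraction: with the `std::cout` removed, the compiler is free to fuse the multiply and the add into a single FMA, which rounds once instead of twice, so the fused loop need not bit-match the eager `mul_` + `addcmul_` sequence. A standalone demo of how one rounding vs. two changes the result:
```cpp
#include <cmath>
#include <cstdio>

int main() {
  float a = 0.1f;
  float prod = a * a;  // a * a rounded to float once

  // Separate multiply then subtract: both operands are already rounded,
  // so the difference is exactly zero.
  float separate = prod - prod;

  // Fused multiply-add: a * a is computed exactly and rounded only once
  // at the end, so the rounding error of `prod` survives.
  float fused = std::fmaf(a, a, -prod);

  std::printf("separate = %g\n", separate);  // 0
  std::printf("fused    = %g\n", fused);     // tiny nonzero value
  return 0;
}
```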
So I will make the tolerance a little higher than the default one.
Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
```diff
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -21,9 +21,10 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_optimizers import (
     optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker)
 from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM)
+    instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes)
 from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase
+from torch.testing._internal.common_cuda import _create_scaling_case
 from torch.testing._internal.common_dtype import floating_types_and
 
 FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
 
@@ -581,6 +582,49 @@ class TestOptimRenewed(TestCase):
             self.assertTrue(a1_grad_imags.all_popped())
             self.assertTrue(losses.all_popped())
 
+    def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None):
+        # why 7? iteration 7 is where we start to see differences for RAdam
+        # params interacting with the small eps value, because that's right
+        # after rho_t becomes greater than 5 in step 6.
+        if assert_eq_kwargs is None:
+            assert_eq_kwargs = {}
+        kIterations = 7
+        tracker = TensorTracker(assert_eq_kwargs)
+        for i in range(kIterations):
+            state, updated_params = [], []
+            if not isinstance(inputs, list):
+                inputs = [inputs, inputs]
+            for input, model, optimizer in zip(inputs, models, optimizers):
+                optimizer.zero_grad()
+
+                # Test that step behaves as expected (a no-op) when grads are set to None
+                if i != 3:
+                    output = model(input)
+                    loss = output.sum()
+                    loss.backward()
+
+                optimizer.step()
+                state.append(optimizer.state)
+                updated_params.append(model.parameters())
+
+            og_state, new_state = state
+            for og_p, new_p in zip(updated_params[0], updated_params[1]):
+                tracker.add(og_p)
+                tracker.pop_check_set(new_p, self)
+
+                # check that optimizer states are the same
+                og_p_state = og_state[og_p]
+                new_p_state = new_state[new_p]
+                if assert_step_dtype is not None:
+                    if torch.is_tensor(og_p_state.get("step", None)):
+                        self.assertEqual(og_p_state["step"].dtype, assert_step_dtype)
+                    if torch.is_tensor(new_p_state.get("step", None)):
+                        self.assertEqual(new_p_state["step"].dtype, assert_step_dtype)
+                for k in og_p_state:
+                    tracker.add(og_p_state[k])
+                    tracker.pop_check_set(new_p_state[k], self)
+
+        self.assertTrue(tracker.all_popped())
+
     def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
         """
@@ -589,16 +633,12 @@ class TestOptimRenewed(TestCase):
         for provided optimizer configurations.
         """
         assert flag in ("foreach", "fused")
-        # why 7? iteration 7 is where we start to see differences for RAdam
-        # params interacting with the small eps value, because that's right
-        # after rho_t becomes greater than 5 in step 6.
-        kIterations = 7
-
-        optim_inputs = optim_info.optim_inputs_func(device=device)
+        assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
+        optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype)
         optim_cls = optim_info.optim_cls
         for optim_input in optim_inputs:
             updated_params, state = [], []
             models, optimizers = [], []
             kwargs = deepcopy(optim_input.kwargs)
             if kwargs.get("capturable", False) and str(device) == "cpu":
                 # capturable is not supported on CPU
@@ -626,39 +666,10 @@ class TestOptimRenewed(TestCase):
                 params = list(model.parameters()) + [empty_param]
 
                 optimizer = optim_cls(params, **kwargs)
                 models.append(model)
                 optimizers.append(optimizer)
 
-                for i in range(kIterations):
-                    optimizer.zero_grad()
-
-                    # Test that step behaves as expected (a no-op) when grads are set to None
-                    if i != 3:
-                        output = model(input)
-                        loss = output.sum()
-                        loss.backward()
-
-                    optimizer.step()
-
-                    if assert_step_dtype is not None:
-                        p_state = optimizer.state[params[0]]
-                        if torch.is_tensor(p_state.get("step", None)):
-                            self.assertEqual(p_state["step"].dtype, assert_step_dtype)
-
-                state.append(optimizer.state)
-                updated_params.append(model.parameters())
-
-            assert_eq_kwargs = {} if not reduced_precision else FP16_REDUCED_PRECISION
-
-            og_state, new_state = state
-            for og_p, new_p in zip(updated_params[0], updated_params[1]):
-                self.assertEqual(og_p, new_p, **assert_eq_kwargs)
-
-                # check that optimizer states are the same
-                og_p_state = og_state[og_p]
-                new_p_state = new_state[new_p]
-
-                for k in og_p_state:
-                    self.assertEqual(og_p_state[k], new_p_state[k], **assert_eq_kwargs)
+            self._compare_between(input, models, optimizers, assert_eq_kwargs, assert_step_dtype)
 
     @skipMPS # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
     @optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64])
@@ -847,16 +858,23 @@ class TestOptimRenewed(TestCase):
         self.assertLessEqual(mt_max_mem, expected_max_mem)
 
-    @onlyCUDA
-    @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float64])
+    @onlyNativeDeviceTypes
+    @optims(
+        [optim for optim in optim_db if "fused" in optim.supported_impls],
+        dtypes=floating_types_and(torch.bfloat16, torch.float16, )
+    )
     def test_fused_matches_forloop(self, device, dtype, optim_info):
+        if device not in optim_info.supports_fused_on:
+            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
-    @onlyCUDA
-    @largeTensorTest("64GB", "cuda")
+    @onlyNativeDeviceTypes
+    @largeTensorTest("64GB")
     @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float16])
     def test_fused_large_tensor(self, device, dtype, optim_info):
+        if device not in optim_info.supports_fused_on:
+            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
         optim_cls = optim_info.optim_cls
         optim_inputs = optim_info.optim_inputs_func(device=device)
         for optim_input in optim_inputs:
@@ -1304,10 +1322,11 @@ class TestOptimRenewed(TestCase):
 
         # Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
         capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
+        fused = state_dict_cpu["param_groups"][0].get("fused", False)
         new_state_dict = optimizer_cuda.state_dict()
         for state_cpu, state_cuda in zip(state_dict_cpu["state"].values(), new_state_dict["state"].values()):
             if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
-                self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable else "cpu")
+                self.assertEqual(state_cuda["step"].device.type, "cuda" if capturable or fused else "cpu")
 
         for _ in range(5):
             optimizer.step(closure)
@@ -1615,6 +1634,104 @@ class TestOptimRenewed(TestCase):
         res2 = optim_neg_inf.step(closure)
         self.assertEqual(type(res1), type(res2))
 
+    @onlyCUDA
+    @optims(
+        [optim for optim in optim_db if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on],
+        dtypes=floating_types_and(torch.bfloat16, torch.float16,)
+    )
+    def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        optim_inputs = optim_info.optim_inputs_func(device="cpu")
+        for optim_input in optim_inputs:
+            inpts, models, optimizers = [], [], []
+            for dev in ('cpu', 'cuda'):
+                kwargs = optim_input.kwargs
+                kwargs["fused"] = True
+                inpt = torch.tensor(
+                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
+                ).reshape(3, 2)
+
+                torch.manual_seed(1)
+                model = torch.nn.Sequential(
+                    torch.nn.Linear(2, 3),
+                    torch.nn.Sigmoid(),
+                    torch.nn.Linear(3, 1),
+                    torch.nn.Sigmoid(),
+                )
+                model.to(dtype=dtype, device=dev)
+
+                # foreach/fused optimizers should be tested with a
+                # zero_size tensor as its last param.
+                # ref: https://github.com/pytorch/pytorch/issues/100701
+                empty_param = torch.empty((), device=dev, dtype=dtype, requires_grad=True)
+                empty_param.grad = torch.rand_like(empty_param)
+                params = list(model.parameters()) + [empty_param]
+
+                optimizer = optim_cls(params, **kwargs)
+                inpts.append(inpt)
+                models.append(model)
+                optimizers.append(optimizer)
+            self._compare_between(inpts, models, optimizers)
+
+    @onlyCPU
+    @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
+    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
+        # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
+        # but only test Adam/AdamW on CPU
+        # TODO: haozhe, support SGD and unified this ut with the CUDA only one
+        if device not in optim_info.supports_fused_on:
+            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
+        optim_inputs = optim_info.optim_inputs_func(device=device)
+        optim_cls = optim_info.optim_cls
+        for optim_input in optim_inputs:
+            kwargs = optim_input.kwargs
+            for _separate_unscale in (True, False):
+                self._grad_scaling_autocast_fused_optimizers(
+                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, separate_unscale=_separate_unscale)
+
+    def _grad_scaling_autocast_fused_optimizers(self, optimizer_ctor, optimizer_kwargs, separate_unscale):
+        (
+            mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, _,
+        ) = _create_scaling_case(optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, device='cpu')
+        kwargs = deepcopy(optimizer_kwargs)
+        kwargs["fused"] = False
+        if 'lr' not in optimizer_kwargs:
+            # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
+            kwargs['lr'] = 1.0
+        opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
+
+        scaler = torch.cpu.amp.GradScaler(init_scale=128.0)
+        for input, target in data:
+            opt_control.zero_grad()
+            with torch.autocast('cpu', dtype=torch.half):
+                output_control = mod_control(input)
+                loss_control = loss_fn(output_control, target)
+            scaler.scale(loss_control).backward()
+            scaler.step(opt_control)
+            scaler.update()
+
+            opt_scaling.zero_grad()
+            with torch.autocast('cpu', dtype=torch.half):
+                output_scaling = mod_scaling(input)
+                loss_scaling = loss_fn(output_scaling, target)
+            scaler.scale(loss_scaling).backward()
+            if separate_unscale:
+                scaler.unscale_(opt_scaling)
+            scaler.step(opt_scaling)
+            scaler.update()
+
+            self.assertEqual(loss_control, loss_scaling,)
+            for param_control, param_scaling in zip(mod_control.parameters(), mod_scaling.parameters()):
+                self.assertEqual(param_control.grad, param_scaling.grad,)
+                self.assertEqual(param_control, param_scaling,)
+
+                state_control, state_scaling = opt_control.state[param_control], opt_scaling.state[param_scaling]
+
+                for k in state_control:
+                    actual = state_scaling[k]
+                    if k == "step":
+                        actual = actual.squeeze()
+                    self.assertEqual(state_control[k], actual,)
+
     @onlyCUDA
     @optims([o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32])
```