[optim] add fused_adagrad support for CPU device (#124905)

Support fused_sgd_kernel support for CPU.

## Bench result:
32 core/sockets ICX
Test Scripts:
https://gist.github.com/zhuhaozhe/79e842e0a6e25d6d7fa1e4598807272c
https://gist.github.com/zhuhaozhe/b4c6998a509dcea1796dd05b3005c969
```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_adagrad time: 0.2500 seconds
_fused_adagrad time: 0.0933 seconds
Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_adagrad time: 2.8819 seconds
_fused_adagrad time: 1.7591 seconds
```
## Test Plan:
```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124905
Approved by: https://github.com/jgong5, https://github.com/janeyx99
This commit is contained in:
haozhe.zhu
2024-05-11 05:44:39 +00:00
committed by PyTorch MergeBot
parent 4b88a5bd0b
commit 1c3fe84033
12 changed files with 499 additions and 12 deletions

View File

@ -911,6 +911,8 @@ class TestOptimRenewed(TestCase):
@onlyCUDA
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
if device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
num_params = 5
@ -940,9 +942,12 @@ class TestOptimRenewed(TestCase):
# Since this is a unit test, it is more expedient to simulate what the state_dict
# would look like, which is basically CPU tensors with fused/capturable flag = True.
optim_cls = optim_info.optim_cls
if optim_cls.__name__ == "SGD" and impl == "capturable":
# Capturable SGD does not exist
opt_name = optim_cls.__name__
if opt_name in ("SGD", "Adagrad", ) and impl == "capturable":
# Capturable SGD/Adagrad does not exist
self.skipTest("SGD does not currently support capturable")
if impl == "fused" and device not in optim_info.supports_fused_on:
self.skipTest(f"{device} is not supported for fused on {opt_name}")
cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
for optim_input in cpu_optim_inputs:
@ -1318,6 +1323,8 @@ class TestOptimRenewed(TestCase):
return closure_loss if optim_info.step_requires_closure else None
for optim_input in cpu_optim_inputs:
if "fused" in optim_input.kwargs and "cuda" not in optim_info.supports_fused_on:
self.skipTest(f"cuda is not supported for fused on {optim_cls.__name__}")
params = [Parameter(torch.randn(2, 3, device="cpu", dtype=dtype)) for _ in range(2)]
for p in params:
p.grad = torch.randn_like(p)