[optim] add fused_adagrad support for CPU device (#124905)
Add `fused_adagrad` kernel support for the CPU device.

## Bench result

32 cores per socket, ICX. Test scripts:
- https://gist.github.com/zhuhaozhe/79e842e0a6e25d6d7fa1e4598807272c
- https://gist.github.com/zhuhaozhe/b4c6998a509dcea1796dd05b3005c969

```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_adagrad time: 0.2500 seconds
_fused_adagrad time: 0.0933 seconds

Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_adagrad time: 2.8819 seconds
_fused_adagrad time: 1.7591 seconds
```

## Test Plan

```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124905
Approved by: https://github.com/jgong5, https://github.com/janeyx99
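For orientation, a minimal sketch of the user-facing API this enables: `fused=True` on a CPU optimizer. This assumes a PyTorch build that includes this change; the model and sizes are illustrative, not taken from the benchmark gists.

```python
import torch

# Illustrative only: with this change, Adagrad accepts fused=True for CPU params.
model = torch.nn.Linear(512, 512)  # toy model, not from the gists
opt = torch.optim.Adagrad(model.parameters(), lr=0.01, fused=True)

loss = model(torch.randn(64, 512)).sum()
loss.backward()
opt.step()       # should dispatch to the fused CPU Adagrad kernel
opt.zero_grad()
```

The gists time the internal `_single_tensor_adagrad` and `_fused_adagrad` loops directly; per the numbers above, the fused path cuts single-thread step time by roughly 2.7x.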
commit 1c3fe84033 (parent 4b88a5bd0b), committed by PyTorch MergeBot
@@ -911,6 +911,8 @@ class TestOptimRenewed(TestCase):
     @onlyCUDA
     @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
     def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
+        if device not in optim_info.supports_fused_on:
+            self.skipTest(f"{device} is not supported for fused on {optim_info.optim_cls.__name__}")
         optim_cls = optim_info.optim_cls
         optim_inputs = optim_info.optim_inputs_func(device=device)
         num_params = 5
@@ -940,9 +942,12 @@ class TestOptimRenewed(TestCase):
         # Since this is a unit test, it is more expedient to simulate what the state_dict
         # would look like, which is basically CPU tensors with fused/capturable flag = True.
         optim_cls = optim_info.optim_cls
-        if optim_cls.__name__ == "SGD" and impl == "capturable":
-            # Capturable SGD does not exist
+        opt_name = optim_cls.__name__
+        if opt_name in ("SGD", "Adagrad", ) and impl == "capturable":
+            # Capturable SGD/Adagrad does not exist
             self.skipTest("SGD does not currently support capturable")
+        if impl == "fused" and device not in optim_info.supports_fused_on:
+            self.skipTest(f"{device} is not supported for fused on {opt_name}")

         cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
         for optim_input in cpu_optim_inputs:
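The hunk above gates a test that loads an older (non-fused) state_dict into a fused optimizer. A hedged sketch of the user-level scenario it simulates, with an illustrative toy model:

```python
import torch

# Sketch of the scenario the test simulates: a state_dict saved from a
# non-fused optimizer should load cleanly into a fused CPU optimizer.
model = torch.nn.Linear(3, 2)
old_opt = torch.optim.Adagrad(model.parameters())      # "older", non-fused run
model(torch.randn(4, 3)).sum().backward()
old_opt.step()
checkpoint = old_opt.state_dict()

new_opt = torch.optim.Adagrad(model.parameters(), fused=True)  # fused on CPU, per this PR
new_opt.load_state_dict(checkpoint)  # should not raise; state tensors stay on CPU
```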
@@ -1318,6 +1323,8 @@ class TestOptimRenewed(TestCase):
             return closure_loss if optim_info.step_requires_closure else None

         for optim_input in cpu_optim_inputs:
+            if "fused" in optim_input.kwargs and "cuda" not in optim_info.supports_fused_on:
+                self.skipTest(f"cuda is not supported for fused on {optim_cls.__name__}")
             params = [Parameter(torch.randn(2, 3, device="cpu", dtype=dtype)) for _ in range(2)]
             for p in params:
                 p.grad = torch.randn_like(p)
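This hunk appears to sit in the grad-scaling/autocast test from the test plan, which exercises fused optimizers under autocast with gradient scaling. A hedged sketch of that combination follows; Adam stands in because, per the skip above, fused Adagrad in this PR is CPU-only while `GradScaler` targets CUDA:

```python
import torch

# Illustrative only (requires a CUDA device). Adam is used because fused
# Adagrad in this PR is CPU-only, while GradScaler is a CUDA AMP tool.
model = torch.nn.Linear(8, 8, device="cuda")
opt = torch.optim.Adam(model.parameters(), fused=True)
scaler = torch.cuda.amp.GradScaler()

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(torch.randn(4, 8, device="cuda")).sum()
scaler.scale(loss).backward()
scaler.step(opt)   # with fused optimizers, the scaler can skip the step on inf/NaN grads
scaler.update()
```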