Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[optim] Move test_grad_scaling_autocast_fused_optimizers to test_cuda.py (#126418)
This PR addresses the comments on #124904:
- Move test_grad_scaling_autocast_fused_optimizers to test_cuda.py.
- Combine _grad_scaling_autocast_fused_optimizers into test_grad_scaling_autocast_fused_optimizers.
- Move to the OptimizerInfo framework.
- For the failing tests test_grad_scaling_autocast_fused_optimizers_AdamW_cuda_float32 and test_grad_scaling_autocast_fused_optimizers_Adam_cuda_float32:
  - Added a toleranceOverride in this PR (an illustrative sketch of such an override follows the commit metadata below).
  - Created issue #127000.

```
> (c2env) [sandish@devgpu166.ash6 ~/pytorch (refactoroptimizers)]$ python test/test_cuda.py -k test_grad_scaling_autocast_fused_optimizers -v
/home/sandish/pytorch/torch/backends/cudnn/__init__.py:106: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
  warnings.warn(
/home/sandish/pytorch/torch/backends/cudnn/__init__.py:106: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
  warnings.warn(
test_grad_scaling_autocast_fused_optimizers_Adagrad_cpu_float32 (__main__.TestCudaOptimsCPU) ...
{'fused': True} {'fused': True}
{'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'lr': 0.1, 'fused': True} {'lr': 0.1, 'fused': True}
{'initial_accumulator_value': 0.1, 'weight_decay': 0.1, 'fused': True} {'initial_accumulator_value': 0.1, 'weight_decay': 0.1, 'fused': True}
{'lr': 0.1, 'lr_decay': 0.5, 'weight_decay': 0.1, 'fused': True} {'lr': 0.1, 'lr_decay': 0.5, 'weight_decay': 0.1, 'fused': True}
{'lr': tensor(0.0010), 'fused': True} {'lr': tensor(0.0010), 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_AdamW_cpu_float32 (__main__.TestCudaOptimsCPU) ...
{'fused': True} {'fused': True}
{'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_Adam_cpu_float32 (__main__.TestCudaOptimsCPU) ...
{'fused': True} {'fused': True}
{'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_SGD_cpu_float32 (__main__.TestCudaOptimsCPU) ...
{'fused': True} {'fused': True}
{'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True}
{'lr': tensor(0.0010), 'fused': True} {'lr': tensor(0.0010), 'fused': True}
{'momentum': 0.9, 'fused': True} {'momentum': 0.9, 'fused': True}
{'momentum': 0.9, 'dampening': 0.5, 'fused': True} {'momentum': 0.9, 'dampening': 0.5, 'fused': True}
{'momentum': 0.9, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_Adagrad_cuda_float32 (__main__.TestCudaOptimsCUDA) ... skipped 'cuda is not supported for fused on Adagrad'
test_grad_scaling_autocast_fused_optimizers_AdamW_cuda_float32 (__main__.TestCudaOptimsCUDA) ...
{'fused': True} {'fused': True}
{'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'capturable': True, 'fused': True} {'capturable': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True}
{'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True} {'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_Adam_cuda_float32 (__main__.TestCudaOptimsCUDA) ...
{'fused': True} {'fused': True}
{'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True} {'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'capturable': True, 'fused': True} {'capturable': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True} {'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True}
{'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True} {'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_SGD_cuda_float32 (__main__.TestCudaOptimsCUDA) ...
{'fused': True} {'fused': True}
{'lr': 0.01, 'fused': True} {'lr': 0.01, 'fused': True}
{'lr': tensor(0.0010), 'fused': True} {'lr': tensor(0.0010), 'fused': True}
{'momentum': 0.9, 'fused': True} {'momentum': 0.9, 'fused': True}
{'momentum': 0.9, 'dampening': 0.5, 'fused': True} {'momentum': 0.9, 'dampening': 0.5, 'fused': True}
{'momentum': 0.9, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True} {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True} {'weight_decay': 0.1, 'maximize': True, 'fused': True}
ok

----------------------------------------------------------------------
Ran 8 tests in 16.117s

OK (skipped=1)

> lintrunner test/test_cuda.py
----------------------------------------------------------------------
ok No lint issues.

> lintrunner torch/testing/_internal/common_optimizers.py
----------------------------------------------------------------------
ok No lint issues.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126418
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: 67739d8c6f
Commit: da39461d61
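For context, the toleranceOverride mentioned in the description is the kind of adjustment that is normally attached through the OptimizerInfo decorator machinery in torch/testing/_internal/common_optimizers.py. The sketch below is illustrative only: the tolerance values, target class, and device are placeholders, not the exact ones added by this PR.

```python
# Hypothetical sketch of loosening float32 tolerances for the fused Adam/AdamW
# grad-scaling test, in the DecorateInfo style used by common_optimizers.py.
# The atol/rtol values and target names below are placeholders.
import torch
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo

adam_grad_scaling_decorators = (
    DecorateInfo(
        # Relax float32 tolerances for this specific test (values are illustrative).
        toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
        "TestCudaOptims",  # test class the override targets (assumed)
        "test_grad_scaling_autocast_fused_optimizers",  # test the override targets
        device_type="cuda",
    ),
)
```

In practice, a tuple like this would be merged into the `decorators=` field of the corresponding `OptimizerInfo` entry so the override applies only to the named test.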
```diff
@@ -20,7 +20,7 @@ from torch.optim.optimizer import (
     register_optimizer_step_post_hook,
     register_optimizer_step_pre_hook,
 )
-from torch.testing._internal.common_cuda import _create_scaling_case, TEST_MULTIGPU
+from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     largeTensorTest,
```
```diff
@@ -1954,100 +1954,6 @@ class TestOptimRenewed(TestCase):
             optimizers.append(optimizer)
         self._compare_between(inpts, models, optimizers)
 
-    @onlyNativeDeviceTypes
-    @optims(
-        [optim for optim in optim_db if "fused" in optim.supported_impls],
-        dtypes=[torch.float32],
-    )
-    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
-        # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers
-        # but only test Adam/AdamW on CPU
-        # TODO: haozhe, support SGD and unified this ut with the CUDA only one
-        if device not in optim_info.supports_fused_on:
-            self.skipTest(
-                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
-            )
-        optim_inputs = optim_info.optim_inputs_func(device=device)
-        optim_cls = optim_info.optim_cls
-        for optim_input in optim_inputs:
-            kwargs = optim_input.kwargs
-            kwargs["fused"] = True
-            for _separate_unscale in (True, False):
-                self._grad_scaling_autocast_fused_optimizers(
-                    device=device,
-                    optimizer_ctor=optim_cls,
-                    optimizer_kwargs=kwargs,
-                    separate_unscale=_separate_unscale,
-                )
-
-    def _grad_scaling_autocast_fused_optimizers(
-        self, device, optimizer_ctor, optimizer_kwargs, separate_unscale
-    ):
-        torch.manual_seed(20)
-        (
-            mod_control,
-            mod_scaling,
-            opt_control,
-            opt_scaling,
-            data,
-            loss_fn,
-            _,
-        ) = _create_scaling_case(
-            optimizer_ctor=optimizer_ctor,
-            optimizer_kwargs=optimizer_kwargs,
-            device="cpu",
-        )
-        kwargs = deepcopy(optimizer_kwargs)
-        kwargs["fused"] = False
-        if "lr" not in optimizer_kwargs:
-            # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
-            kwargs["lr"] = 1.0
-        opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
-
-        scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
-        scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
-        tracker = TensorTracker()
-        for input, target in data:
-            opt_control.zero_grad()
-            with torch.autocast(device_type=device, dtype=torch.half):
-                output_control = mod_control(input)
-                loss_control = loss_fn(output_control, target)
-            scaler_control.scale(loss_control).backward()
-            scaler_control.step(opt_control)
-            scaler_control.update()
-
-            opt_scaling.zero_grad()
-            with torch.autocast(device_type=device, dtype=torch.half):
-                output_scaling = mod_scaling(input)
-                loss_scaling = loss_fn(output_scaling, target)
-            scaler_scaling.scale(loss_scaling).backward()
-            if separate_unscale:
-                scaler_scaling.unscale_(opt_scaling)
-            scaler_scaling.step(opt_scaling)
-            scaler_scaling.update()
-
-            tracker.add(loss_control)
-            tracker.pop_check_set(loss_scaling, self)
-            for param_control, param_scaling in zip(
-                mod_control.parameters(), mod_scaling.parameters()
-            ):
-                tracker.add(param_control.grad)
-                tracker.pop_check_set(param_scaling.grad, self)
-                tracker.add(param_control)
-                tracker.pop_check_set(param_scaling, self)
-
-                state_control, state_scaling = (
-                    opt_control.state[param_control],
-                    opt_scaling.state[param_scaling],
-                )
-
-                for k in state_control:
-                    actual = state_scaling[k]
-                    if k == "step":
-                        actual = actual.squeeze()
-                    tracker.add(state_control[k])
-                    tracker.pop_check_set(actual, self)
-
     @onlyCUDA
     @optims(
         [o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32]
```
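For readers unfamiliar with the pattern the removed helper exercises, it is the standard torch.amp autocast + GradScaler training loop run against a fused optimizer. Below is a minimal, self-contained sketch of that pattern; the model, data, and hyperparameters are invented for illustration and are not taken from the test.

```python
import torch

device = "cuda"  # the test only runs fused impls on devices in OptimizerInfo.supports_fused_on
model = torch.nn.Linear(8, 8, device=device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
scaler = torch.amp.GradScaler(device, init_scale=128.0)
loss_fn = torch.nn.MSELoss()

for _ in range(3):  # a few illustrative steps
    inp = torch.randn(4, 8, device=device)
    target = torch.randn(4, 8, device=device)

    opt.zero_grad()
    # Run the forward pass in reduced precision under autocast.
    with torch.autocast(device_type=device, dtype=torch.half):
        loss = loss_fn(model(inp), target)

    # Scale the loss so fp16 gradients do not underflow, then backprop.
    scaler.scale(loss).backward()
    # Optional explicit unscale; the test covers both branches via separate_unscale.
    scaler.unscale_(opt)
    # step() skips the update if any gradient is inf/nan; update() adjusts the scale.
    scaler.step(opt)
    scaler.update()
```

The value of the removed test (now living in test_cuda.py) is that it runs this loop twice, once with fused=True and once with fused=False, and uses TensorTracker to check that losses, gradients, parameters, and optimizer state stay numerically aligned between the two runs.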