Deprecate device-specific GradScaler autocast API (#126527)

# Motivation

## For `torch.amp.GradScaler`
- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

So we intend to deprecate them and **strongly recommend** that developers use `torch.amp.GradScaler` instead.
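
For reference, a minimal sketch of the equivalence (illustrative values; the deprecated constructors still work but emit a deprecation warning):

```python
import torch

# Deprecated spellings: still functional, but emit a deprecation warning.
cuda_scaler_old = torch.cuda.amp.GradScaler(init_scale=2.0**16)
cpu_scaler_old = torch.cpu.amp.GradScaler()

# Recommended device-agnostic spellings: same behavior, device passed explicitly.
cuda_scaler = torch.amp.GradScaler("cuda", init_scale=2.0**16)
cpu_scaler = torch.amp.GradScaler("cpu")

# Typical training-loop usage is unchanged:
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
```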

## For `custom_fwd` and `custom_bwd`
These decorators make a custom autograd function behave correctly whether or not it is invoked in an autocast-enabled region, and that behavior can be shared by other backends, such as CPU and XPU.
So we generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.
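
As a rough sketch of how the generalized decorators are meant to be used (this follows the existing `torch.cuda.amp.custom_fwd`/`custom_bwd` pattern; the `device_type` argument reflects the device-agnostic generalization described above):

```python
import torch

class ScaledMM(torch.autograd.Function):
    # Cast inputs to float32 so the forward runs in full precision
    # even when called inside an autocast-enabled region.
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    # Run the backward with the same autocast state as the forward.
    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)
```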

# Additional Context
We add unit tests to cover the deprecation warnings.
No additional unit tests are needed for the functionality of `torch.amp.custom_fwd`/`custom_bwd`; the existing tests that previously covered `torch.cuda.amp.custom_fwd`/`custom_bwd` already exercise it.
To facilitate review, we split these changes into two PRs: this first PR covers `torch.amp.GradScaler`, and the follow-up covers `custom_fwd` and `custom_bwd`.
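
A hedged sketch of what such a deprecation-warning test could look like (the warning category and message pattern below are assumptions for illustration, not copied from the actual test code):

```python
import unittest
import torch

class TestGradScalerDeprecation(unittest.TestCase):
    def test_cuda_gradscaler_constructor_warns(self):
        # Assumption: the deprecated constructor emits a FutureWarning that
        # points users at torch.amp.GradScaler. enabled=False avoids needing a GPU.
        with self.assertWarnsRegex(FutureWarning, "torch.amp.GradScaler"):
            torch.cuda.amp.GradScaler(enabled=False)

if __name__ == "__main__":
    unittest.main()
```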

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
Author: Yu, Guangye
Date: 2024-05-24 21:49:05 +00:00
Committed by: PyTorch MergeBot
Parent: ef86a27dba
Commit: c09205a057
8 changed files with 43 additions and 19 deletions


@@ -1159,7 +1159,7 @@ t2.start()
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_grad_scaling_scale(self):
-        scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+        scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
         t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
         t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
         # Create some nested iterables of tensors on different devices.
@@ -1205,8 +1205,12 @@ t2.start()
             opt_scaling1,
         ) = _create_scaling_models_optimizers(device=dev1)
-        scaler = torch.cuda.amp.GradScaler(
-            init_scale=128.0, growth_factor=2.0, enabled=enabled, growth_interval=1
+        scaler = torch.amp.GradScaler(
+            device="cuda",
+            init_scale=128.0,
+            growth_factor=2.0,
+            enabled=enabled,
+            growth_interval=1,
         )
         def run(model0, model1, optimizer0, optimizer1, try_scaling_api):