Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Deprecate device-specific GradScaler autocast API (#126527)
# Motivation

## For `torch.amp.GradScaler`

- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

We therefore intend to deprecate the device-specific constructors and **strongly recommend** that developers use `torch.amp.GradScaler` instead.

## For `custom_fwd` and `custom_bwd`

These decorators let a custom autograd function run correctly whether or not it is invoked inside an autocast-enabled region, and the same mechanism can be shared by other backends, such as CPU and XPU. We therefore generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.

# Additional Context

A unit test is added to cover the deprecation warning. No additional tests are needed for the functionality of `torch.amp.custom_fwd`/`custom_bwd`; the existing tests that previously covered `torch.cuda.amp.custom_fwd`/`custom_bwd` already exercise them.

To facilitate review, these changes are split into two PRs: this first PR covers `torch.amp.GradScaler`, and the follow-up covers `custom_fwd` and `custom_bwd`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
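To illustrate the recommended migration, here is a minimal sketch of a scaled training step using the device-agnostic constructor. It is not part of this PR's diff; the model, optimizer, and data are illustrative placeholders and assume a CUDA device is available:

```python
import torch

# Deprecated spelling (per this PR): torch.cuda.amp.GradScaler(init_scale=2.0)
# Device-agnostic replacement:
scaler = torch.amp.GradScaler("cuda", init_scale=2.0)

# Toy model/optimizer/data, purely illustrative.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device="cuda")

optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).square().mean()

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales grads, skips the step on inf/NaN
scaler.update()                # adjusts the scale factor for the next step
```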
Committed by: PyTorch MergeBot
Parent: ef86a27dba
Commit: c09205a057
```diff
@@ -1159,7 +1159,7 @@ t2.start()
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_grad_scaling_scale(self):
-        scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+        scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
         t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
         t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
         # Create some nested iterables of tensors on different devices.
@@ -1205,8 +1205,12 @@ t2.start()
             opt_scaling1,
         ) = _create_scaling_models_optimizers(device=dev1)
 
-        scaler = torch.cuda.amp.GradScaler(
-            init_scale=128.0, growth_factor=2.0, enabled=enabled, growth_interval=1
+        scaler = torch.amp.GradScaler(
+            device="cuda",
+            init_scale=128.0,
+            growth_factor=2.0,
+            enabled=enabled,
+            growth_interval=1,
         )
 
         def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
```
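The follow-up PR mentioned above generalizes `custom_fwd`/`custom_bwd`. Below is a minimal, hedged sketch of the intended device-agnostic usage; the `device_type` keyword and the `MulConstant` function are illustrative assumptions, not code from this PR:

```python
import torch

class MulConstant(torch.autograd.Function):
    # Hypothetical custom op used only to demonstrate the decorators.
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")  # was: @torch.cuda.amp.custom_fwd
    def forward(ctx, x, constant):
        ctx.constant = constant
        return x * constant

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")  # was: @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        # No gradient for the plain-Python constant argument.
        return grad_output * ctx.constant, None
```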