Files
pytorch/torch/cuda/amp/grad_scaler.py
Xuehai Pan f3fce597e9 [BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
2024-08-04 10:24:09 +00:00

39 lines
1.0 KiB
Python

from typing_extensions import deprecated
import torch
# We need to keep this unused import for BC reasons
from torch.amp.grad_scaler import OptState # noqa: F401
__all__ = ["GradScaler"]
class GradScaler(torch.amp.GradScaler):
    r"""
    See :class:`torch.amp.GradScaler`.
    ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
    """

    @deprecated(
        "`torch.cuda.amp.GradScaler(args...)` is deprecated. "
        "Please use `torch.amp.GradScaler('cuda', args...)` instead.",
        category=FutureWarning,
    )
    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        # Collect every option and hand it to the device-generic scaler,
        # pinning the device argument to "cuda" so this shim reproduces the
        # behavior of the old CUDA-specific API exactly.
        scaler_options = {
            "init_scale": init_scale,
            "growth_factor": growth_factor,
            "backoff_factor": backoff_factor,
            "growth_interval": growth_interval,
            "enabled": enabled,
        }
        super().__init__("cuda", **scaler_options)