move param's device check to _init_group for fused (#131153)

In some cases the params are still on the meta device when the optimizer's `__init__` is called and are only materialized before the first computation. Moving the device/dtype check for `fused=True` out of `__init__` and into `_init_group` allows this usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131153
Approved by: https://github.com/mlazos, https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Author: Masaki Kozuki
Date: 2024-08-17 04:49:47 +00:00
Committed by: PyTorch MergeBot
parent 12b8e29203
commit 702c810780
6 changed files with 70 additions and 52 deletions
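
For illustration, this is the usage pattern the change enables, mirroring the new test below (a minimal sketch; Adam, CPU, and the layer sizes are arbitrary choices, not part of the diff):

import torch

# Construct the model on the meta device; the parameters have no storage yet.
with torch.device("meta"):
    model = torch.nn.Linear(2, 3)

# Before this change, the fused device/dtype check ran in the optimizer's
# __init__ and rejected the meta params; it now runs lazily in _init_group
# on the first step.
optimizer = torch.optim.Adam(model.parameters(), fused=True)

# Materialize the params on a real, fused-supported device before stepping.
# to_empty() leaves the values uninitialized, so a real script would also
# (re)initialize the weights here.
model.to_empty(device="cpu")
for p in model.parameters():
    p.grad = torch.rand_like(p)
optimizer.step()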

test/test_optim.py

@@ -1029,6 +1029,41 @@ class TestOptimRenewed(TestCase):
self.skipTest("MPS supports only torch.float16 and torch.float32")
self._test_derived_optimizers(device, dtype, optim_info, "fused")
@optims(
[optim for optim in optim_db if "fused" in optim.supported_impls],
dtypes=(torch.float32,),
)
def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
if _get_device_type(device) not in optim_info.supports_fused_on:
self.skipTest(
f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
)
with torch.device("meta"):
model = torch.nn.Sequential(
torch.nn.Linear(2, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
).to(dtype)
optimizer = optim_info.optim_cls(model.parameters(), fused=True)
with torch.device("meta"):
for p in model.parameters():
p.grad = torch.rand_like(p)
with self.assertRaisesRegex(
RuntimeError,
"`fused=True` requires all the params to be floating point Tensors",
):
optimizer.step()
optimizer.zero_grad(set_to_none=True)
model.to_empty(device=device)
for p in model.parameters():
p.grad = torch.rand_like(p)
optimizer.step()
@onlyNativeDeviceTypes
@largeTensorTest("64GB")
@optims(

torch/optim/adagrad.py

@@ -3,10 +3,10 @@ from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_foreach_doc,
_get_scalar_dtype,
@@ -68,21 +68,9 @@ class Adagrad(Optimizer):
if fused:
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
fused_supported_devices = _get_fused_kernels_supported_devices()
# Not support CUDA yet
fused_supported_devices.remove("cuda")
if not all(
p.device.type in fused_supported_devices and torch.is_floating_point(p)
for pg in self.param_groups
for p in pg["params"]
):
raise RuntimeError(
"`fused=True` requires all the params to be floating point Tensors of "
f"supported devices: {fused_supported_devices}."
)
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
self._need_device_dtype_check_for_fused = True
for group in self.param_groups:
for p in group["params"]:
@@ -136,6 +124,13 @@ class Adagrad(Optimizer):
has_sparse_grad, has_complex = False, False
for p in group["params"]:
if p.grad is not None:
if group["fused"] and getattr(
self,
"_need_device_dtype_check_for_fused",
True,
):
_device_dtype_check_for_fused(p, cuda_unsupported=True)
self._need_device_dtype_check_for_fused = False
has_sparse_grad |= p.grad.is_sparse
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
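
The `getattr(self, "_need_device_dtype_check_for_fused", True)` guard above runs the per-param check at most once and defaults to True when the flag was never set on the instance. A stripped-down sketch of this lazy, run-once validation pattern (the class and attribute names below are illustrative, not PyTorch API):

import torch

class _LazyCheckSketch:
    """Illustrative only: validates params on the first step() instead of in __init__."""

    def __init__(self, params, fused=False):
        self.params = list(params)
        self.fused = fused
        if fused:
            # Defer the device/dtype check so params may still be unmaterialized
            # (e.g. on the meta device) at construction time.
            self._need_check = True

    def step(self):
        for p in self.params:
            if self.fused and getattr(self, "_need_check", True):
                # Default of True: if the flag is missing for any reason,
                # err on the side of running the check once.
                if not torch.is_floating_point(p):
                    raise RuntimeError("`fused=True` requires floating point params")
                self._need_check = False
            # parameter update elided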

torch/optim/adam.py

@@ -4,11 +4,11 @@ from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
@@ -85,16 +85,6 @@ class Adam(Optimizer):
# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
fused_supported_devices = _get_fused_kernels_supported_devices()
if not all(
p.device.type in fused_supported_devices and torch.is_floating_point(p)
for pg in self.param_groups
for p in pg["params"]
):
raise RuntimeError(
"`fused=True` requires all the params to be floating point Tensors of "
f"supported devices: {fused_supported_devices}."
)
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
@@ -145,6 +135,8 @@ class Adam(Optimizer):
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): [special device hosting for step]
# Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.

torch/optim/adamw.py

@@ -4,11 +4,11 @@ from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
@@ -80,20 +80,6 @@ class AdamW(Optimizer):
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
# TODO(crcrpar): [low prec params & their higher prec copy]
# Suppor AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
fused_supported_devices = _get_fused_kernels_supported_devices()
if not all(
p.device.type in fused_supported_devices and torch.is_floating_point(p)
for pg in self.param_groups
for p in pg["params"]
):
raise RuntimeError(
"`fused=True` requires all the params to be floating point Tensors of "
f"supported devices: {fused_supported_devices}."
)
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
@@ -145,6 +131,8 @@ class AdamW(Optimizer):
# State initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (

torch/optim/optimizer.py

@@ -192,6 +192,19 @@ def _default_to_fused_or_foreach(
return fused, foreach
def _device_dtype_check_for_fused(
p: torch.Tensor, cuda_unsupported: bool = False
) -> None:
fused_supported_devices = _get_fused_kernels_supported_devices()
if cuda_unsupported:
fused_supported_devices.remove("cuda")
if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
raise RuntimeError(
"`fused=True` requires all the params to be floating point Tensors of "
f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
)
def _view_as_real(params, *state_and_grads):
for i, p in enumerate(params):
if torch.is_complex(p):
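
To make the new helper's behavior concrete, here is a small sketch of how it reacts to different parameters (the helper is private and may move; `cuda_unsupported=True` is the variant used by Adagrad, whose fused path has no CUDA kernel yet):

import torch
from torch.optim.optimizer import _device_dtype_check_for_fused  # private helper

# A floating point tensor on a fused-supported device (CPU here) passes silently.
_device_dtype_check_for_fused(torch.zeros(3))

# Non-floating-point params are rejected.
try:
    _device_dtype_check_for_fused(torch.zeros(3, dtype=torch.int64))
except RuntimeError as e:
    print(e)

# Meta tensors are rejected too, which is why the call now happens in
# _init_group, after materialization, rather than in the optimizers' __init__.
try:
    _device_dtype_check_for_fused(torch.zeros(3, device="meta"))
except RuntimeError as e:
    print(e)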

torch/optim/sgd.py

@@ -4,10 +4,10 @@ from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
from .optimizer import (
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_foreach_doc,
_fused_doc,
@@ -62,17 +62,7 @@ class SGD(Optimizer): # noqa: D101
if fused:
self._step_supports_amp_scaling = True
fused_supported_devices = _get_fused_kernels_supported_devices()
if not all(
p.device.type in fused_supported_devices and torch.is_floating_point(p)
for pg in self.param_groups
for p in pg["params"]
):
raise RuntimeError(
"`fused=True` requires all the params to be floating point Tensors of "
f"supported devices: {fused_supported_devices}."
)
self._need_device_dtype_check_for_fused = True
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
if foreach:
@@ -92,6 +82,11 @@ class SGD(Optimizer): # noqa: D101
for p in group["params"]:
if p.grad is not None:
if group["fused"] and getattr(
self, "_need_device_dtype_check_for_fused", True
):
_device_dtype_check_for_fused(p)
self._need_device_dtype_check_for_fused = False
params.append(p)
grads.append(p.grad)
if p.grad.is_sparse: