Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
move param's device check to _init_group for fused (#131153)
There are cases where the params are still on the meta device when the optimizer's __init__ is called and are only materialized before the first computation. This change allows that situation by deferring the device/dtype check for fused optimizers to _init_group.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131153
Approved by: https://github.com/mlazos, https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 12b8e29203
Commit: 702c810780
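A minimal sketch (not from the PR itself) of the pattern this change enables, mirroring the new test added below; Adam with `fused=True` and the CPU/CUDA device pick are illustrative assumptions, any fused-capable optimizer would do:

    import torch

    # Parameters created under the meta device are shape/dtype placeholders with
    # no storage; previously the fused device check in __init__ rejected them.
    with torch.device("meta"):
        model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Sigmoid())

    # With this change, constructing the fused optimizer no longer raises here;
    # the device/dtype check now happens in _init_group (i.e. on step()).
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)

    # Materialize the parameters on a real device, then step as usual.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to_empty(device=device)
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer.step()
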
test/test_optim.py
@@ -1029,6 +1029,41 @@ class TestOptimRenewed(TestCase):
             self.skipTest("MPS supports only torch.float16 and torch.float32")
         self._test_derived_optimizers(device, dtype, optim_info, "fused")
 
+    @optims(
+        [optim for optim in optim_db if "fused" in optim.supported_impls],
+        dtypes=(torch.float32,),
+    )
+    def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
+        if _get_device_type(device) not in optim_info.supports_fused_on:
+            self.skipTest(
+                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
+            )
+
+        with torch.device("meta"):
+            model = torch.nn.Sequential(
+                torch.nn.Linear(2, 3),
+                torch.nn.Sigmoid(),
+                torch.nn.Linear(3, 1),
+                torch.nn.Sigmoid(),
+            ).to(dtype)
+
+        optimizer = optim_info.optim_cls(model.parameters(), fused=True)
+        with torch.device("meta"):
+            for p in model.parameters():
+                p.grad = torch.rand_like(p)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "`fused=True` requires all the params to be floating point Tensors",
+        ):
+            optimizer.step()
+
+        optimizer.zero_grad(set_to_none=True)
+        model.to_empty(device=device)
+        for p in model.parameters():
+            p.grad = torch.rand_like(p)
+        optimizer.step()
+
     @onlyNativeDeviceTypes
     @largeTensorTest("64GB")
     @optims(

torch/optim/adagrad.py
@@ -3,10 +3,10 @@ from typing import cast, List, Optional, Union
 
 import torch
 from torch import Tensor
-from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
 
 from .optimizer import (
     _default_to_fused_or_foreach,
+    _device_dtype_check_for_fused,
     _differentiable_doc,
     _foreach_doc,
     _get_scalar_dtype,
@@ -68,21 +68,9 @@ class Adagrad(Optimizer):
         if fused:
             if differentiable:
                 raise RuntimeError("`fused` does not support `differentiable`")
             self._step_supports_amp_scaling = True
-            fused_supported_devices = _get_fused_kernels_supported_devices()
-            # Not support CUDA yet
-            fused_supported_devices.remove("cuda")
-            if not all(
-                p.device.type in fused_supported_devices and torch.is_floating_point(p)
-                for pg in self.param_groups
-                for p in pg["params"]
-            ):
-                raise RuntimeError(
-                    "`fused=True` requires all the params to be floating point Tensors of "
-                    f"supported devices: {fused_supported_devices}."
-                )
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
+            self._need_device_dtype_check_for_fused = True
 
         for group in self.param_groups:
             for p in group["params"]:
@@ -136,6 +124,13 @@ class Adagrad(Optimizer):
         has_sparse_grad, has_complex = False, False
         for p in group["params"]:
             if p.grad is not None:
+                if group["fused"] and getattr(
+                    self,
+                    "_need_device_dtype_check_for_fused",
+                    True,
+                ):
+                    _device_dtype_check_for_fused(p, cuda_unsupported=True)
+                    self._need_device_dtype_check_for_fused = False
                 has_sparse_grad |= p.grad.is_sparse
                 has_complex |= torch.is_complex(p)
                 params_with_grad.append(p)

torch/optim/adam.py
@@ -4,11 +4,11 @@ from typing import cast, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
-from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
 
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
+    _device_dtype_check_for_fused,
     _differentiable_doc,
     _disable_dynamo_if_unsupported,
     _foreach_doc,
@@ -85,16 +85,6 @@ class Adam(Optimizer):
             # Support AMP with FP16/BF16 model params which would need
             # higher prec copy of params to do update math in higher prec to
             # alleviate the loss of information.
-            fused_supported_devices = _get_fused_kernels_supported_devices()
-            if not all(
-                p.device.type in fused_supported_devices and torch.is_floating_point(p)
-                for pg in self.param_groups
-                for p in pg["params"]
-            ):
-                raise RuntimeError(
-                    "`fused=True` requires all the params to be floating point Tensors of "
-                    f"supported devices: {fused_supported_devices}."
-                )
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
 
@@ -145,6 +135,8 @@ class Adam(Optimizer):
                 state = self.state[p]
                 # Lazy state initialization
                 if len(state) == 0:
+                    if group["fused"]:
+                        _device_dtype_check_for_fused(p)
                     # note(crcrpar): [special device hosting for step]
                     # Deliberately host `step` on CPU if both capturable and fused are off.
                     # This is because kernel launches are costly on CUDA and XLA.

torch/optim/adamw.py
@@ -4,11 +4,11 @@ from typing import cast, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
-from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
 
 from .optimizer import (
     _capturable_doc,
     _default_to_fused_or_foreach,
+    _device_dtype_check_for_fused,
     _differentiable_doc,
     _disable_dynamo_if_unsupported,
     _foreach_doc,
@@ -80,20 +80,6 @@ class AdamW(Optimizer):
             if differentiable:
                 raise RuntimeError("`fused` does not support `differentiable`")
             self._step_supports_amp_scaling = True
-            # TODO(crcrpar): [low prec params & their higher prec copy]
-            # Suppor AMP with FP16/BF16 model params which would need
-            # higher prec copy of params to do update math in higher prec to
-            # alleviate the loss of information.
-            fused_supported_devices = _get_fused_kernels_supported_devices()
-            if not all(
-                p.device.type in fused_supported_devices and torch.is_floating_point(p)
-                for pg in self.param_groups
-                for p in pg["params"]
-            ):
-                raise RuntimeError(
-                    "`fused=True` requires all the params to be floating point Tensors of "
-                    f"supported devices: {fused_supported_devices}."
-                )
             if foreach:
                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
 
@@ -145,6 +131,8 @@ class AdamW(Optimizer):
 
             # State initialization
             if len(state) == 0:
+                if group["fused"]:
+                    _device_dtype_check_for_fused(p)
                 # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (

torch/optim/optimizer.py
@@ -192,6 +192,19 @@ def _default_to_fused_or_foreach(
     return fused, foreach
 
 
+def _device_dtype_check_for_fused(
+    p: torch.Tensor, cuda_unsupported: bool = False
+) -> None:
+    fused_supported_devices = _get_fused_kernels_supported_devices()
+    if cuda_unsupported:
+        fused_supported_devices.remove("cuda")
+    if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
+        raise RuntimeError(
+            "`fused=True` requires all the params to be floating point Tensors of "
+            f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
+        )
+
+
 def _view_as_real(params, *state_and_grads):
     for i, p in enumerate(params):
         if torch.is_complex(p):
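For illustration only (not part of the PR), a tiny sketch of what the new helper accepts and rejects; importing the private `_device_dtype_check_for_fused` directly is done purely for demonstration and is not a public API:

    import torch
    from torch.optim.optimizer import _device_dtype_check_for_fused

    # A floating point tensor on a supported device (e.g. CPU) passes silently.
    _device_dtype_check_for_fused(torch.zeros(2, 3))

    # A meta-device (or non-floating-point) tensor is rejected with the same
    # RuntimeError message the new test above asserts on.
    try:
        _device_dtype_check_for_fused(torch.zeros(2, 3, device="meta"))
    except RuntimeError as err:
        print(err)
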

torch/optim/sgd.py
@@ -4,10 +4,10 @@ from typing import cast, List, Optional, Union
 
 import torch
 from torch import Tensor
-from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
 
 from .optimizer import (
     _default_to_fused_or_foreach,
+    _device_dtype_check_for_fused,
     _differentiable_doc,
     _foreach_doc,
     _fused_doc,
@@ -62,17 +62,7 @@ class SGD(Optimizer): # noqa: D101
 
         if fused:
             self._step_supports_amp_scaling = True
-
-            fused_supported_devices = _get_fused_kernels_supported_devices()
-            if not all(
-                p.device.type in fused_supported_devices and torch.is_floating_point(p)
-                for pg in self.param_groups
-                for p in pg["params"]
-            ):
-                raise RuntimeError(
-                    "`fused=True` requires all the params to be floating point Tensors of "
-                    f"supported devices: {fused_supported_devices}."
-                )
+            self._need_device_dtype_check_for_fused = True
             if differentiable:
                 raise RuntimeError("`fused` does not support `differentiable`")
             if foreach:
@@ -92,6 +82,11 @@ class SGD(Optimizer): # noqa: D101
 
         for p in group["params"]:
             if p.grad is not None:
+                if group["fused"] and getattr(
+                    self, "_need_device_dtype_check_for_fused", True
+                ):
+                    _device_dtype_check_for_fused(p)
+                    self._need_device_dtype_check_for_fused = False
                 params.append(p)
                 grads.append(p.grad)
                 if p.grad.is_sparse: