Compare commits

...

6 Commits

SHA1 Message Date
a0375cbff7 Support tensor betas in Adam and AdamW
ghstack-source-id: 87320da60be589e62491da7186048d7185b8ac35
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134171
2024-08-30 14:36:46 -07:00
eaa026ced6 Support more foreach ops for tensor beta support
ghstack-source-id: 84d4ab8b822f90a17942874576a1dfc7b56302f2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134170

Fix test counts
2024-08-30 14:36:46 -07:00
2fe6299fd8 Update compiled optimizer tests for tensor betas
ghstack-source-id: 1a402bdc348023a641afabef17d68cd1b7ea6097
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134169
2024-08-30 07:33:53 -07:00
e95444648e [dynamo] rewrite addcmul_ to remove graph break
ghstack-source-id: ccc89cd5c0e88399513bc9b23ce6b5f6287452d9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134168
2024-08-30 07:33:53 -07:00
d6a2f54853 [dynamo] Rewrite foreach pow to broadcast scalar argument
ghstack-source-id: c8e8a89ab9d32de0921318b2123169c5e6989834
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134167
2024-08-30 07:33:53 -07:00
1f091655f1 [dynamo] Rewrite lerp to avoid item call in aten
ghstack-source-id: b3163082c78e315d57f80c6f798beefd35efed3a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134166
2024-08-30 07:33:52 -07:00
10 changed files with 195 additions and 17 deletions

View File

@ -156,6 +156,64 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
        if a is not None and b is not None:
            return a + b

    def test_foreach_lerp_(self):
        from torch._dynamo.utils import same

        def fn(x, y, s):
            return torch._foreach_lerp_(x, y, s)

        cnt = torch._dynamo.testing.CompileCounter()

        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
        expected = fn(
            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
            [torch.ones(2, 2), torch.ones(2, 2)],
            torch.tensor(0.5),
        )

        actual = fn_opt(
            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
            [torch.ones(2, 2), torch.ones(2, 2)],
            torch.tensor(0.5),
        )

        self.assertTrue(same(expected, actual))

    def test_broadcast_foreach_pow(self):
        from torch._dynamo.utils import same

        def fn(x, y):
            return torch._foreach_pow(x, y)

        cnt = torch._dynamo.testing.CompileCounter()
        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
        inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)])

        actual = fn_opt(*inps)
        expected = fn(*inps)

        self.assertTrue(same(actual, expected))
        self.assertEqual(cnt.frame_count, 1)

    def test_addcmul_(self):
        from copy import deepcopy

        from torch._dynamo.utils import same

        def fn(x, y, z, s):
            return x.addcmul_(y, z, value=s)

        cnt = torch._dynamo.testing.CompileCounter()
        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
        inps = (
            torch.ones(2, 2),
            torch.ones(2, 2) + 1,
            torch.rand(2, 2),
            torch.tensor(0.3),
        )
        inps_2 = deepcopy(inps)

        actual = fn_opt(*inps)
        expected = fn(*inps_2)

        self.assertTrue(same(actual, expected))
        self.assertEqual(cnt.frame_count, 1)

    @make_test
    def test_functools_partial(a, b):
        return clip01(a + b)

View File

@ -120,8 +120,7 @@ KERNEL_COUNT_OVERRIDES = {
"test_adamw_amsgrad_capturable_foreach_xpu": 3,
"test_adamw_amsgrad_capturable_cuda": 6,
"test_adamw_amsgrad_capturable_xpu": 6,
"test_adamw_tensor_lr_amsgrad_capturable_foreach_cuda": 3,
"test_adamw_tensor_lr_amsgrad_capturable_foreach_xpu": 3,
"test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6,
"test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
"test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
"test_adam_tensor_lr_amsgrad_capturable_cuda": 6,
@ -132,8 +131,6 @@ KERNEL_COUNT_OVERRIDES = {
"test_adadelta_tensor_lr_capturable_xpu": 6,
"test_rmsprop_tensor_lr_capturable_cuda": 6,
"test_rmsprop_tensor_lr_capturable_xpu": 6,
"test_adadelta_tensor_lr_capturable_foreach_cuda": 4,
"test_adadelta_tensor_lr_capturable_foreach_xpu": 4,
"test_adadelta_foreach_weight_decay_maximize_cpu": 12,
"test_adadelta_foreach_rho_weight_decay_cpu": 12,
"test_adadelta_foreach_weight_decay_cpu": 12,
@ -169,8 +166,6 @@ KERNEL_COUNT_OVERRIDES = {
"test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_cuda": 4,
"test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_xpu": 4,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
"test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_cuda": 3,
@ -217,7 +212,7 @@ def build_opt_kwarg_db():
        has_tensor_lr = False
        for key, val in kwargs.items():
-           if not key == "lr" and (
+           if (not key == "lr" and not key == "betas") and (
                not isinstance(val, bool) or (isinstance(val, bool) and val)
            ):
                name += "_" + key
@ -226,6 +221,9 @@ def build_opt_kwarg_db():
                has_tensor_lr = True
                name += "_tensor_lr"
            if key == "betas" and isinstance(kwargs["betas"][0], torch.Tensor):
                name += "_tensor_betas"

        name += f"_{device}"
        kwargs["device"] = device
@ -367,6 +365,16 @@ def make_test(
kwargs["lr"] = kwargs["lr"].to(device)
kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)
if "betas" in kwargs and isinstance(kwargs["betas"][0], torch.Tensor):
kwargs["betas"] = (
kwargs["betas"][0].to(device),
kwargs["betas"][1].to(device),
)
kwargs_compiled["betas"] = (
kwargs_compiled["betas"][0].to(device),
kwargs_compiled["betas"][1].to(device),
)
torch._dynamo.reset()
torch._inductor.metrics.reset()
input = torch.ones([10, 10], device=device)

View File

@ -133,3 +133,19 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
    if isinstance(obj, cls):
        obj.__init__(*args, **kwargs)
    return obj


def foreach_lerp_inplace(self, end, weight):
    # Decompose foreach lerp into its constituent ops; this prevents a graph break
    # caused by converting the weight to a scalar when arg[2] is a single tensor
    result = torch._foreach_sub(end, self)
    result = torch._foreach_mul(result, weight)
    return torch._foreach_add_(self, result)


def foreach_pow_scalar(scalar, exps):
    return torch._foreach_pow([scalar for _ in exps], exps)


def addcmul_inplace(self, tensor1, tensor2, value):
    return self.add_(tensor1 * tensor2 * value)
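Note (editor's sketch, not part of the diff): foreach_lerp_inplace relies on the identity lerp(a, b, w) = a + w * (b - a). A quick eager check that the decomposition matches the fused op, assuming float inputs:

    import torch

    xs = [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14]
    ys = [torch.ones(2, 2), torch.ones(2, 2)]
    w = torch.tensor(0.5)

    ref = [x.clone() for x in xs]
    torch._foreach_lerp_(ref, ys, 0.5)  # fused op with a Python-number weight

    out = [x.clone() for x in xs]
    # same steps as foreach_lerp_inplace: sub, mul, then in-place add
    torch._foreach_add_(out, torch._foreach_mul(torch._foreach_sub(ys, out), w))

    torch.testing.assert_close(ref, out)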

View File

@ -795,6 +795,20 @@ class TensorVariable(VariableTracker):
        tx = InstructionTranslator.current_tx()
        return self.call_method(tx, "size", [ConstantVariable.create(0)], {})

    def method_addcmul_(self, tensor1, tensor2, *, value=None):
        from ..symbolic_convert import InstructionTranslator

        tx = InstructionTranslator.current_tx()
        if value is not None:
            from .. import polyfills
            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
                [self, tensor1, tensor2, value],
                {},
            )

    def method___setitem__(self, key, value):
        def has_bool_key(v):
            if isinstance(v, TensorVariable):
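Note (editor's sketch, not part of the diff): the rewrite matters when `value` is a 0-dim tensor, which previously forced a graph break on the implicit .item(). Numerically, the polyfill route is the plain mul/add expansion:

    import torch

    x = torch.ones(2, 2)
    y, z, s = torch.full((2, 2), 2.0), torch.rand(2, 2), torch.tensor(0.3)

    ref = x.clone().addcmul_(y, z, value=0.3)  # Python-number value
    out = x.clone().add_(y * z * s)            # what polyfills.addcmul_inplace computes

    torch.testing.assert_close(ref, out)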

View File

@ -581,6 +581,30 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                tx, [args[0], result], {}
            )

        @register(torch._foreach_lerp_)
        def handle_inplace_foreach_lerp_scalar(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
                return tx.inline_user_function_return(
                    SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
                    args,
                    kwargs,
                )

        @register(torch._foreach_pow)
        def handle_foreach_pow_scalar(
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            # In eager it's more performant to call .item() from within the C++ op
            # implementation; under compile, it's more performant to avoid the graph break.
            if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
                return tx.inline_user_function_return(
                    SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
                    args,
                    kwargs,
                )

        @register(torch._assert)
        def handle_assert(self, tx: "InstructionTranslator", condition, message):
            if (condition.is_python_constant() and condition.as_python_constant()) or (
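Note (editor's sketch, not part of the diff): the broadcast that handle_foreach_pow_scalar inlines is observable in eager as well; both spellings agree:

    import torch

    scalar = torch.tensor(0.8)
    exps = [torch.tensor(3.4), torch.tensor(7.8)]

    # overload whose implementation reads the base tensor with .item()
    eager = torch._foreach_pow(scalar, exps)
    # broadcast form that polyfills.foreach_pow_scalar traces instead
    broadcast = torch._foreach_pow([scalar for _ in exps], exps)

    torch.testing.assert_close(eager, broadcast)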

View File

@ -5911,14 +5911,17 @@ foreach_add_scalar = register_foreach_pointwise(
)
register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
+register_foreach_pointwise(aten._foreach_mul.Tensor, mul)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
register_foreach_pointwise(aten._foreach_sub.List, sub)
register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
register_foreach_pointwise(aten._foreach_neg.default, neg)
register_foreach_pointwise(aten._foreach_abs.default, abs)
register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
register_foreach_pointwise(aten._foreach_pow.List, pow)
+register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
+register_foreach_pointwise(aten._foreach_div.Tensor, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
register_foreach_pointwise(aten._foreach_sqrt, sqrt)
register_foreach_pointwise(aten._foreach_maximum.List, maximum)
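Note (editor's sketch, not part of the diff; assumes a build containing these lowerings): the new Tensor-overload registrations let compiled code keep a 0-dim tensor operand in a foreach op instead of falling back:

    import torch

    @torch.compile(fullgraph=True)
    def scale(xs, s):
        return torch._foreach_mul(xs, s)

    scale([torch.ones(2), torch.ones(3)], torch.tensor(0.5))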

View File

@ -17,6 +17,7 @@ from .optimizer import (
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
+   _maybe_copy_beta,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
@ -375,7 +376,9 @@ def _single_tensor_adam(
            param = torch.view_as_real(param)

        # Decay the first and second moment running average coefficient
-       exp_avg.lerp_(grad, 1 - beta1)
+       beta1, device_beta1 = _maybe_copy_beta(beta1, grad.device, grad.dtype)
+       exp_avg.lerp_(grad, 1 - device_beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        if capturable or differentiable:
@ -483,6 +486,7 @@ def _multi_tensor_adam(
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )

    for (
        device_params_,
        device_grads_,
@ -496,6 +500,9 @@ def _multi_tensor_adam(
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

+       beta1, device_beta1 = _maybe_copy_beta(
+           beta1, device=device_params[0].device, dtype=device_params[0].dtype
+       )

        # Handle complex parameters
        if has_complex:
@ -537,15 +544,29 @@ def _multi_tensor_adam(
            )

        # Decay the first and second moment running average coefficient
-       torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
+       # Use device_beta1 if beta1 is a tensor to ensure all
+       # tensors are on the same device
+       torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)

+       # Due to the strictness of the _foreach_addcmul_ API, we can't pass a
+       # 0-dim tensor as the scalar arg (only a Python number is supported there),
+       # so we separate out the multiplication by (1 - beta2).
+       if isinstance(beta2, torch.Tensor):
+           scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2)
+           value = 1.0
+       else:
+           scaled_device_grads = device_grads
+           value = 1 - beta2

        torch._foreach_addcmul_(
-           device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
+           device_exp_avg_sqs, scaled_device_grads, device_grads, value
        )

-       # Delete the local intermediate since it won't be used anymore to save on peak memory
+       # Delete the local intermediates, which won't be used again, to save on peak memory
        del device_grads
+       del scaled_device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
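Note (editor's sketch, not part of the diff): the beta2 workaround above preserves the update exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad while keeping _foreach_addcmul_'s scalar arg a plain Python number:

    import torch

    beta2 = torch.tensor(0.99)
    exp_avg_sqs = [torch.ones(3)]
    grads = [torch.randn(3)]

    ref = [v * beta2 + (1 - beta2) * g * g for v, g in zip(exp_avg_sqs, grads)]

    vals = [v * beta2 for v in exp_avg_sqs]            # the _foreach_mul_ step
    scaled = torch._foreach_mul(grads, 1 - beta2)      # tensor scalar folded into grads
    torch._foreach_addcmul_(vals, scaled, grads, 1.0)  # scalar arg stays a Python number

    torch.testing.assert_close(ref, vals)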

View File

@ -17,6 +17,7 @@ from .optimizer import (
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
+   _maybe_copy_beta,
    _stack_if_compiling,
    _use_grad_for_differentiable,
    _view_as_real,
@ -372,7 +373,8 @@ def _single_tensor_adamw(
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
-       exp_avg.lerp_(grad, 1 - beta1)
+       beta1, device_beta1 = _maybe_copy_beta(beta1, grad.device, grad.dtype)
+       exp_avg.lerp_(grad, 1 - device_beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
@ -493,6 +495,9 @@ def _multi_tensor_adamw(
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

+       beta1, device_beta1 = _maybe_copy_beta(
+           beta1, device=device_params[0].device, dtype=device_params[0].dtype
+       )

        if has_complex:
            if amsgrad:
@ -528,15 +533,26 @@ def _multi_tensor_adamw(
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
-       torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
+       torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)

+       # Due to the strictness of the _foreach_addcmul_ API, we can't pass a
+       # 0-dim tensor as the scalar arg (only a Python number is supported there),
+       # so we separate out the multiplication by (1 - beta2).
+       if isinstance(beta2, torch.Tensor):
+           scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2)
+           value = 1.0
+       else:
+           scaled_device_grads = device_grads
+           value = 1 - beta2

        torch._foreach_addcmul_(
-           device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
+           device_exp_avg_sqs, scaled_device_grads, device_grads, value
        )

-       # Delete the local intermediate since it won't be used anymore to save on peak memory
+       # Delete the local intermediates, which won't be used again, to save on peak memory
        del device_grads
+       del scaled_device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]

View File

@ -68,6 +68,16 @@ class _RequiredParameter:
required = _RequiredParameter()


def _maybe_copy_beta(beta, device, dtype):
    """Return beta unchanged, plus a copy on the target device/dtype if beta is a tensor."""
    if isinstance(beta, torch.Tensor):
        device_beta = beta.to(device=device, dtype=dtype)
    else:
        device_beta = beta
    return beta, device_beta


def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo
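Note (editor's sketch, not part of the diff): the helper is a passthrough for float betas and a device/dtype move for tensor betas, returning both the original and the moved value:

    import torch
    from torch.optim.optimizer import _maybe_copy_beta  # private helper added above

    b1, device_b1 = _maybe_copy_beta(0.9, torch.device("cpu"), torch.float32)
    assert b1 == device_b1 == 0.9

    t = torch.tensor(0.9, dtype=torch.float64)
    orig, moved = _maybe_copy_beta(t, torch.device("cpu"), torch.float32)
    assert orig is t and moved.dtype == torch.float32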

View File

@ -482,8 +482,6 @@ def optim_error_inputs_func_adagrad(device, dtype):
    return error_inputs


-# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
-# with all implementation code paths...
def optim_inputs_func_adam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
@ -497,6 +495,16 @@ def optim_inputs_func_adam(device, dtype=None):
kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
desc="Tensor lr with capturable and amsgrad",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"betas": (torch.tensor(0.9), torch.tensor(0.99)),
"amsgrad": True,
"capturable": True,
},
desc="Tensor lr, Tensor betas, with capturable and amsgrad",
),
]
mps_supported_configs = [
OptimizerInput(
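Note (editor's sketch, not part of the diff; assumes a build containing this stack): an end-to-end use of the new config, mirroring the tensor-lr/tensor-betas input above. Capturable tensor-lr/betas configs are exercised on CUDA, so the sketch is guarded accordingly:

    import torch

    if torch.cuda.is_available():
        params = [torch.randn(4, 4, device="cuda", requires_grad=True)]
        opt = torch.optim.Adam(
            params,
            lr=torch.tensor(0.001, device="cuda"),
            betas=(torch.tensor(0.9), torch.tensor(0.99)),
            amsgrad=True,
            capturable=True,
        )
        params[0].sum().backward()
        opt.step()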