Compare commits

...

6 Commits

Author SHA1 Message Date
05ae40859d Enable dynamo traced test_param_group_with_lrscheduler_goes_right_direction
ghstack-source-id: 1458e58daddc843fdb14ae3dee69ca0e125b1c36
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124544
2024-05-09 00:31:52 -07:00
914d8e14d4 Fix capturable enablement conditions
ghstack-source-id: 46e7795411dc7da6a7005f5399659fc8347782ed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125826
2024-05-09 00:31:51 -07:00
64193bda26 Tighten fallback conditions for compiled optim
ghstack-source-id: 8d4bf3ce18be3d87ac0b31baedaf230c03800e6c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125825
2024-05-09 00:31:51 -07:00
3452d0c807 Stack tensors for adamax
ghstack-source-id: 1f10defcc2edc3d212621b81942fcc984bbe4419
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125833
2024-05-09 00:31:51 -07:00
ea6eab8847 Fix item call in radam
ghstack-source-id: 213bc92a9d396cd07bf2fd9fdc2dbf6f2159be03
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125824
2024-05-08 22:04:19 -07:00
3b01f03ef8 Fix call to item in asgd
ghstack-source-id: 6bd0f70ff1f24f27f432d2da073a06fe70cbe91e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125823
2024-05-08 22:04:18 -07:00
7 changed files with 70 additions and 25 deletions

View File

@@ -232,7 +232,9 @@ class TestOptimRenewed(TestCase):
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
inpt = torch.randn(5, device=device, dtype=dtype)
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": 0.01}])
# avoid endless recompiles by wrapping LR in a tensor if we're compiling
lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
def closure():
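
To make the intent of this change concrete, here is a minimal standalone sketch (not part of the diff) of the tensor-LR pattern it relies on, assuming a PyTorch build where optimizers accept a tensor learning rate: keeping the LR in a 0-dim tensor lets the compiled step read it as data rather than baking a Python float into the graph, so a scheduler can change it without forcing a recompile.

import torch

model = torch.nn.Linear(4, 1)
lr = torch.tensor(0.01)  # tensor LR instead of a Python float
opt = torch.optim.Adam(model.parameters(), lr=lr)

@torch.compile
def compiled_step():
    opt.step()

for _ in range(3):
    model(torch.randn(2, 4)).sum().backward()
    compiled_step()   # traced once; lr is read from the tensor each call
    opt.zero_grad()
    lr.mul_(0.9)      # in-place LR update, no recompile expected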

View File

@@ -124,21 +124,38 @@ class OptimizerVariable(UserDefinedObjectVariable):
from . import LazyVariableTracker
from .builder import VariableBuilder
# Set capturable to True
for group in self.value.param_groups:
if "capturable" in group:
# We only set capturable if params are on cuda
# and the state is not initialized
def safe_to_set_capturable(group):
all_uninitialized = True
all_cuda = True
for p in group.get("params", list()):
all_cuda &= p.is_cuda
all_uninitialized &= p not in self.value.state
return "capturable" in group and all_uninitialized and all_cuda
# Track the indices of groups we skip so that we don't
# need to realize the whole state in the variable tracker;
# guarding the state is handled specially elsewhere.
indices_to_ignore = set()
for ind, group in enumerate(self.value.param_groups):
if safe_to_set_capturable(group):
group["capturable"] = True
else:
indices_to_ignore.add(ind)
param_groups_vt = LazyVariableTracker.realize_all(
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
)
for param_group_vt in param_groups_vt.items:
for ind, param_group_vt in enumerate(param_groups_vt.items):
key = ConstDictVariable._HashableTracker(
ConstantVariable.create("capturable")
)
if key in param_group_vt.items:
if key in param_group_vt.items and ind not in indices_to_ignore:
param_group_vt.items[key] = ConstantVariable.create(True)
def get_python_args(self, *args, **kwargs):
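
For readers unfamiliar with the variable-tracker machinery, the guard logic above can be restated as the following hypothetical eager-mode sketch (illustrative names, not the dynamo code): capturable is flipped on only for param groups whose params are all on CUDA and have no optimizer state yet, and the indices of the remaining groups are recorded so they are left untouched.

import torch

def safe_to_set_capturable(optimizer, group):
    # A param group may be switched to capturable=True only if it already
    # exposes the flag, every param is on CUDA, and none of the params has
    # initialized optimizer state.
    params = group.get("params", [])
    return (
        "capturable" in group
        and all(p.is_cuda for p in params)
        and all(p not in optimizer.state for p in params)
    )

def mark_capturable_groups(optimizer):
    # Returns the indices of the groups that were left alone, mirroring the
    # indices_to_ignore bookkeeping in the hunk above.
    ignored = set()
    for ind, group in enumerate(optimizer.param_groups):
        if safe_to_set_capturable(optimizer, group):
            group["capturable"] = True
        else:
            ignored.add(ind)
    return ignored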

View File

@@ -12,6 +12,7 @@ from .optimizer import (
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
@@ -386,7 +387,9 @@ def _multi_tensor_adamax(
bias_corrections = [
1 - beta1 ** _get_value(step) for step in grouped_state_steps
]
step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
step_size = _stack_if_compiling(
[(_get_value(lr) / bc) * -1 for bc in bias_corrections]
)
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
)
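
The behavior assumed of _stack_if_compiling is sketched below (a hedged guess at its intent, not the library's actual implementation): while dynamo traces, the per-parameter step sizes are folded into one tensor so the foreach addcdiv receives a single tensor argument rather than a list of Python floats that would each be specialized into the graph; in eager mode the plain list is returned unchanged.

import torch

def stack_if_compiling(values):
    # Hypothetical sketch: stack scalar step sizes into one tensor while
    # tracing, otherwise hand back the plain list for the eager path.
    if torch._utils.is_compiling():
        return torch.stack([torch.as_tensor(v) for v in values])
    return values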

View File

@@ -246,7 +246,9 @@ def _single_tensor_asgd(
param.add_(grad, alpha=-eta_value) # update parameter
# averaging
if capturable or mu.item() != 1:
# The compiler will only launch one kernel
# and it does not support data-dependent control flow
if torch._utils.is_compiling() or capturable or mu.item() != 1:
ax.add_(param.sub(ax).mul_(mu))
else:
ax.copy_(param)
@@ -375,7 +377,7 @@ def _multi_tensor_asgd(
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
else:
step = grouped_state_steps[0].item()
step = _get_value(grouped_state_steps[0])
new_etas = []
new_mus = []
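
Both asgd edits follow the same rule: under torch.compile, calling .item() on a state tensor forces a host sync and creates data-dependent control flow that cannot be traced. A hypothetical standalone rendering of the averaging branch (capturable omitted for brevity):

import torch

def averaged_update(ax, param, mu):
    # When compiling, always take the general tensor-math branch: with
    # mu == 1 it reduces to copying param into ax (up to floating-point
    # rounding), so the eager copy_ fast path is purely an optimization.
    if torch._utils.is_compiling() or mu.item() != 1:
        ax.add_(param.sub(ax).mul_(mu))
    else:
        ax.copy_(param)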

View File

@@ -127,12 +127,36 @@ def _disable_dynamo_if_unsupported(single_tensor_fn=None):
globals()[single_tensor_fn.__name__] = single_tensor_fn
def wrapper(func):
import inspect
disabled_func = torch._disable_dynamo(func)
ps = inspect.signature(func).parameters
has_state_steps = True
try:
state_steps_ind = list(ps.keys()).index("state_steps")
except ValueError:
has_state_steps = False
# Today, there are cases where we stack state steps
# and pass them as the value arg of foreach ops.
# Having state steps on cuda as the value arg is not supported in eager,
# but this only occurs in the rare case that the user explicitly deletes
# the capturable flag. If capturable=True, this is not a problem.
@functools.wraps(func)
def maybe_fallback(self, *args, **kwargs):
if is_compiling() and not kwargs.get("capturable", False):
return torch._disable_dynamo(func(self, *args, **kwargs))
def maybe_fallback(*args, **kwargs):
if is_compiling() and (
not kwargs.get("capturable", False)
and has_state_steps
and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
or (
"state_steps" in kwargs
and kwargs["state_steps"]
and kwargs["state_steps"][0].is_cuda
)
):
return disabled_func(*args, **kwargs)
else:
return func(self, *args, **kwargs)
return func(*args, **kwargs)
return maybe_fallback
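
Stripped of the state_steps introspection, the decorator above boils down to the following hypothetical sketch, where the predicate stands in for the capturable/CUDA-state_steps check. Note the diff also moves torch._disable_dynamo so it wraps func itself once at decoration time instead of wrapping the value returned by calling func.

import functools
import torch

def fallback_to_eager_if(predicate):
    # Hypothetical simplified decorator: when tracing and the predicate says
    # the call is unsupported, dispatch to a dynamo-disabled copy of the
    # function so it runs eagerly outside the graph.
    def wrapper(func):
        disabled_func = torch._disable_dynamo(func)

        @functools.wraps(func)
        def maybe_fallback(*args, **kwargs):
            if torch._utils.is_compiling() and predicate(*args, **kwargs):
                return disabled_func(*args, **kwargs)
            return func(*args, **kwargs)

        return maybe_fallback
    return wrapper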

View File

@@ -313,7 +313,9 @@ def _single_tensor_radam(
return (bias_correction2**0.5) / exp_avg_sq_sqrt
# Compute the variance rectification term and update parameters accordingly
if capturable:
# data-dependent control flow is not supported by the compiler
# so use torch.where if compiling
if capturable or torch._utils.is_compiling():
update = torch.where(
rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0
)
@@ -478,7 +480,12 @@ def _multi_tensor_radam(
else 0
for rho_t in rho_t_list
]
unrectified = [0 if rect > 0 else 1.0 for rect in rect]
# data-dependent control flow is not supported by the compiler
if torch._utils.is_compiling():
unrectified = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
else:
unrectified = [0.0 if rect > 0.0 else 1.0 for rect in rect]
bias_correction1 = [
1 - beta1 ** _get_value(step) for step in grouped_state_steps
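
The radam hunks apply the same technique as asgd: Python branches on per-parameter scalars become torch.where selections when compiling, since those terms are traced tensors there. A hypothetical standalone version of the second change:

import torch

def unrectified_terms(rect):
    # Under compile the rect terms are traced tensors, so the Python
    # conditional would need a graph-breaking .item(); torch.where builds
    # the same 0/1 mask inside the graph. Eager keeps the cheap branch.
    if torch._utils.is_compiling():
        return [torch.where(r > 0, 0.0, 1.0) for r in rect]
    return [0.0 if r > 0.0 else 1.0 for r in rect]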

View File

@@ -1245,11 +1245,6 @@ optim_db: List[OptimizerInfo] = [
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
"TestOptimRenewed",
@@ -1714,11 +1709,6 @@ optim_db: List[OptimizerInfo] = [
"cuda",
),
skips=(
DecorateInfo(
skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
"TestOptimRenewed",