Compare commits

...

3 Commits

Author SHA1 Message Date
26004dc2e5 Enable dynamo traced test_param_group_with_lrscheduler_goes_right_direction
ghstack-source-id: 73462085c1665607b0ca6cc09a1c4924de8116e6
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124544
2024-05-10 18:10:30 -07:00
7fb495730a Fix capturable enablement conditions
ghstack-source-id: 0681111bbafcf6c47a4b086b95079d75c5d1a47f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125826
2024-05-10 18:10:30 -07:00
595c67e8ea Tighten fallback conditions for compiled optim
ghstack-source-id: 1e738b29711afc6013d781802e698cffbd40c458
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125825
2024-05-10 18:10:29 -07:00
4 changed files with 50 additions and 21 deletions

View File

@@ -232,7 +232,9 @@ class TestOptimRenewed(TestCase):
         bias = Parameter(torch.randn((10), device=device, dtype=dtype))
         inpt = torch.randn(5, device=device, dtype=dtype)
-        optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": 0.01}])
+        # avoid endless recompiles by wrapping LR in a tensor if we're compiling
+        lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01
+        optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
         schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

         def closure():

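A hypothetical sketch (not part of this diff) of why the test wraps the LR in a tensor under compile: dynamo specializes on Python floats, so every scheduler-driven change to a float LR would compile a fresh graph, while a tensor LR is traced as data and can be updated in place.

# Hypothetical illustration: a Python float becomes a constant that dynamo
# specializes on, so each new value triggers another compilation; a tensor
# is a graph input and can change in place without recompiling.
import torch

@torch.compile
def scale(x, lr):
    return x * lr

x = torch.randn(4)

# Float LR: every new value is a new constant -> another compilation.
for lr in (0.01, 0.009, 0.0081):
    scale(x, lr)

# Tensor LR: compiled once; in-place updates reuse the same graph.
lr_t = torch.tensor(0.01)
for _ in range(3):
    scale(x, lr_t)
    lr_t.mul_(0.9)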
View File

@@ -124,9 +124,23 @@ class OptimizerVariable(UserDefinedObjectVariable):
         from . import LazyVariableTracker
         from .builder import VariableBuilder

-        # Set capturable to True
-        for group in self.value.param_groups:
-            if "capturable" in group:
+        # We only set capturable if params are on cuda
+        # and the state is not initialized
+        def safe_to_set_capturable(group):
+            all_uninitialized = True
+            all_cuda = True
+
+            for p in group.get("params", list()):
+                all_cuda &= p.is_cuda
+                all_uninitialized &= p not in self.value.state
+
+            return "capturable" in group and all_uninitialized and all_cuda
+
+        # track indices to not set so we don't need to
+        # in the variable tracker realize the whole state
+        # we handle guarding the state specially
+        for ind, group in enumerate(self.value.param_groups):
+            if safe_to_set_capturable(group):
                 group["capturable"] = True

         param_groups_vt = LazyVariableTracker.realize_all(
@@ -134,12 +148,11 @@ class OptimizerVariable(UserDefinedObjectVariable):
                 self.value.param_groups
             )
         )
-        for param_group_vt in param_groups_vt.items:
+        for ind, param_group_vt in enumerate(param_groups_vt.items):
             key = ConstDictVariable._HashableTracker(
                 ConstantVariable.create("capturable")
             )
-            if key in param_group_vt.items:
-                param_group_vt.items[key] = ConstantVariable.create(True)
+            param_group_vt.items[key] = ConstantVariable.create(True)

     def get_python_args(self, *args, **kwargs):
         """Get python values equivalent to the variable tracker args"""

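The gating above runs on dynamo's variable trackers inside OptimizerVariable. As a rough eager-level sketch (hypothetical helper, not part of this diff), the same condition reads: only flip capturable on for a param group whose params are all on CUDA and have no optimizer state yet.

# Hypothetical eager-level restatement of the check: set capturable only when
# every param in the group is on CUDA and has no optimizer state entry yet.
import torch

def safe_to_set_capturable(optimizer, group):
    params = group.get("params", [])
    all_cuda = all(p.is_cuda for p in params)
    all_uninitialized = all(p not in optimizer.state for p in params)
    return "capturable" in group and all_cuda and all_uninitialized

opt = torch.optim.Adam([torch.nn.Parameter(torch.randn(3))])
for group in opt.param_groups:
    if safe_to_set_capturable(opt, group):
        group["capturable"] = True  # not reached here: the example params are on CPU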
View File

@@ -127,12 +127,36 @@ def _disable_dynamo_if_unsupported(single_tensor_fn=None):
         globals()[single_tensor_fn.__name__] = single_tensor_fn

     def wrapper(func):
+        import inspect
+
+        disabled_func = torch._disable_dynamo(func)
+        ps = inspect.signature(func).parameters
+        has_state_steps = True
+        try:
+            state_steps_ind = list(ps.keys()).index("state_steps")
+        except ValueError:
+            has_state_steps = False
+
+        # Today, there are cases where we stack state steps
+        # and pass them as the value arg of foreach ops.
+        # Having state steps on cuda as the value arg is not supported in eager,
+        # but this only occurs in the rare case that the user explicitly deletes
+        # the capturable flag. If capturable=True, this is not a problem.
         @functools.wraps(func)
-        def maybe_fallback(self, *args, **kwargs):
-            if is_compiling() and not kwargs.get("capturable", False):
-                return torch._disable_dynamo(func(self, *args, **kwargs))
+        def maybe_fallback(*args, **kwargs):
+            if is_compiling() and (
+                not kwargs.get("capturable", False)
+                and has_state_steps
+                and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
+                or (
+                    "state_steps" in kwargs
+                    and kwargs["state_steps"]
+                    and kwargs["state_steps"][0].is_cuda
+                )
+            ):
+                return disabled_func(*args, **kwargs)
             else:
-                return func(self, *args, **kwargs)
+                return func(*args, **kwargs)

         return maybe_fallback

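The decorator change above is the "tighten fallback conditions" part: instead of disabling dynamo whenever capturable is False, it only falls back when the call would actually hit the unsupported eager path, i.e. CUDA state_steps without capturable. A standalone sketch of that shape, with a hypothetical decorator name and simplified argument handling, not the actual torch.optim implementation:

# Hypothetical sketch of the fallback pattern: resolve the position of
# "state_steps" once at decoration time, then at call time route to a
# dynamo-disabled copy only if compiling without capturable=True while the
# state steps are CUDA tensors.
import functools
import inspect
import torch

def fallback_if_unsupported(func):
    disabled_func = torch._disable_dynamo(func)
    try:
        state_steps_ind = list(inspect.signature(func).parameters).index("state_steps")
    except ValueError:
        state_steps_ind = None  # wrapped function takes no state_steps

    @functools.wraps(func)
    def maybe_fallback(*args, **kwargs):
        state_steps = kwargs.get("state_steps")
        if state_steps is None and state_steps_ind is not None and len(args) > state_steps_ind:
            state_steps = args[state_steps_ind]
        needs_fallback = (
            torch.compiler.is_compiling()
            and not kwargs.get("capturable", False)
            and bool(state_steps)
            and state_steps[0].is_cuda
        )
        return disabled_func(*args, **kwargs) if needs_fallback else func(*args, **kwargs)

    return maybe_fallback

In eager mode is_compiling() is False, so a wrapped function behaves exactly as before; the fallback only matters under torch.compile.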
View File

@@ -1245,11 +1245,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_forloop_goes_right_direction",
                 active_if=lambda kwargs: not kwargs["contiguous"],
             ),
-            DecorateInfo(
-                skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
-                "TestOptimRenewed",
-                "test_param_group_with_lrscheduler_goes_right_direction",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                 "TestOptimRenewed",
@@ -1714,11 +1709,6 @@ optim_db: List[OptimizerInfo] = [
             "cuda",
         ),
         skips=(
-            DecorateInfo(
-                skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
-                "TestOptimRenewed",
-                "test_param_group_with_lrscheduler_goes_right_direction",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                 "TestOptimRenewed",