Compare commits

...

5 Commits

Author SHA1 Message Date
654f95cec9 Add LRScheduler E2E Tests
ghstack-source-id: e2cf263d50e4dbb13da49f0635a17b0d796adf32
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125653
2024-05-08 12:15:22 -07:00
643939a782 LRScheduler composability tests
ghstack-source-id: d6ebd0c56ba894672b04007215ae09666fadcd5c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125383
2024-05-08 12:15:22 -07:00
989a3aaf66 Enable LR Scheduler to update tensor LR
ghstack-source-id: 7fe8a5a8bae6557f689d4ce0fbcf37d1daf4ec3b
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123753
2024-05-08 12:15:21 -07:00
5dec0f5f2f Fix user warning for tensor LR
ghstack-source-id: d0da131e19d1beb3faa0ee9379f750028b8634f2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123752
2024-05-08 12:15:21 -07:00
86d9bbc75d Switch warning from counter to flag
ghstack-source-id: 2ffaf06c9ef2729b03f4eed8c7780070cafea493
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123751
2024-05-08 12:15:20 -07:00
5 changed files with 304 additions and 162 deletions
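
Taken together, the stack lets an LR scheduler drive a tensor-valued learning rate: the scheduler now writes the new value into the LR tensor in place, the ASGD copy-construct warning for tensor LRs is silenced, the "scheduler before optimizer.step()" warning is tracked with a flag instead of a counter, and composability/E2E tests cover every scheduler. A minimal sketch of the intended usage, assuming a PyTorch build that includes these commits (the model and scheduler choice below are illustrative):

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)

model(torch.randn(2, 4)).sum().backward()
opt.step()                        # the patched step records that the optimizer ran
sched.step()                      # no ordering warning; the tensor LR is filled in place
print(opt.param_groups[0]["lr"])  # still a Tensor, now holding 0.001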

View File

@@ -12,6 +12,7 @@ import torch
import torch._inductor
import torch._inductor.cudagraph_trees
import torch.optim.lr_scheduler
from torch._inductor import config
from torch._inductor.test_case import TestCase
@@ -30,6 +31,25 @@ from torch.optim import (
SGD,
SparseAdam,
)
from torch.optim.lr_scheduler import (
ChainedScheduler,
ConstantLR,
CosineAnnealingLR,
CosineAnnealingWarmRestarts,
CyclicLR,
ExponentialLR,
LambdaLR,
LinearLR,
MultiplicativeLR,
MultiStepLR,
OneCycleLR,
PolynomialLR,
ReduceLROnPlateau,
SequentialLR,
StepLR,
)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
skipCUDAIf,
@@ -46,6 +66,45 @@ from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA, has_triton
from torch.testing._internal.triton_utils import requires_cuda
# Note: we use atypical values to amplify error
LR_SCHEDULER_TO_KWARGS = {
LambdaLR: {"lr_lambda": lambda x: 10},
MultiplicativeLR: {"lr_lambda": lambda x: 10},
StepLR: {"step_size": 1, "gamma": 100},
MultiStepLR: {"milestones": [1, 2], "gamma": 100},
ExponentialLR: {"gamma": 100},
SequentialLR: {"schedulers": None, "milestones": [1, 2]},
CosineAnnealingLR: {"T_max": 7},
ChainedScheduler: {"schedulers": None},
CyclicLR: {"base_lr": 0.001, "max_lr": 0.02, "cycle_momentum": False},
CosineAnnealingWarmRestarts: {"T_0": 1},
OneCycleLR: {
"max_lr": 0.02,
"cycle_momentum": False,
"steps_per_epoch": 1,
"epochs": 10,
},
ConstantLR: {"factor": 0.001},
LinearLR: {},
ReduceLROnPlateau: {"factor": 0.99, "patience": 1},
PolynomialLR: {},
}
def create_scheduler(scheduler, optim):
kwargs = LR_SCHEDULER_TO_KWARGS[scheduler]
if "schedulers" in kwargs:
kwargs["schedulers"] = [
create_scheduler(torch.optim.lr_scheduler.ConstantLR, optim)
for _ in range(2)
] + [create_scheduler(torch.optim.lr_scheduler.LambdaLR, optim)]
if scheduler == ChainedScheduler:
return scheduler(**kwargs)
else:
return scheduler(optim, **kwargs)
class KernelCounts(NamedTuple):
multitensor: int
singletensor: int
@@ -128,6 +187,7 @@ def build_opt_kwarg_db():
kwargs = dict(optim_inputs.kwargs)
name = f"test_{optim_info.optim_cls.__name__.lower()}"
has_tensor_lr = False
for key, val in kwargs.items():
if not key == "lr" and (
not isinstance(val, bool) or (isinstance(val, bool) and val)
@@ -135,6 +195,7 @@ def build_opt_kwarg_db():
name += "_" + key
if key == "lr" and isinstance(kwargs["lr"], torch.Tensor):
has_tensor_lr = True
name += "_tensor_lr"
name += f"_{device}"
@@ -152,7 +213,19 @@ def build_opt_kwarg_db():
if kwargs["kernel_count"] is None or kwargs.get("fused", False):
continue
compiled_opt_db.append((optim_info.optim_cls, name, kwargs))
if has_tensor_lr:
for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys():
name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}"
compiled_opt_db.append(
(
optim_info.optim_cls,
name_w_scheduler,
kwargs,
scheduler_cls,
)
)
else:
compiled_opt_db.append((optim_info.optim_cls, name, kwargs, None))
return compiled_opt_db
@@ -174,6 +247,13 @@ except (unittest.SkipTest, ImportError) as e:
raise
def call_scheduler(scheduler):
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
scheduler.step(1.0) # we won't reduce the metric over two iters anyway
else:
scheduler.step()
def compile_opt(opt_compiled, closure=None, fullgraph=True):
# run the patcher so that step has the expected structure
torch._dynamo.eval_frame.TorchPatcher.patch()
@@ -185,6 +265,10 @@ def compile_opt(opt_compiled, closure=None, fullgraph=True):
# and instead manually disables grad before calling step, which is fine
# for now as dynamo does not support differentiable optimizers anyway
step_fn = opt_compiled.step.__wrapped__.__wrapped__
# This ensures we don't receive spam of warnings from LR Scheduler
opt_compiled._opt_called = True
if closure is not None:
def fn():
@@ -236,6 +320,7 @@ def check_optim(
def make_test(
optim_cls,
closure=None,
scheduler_cls=None,
kernel_count=2,
device="cuda",
**kwargs,
@@ -249,8 +334,10 @@ def make_test(
if run_cudagraphs:
stack.enter_context(config.patch({"triton.cudagraphs": True}))
kwargs_compiled = deepcopy(kwargs)
if isinstance(kwargs.get("lr", None), torch.Tensor):
kwargs["lr"] = kwargs["lr"].to(device)
kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)
torch._dynamo.reset()
torch._inductor.metrics.reset()
@@ -265,14 +352,23 @@ def make_test(
model_compiled(input).sum().backward()
opt_eager = optim_cls(model_eager.parameters(), **kwargs)
opt_compiled = optim_cls(model_compiled.parameters(), **kwargs)
opt_compiled = optim_cls(model_compiled.parameters(), **kwargs_compiled)
compiled_step = compile_opt(opt_compiled, closure=closure)
if scheduler_cls:
scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
# some schedulers only change after at least an epoch has passed
scheduler_compiled.last_epoch = 1
scheduler_eager.last_epoch = 1
with torch.set_grad_enabled(False):
compiled_step()
for i in range(2):
compiled_step()
opt_eager.step()
opt_eager.step()
if scheduler_cls:
call_scheduler(scheduler_eager)
call_scheduler(scheduler_compiled)
check_optim(
self,
@@ -369,6 +465,12 @@ class CompiledOptimizerParityTests(TestCase):
for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
use_scheduler = isinstance(kwargs.get("lr", None), torch.Tensor)
scheduler_classes = (
list(LR_SCHEDULER_TO_KWARGS.keys()) if use_scheduler else [None]
)
for scheduler_cls in scheduler_classes:
torch._dynamo.reset()
torch._inductor.metrics.reset()
input = torch.ones([10, 10], device=device)
@@ -385,9 +487,18 @@ class CompiledOptimizerParityTests(TestCase):
for param in model_compiled.parameters():
param.grad = param.grad.to_sparse()
opt_compiled = optim_cls(model_compiled.parameters(), **kwargs)
opt_eager = optim_cls(model_eager.parameters(), **kwargs)
opt_compiled = optim_cls(
model_compiled.parameters(), **deepcopy(kwargs)
)
opt_eager = optim_cls(model_eager.parameters(), **deepcopy(kwargs))
if scheduler_cls:
scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
# some schedulers only change after at least an epoch has passed
scheduler_compiled.last_epoch = 1
scheduler_eager.last_epoch = 1
num_steps = 2
if use_closure:
@torch.compile()
@@ -401,6 +512,8 @@ class CompiledOptimizerParityTests(TestCase):
return loss
opt_compiled.step(closure)
if scheduler_cls:
call_scheduler(scheduler_compiled)
def closure_eager():
loss = model_eager(input).sum()
@@ -411,18 +524,24 @@ class CompiledOptimizerParityTests(TestCase):
return loss
for _ in range(num_steps):
opt_eager.step(closure_eager)
opt_eager.step(closure_eager)
if scheduler_cls:
call_scheduler(scheduler_eager)
else:
@torch.compile()
def fn():
opt_compiled.step()
if scheduler_cls:
call_scheduler(scheduler_compiled)
for _ in range(num_steps):
opt_eager.step()
opt_eager.step()
if scheduler_cls:
call_scheduler(scheduler_eager)
fn()
for _ in range(num_steps):
fn()
check_optim(
@@ -648,8 +767,12 @@ class CompiledOptimizerTests(TestCase):
self.assertEqual(ret_val, x)
for optim_cls, name, kwargs in COMPILED_OPT_KWARG_DB:
setattr(CompiledOptimizerTests, name, make_test(optim_cls, **kwargs))
for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
setattr(
CompiledOptimizerTests,
name,
make_test(optim_cls, scheduler_cls=scheduler_cls, **kwargs),
)
instantiate_device_type_tests(CompiledOptimizerParityTests, globals())
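
Since the final loop above now unpacks four elements, here is a rough sketch of the shape of COMPILED_OPT_KWARG_DB entries after this change; the optimizer, test names, and kwargs below are hypothetical placeholders, not values taken from the diff:

import torch

# (optimizer class, generated test name, kwargs, scheduler class or None)
example_db = [
    # scalar LR: a single entry with no scheduler
    (torch.optim.SGD, "test_sgd_cuda", {"lr": 0.01, "kernel_count": 2}, None),
    # tensor LR: build_opt_kwarg_db() adds one entry per scheduler in LR_SCHEDULER_TO_KWARGS
    (
        torch.optim.SGD,
        "test_sgd_tensor_lr_cuda_steplr",
        {"lr": torch.tensor(0.01), "kernel_count": 2},
        torch.optim.lr_scheduler.StepLR,
    ),
]

Each tuple becomes a test method via make_test(optim_cls, scheduler_cls=scheduler_cls, **kwargs), which builds the scheduler for both the eager and compiled optimizer whenever scheduler_cls is not None.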

View File

@@ -1,6 +1,5 @@
# Owner(s): ["module: optimizer"]
import functools
import itertools
import math
import tempfile
from typing import Any, Dict, Tuple
@@ -125,9 +124,13 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction(self, device, dtype, optim_info, contiguous, with_lrsched):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
schedulers_constructors = optim_info.scheduler_inputs if with_lrsched else [None]
for optim_input, schedulers_constructor in itertools.product(optim_inputs, schedulers_constructors):
for schedulers_constructor in schedulers_constructors:
# with tensor LR we need fresh inputs for each scheduler
# or mutating it will carry across iters
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
if contiguous:
@@ -178,9 +181,12 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info, with_lrsched):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
schedulers_constructors = optim_info.scheduler_inputs if with_lrsched else [None]
for optim_input, schedulers_constructor in itertools.product(optim_inputs, schedulers_constructors):
for schedulers_constructor in schedulers_constructors:
# We need a fresh set of inputs if we have a tensor LR
# to not carry mutations across iterations.
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop

View File

@@ -3295,6 +3295,13 @@ FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE = re.compile(
f".*({'|'.join(map(re.escape, FBCODE_INLINE_FILES_IN_SKIPPED_DIRS))})"
)
# torch.optim is a special case,
# we usually want to inline it, but the directory
# structure does not match the module structure
# and we want to skip the functions in optim/lr_scheduler.py
# this has precedence over all other rules in check_file
FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"}
def _recompile_re():
global SKIP_DIRS_RE
@@ -3328,6 +3335,8 @@ def check_file(filename, is_inlined_call=False):
"""Should skip this file?"""
if filename is None:
return SkipResult(True, "filename is None")
if filename in FORCE_SKIP_FILES:
return SkipResult(True, "FORCE_SKIP_FILES")
if any(filename.startswith(d) for d in get_legacy_mod_inlinelist()):
return SkipResult(
False,

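For context, the intended effect of the new branch is roughly the following (illustrative; it assumes check_file, SkipResult, and _module_dir live in torch._dynamo.trace_rules, as the surrounding hunk suggests):

import torch
from torch._dynamo import trace_rules

result = trace_rules.check_file(f"{trace_rules._module_dir(torch)}optim/lr_scheduler.py")
# -> a SkipResult of (True, "FORCE_SKIP_FILES"); this wins over the inline lists consulted
#    later in check_file, so Dynamo leaves scheduler code uncompiled even though the rest
#    of torch.optim is normally inlined.
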
View File

@@ -99,9 +99,13 @@ class ASGD(Optimizer):
state["step"] = torch.zeros(
(), device=p.device, dtype=_get_scalar_dtype()
)
state["eta"] = torch.tensor(
state["eta"] = (
torch.as_tensor(
group["lr"], device=p.device, dtype=_get_scalar_dtype()
)
.clone()
.detach()
)
state["mu"] = torch.ones(
(), device=p.device, dtype=_get_scalar_dtype()
)
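
This hunk is the "Fix user warning for tensor LR" change: when group["lr"] is already a Tensor, torch.tensor(...) copy-constructs it and PyTorch emits the "To copy construct from a tensor ..." UserWarning, while torch.as_tensor(...).clone().detach() produces the same detached copy silently. A standalone sketch of the difference (outside the optimizer, illustrative dtype):

import torch

lr = torch.tensor(0.01)  # a tensor LR, as enabled elsewhere in this stack

# old path: warns "To copy construct from a tensor, it is recommended to use
# sourceTensor.clone().detach() ..." because torch.tensor() copies an existing Tensor
eta_old = torch.tensor(lr, dtype=torch.float32)

# new path: same detached copy of the value, no warning
eta_new = torch.as_tensor(lr, dtype=torch.float32).clone().detach()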

View File

@@ -1,13 +1,13 @@
import math
import types
import warnings
import weakref
from bisect import bisect_right
from collections import Counter
from functools import partial, wraps
from functools import partial
from typing import Optional, Sequence
from weakref import ref
from torch import inf
from torch import inf, Tensor
from .optimizer import Optimizer
@@ -76,39 +76,31 @@ class LRScheduler:
# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `lr_scheduler.step()` is called after
# `optimizer.step()`
def with_counter(method):
if getattr(method, "_with_counter", False):
# `optimizer.step()` has already been replaced, return.
return method
def patch_track_step_called(opt):
if hasattr(opt.step, "_wrapped_by_lr_sched"):
# we've already patched
return opt.step
# Keep a weak reference to the optimizer instance to prevent
# cyclic references.
instance_ref = weakref.ref(method.__self__)
# Get the unbound method for the same purpose.
func = method.__func__
cls = instance_ref().__class__
del method
def wrap_step(step_fn):
opt_ref = ref(self.optimizer)
func = step_fn.__func__
@wraps(func)
def wrapper(*args, **kwargs):
instance = instance_ref()
instance._step_count += 1
wrapped = func.__get__(instance, cls)
return wrapped(*args, **kwargs)
opt = opt_ref()
opt._opt_called = True
return func.__get__(opt, opt.__class__)(*args, **kwargs)
# Note that the returned function here is no longer a bound method,
# so attributes like `__func__` and `__self__` no longer exist.
wrapper._with_counter = True
wrapper._wrapped_by_lr_sched = True
return wrapper
self.optimizer.step = with_counter(self.optimizer.step)
self.verbose = _check_verbose_deprecated_warning(verbose)
opt.step = wrap_step(opt.step)
patch_track_step_called(self.optimizer)
self.verbose = _check_verbose_deprecated_warning(verbose)
self._initial_step()
def _initial_step(self):
"""Initialize step counts and performs a step"""
self.optimizer._step_count = 0
self._step_count = 0
self.step()
@@ -154,7 +146,7 @@ class LRScheduler:
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_with_counter"):
if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
warnings.warn(
"Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
@@ -164,7 +156,7 @@ class LRScheduler:
)
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif self.optimizer._step_count < 1:
elif not getattr(self.optimizer, "_opt_called", False):
warnings.warn(
"Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
@@ -190,6 +182,10 @@ class LRScheduler:
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
if isinstance(param_group["lr"], Tensor):
lr_val = lr.item() if isinstance(lr, Tensor) else lr
param_group["lr"].fill_(lr)
else:
param_group["lr"] = lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
@@ -1421,6 +1417,10 @@ class CyclicLR(LRScheduler):
base_lrs = self._format_param("base_lr", optimizer, base_lr)
if last_epoch == -1:
for lr, group in zip(base_lrs, optimizer.param_groups):
if isinstance(group["lr"], Tensor):
lr_val = lr.item() if isinstance(lr, Tensor) else lr
group["lr"].fill_(lr)
else:
group["lr"] = lr
self.max_lrs = self._format_param("max_lr", optimizer, max_lr)
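
To make the new mechanism concrete: patch_track_step_called wraps optimizer.step so the first real call sets _opt_called on the optimizer, and lr_scheduler.step() now warns about ordering only when that flag is missing, instead of comparing step counters. A minimal sketch, assuming this version of lr_scheduler.py:

import warnings

import torch

opt = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=torch.tensor(0.1))
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sched.step()                     # optimizer.step() never ran -> _opt_called unset -> warn
assert any("before `optimizer.step()`" in str(w.message) for w in caught)

opt.step()                           # the wrapped step sets opt._opt_called = True
sched.step()                         # no warning; the tensor LR is decayed in place via fill_()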