[BE]: Improve decorator typing for Optimizer subclasses (#153374)

Improves typing so that all the optimizer subclasses (which all of them that subtype step) do not erase their type signature when this decorator is used. Now *kwarg values and returns will propogate

This complements @tsunghsienlee PR #153367  as the type signature of step() was being erased on all the optimizer subclasses by this untyped decorator

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153374
Approved by: https://github.com/janeyx99, https://github.com/tsunghsienlee
This commit is contained in:
Aaron Gokaslan
2025-05-12 22:55:25 +00:00
committed by PyTorch MergeBot
parent b0f2891e43
commit f05b38aa26

View File

@ -56,10 +56,11 @@ class _RequiredParameter:
required = _RequiredParameter() required = _RequiredParameter()
def _use_grad_for_differentiable(func): def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
def _use_grad(self, *args, **kwargs): def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
import torch._dynamo import torch._dynamo
self = cast(Optimizer, args[0]) # assume first positional arg is `self`
prev_grad = torch.is_grad_enabled() prev_grad = torch.is_grad_enabled()
try: try:
# Note on graph break below: # Note on graph break below:
@ -76,7 +77,7 @@ def _use_grad_for_differentiable(func):
# see https://github.com/pytorch/pytorch/issues/104053 # see https://github.com/pytorch/pytorch/issues/104053
torch.set_grad_enabled(self.defaults["differentiable"]) torch.set_grad_enabled(self.defaults["differentiable"])
torch._dynamo.graph_break() torch._dynamo.graph_break()
ret = func(self, *args, **kwargs) ret = func(*args, **kwargs)
finally: finally:
torch._dynamo.graph_break() torch._dynamo.graph_break()
torch.set_grad_enabled(prev_grad) torch.set_grad_enabled(prev_grad)