Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE]: Improve decorator typing for Optimizer subclasses (#153374)
Improves typing so that the Optimizer subclasses (all of which subtype step()) no longer have their type signatures erased when this decorator is applied. Keyword-argument values and return types now propagate through the decorator. This complements @tsunghsienlee's PR #153367, since the type signature of step() was being erased on every optimizer subclass by this previously untyped decorator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153374
Approved by: https://github.com/janeyx99, https://github.com/tsunghsienlee
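The change below parameterizes the decorator with a ParamSpec and a TypeVar so that type checkers see the wrapped function's exact signature. A minimal, self-contained sketch of that pattern follows (the names preserve_signature and step_like are illustrative only, not part of PyTorch; typing.ParamSpec requires Python 3.10+):

from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")

def preserve_signature(func: Callable[_P, _T]) -> Callable[_P, _T]:
    # Annotating the wrapper with _P.args / _P.kwargs lets type checkers
    # report the wrapped function's exact parameter list and return type
    # instead of an untyped (*args, **kwargs) -> Any signature.
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        return func(*args, **kwargs)
    return wrapper

@preserve_signature
def step_like(lr: float, maximize: bool = False) -> float:
    return -lr if maximize else lr

# reveal_type(step_like) under mypy/pyright now reports
# (lr: float, maximize: bool = ...) -> float rather than Any.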
Committed by: PyTorch MergeBot
Parent: b0f2891e43
Commit: f05b38aa26
@@ -56,10 +56,11 @@ class _RequiredParameter:
 required = _RequiredParameter()
 
 
-def _use_grad_for_differentiable(func):
-    def _use_grad(self, *args, **kwargs):
+def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
+    def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
         import torch._dynamo
 
+        self = cast(Optimizer, args[0])  # assume first positional arg is `self`
         prev_grad = torch.is_grad_enabled()
         try:
             # Note on graph break below:
@@ -76,7 +77,7 @@ def _use_grad_for_differentiable(func):
             # see https://github.com/pytorch/pytorch/issues/104053
             torch.set_grad_enabled(self.defaults["differentiable"])
             torch._dynamo.graph_break()
-            ret = func(self, *args, **kwargs)
+            ret = func(*args, **kwargs)
         finally:
             torch._dynamo.graph_break()
             torch.set_grad_enabled(prev_grad)
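Because _P.args must capture the decorated method's full positional argument list (including self), the wrapper no longer declares self explicitly; it recovers the instance from args[0] and forwards the call as func(*args, **kwargs). A hedged, self-contained sketch of that calling convention (MyOptimizer and its step() are hypothetical stand-ins, not the real torch.optim classes):

from typing import Callable, Optional, ParamSpec, TypeVar, cast

_P = ParamSpec("_P")
_T = TypeVar("_T")

def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
    def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        # The bound instance arrives as the first positional argument, so it
        # is pulled out of args[0] rather than named in the wrapper signature.
        # The real wrapper also toggles grad mode around the call; that part
        # is omitted to keep this sketch self-contained.
        self = cast("MyOptimizer", args[0])  # not used further in this trimmed sketch
        return func(*args, **kwargs)
    return _use_grad

class MyOptimizer:
    @_use_grad_for_differentiable
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        return closure() if closure is not None else None

# Type checkers still see MyOptimizer.step with its closure parameter and
# Optional[float] return type, instead of an erased (*args, **kwargs) -> Any.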