mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE]: Improve decorator typing for Optimizer subclasses (#153374)
Improves typing so that all the optimizer subclasses (which all of them that subtype step) do not erase their type signature when this decorator is used. Now *kwarg values and returns will propogate This complements @tsunghsienlee PR #153367 as the type signature of step() was being erased on all the optimizer subclasses by this untyped decorator Pull Request resolved: https://github.com/pytorch/pytorch/pull/153374 Approved by: https://github.com/janeyx99, https://github.com/tsunghsienlee
This commit is contained in:
committed by
PyTorch MergeBot
parent
b0f2891e43
commit
f05b38aa26
@ -56,10 +56,11 @@ class _RequiredParameter:
|
|||||||
required = _RequiredParameter()
|
required = _RequiredParameter()
|
||||||
|
|
||||||
|
|
||||||
def _use_grad_for_differentiable(func):
|
def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||||
def _use_grad(self, *args, **kwargs):
|
def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||||
import torch._dynamo
|
import torch._dynamo
|
||||||
|
|
||||||
|
self = cast(Optimizer, args[0]) # assume first positional arg is `self`
|
||||||
prev_grad = torch.is_grad_enabled()
|
prev_grad = torch.is_grad_enabled()
|
||||||
try:
|
try:
|
||||||
# Note on graph break below:
|
# Note on graph break below:
|
||||||
@ -76,7 +77,7 @@ def _use_grad_for_differentiable(func):
|
|||||||
# see https://github.com/pytorch/pytorch/issues/104053
|
# see https://github.com/pytorch/pytorch/issues/104053
|
||||||
torch.set_grad_enabled(self.defaults["differentiable"])
|
torch.set_grad_enabled(self.defaults["differentiable"])
|
||||||
torch._dynamo.graph_break()
|
torch._dynamo.graph_break()
|
||||||
ret = func(self, *args, **kwargs)
|
ret = func(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
torch._dynamo.graph_break()
|
torch._dynamo.graph_break()
|
||||||
torch.set_grad_enabled(prev_grad)
|
torch.set_grad_enabled(prev_grad)
|
||||||
|
Reference in New Issue
Block a user