mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fixes #99665

Let me explain the root cause using the unit test I added:

* This bug is triggered when:
  * `wrapped` is a nested function.
  * `wrapped` is defined in a different module from the main function `fn`.
  * There is a graph break inside `wrapped`.
* The root cause: when the nested function resumes, we actually use the global variables of the outermost function (`fn` in my example). However, `wrapped` calls `inner_func`, which is not among `fn`'s globals. We therefore have to set the correct globals when the nested function resumes execution.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
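The globals mismatch described above can be reproduced in plain Python, without Dynamo. This is a hedged sketch, not Dynamo's actual resume logic. `helper_globals`, `caller_globals`, and `call_with_globals` are hypothetical names: the two dicts stand in for the helper module and for the module that defines `fn`. `types.FunctionType` re-runs the same code object under a chosen globals dict, roughly the way a resumed frame would. The closure over `func` is left out to keep the example short.

```python
import types

# Hypothetical stand-in for the helper module's namespace (the module that
# defines `outer_func` in the test). `inner_func` exists only here.
helper_globals = {"__builtins__": __builtins__}
exec(
    "def inner_func():\n"
    "    return 'called inner_func'\n"
    "\n"
    "def wrapped():\n"
    "    return inner_func()\n",
    helper_globals,
)
wrapped = helper_globals["wrapped"]

# Hypothetical stand-in for the namespace of the outermost function `fn`,
# which never imports or defines `inner_func`.
caller_globals = {"__builtins__": __builtins__}


def call_with_globals(func, globs):
    """Rebuild `func` from its code object, bound to `globs`, and call it.

    Returns the call's result, or the NameError message when a global
    name cannot be resolved in `globs`.
    """
    rebuilt = types.FunctionType(func.__code__, globs, func.__name__)
    try:
        return rebuilt()
    except NameError as exc:
        return f"NameError: {exc}"


# Buggy behavior: resume `wrapped` with the caller's globals.
print(call_with_globals(wrapped, caller_globals))
# Fixed behavior: resume `wrapped` with its own defining module's globals.
print(call_with_globals(wrapped, wrapped.__globals__))
```

The first call fails because `inner_func` is looked up in the caller's namespace. Binding to `wrapped.__globals__` resolves it, which mirrors the fix of using the nested function's own globals when it resumes.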
18 lines
295 B
Python
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo


def inner_func():
    return torch.is_grad_enabled()


def outer_func(func):
    def wrapped(*args):
        a = func(*args)
        torch._dynamo.graph_break()
        return torch.sin(a + 1), inner_func()

    return wrapped