[Dynamo] Fix nested function resume execution (#100426)

Fixes #99665

Let me explain the root cause using the unit test I added:
* This bug is triggered when:
  * ```wrapped``` is a nested function.
  * ```wrapped``` is defined in a module different from the one containing the main function ```fn```.
  * There is a graph break inside of ```wrapped```.
* The root cause: when resuming a nested function after a graph break, Dynamo was building the resume function with the globals of the outermost function (```fn``` in my example). But ```wrapped``` calls ```inner_func```, which is not in ```fn```'s globals, so we have to install the correct globals when the nested function resumes execution (see the sketch below).
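
For concreteness, here is a condensed sketch of the failing setup. It mirrors the new ```test/dynamo/utils.py``` helper and the unit test added in this PR; treat it as a two-file illustration rather than a single runnable script.

```python
# --- test/dynamo/utils.py (a module separate from the one defining fn) ---
import torch
import torch._dynamo


def inner_func():
    return torch.is_grad_enabled()


def outer_func(func):
    def wrapped(*args):
        a = func(*args)
        torch._dynamo.graph_break()  # Dynamo must resume `wrapped` from here
        # `inner_func` lives in utils.py's globals, not in the caller's globals.
        return torch.sin(a + 1), inner_func()

    return wrapped


# --- the test module that defines fn ---
from utils import outer_func


def gn(x, y):
    return x + y


def fn(x, y):
    return outer_func(gn)(x, y)


x, y = torch.rand([3]), torch.rand([3])
opt_fn = torch.compile(backend="eager")(fn)
ref = fn(x, y)
res = opt_fn(x, y)  # before this fix, the resume code ran with fn's globals
                    # and could not find `inner_func`; now ref matches res
```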

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
Author: Yanbo Liang
Committed by: PyTorch MergeBot
Date: 2023-05-11 03:10:23 +00:00
Parent: c84627c2ee
Commit: 075d36d37f

8 changed files with 93 additions and 26 deletions


@@ -233,6 +233,7 @@ test_dynamo_shard() {
--exclude-distributed-tests \
--exclude \
test_autograd \
test_jit \
test_proxy_tensor \
test_quantization \
test_public_bindings \


@@ -5157,6 +5157,26 @@ def fn():
        self.assertTrue(isinstance(compile_out, torch.Size))
        self.assertEqual(eager_out, compile_out)

    def test_nested_function_resuming_with_correct_globals(self):
        # https://github.com/pytorch/pytorch/issues/99665
        try:
            from .utils import outer_func
        except ImportError:
            from utils import outer_func

        def gn(x, y):
            return x + y

        def fn(x, y):
            return outer_func(gn)(x, y)

        x = torch.rand([3])
        y = torch.rand([3])
        opt_fn = torch.compile(backend="eager")(fn)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertTrue(same(ref, res))


class CustomFunc1(torch.autograd.Function):
    @staticmethod

test/dynamo/utils.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo


def inner_func():
    return torch.is_grad_enabled()


def outer_func(func):
    def wrapped(*args):
        a = func(*args)
        torch._dynamo.graph_break()
        return torch.sin(a + 1), inner_func()

    return wrapped


@@ -35,6 +35,7 @@ if __name__ == '__main__':
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_large_nbr_kernel_args(self):
@@ -1990,6 +1991,7 @@ class TestTracer(JitTestCase):
self.assertEqual(model(**input_dict), traced_model(**input_dict))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestMixTracingScripting(JitTestCase):
def test_trace_script(self):
@torch.jit.script


@@ -27,6 +27,7 @@ from torch.testing._internal.common_utils import (
numpy_to_torch_dtype_dict,
TEST_SCIPY,
set_default_dtype,
skipIfTorchDynamo,
)
from torch.testing._internal.common_device_type import (
expectedFailureMeta,
@@ -1852,6 +1853,7 @@ class TestBinaryUfuncs(TestCase):
_scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)
@onlyNativeDeviceTypes
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_div_and_floordiv_script_vs_python(self, device):
# Creates jitted functions of two tensors
def _wrapped_div(a, b):
@@ -1924,6 +1926,7 @@
self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
@onlyNativeDeviceTypes
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_idiv_and_ifloordiv_vs_python(self, device):
def _wrapped_idiv_tensor(a, b):
a /= b


@@ -12,7 +12,7 @@ import numpy as np
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, TEST_WITH_TORCHDYNAMO)
TestCase, run_tests, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
onlyNativeDeviceTypes, skipXLA)
@@ -738,10 +738,7 @@ class TestIndexing(TestCase):
self.assertEqual(y, torch.ones(size=(10, 10), device=device))
self.assertEqual(len(w), 2)
@unittest.skipIf(
TEST_WITH_TORCHDYNAMO,
"This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
)
@skipIfTorchDynamo("This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472")
def test_index_put_accumulate_large_tensor(self, device):
# This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
N = (1 << 31) + 5
@@ -839,6 +836,7 @@ class TestIndexing(TestCase):
self.assertEqual(out_cuda.cpu(), out_cpu)
@onlyCUDA
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_index_put_accumulate_with_optional_tensors(self, device):
# TODO: replace with a better solution.
# Currently, here using torchscript to put None into indices.
@@ -935,6 +933,7 @@ class TestIndexing(TestCase):
r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_jit_indexing(self, device):
def fn1(x):
x[x < 50] = 1.0


@@ -2,7 +2,7 @@
from typing import Optional, List
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
# End-to-end tests of features in native_functions.yaml
@@ -81,6 +81,7 @@ class TestNativeFunctions(TestCase):
return torch._C._nn._test_optional_floatlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_optional_floatlist(self):
self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
@@ -134,6 +135,7 @@ class TestNativeFunctions(TestCase):
return torch._C._nn._test_optional_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_optional_intlist(self):
self.do_test_optional_intlist_with_module(IntListWrapperModule())
self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
@@ -187,6 +189,7 @@ class TestNativeFunctions(TestCase):
return torch._C._nn._test_optional_filled_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_optional_filled_intlist(self):
def f(n: int):


@@ -3,7 +3,6 @@ import enum
import functools
import inspect
import itertools
import sys
import types
from typing import Dict, List
@@ -11,11 +10,7 @@ import torch
from .. import variables
from ..allowed_functions import is_allowed, is_builtin_callable
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_rot_n,
)
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import istensor, istype, make_cell
@@ -89,6 +84,26 @@ def init_cellvars(parent, result, code):
    return closure_cells


def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        from itertools import pairwise

        annotations = dict(pairwise(annotations))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func


class BaseUserFunctionVariable(VariableTracker):
    def get_filename(self):
        return self.get_code().co_filename
@@ -460,17 +475,27 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
parent.symbolic_locals[var] = child.symbolic_locals[var]
def reconstruct(self, codegen):
flags = 0x00
codegen.load_import_from(__name__, "_create_nested_fn")
codegen(self.code)
codegen.extend_output([codegen._create_load_const(self.f_globals)])
codegen(self.fn_name)
if self.defaults:
flags |= 0x01
codegen(self.defaults)
else:
codegen.extend_output([codegen.create_load_const(None)])
if self.closure:
codegen(self.closure)
else:
codegen.extend_output([codegen.create_load_const(None)])
if self.kwdefaults:
flags |= 0x02
codegen(self.kwdefaults)
if isinstance(
self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
):
flags |= 0x04
else:
codegen.extend_output([codegen.create_load_const(None)])
if self.annotations:
try:
if isinstance(self.annotations, variables.ConstDictVariable):
annotations = {
@@ -484,13 +509,10 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
codegen.extend_output([codegen._create_load_const(annotations)])
except NotImplementedError:
codegen(self.annotations)
if self.closure:
flags |= 0x08
codegen(self.closure)
codegen(self.code)
if sys.version_info < (3, 11):
codegen(self.fn_name)
codegen.extend_output([create_instruction("MAKE_FUNCTION", arg=flags)])
else:
codegen.extend_output([codegen.create_load_const(None)])
codegen.extend_output(create_call_function(7, push_null=True))
if self.wraps_source:
codegen.load_import_from("functools", "wraps")