Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Dynamo] Fix nested function resume execution (#100426)
Fixes #99665

Let me explain the root cause using the unit test I added:
* This bug is triggered when:
  * ```wrapped``` is a nested function.
  * ```wrapped``` is defined in a different module from the main function ```fn```.
  * There is a graph break inside ```wrapped```.
* The root cause: when resuming the nested function, we were actually using the outermost function's (```fn``` in my example) globals, but ```wrapped``` calls ```inner_func```, which is not part of ```fn```'s globals, so we have to set the correct globals when the nested function resumes execution.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100426
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: c84627c2ee
Commit: 075d36d37f
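A minimal sketch of the failing pattern, condensed from the test and helper module added in the diff below. The two-file split matters: ```wrapped``` must live in a different module than ```fn``` (the file names here are illustrative):

```python
# helpers.py -- a different module from the one that defines `fn`
import torch
import torch._dynamo


def inner_func():
    return torch.is_grad_enabled()


def outer_func(func):
    def wrapped(*args):
        a = func(*args)
        torch._dynamo.graph_break()  # graph break inside the nested function
        return torch.sin(a + 1), inner_func()  # `inner_func` lives in this module's globals

    return wrapped
```

```python
# main.py -- the module that defines and compiles `fn`
import torch
from helpers import outer_func


def fn(x, y):
    return outer_func(lambda a, b: a + b)(x, y)


x, y = torch.rand(3), torch.rand(3)
ref = fn(x, y)
res = torch.compile(fn, backend="eager")(x, y)
# Before this fix, resuming `wrapped` ran with fn's globals, where `inner_func`
# does not exist; with the fix, eager and compiled results match.
assert torch.allclose(ref[0], res[0]) and ref[1] == res[1]
```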
@@ -233,6 +233,7 @@ test_dynamo_shard() {
    --exclude-distributed-tests \
    --exclude \
      test_autograd \
      test_jit \
      test_proxy_tensor \
      test_quantization \
      test_public_bindings \
@@ -5157,6 +5157,26 @@ def fn():
        self.assertTrue(isinstance(compile_out, torch.Size))
        self.assertEqual(eager_out, compile_out)

    def test_nested_function_resuming_with_correct_globals(self):
        # https://github.com/pytorch/pytorch/issues/99665
        try:
            from .utils import outer_func
        except ImportError:
            from utils import outer_func

        def gn(x, y):
            return x + y

        def fn(x, y):
            return outer_func(gn)(x, y)

        x = torch.rand([3])
        y = torch.rand([3])
        opt_fn = torch.compile(backend="eager")(fn)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertTrue(same(ref, res))


class CustomFunc1(torch.autograd.Function):
    @staticmethod
test/dynamo/utils.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo


def inner_func():
    return torch.is_grad_enabled()


def outer_func(func):
    def wrapped(*args):
        a = func(*args)
        torch._dynamo.graph_break()
        return torch.sin(a + 1), inner_func()

    return wrapped
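Not part of the diff: a hedged usage sketch showing why this helper exercises the resume path. The explicit `graph_break()` splits tracing into more than one graph, so `wrapped` has to be resumed mid-function, which is exactly where the wrong globals used to be picked up. `CompileCounter` is the counting backend used throughout the dynamo test suite; this assumes it is run from `test/dynamo/` so `utils` resolves:

```python
import torch
import torch._dynamo.testing

from utils import outer_func  # the new helper module above (test/dynamo/utils.py)


def gn(x, y):
    return x + y


def fn(x, y):
    return outer_func(gn)(x, y)


cnt = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnt)
opt_fn(torch.rand(3), torch.rand(3))

# The graph_break() inside `wrapped` forces more than one compiled graph,
# so a resume function is generated for the remainder of `wrapped`.
assert cnt.frame_count > 1
```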
@@ -35,6 +35,7 @@ if __name__ == '__main__':
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_large_nbr_kernel_args(self):
@@ -1990,6 +1991,7 @@ class TestTracer(JitTestCase):
        self.assertEqual(model(**input_dict), traced_model(**input_dict))


@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestMixTracingScripting(JitTestCase):
    def test_trace_script(self):
        @torch.jit.script
@@ -27,6 +27,7 @@ from torch.testing._internal.common_utils import (
    numpy_to_torch_dtype_dict,
    TEST_SCIPY,
    set_default_dtype,
    skipIfTorchDynamo,
)
from torch.testing._internal.common_device_type import (
    expectedFailureMeta,
@@ -1852,6 +1853,7 @@ class TestBinaryUfuncs(TestCase):
        _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)

    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_div_and_floordiv_script_vs_python(self, device):
        # Creates jitted functions of two tensors
        def _wrapped_div(a, b):
@@ -1924,6 +1926,7 @@ class TestBinaryUfuncs(TestCase):
        self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))

    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_idiv_and_ifloordiv_vs_python(self, device):
        def _wrapped_idiv_tensor(a, b):
            a /= b
@@ -12,7 +12,7 @@ import numpy as np

from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    TestCase, run_tests, TEST_WITH_TORCHDYNAMO)
    TestCase, run_tests, skipIfTorchDynamo)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
    onlyNativeDeviceTypes, skipXLA)
@@ -738,10 +738,7 @@ class TestIndexing(TestCase):
        self.assertEqual(y, torch.ones(size=(10, 10), device=device))
        self.assertEqual(len(w), 2)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472"
    )
    @skipIfTorchDynamo("This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472")
    def test_index_put_accumulate_large_tensor(self, device):
        # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
        N = (1 << 31) + 5
@@ -839,6 +836,7 @@ class TestIndexing(TestCase):
        self.assertEqual(out_cuda.cpu(), out_cpu)

    @onlyCUDA
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_index_put_accumulate_with_optional_tensors(self, device):
        # TODO: replace with a better solution.
        # Currently, here using torchscript to put None into indices.
@@ -935,6 +933,7 @@ class TestIndexing(TestCase):
        r = v[c > 0]
        self.assertEqual(r.shape, (num_ones, 3))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_jit_indexing(self, device):
        def fn1(x):
            x[x < 50] = 1.0
@@ -2,7 +2,7 @@

from typing import Optional, List
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo

# End-to-end tests of features in native_functions.yaml

@@ -81,6 +81,7 @@ class TestNativeFunctions(TestCase):
            return torch._C._nn._test_optional_floatlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_floatlist(self):
        self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
        self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
@@ -134,6 +135,7 @@ class TestNativeFunctions(TestCase):
            return torch._C._nn._test_optional_intlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_intlist(self):
        self.do_test_optional_intlist_with_module(IntListWrapperModule())
        self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
@@ -187,6 +189,7 @@ class TestNativeFunctions(TestCase):
            return torch._C._nn._test_optional_filled_intlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_filled_intlist(self):

        def f(n: int):
@@ -3,7 +3,6 @@ import enum
import functools
import inspect
import itertools
import sys
import types
from typing import Dict, List

@@ -11,11 +10,7 @@ import torch

from .. import variables
from ..allowed_functions import is_allowed, is_builtin_callable
from ..bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_rot_n,
)
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import istensor, istype, make_cell
@@ -89,6 +84,26 @@ def init_cellvars(parent, result, code):
    return closure_cells


def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        from itertools import pairwise

        annotations = dict(pairwise(annotations))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func


class BaseUserFunctionVariable(VariableTracker):
    def get_filename(self):
        return self.get_code().co_filename
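Aside (not part of the diff): the helper above relies on `types.FunctionType`, which accepts an explicit globals dict. That is what lets the reconstructed nested function carry its own `f_globals` instead of inheriting the resuming frame's. A minimal standalone sketch of that mechanism, independent of dynamo:

```python
import types


def template(x):
    # `helper` is resolved through the function's __globals__ at call time.
    return helper(x) + 1  # noqa: F821


# Rebuild `template` against a globals dict that actually defines `helper`.
other_globals = {"helper": lambda v: v * 10}
rebuilt = types.FunctionType(
    template.__code__,      # code object
    other_globals,          # explicit globals for the new function
    "rebuilt",              # name
    template.__defaults__,  # defaults
    template.__closure__,   # closure cells
)

assert rebuilt(3) == 31  # helper(3) + 1
assert rebuilt.__globals__ is other_globals
```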
@@ -460,17 +475,27 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
            parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        flags = 0x00
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(self.fn_name)

        if self.defaults:
            flags |= 0x01
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            flags |= 0x02
            codegen(self.kwdefaults)
        if isinstance(
            self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
        ):
            flags |= 0x04
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
@@ -484,13 +509,10 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        if self.closure:
            flags |= 0x08
            codegen(self.closure)
        codegen(self.code)
        if sys.version_info < (3, 11):
            codegen(self.fn_name)
        codegen.extend_output([create_instruction("MAKE_FUNCTION", arg=flags)])
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wraps_source:
            codegen.load_import_from("functools", "wraps")
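Aside (not part of the diff): the removed branch above emitted a raw `MAKE_FUNCTION` whose oparg bits 0x01/0x02/0x04/0x08 mark that defaults, kwdefaults, annotations, and a closure tuple were pushed on the stack; the replacement instead loads `_create_nested_fn` and calls it with seven arguments (`create_call_function(7, push_null=True)`), which is what allows the correct `f_globals` to be passed along. A small sketch, assuming the pre-3.11 bytecode that the removed `sys.version_info < (3, 11)` check targeted, showing those flag bits with `dis`:

```python
import dis
import sys


def make():
    captured = 10

    def inner(x=1, *, k=2) -> int:  # defaults, kwdefaults, annotations, closure
        return x + k + captured

    return inner


if sys.version_info < (3, 11):
    # Expect a line such as:
    #   MAKE_FUNCTION 15 (defaults, kwdefaults, annotations, closure)
    # where 15 == 0x01 | 0x02 | 0x04 | 0x08, matching the removed `flags` bits.
    dis.dis(make)
else:
    print("Newer CPython changed how nested functions are assembled in bytecode.")
```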