We are seeing crashes of the form:

```
Traceback (most recent call last):
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 1487, in run
    while self.step():
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 1348, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 2437, in LOAD_ATTR
    self._load_attr(inst)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 2425, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builtin.py", line 1347, in call_function
    return handler(tx, args, kwargs)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builtin.py", line 967, in <lambda>
    tx, [v.realize() for v in args], kwargs
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builtin.py", line 967, in <listcomp>
    tx, [v.realize() for v in args], kwargs
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/lazy.py", line 72, in realize
    self._cache.realize()
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/lazy.py", line 33, in realize
    self.vt = builder.VariableBuilder(tx, self.source)(self.value)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builder.py", line 445, in __call__
    vt = self._wrap(value)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builder.py", line 1043, in _wrap
    torch._dynamo.utils.store_user_object_weakref(value)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/utils.py", line 4694, in store_user_object_weakref
    user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
torch._dynamo.exc.InternalTorchDynamoError: TypeError: cannot create weak reference to 'torch.Event' object
```

This pull request makes us gracefully graph break instead of explicitly crashing. I've added a test that reproduces the issue.

There is a side discussion about how `torch.Event` support ever worked here, since it appears you cannot take a weakref to a `torch.Event`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163168
Approved by: https://github.com/Lucaskabela, https://github.com/jansel
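For context, a minimal sketch of the failure mode and of the post-fix behavior. The free function `fn` here is illustrative (the new test exercises the same path through a module attribute instead): `torch.Event` cannot be weakly referenced, so Dynamo's `store_user_object_weakref` raised `TypeError`; with this change Dynamo graph breaks on the event and the code still runs.

```python
import weakref

import torch

e = torch.Event()
try:
    weakref.ref(e)  # torch.Event objects do not support weak references
except TypeError as exc:
    print(exc)  # cannot create weak reference to 'torch.Event' object

# After this fix, Dynamo graph breaks on the event instead of raising
# InternalTorchDynamoError, so the compiled function still returns x.
@torch.compile(backend="eager")
def fn(x, event):
    # query() is True for an event with no outstanding work, so this is x * 1
    return x * int(event.query())

print(fn(torch.randn(3), e))
```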
# Owner(s): ["module: dynamo"]

import inspect
import io
import os
import tempfile
from unittest.mock import patch

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))


class InPlaceCompilationTests(TestCase):
    def test_compilation(self):
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        torch._dynamo.reset()
        model = ToyModel()
        self.assertTrue(model._compiled_call_impl is None)
        model.compile()
        self.assertTrue(model._compiled_call_impl is not None)

    def test_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))

        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model, os.path.join(tmpdirname, "model.pt"))
            # weights_only=False as this is a legacy use case that loads a module
            loaded_model = torch.load(
                os.path.join(tmpdirname, "model.pt"), weights_only=False
            )
            loaded_model(torch.randn(1, 10))

    def test_state_dict_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt"))
            loaded_model = ToyModel()
            loaded_model.load_state_dict(
                # weights_only=False as this is a legacy use case that loads a module
                torch.load(os.path.join(tmpdirname, "model.pt"), weights_only=False)
            )
            loaded_model(torch.randn(1, 10))

    def test_jit_save(self):
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
            loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
            loaded_model(torch.randn(1, 10))

    def test_compilation_callback(self):
        torch._dynamo.reset()

        @torch._dynamo.on_compile_start
        def start_callback(_):
            print("Compilation started.")

        @torch._dynamo.on_compile_end
        def end_callback(_):
            print("Compilation ended.")

        mod = ToyModel()
        x = torch.randn(10, 10)

        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_mod = torch.compile(backend="eager", fullgraph=True)(mod)
            opt_mod(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(printed_output, "Compilation started.\nCompilation ended.")

    def test_compile_eager_options(self):
        @torch.compile(backend="eager", options={"foo": 2})
        def f(x):
            return x + x

        f(torch.randn(3))

        @torch.compile(backend="aot_eager", options={"foo": 2})
        def g(x):
            return x + x

        g(torch.randn(3))

    def test_compilation_callback_with_graph_break(self):
        torch._dynamo.reset()
        counter = 0

        @torch._dynamo.on_compile_start
        def start_callback(_):
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch._dynamo.on_compile_end
        def end_callback(_):
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            return torch.sin(x)

        x = torch.randn(10, 10)

        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fn(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(
            printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4"
        )

    def test_compilation_constant_hasattr_fail(self):
        @torch.compile(backend="eager")
        def fn(x):
            return x.max()

        # We should fall back to normal mode and throw an AttributeError,
        # not an internal Dynamo exception.
        with self.assertRaises(AttributeError):
            fn(None)

    def test_compilation_enum_hasattr_fail(self):
        from enum import Enum

        class TestEnum(Enum):
            VALID = 1

        @torch.compile(backend="eager")
        def fn(x):
            return x.max()

        # We should fall back to normal mode and throw an AttributeError,
        # not an internal Dynamo exception.
        with self.assertRaises(AttributeError):
            fn(TestEnum.VALID)

    def test_compilation_name_error(self):
        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            does_not_exist()  # noqa: F821
            return x

        x = torch.randn(10, 10)
        with self.assertRaises(NameError):
            fn(x)

    def test_compilation_tensor_invalid_method(self):
        @torch.compile(backend="eager")
        def fn(x):
            y = torch.tensor(x)
            return y.doesnotexist()

        x = torch.randn(10, 10)

        with self.assertRaises(AttributeError):
            fn(x)

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_compilation_nn_module_invalid_method(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x + self.doesnotexist

        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager")
        x = torch.randn(1, 1)
        with self.assertRaises(AttributeError):
            opt_mod(x)

    def test_torch_script_compilation(self):
        @torch.jit.script
        def fn(x: torch.Tensor) -> torch.Tensor:
            return x

        a = torch.randn(1, 1)
        out = torch.compile(fn)(a)
        self.assertEqual(out, a)

    def test_to_sparse_to_dense_with_graph_break(self):
        def fn(x):
            x = x.to_sparse()
            x = x.to_dense()
            return x

        x = torch.tensor([[1.0]])
        c_fn = torch.compile(fn)

        output = fn(x)
        c_output = c_fn(x)
        self.assertEqual(output, c_output)

    def test_list_bad_access(self):
        @torch.compile(backend="eager")
        def fn(x, y):
            a = [x]
            return a[y]

        with self.assertRaises(IndexError):
            fn(torch.randn(10), 99)
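
    # Regression test for https://github.com/pytorch/pytorch/pull/163168: you cannot
    # take a weakref to a torch.Event, so Dynamo should graph break gracefully here
    # instead of raising InternalTorchDynamoError.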
    def test_list_bad_weakref(self):
        import weakref

        a = torch.Event()
        with self.assertRaises(TypeError):
            weakref.ref(a)

        @torch.compile(backend="eager")
        class Mod(torch.nn.Module):
            def __init__(self, event):
                super().__init__()
                self.event = event

            def forward(self, x):
                return x * int(self.event.query())

        e = torch.Event()
        m = Mod(e)
        a = torch.randn(10)
        self.assertEqual(m(a), a)


# The private variants of the functions below are extensively tested,
# so as long as the signatures match we're good.
class PublicTorchCompilerTests(TestCase):
    def check_signature(self, public_fn_name, private_fn_name, private_namespace):
        public_fn = getattr(torch.compiler, public_fn_name)
        private_fn = getattr(private_namespace, private_fn_name)

        public_sig = inspect.signature(public_fn)
        private_sig = inspect.signature(private_fn)

        matching = public_sig == private_sig
        # Also accept a public signature that is a strict prefix of the private
        # one (the private variant may take extra trailing parameters).
        matching |= len(public_sig.parameters) < len(private_sig.parameters) and all(
            public == private
            for public, private in zip(
                public_sig.parameters.items(), private_sig.parameters.items()
            )
        )

        self.assertEqual(
            matching,
            True,
            f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
        )

    def test_dynamo_signatures(self):
        function_names = [
            "reset",
            "allow_in_graph",
            "list_backends",
            "assume_constant_result",
            "disable",
        ]

        for fn_name in function_names:
            self.check_signature(fn_name, fn_name, torch._dynamo)


if __name__ == "__main__":
    run_tests()