[dynamo][easy] Add AC test and improve graph break message (#121394)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121394
Approved by: https://github.com/yanboliang
This commit is contained in:
Animesh Jain
2024-04-05 14:43:31 -07:00
committed by PyTorch MergeBot
parent 954d750516
commit 8e98fda7a9
2 changed files with 33 additions and 1 deletions

View File

@ -9,6 +9,7 @@ import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.distributed as dist
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
@ -20,6 +21,9 @@ from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
unittest.skipIf, not dist.is_available(), "requires distributed"
)
def checkpoint_wrapper(fn):
@ -1079,6 +1083,32 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
)
)
@requires_cuda
@requires_distributed()
def test_distributed_utils_checkpoint_wrapper(self):
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as dist_checkpoint_wrapper,
)
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.c = 2
def forward(self, x):
x = torch.sin(x)
x = self.linear(x)
x = torch.cos(x)
return x * self.c
mod = dist_checkpoint_wrapper(MockModule())
x = torch.randn(4, 4)
ref = mod(x)
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
res = opt_mod(x)
self.assertEqual(ref, res)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -242,7 +242,9 @@ class NNModuleVariable(VariableTracker):
# Support possibly common cases of class members
return VariableBuilder(tx, NNModuleSource(source))(subobj)
else:
unimplemented(f"class property {typestr(base)} {typestr(subobj)}")
unimplemented(
f"class property {name} - {typestr(base)} {typestr(subobj)}"
)
return variables.GetAttrVariable(self, name, source=source)