mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][easy] Add AC test and improve graph break message (#121394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121394 Approved by: https://github.com/yanboliang
This commit is contained in:
committed by
PyTorch MergeBot
parent
954d750516
commit
8e98fda7a9
@ -9,6 +9,7 @@ import torch._dynamo.config
|
||||
|
||||
import torch._dynamo.test_case
|
||||
import torch._functorch.config
|
||||
import torch.distributed as dist
|
||||
import torch.utils.checkpoint
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
@ -20,6 +21,9 @@ from torch.testing._internal.two_tensor import TwoTensor
|
||||
from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
|
||||
|
||||
# Environment-gated skip decorators for the tests below.
# `requires_cuda` is a ready-made decorator; `requires_distributed` is a
# decorator *factory* (note the trailing `()` at use sites), built lazily
# so the availability check runs once at import time.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

_no_distributed = not dist.is_available()
requires_distributed = functools.partial(
    unittest.skipIf, _no_distributed, "requires distributed"
)
|
||||
|
||||
|
||||
def checkpoint_wrapper(fn):
|
||||
@ -1079,6 +1083,32 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
@requires_cuda
@requires_distributed()
def test_distributed_utils_checkpoint_wrapper(self):
    """Compiling a module wrapped by the distributed checkpoint_wrapper
    with fullgraph=True must produce the same output as eager execution
    (i.e. the wrapper does not introduce a graph break)."""
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper as dist_checkpoint_wrapper,
    )

    class SinLinearCos(torch.nn.Module):
        """Small module mixing pointwise ops, a Linear, and a non-tensor
        attribute (`self.c`) read inside forward."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.c = 2

        def forward(self, inp):
            out = self.linear(torch.sin(inp))
            return torch.cos(out) * self.c

    wrapped = dist_checkpoint_wrapper(SinLinearCos())
    compiled = torch.compile(wrapped, backend="eager", fullgraph=True)

    inp = torch.randn(4, 4)
    eager_out = wrapped(inp)
    compiled_out = compiled(inp)
    self.assertEqual(eager_out, compiled_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -242,7 +242,9 @@ class NNModuleVariable(VariableTracker):
|
||||
# Support possibly common cases of class members
|
||||
return VariableBuilder(tx, NNModuleSource(source))(subobj)
|
||||
else:
|
||||
unimplemented(f"class property {typestr(base)} {typestr(subobj)}")
|
||||
unimplemented(
|
||||
f"class property {name} - {typestr(base)} {typestr(subobj)}"
|
||||
)
|
||||
|
||||
return variables.GetAttrVariable(self, name, source=source)
|
||||
|
||||
|
Reference in New Issue
Block a user