Files
pytorch/test/dynamo/test_activation_checkpointing.py
eellison 481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x] Add handling for when the forward is invoked multiple times without invoking the backward, in which case the fwd/bwd rng states would otherwise be out of sync
- [x] Update rng state initialization to take the state from the correct device
- [x] Tests
- [x] Handle retain_graph
- [x] Respect `fallback_random`

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the mechanism for CUDAGraph-safe rng states added in https://github.com/pytorch/pytorch/pull/114068.

We keep a pair of rng states, one for the forward and one for the backward. In both the forward and backward graphs, the rng op is run with `graphsafe_run_with_rng_state`, which takes an RNG state and hooks it into the current RNG generator before running the operator. The fwd and bwd rng states are initialized to the same value, so for any given run of the forward, the corresponding backward run sees the same rng state for the op as was observed in the forward. (A minimal illustration follows the graph dump below.)

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```
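
As a rough illustration of the mechanism (a conceptual sketch only, not the AOTAutograd implementation; the variable names are made up), a pair of generators initialized to the same state is enough for the recomputed rng op in the "backward" to reproduce the values the "forward" saw:

```
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# fwd_rng_state_0 / bwd_rng_state_0 start out identical
fwd_gen = torch.Generator(device=device)
fwd_gen.manual_seed(1234)
bwd_gen = torch.Generator(device=device)
bwd_gen.set_state(fwd_gen.get_state())

# "forward": the rng op draws from the forward generator
fwd_rand = torch.rand(4, 4, device=device, generator=fwd_gen)

# "backward" recompute: the same op draws from the backward generator
bwd_rand = torch.rand(4, 4, device=device, generator=bwd_gen)

assert torch.equal(fwd_rand, bwd_rand)  # recompute matches the forward
```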

There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than the forwards. For example, with states fwd_rng_state0 and bwd_rng_state0, suppose the user calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then, naively, when bwd1 is invoked the bwd rng states would not match the states that were observed during fwd1. I added handling for this in the aot runtime wrappers: they detect pending backward invocations and the current position of the bwd rng states, and update them when necessary.
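
Roughly, the bookkeeping looks like this (a hypothetical sketch; the class and method names are illustrative, not the actual runtime-wrapper code):

```
import torch

class PerOpRngStates:
    """Tracks one fwd/bwd generator pair and the state each pending forward saw."""

    def __init__(self, device="cpu"):
        self.fwd_gen = torch.Generator(device=device)
        self.bwd_gen = torch.Generator(device=device)
        self.bwd_gen.set_state(self.fwd_gen.get_state())
        self.pending = {}  # forward run id -> generator state observed by that forward

    def on_forward(self, run_id):
        # Record the state this forward starts from; running the forward
        # graph then advances fwd_gen past it.
        self.pending[run_id] = self.fwd_gen.get_state()

    def on_backward(self, run_id):
        # Backwards may arrive out of order (bwd1 before bwd0) or be replayed
        # with retain_graph; rewind bwd_gen to the state the matching forward saw.
        self.bwd_gen.set_state(self.pending.pop(run_id))
```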

Other notes:

Because nodes that appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused one rng state across ops, the forward and backward would consume randomness in different orders and therefore run with different rng states.
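
For intuition (illustrative only, using plain generators rather than the compiled rng states): with a single shared state, the reversed op order in the backward hands each op the randomness a different op consumed in the forward:

```
import torch

shared_fwd = torch.Generator().manual_seed(0)
a_fwd = torch.rand(2, generator=shared_fwd)  # op A runs first in the forward
b_fwd = torch.rand(2, generator=shared_fwd)  # op B runs second

shared_bwd = torch.Generator().manual_seed(0)
b_bwd = torch.rand(2, generator=shared_bwd)  # op B is recomputed first in the backward
a_bwd = torch.rand(2, generator=shared_bwd)  # op A is recomputed second

# The recomputed values end up attached to the wrong ops.
assert not torch.equal(a_fwd, a_bwd) and not torch.equal(b_fwd, b_bwd)
```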

Questions for reviewers:

This does change numerics, because the rng for the op is now taken from the input rng state instead of whatever the rng happens to be midway through running the graph. Technically, we only need this for cudagraphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.

Edit: decided to apply this to non-cudagraph compilation as well, as long as `fallback_random` is not set.
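
For reference, opting out looks like the existing tests in this file (a minimal sketch, assuming the config knob behaves as described above):

```
import torch

# With fallback_random=True, inductor falls back to eager-style rng, so the
# graphsafe rng-state path described above is not used.
with torch._inductor.config.patch(fallback_random=True):
    fn = torch.compile(lambda t: torch.nn.functional.dropout(t, p=0.5, training=True))
    out = fn(torch.randn(4, 4))
```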

I'm initializing the rng states by cloning the current state. If you had, say, 5 different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, such as taking the seed from the graph position. Not sure. Let me know your thoughts.

Edit: updated to initialize the rng states from torch.randint().
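
A minimal sketch of the difference (hypothetical helper, not the actual initialization code): seeding each per-op state from torch.randint keeps same-shaped rand ops from producing identical values, which cloning a single state would not:

```
import torch

def fresh_rng_state(device="cpu"):
    # Draw a seed from the ambient rng rather than cloning its state.
    seed = int(torch.randint(0, 2**62, (1,)).item())
    return torch.Generator(device=device).manual_seed(seed)

gen_a = fresh_rng_state()
gen_b = fresh_rng_state()
# Two same-shaped rand ops now get different values.
assert not torch.equal(torch.rand(4, 4, generator=gen_a),
                       torch.rand(4, 4, generator=gen_b))
```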

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00


# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import math
import unittest # noqa: F811
from importlib import import_module
import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
SM90OrLater,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import (
checkpoint,
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
unittest.skipIf, not dist.is_available(), "requires distributed"
)
def checkpoint_wrapper(fn):
def inner(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
return inner
def count_ops(
gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
):
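# fw/bw compiler hook: asserts that each op in `ops` occurs `freqs` times (or at
# least `freqs_ge` times) in the traced graph, also matching ops wrapped in rng
# higher-order ops.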
def match_rng_op(node, op):
if isinstance(node.target, torch._ops.HigherOrderOperator):
if node.name == "run_and_save_rng_state":
return node.args[0] == op
elif node.name == "run_with_rng_state":
return node.args[1] == op
elif node.name == "graphsafe_run_with_rng_state":
return node.args[0] == op
return False
# assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
if op is not None:
assert not isinstance(op, list)
ops = [op]
if freq is not None:
freqs = [freq]
if freq_ge is not None:
freqs_ge = [freq_ge]
if freqs:
for op, freq in zip(ops, freqs):
actual_count = 0
for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op:
actual_count += 1
err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
assert actual_count == freq, err_msg
else:
assert freqs_ge is not None
for op, freq_ge in zip(ops, freqs_ge):
actual_count = 0
for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op:
actual_count += 1
assert (
actual_count >= freq_ge
), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
return gm
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
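# Post-grad pass hook: records the stringified outputs of the forward graph so
# tests can diff which extra tensors get saved for backward.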
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph
return_node = list(graph.nodes)[-1]
assert return_node.target == "output"
for x in return_node.args[0]:
fwd_outputs.add(str(x))
class _InvalidContext:
def __init__(self) -> None:
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def _invalid_context_gen():
return _InvalidContext(), _InvalidContext()
def find_first_node(gm, func):
for node in gm.graph.nodes:
if node.target is func:
return node
return None
def op_count(gm):
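# Counts the call_* nodes (call_function / call_method / call_module) in a graph module.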
result = 0
for node in gm.graph.nodes:
if "call" in node.op:
result += 1
return result
def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
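# Builds a selective activation checkpointing policy: ops in no_recompute_list are
# saved, ops in must_recompute_list are recomputed, everything else defaults to
# PREFER_RECOMPUTE.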
def _custom_policy(ctx, func, *args, **kwargs):
if no_recompute_list is not None and func in no_recompute_list:
return CheckpointPolicy.MUST_SAVE
if must_recompute_list is not None and func in must_recompute_list:
return CheckpointPolicy.MUST_RECOMPUTE
else:
return CheckpointPolicy.PREFER_RECOMPUTE
return _custom_policy
class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
def _validate(
self,
fn,
backend,
*args,
skip_check=False,
fullgraph=True,
compiled_autograd=False,
):
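# Runs fn eagerly and under torch.compile with the given backend (optionally with
# compiled autograd), then compares outputs and input gradients unless skip_check is set.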
cloned_args = []
for arg in args:
cloned_args.append(arg.detach().clone().requires_grad_(arg.requires_grad))
cloned_fn = copy.deepcopy(fn)
torch.manual_seed(0)
expected = fn(*args)
expected.sum().backward()
torch.manual_seed(0)
compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend)
ctx = contextlib.nullcontext()
if compiled_autograd:
ctx = torch._dynamo.compiled_autograd._enable(
lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend)
)
with ctx:
result = compiled_fn(*cloned_args)
result.sum().backward()
if not skip_check:
self.assertEqual(
result,
expected,
msg="Output mismatch between torch.compile and eager versions",
)
for arg, cloned_arg in zip(args, cloned_args):
self.assertEqual(
arg.grad,
cloned_arg.grad,
msg="Gradient mismatch between torch.compile and eager versions",
)
def _compare_orig_and_checkpointed_fns(
self, orig_fn, checkpointed_fn, *args, fullgraph=True
):
# The original version and the checkpointed version of the same function
# should produce the same outputs and the same gradients under torch.compile.
# Run original version
cloned_args_orig_fn = []
for arg in args:
cloned_args_orig_fn.append(
arg.detach().clone().requires_grad_(arg.requires_grad)
)
torch.manual_seed(0)
compiled_orig_fn = torch.compile(
orig_fn, fullgraph=fullgraph, backend="inductor"
)
result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
result_orig_fn.sum().backward()
# Run checkpointed version
cloned_args_checkpointed_fn = []
for arg in args:
cloned_args_checkpointed_fn.append(
arg.detach().clone().requires_grad_(arg.requires_grad)
)
torch.manual_seed(0)
compiled_checkpointed_fn = torch.compile(
checkpointed_fn, fullgraph=fullgraph, backend="inductor"
)
result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn)
result_checkpointed_fn.sum().backward()
# Check that outputs and gradients are equal
self.assertEqual(
result_orig_fn,
result_checkpointed_fn,
msg="Output mismatch between the original version and the checkpointed version of the same function",
)
for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
cloned_args_orig_fn, cloned_args_checkpointed_fn
):
self.assertEqual(
cloned_arg_orig_fn.grad,
cloned_arg_checkpointed_fn.grad,
msg="Gradient mismatch between the original version and the checkpointed version of the same function",
)
def test_tags_function(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True
)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_function_via_global_checkpoint(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
# This goes through VariableBuilder
return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_function_with_kwargs(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_sequential_layers(self, device):
def gn(x):
x = x.cos()
for _ in range(3):
x = torch.mm(x, x)
x = x.cos()
return x
def fn(x):
x = torch.utils.checkpoint.checkpoint(gn, x)
x = torch.utils.checkpoint.checkpoint(gn, x)
return x
x = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=6, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops,
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda
def test_tags_multiple_checkpoints(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(z)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda
def test_tags_module(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.sigmoid(self.linear(x))
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(
mod, torch.sin(x), use_reentrant=True
)
x = torch.randn(10, 10, device=device, requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda
def test_tags_decomps(self, device):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.nn.functional.gelu(self.linear(x))
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(
mod, torch.sin(x), use_reentrant=True
)
x = torch.randn(10, 10, device=device, requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
)
self._validate(fn, backend, x)
@requires_cuda
@torch._inductor.config.patch(fallback_random=True)
def test_tags_recomputed_rand(self, device):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
# bw_compiler = functools.partial(
# count_ops, freq=6, op=torch.ops.aten.mm.default
# ) # mm recomputed in the bwd
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = "inductor"
self._validate(fn, backend, x, y)
@requires_cuda
@torch._inductor.config.patch(fallback_random=True)
def test_tags_rand(self, device):
def gn(x, y):
x = torch.mm(x, y)
x = torch.mm(x, y)
return x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
# x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return x
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
# bw_compiler = functools.partial(
# count_ops, freq=6, op=torch.ops.aten.mm.default
# ) # mm recomputed in the bwd
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
# backend = "aot_eager"
backend = "inductor"
self._validate(fn, backend, x, y)
@requires_cuda
@torch._inductor.config.patch(fallback_random=True)
def test_tags_dropout(self, device):
# Figure out a way to test the number of inductor_random calls
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.dropout = torch.nn.Dropout(0.2)
def forward(self, x):
return self.dropout(self.linear(x))
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
x = torch.randn(10, 10, device=device, requires_grad=True)
backend = "inductor"
# rand decomps do not have the same numerical results as eager
self._validate(fn, backend, x, skip_check=True)
@skipIfHpu
@torch._functorch.config.patch(recompute_views=True)
@torch._inductor.config.patch(fx_graph_cache=False)
def test_tags_must_save_tensor_that_has_backward_hook(self):
def my_post_forward_hook(submod, args, output):
output.register_hook(my_backward_hook)
return output
def my_backward_hook(grad):
return grad
class MySubmod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
y = torch.matmul(x, x)
z = y * y
return z
class MyMod(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = MySubmod()
self.norm = torch.nn.LayerNorm(4)
def forward(self, x):
out = torch.utils.checkpoint.checkpoint(
self.submod, x, use_reentrant=False
)
norm_out = self.norm(out)
return norm_out
def _factory_fn():
mod = MyMod()
x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True)
backend = "inductor"
return mod, x, backend
mod_no_hook, x, backend = _factory_fn()
mod_no_hook_fwd_outputs = set()
with torch._inductor.config.patch(
post_grad_custom_pre_pass=functools.partial(
collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs
)
):
self._validate(
mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True
)
torch._dynamo.reset()
mod_with_hook, x, backend = _factory_fn()
mod_with_hook.submod.register_forward_hook(my_post_forward_hook)
mod_with_hook_fwd_outputs = set()
with torch._inductor.config.patch(
post_grad_custom_pre_pass=functools.partial(
collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs
)
):
self._validate(
mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True
)
# If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors.
mod_no_hook_fwd_outputs_no_primal = {
x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_")
}
mod_with_hook_fwd_outputs_no_primal = {
x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_")
}
additional_saved_tensors = (
mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal
)
expected_additional_saved_tensors = {"mul"}
self.assertEqual(
additional_saved_tensors,
expected_additional_saved_tensors,
f"""
Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}.
Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}.
Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""",
)
@requires_cuda
def test_fallback(self, device):
def gn(x, y):
torch._dynamo.graph_break()
a = torch.sigmoid(torch.matmul(x, y))
torch._dynamo.graph_break()
return torch.cos(a)
def fn(x, y):
return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
args = (x, y)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
# One graph for torch.sin on the input, and another for torch.cos.
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
self.assertEqual(len(cnt.graphs), 2)
@requires_cuda
def test_kwargs(self, device):
def gn(x, y, z=None):
a = torch.matmul(x, y)
if z is not None:
return torch.matmul(a, z)
return a
def fn(x, y, z):
return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
z = torch.randn(4, 4, requires_grad=True, device=device)
args = (x, y, z)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(cnt.graphs), 1)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
# one for checkpoint, and 3 for x, y, z
self.assertEqual(len(wrap_node.args), 4)
body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
self.assertEqual(op_count(body_function), 2)
@requires_cuda
def test_symints_location(self, device):
def gn(x, y):
return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
opt_fn = torch.compile(fn, backend=cnt)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
x = torch.randn(5, 5, requires_grad=True, device=device)
y = torch.randn(5, 5, requires_grad=True, device=device)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(len(cnt.graphs), 2)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
self.assertEqual(len(wrap_node.args), 3)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(
must_recompute_list=must_recompute_list,
),
)
def context_fn_no_recompute_mm():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(
no_recompute_list=no_recompute_list,
),
)
def _test(context_fn, bw_compiler):
def gn(x):
return torch.sigmoid(torch.matmul(x, x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
gn,
x,
use_reentrant=False,
context_fn=context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x)
_test(
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=3,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3
op=torch.ops.aten.mm.default,
),
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=2, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 6 here
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
# if we didn't enable selective checkpointing.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
rand_tensor = torch.randn(4, 4, requires_grad=True, device=device)
# tensor subclasses as inputs
x = TwoTensor(rand_tensor, rand_tensor.clone())
y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())
fw_compiler = functools.partial(
count_ops,
freq=4,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 12 here
# (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12)
# if we didn't enable selective checkpointing.
freq=8,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
]
def _custom_policy(mode, func, *args, **kwargs):
mm_count_key = f"{mode}_mm_count"
if mm_count_key not in meta:
meta[mm_count_key] = 0
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except second mm
# (i.e. we will hint the partitioner to recompute second mm in backward pass)
return func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
)
return _custom_policy
def selective_checkpointing_context_fn():
meta = {}
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
def gn(x, y):
return torch.sigmoid(
torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
)
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# Q: How do we come to this number 4?
# A: We have 2 matmuls in the forward pass, each matmul contributes 2 `mm` ops in the backward pass,
# so we have at least 4 `mm` ops in backward pass. It's "at least" because whether second matmul in
# the forward pass is recomputed in the backward pass is up to the partitioner to decide.
freq_ge=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=functools.partial(
selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
),
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 6 here
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
# if we didn't enable selective checkpointing.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list),
)
def gn(x, y):
return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[4, 0],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skip(
"In-place op support in selective checkpointing + torch.compile "
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda
def test_compile_selective_checkpoint_inplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(
torch.selu_(torch.matmul(torch.matmul(x, y), y))
).relu_()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[4, 0],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.sigmoid.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x):
return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))
def fn(x):
return torch.utils.checkpoint.checkpoint(
gn,
x,
use_reentrant=False,
# Regardless of whether `preserve_rng_state` is True or False,
# we will always preserve RNG state when using `torch.compile`.
preserve_rng_state=preserve_rng_state,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[
torch.ops.aten.sigmoid.default,
torch.ops.aten.native_dropout.default,
],
)
bw_compiler = functools.partial(
count_ops,
# NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
freqs=[0, 1],
ops=[
torch.ops.aten.sigmoid.default,
torch.ops.aten.native_dropout.default,
],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
# because eager version doesn't preserve RNG state while torch.compile still does.
# Hence when `preserve_rng_state` is False, we skip the output and gradient comparison
# between torch.compile and eager.
self._validate(fn, backend, x, skip_check=not preserve_rng_state)
self._compare_orig_and_checkpointed_fns(gn, fn, x)
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=_invalid_context_gen,
)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
freq_ge=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
):
self._validate(fn, backend, x, y)
@requires_cuda
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
to_recompute = func in {
torch.ops.aten.mul.Tensor,
torch.ops.aten.sigmoid.default,
}
return (
CheckpointPolicy.MUST_RECOMPUTE
if to_recompute
else CheckpointPolicy.MUST_SAVE
)
return _custom_policy
return create_selective_checkpoint_contexts(_recomp_policy())
class Parametrization(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def parametrization(self, x):
return torch.sigmoid(torch.mul(x, x))
def forward(self, x):
return checkpoint(
self.parametrization, x, use_reentrant=False, context_fn=sac_policy
)
def apply_parametrization(model):
modules = list(model.modules())
for mod in modules:
params_dict = dict(mod.named_parameters(recurse=False))
for p_name, p in params_dict.items():
mod.register_parameter(p_name, nn.Parameter(p))
nn.utils.parametrize.register_parametrization(
mod, p_name, Parametrization(), unsafe=True
)
return model
class MLPModule(nn.Module):
def __init__(self) -> None:
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(16, 16, bias=False)
def forward(self, x):
return self.net1(x)
def reset_parameters(self):
self.net1.reset_parameters()
fw_compiler = functools.partial(
count_ops,
freqs=[1, 1],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[
2, # 1 from mul recompute, 1 from mul backward
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
model = MLPModule()
model = apply_parametrization(model)
model_compiled = torch.compile(
copy.deepcopy(model), backend=backend, fullgraph=True
)
input = torch.randn(8, 16, requires_grad=True)
input_compiled = copy.deepcopy(input)
out = model(input)
out.sum().backward()
out_compiled = model_compiled(input_compiled)
out_compiled.sum().backward()
self.assertEqual(out, out_compiled)
self.assertEqual(input.grad, input_compiled.grad)
@skipIfRocm
@requires_cuda
def test_autocast_flash_attention(self, device):
def fn(primals_1, primals_2, primals_3):
return torch.ops.aten._scaled_dot_product_efficient_attention.default(
primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
)[0]
def gn(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
with torch.autocast(device_type=device):
x = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
y = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
z = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
args = (x, y, z)
torch.manual_seed(0)
ref = gn(*args)
opt_gn = torch.compile(gn)
torch.manual_seed(0)
res = opt_gn(*args)
self.assertEqual(ref, res)
@requires_cuda
def test_error_msg(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
x = torch.randn(4, 4).to(device)
opt_fn = torch.compile(fn, fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "User-inserted graph break"
):
opt_fn(x)
@requires_cuda
def test_list_inputs(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, ys):
a = torch.sin(x) # noqa: F841
b = torch.cos(ys[0])
c = torch.cos(ys[1])
return (x, [b, c])
mod = MockModule().to(device)
def fn(x, ys):
return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)
x = torch.randn(4, 4).to(device)
y = torch.randn(4, 4).to(device)
z = torch.randn(4, 4).to(device)
ref = fn(x, [y, z])
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x, [y, z])
self.assertEqual(ref, res)
@requires_cuda
def test_pattern_matcher(self, device):
# Check that the sdpa op is recomputed in the backward graph
# tests percolate_tags
@checkpoint_wrapper
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return (
torch.matmul(query, key.transpose(-2, -1))
.mul(1.0 / math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)
def fn(query, key, value):
# Checks that sin is not recomputed in the backward graph
return dot_prod_attention(query.sin(), key, value)
tensor_shape = (4, 2, 16, 32)
dtype = torch.float16
args1 = [
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
]
# Save the AOT graphs
aot_graphs = []
from torch._inductor import compile_fx
def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
aot_graphs.append(graph)
return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)
backend = functools.partial(
compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
opt_fn(*args1).sum().backward()
if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
else:
op = torch.ops.aten._scaled_dot_product_flash_attention.default
fwd_graph = aot_graphs[0]
self.assertTrue(
count_ops(
fwd_graph,
[],
freq=1,
op=op,
)
)
bwd_graph = aot_graphs[1]
# Check that sin is not recomputed in the backward graph - checks percolate tags
self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
# Check that the sdpa op is recomputed in the backward graph
self.assertTrue(
count_ops(
bwd_graph,
[],
freq=1,
op=op,
)
)
@requires_distributed()
@requires_cuda
def test_distributed_utils_checkpoint_wrapper(self):
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as dist_checkpoint_wrapper,
)
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.c = 2
def forward(self, x):
x = torch.sin(x)
x = self.linear(x)
x = torch.cos(x)
return x * self.c
mod = dist_checkpoint_wrapper(MockModule())
x = torch.randn(4, 4)
ref = mod(x)
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
res = opt_mod(x)
self.assertEqual(ref, res)
@requires_distributed()
@requires_cuda
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_dynamo_does_not_trace_getattr_as_top_frame(self):
# inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
cnt = CompileCounterWithBackend("eager")
lin = torch.nn.Linear(1, 1)
mod = torch.nn.Sequential(lin, lin)
mod = CheckpointWrapper(mod)
mod._checkpoint_wrapped_module.a = torch.ones(1, 1)
def fn(x):
return mod(x) * mod.a
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
x = torch.randn(1, 1)
self.assertEqual(opt_fn(x), fn(x))
devices = ["cuda", "hpu"]
instantiate_device_type_tests(
ActivationCheckpointingViaTagsTests, globals(), only_for=devices
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()