pytorch/test/dynamo/test_activation_checkpointing.py
Animesh Jain 616c6bdf8f [dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)
Eager AC/SAC reapplies mutations (such as global dict mutations) in the backward pass when the forward is recomputed. torch.compile has no easy way to reapply Python mutations in the backward, but many users may be fine with skipping the reapplication of side effects there. Setting this config flag lets them accept this divergence between eager and compile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
2025-10-17 22:04:19 +00:00
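
A minimal sketch of opting into the flag (assumed usage; checkpointed_fn and x are hypothetical placeholders for a function that calls torch.utils.checkpoint.checkpoint and its input; the config name is the one exercised by test_nonlocal_mutation below):

import torch

# Accept eager/compile divergence: do not reapply Python side effects
# (e.g. global or nonlocal mutations) when the forward of a checkpointed
# region is recomputed in the backward.
with torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True):
    compiled_fn = torch.compile(checkpointed_fn, fullgraph=True)
    compiled_fn(x).sum().backward()  # side effects run once, in the forward only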

# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import math
import unittest # noqa: F811
from importlib import import_module
import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import (
checkpoint,
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
if HAS_CUDA_AND_TRITON:
import triton
from triton import language as tl
@triton.jit
def add_one_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = x + 1
tl.store(out_ptr + offsets, output, mask=mask)
requires_distributed = functools.partial(
unittest.skipIf, not dist.is_available(), "requires distributed"
)
def checkpoint_wrapper(fn):
def inner(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
return inner
def count_ops(
gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
):
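    """Assert per-op node counts in an FX graph.

    Accepts either a single (op, freq/freq_ge) pair or parallel lists
    (ops, freqs/freqs_ge). RNG-wrapped nodes (run_and_save_rng_state,
    run_with_rng_state, graphsafe_run_with_rng_state) are matched against
    the op they wrap. Returns the graph module unchanged so this can be
    used directly as an aot_autograd fw/bw compiler.
    """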
def match_rng_op(node, op):
if isinstance(node.target, torch._ops.HigherOrderOperator):
if node.name == "run_and_save_rng_state":
return node.args[0] == op
elif node.name == "run_with_rng_state":
return node.args[1] == op
elif node.name == "graphsafe_run_with_rng_state":
return node.args[0] == op
return False
# assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
if op is not None:
assert not isinstance(op, list)
ops = [op]
if freq is not None:
freqs = [freq]
if freq_ge is not None:
freqs_ge = [freq_ge]
if freqs:
for op, freq in zip(ops, freqs):
actual_count = 0
for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op:
actual_count += 1
err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
assert actual_count == freq, err_msg
else:
assert freqs_ge is not None
for op, freq_ge in zip(ops, freqs_ge):
actual_count = 0
for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op:
actual_count += 1
assert actual_count >= freq_ge, (
f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
)
return gm
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
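    """Post-grad pass hook that records the forward graph's output names.

    Only runs on the forward graph (it is skipped inside the compiled
    autograd region) and adds each output node's string name to `fwd_outputs`.
    """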
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph
return_node = list(graph.nodes)[-1]
assert return_node.target == "output"
for x in return_node.args[0]:
fwd_outputs.add(str(x))
class _InvalidContext:
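    """A context manager that is not a TorchDispatchMode.

    Used via _invalid_context_gen to exercise the error path for invalid
    selective-checkpointing context functions.
    """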
def __init__(self) -> None:
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def _invalid_context_gen():
return _InvalidContext(), _InvalidContext()
def find_first_node(gm, func):
for node in gm.graph.nodes:
if node.target is func:
return node
return None
def op_count(gm):
result = 0
for node in gm.graph.nodes:
if "call" in node.op:
result += 1
return result
def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
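    """Build a selective-checkpointing policy function.

    Ops in `no_recompute_list` are saved (MUST_SAVE), ops in
    `must_recompute_list` are recomputed (MUST_RECOMPUTE), and everything
    else defaults to PREFER_RECOMPUTE.
    """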
def _custom_policy(ctx, func, *args, **kwargs):
if no_recompute_list is not None and func in no_recompute_list:
return CheckpointPolicy.MUST_SAVE
if must_recompute_list is not None and func in must_recompute_list:
return CheckpointPolicy.MUST_RECOMPUTE
else:
return CheckpointPolicy.PREFER_RECOMPUTE
return _custom_policy
class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
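    """Tests that activation checkpointing (AC/SAC) regions are tagged,
    partitioned, and recomputed as expected under torch.compile."""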
def _validate(
self,
fn,
backend,
*args,
skip_check=False,
fullgraph=True,
compiled_autograd=False,
):
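        """Run `fn` in eager and under torch.compile with the given backend
        (optionally under compiled autograd) on cloned inputs with the same
        seed, then compare outputs and input gradients unless `skip_check`
        is set."""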
cloned_args = []
for arg in args:
cloned_args.append(arg.detach().clone().requires_grad_(arg.requires_grad))
cloned_fn = copy.deepcopy(fn)
torch.manual_seed(0)
expected = fn(*args)
expected.sum().backward()
torch.manual_seed(0)
compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend)
ctx = contextlib.nullcontext()
if compiled_autograd:
ctx = torch._dynamo.compiled_autograd._enable(
lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend)
)
with ctx:
result = compiled_fn(*cloned_args)
result.sum().backward()
if not skip_check:
self.assertEqual(
result,
expected,
msg="Output mismatch between torch.compile and eager versions",
)
for arg, cloned_arg in zip(args, cloned_args):
self.assertEqual(
arg.grad,
cloned_arg.grad,
msg="Gradient mismatch between torch.compile and eager versions",
)
def _compare_orig_and_checkpointed_fns(
self, orig_fn, checkpointed_fn, *args, fullgraph=True
):
# The original version and the checkpointed version of the same function
# should produce the same outputs and the same gradients under torch.compile.
def clone_args(args):
cloned_args = []
for arg in args:
cloned_args.append(
arg.detach().clone().requires_grad_(arg.requires_grad)
)
return cloned_args
def run(compiler):
# Run original version
cloned_args_orig_fn = clone_args(args)
torch.manual_seed(0)
compiled_orig_fn = compiler(orig_fn)
result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
result_orig_fn.sum().backward()
# Run checkpointed version
cloned_args_checkpointed_fn = clone_args(args)
torch.manual_seed(0)
compiled_checkpointed_fn = compiler(copy.deepcopy(checkpointed_fn))
result_checkpointed_fn = compiled_checkpointed_fn(
*cloned_args_checkpointed_fn
)
result_checkpointed_fn.sum().backward()
# Check that outputs and gradients are equal
self.assertEqual(
result_orig_fn,
result_checkpointed_fn,
msg="Output mismatch between the original version and the checkpointed version of the same function",
)
for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
cloned_args_orig_fn, cloned_args_checkpointed_fn
):
self.assertEqual(
cloned_arg_orig_fn.grad,
cloned_arg_checkpointed_fn.grad,
msg="Gradient mismatch between the original version and the checkpointed version of the same function",
)
run(functools.partial(torch.compile, fullgraph=fullgraph))
if fullgraph:
def export_compiler(fn):
class WrapAsModule(nn.Module):
def forward(self, *args, **kwargs):
return fn(*args, **kwargs)
mod = WrapAsModule()
def runtime_wrapper(*runtime_args):
from torch.export import _trace
gm = _trace._export_to_torch_ir(
f=mod,
args=tuple(clone_args(args)),
kwargs={},
dynamic_shapes=None,
preserve_module_call_signature=(),
restore_fqn=False,
prefer_deferred_runtime_asserts_over_guards=False,
_log_export_usage=False,
)
# NOTE: this is necessary for rng to be added to the exported graph
return torch.compile(gm, fullgraph=False)(*runtime_args)
return runtime_wrapper
run(export_compiler)
def test_tags_function(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True
)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_via_global_checkpoint(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
# This goes through VariableBuilder
return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_with_kwargs(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_sequential_layers(self, device):
def gn(x):
x = x.cos()
for _ in range(3):
x = torch.mm(x, x)
x = x.cos()
return x
def fn(x):
x = torch.utils.checkpoint.checkpoint(gn, x)
x = torch.utils.checkpoint.checkpoint(gn, x)
return x
x = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=6, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops,
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_multiple_checkpoints(self, device):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(z)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_module(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.sigmoid(self.linear(x))
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(
mod, torch.sin(x), use_reentrant=True
)
x = torch.randn(10, 10, device=device, requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_decomps(self, device):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return torch.nn.functional.gelu(self.linear(x))
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(
mod, torch.sin(x), use_reentrant=True
)
x = torch.randn(10, 10, device=device, requires_grad=True)
fw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.erf.default
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
@torch._inductor.config.patch(fallback_random=True)
def test_tags_recomputed_rand(self, device):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return z
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
# bw_compiler = functools.partial(
# count_ops, freq=6, op=torch.ops.aten.mm.default
# ) # mm recomputed in the bwd
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = "inductor"
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@torch._inductor.config.patch(fallback_random=True)
def test_tags_rand(self, device):
def gn(x, y):
x = torch.mm(x, y)
x = torch.mm(x, y)
return x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
# x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
return x
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
# bw_compiler = functools.partial(
# count_ops, freq=6, op=torch.ops.aten.mm.default
# ) # mm recomputed in the bwd
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
# backend = "aot_eager"
backend = "inductor"
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@torch._inductor.config.patch(fallback_random=True)
def test_tags_dropout(self, device):
# Figure out a way to test the number of inductor_random calls
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.dropout = torch.nn.Dropout(0.2)
def forward(self, x):
return self.dropout(self.linear(x))
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
x = torch.randn(10, 10, device=device, requires_grad=True)
backend = "inductor"
        # rand decomps do not have the same numerical results as eager
self._validate(fn, backend, x, skip_check=True)
@skipIfHpu
@torch._functorch.config.patch(recompute_views=True)
@torch._inductor.config.patch(fx_graph_cache=False)
def test_tags_must_save_tensor_that_has_backward_hook(self):
def my_post_forward_hook(submod, args, output):
output.register_hook(my_backward_hook)
return output
def my_backward_hook(grad):
return grad
class MySubmod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
y = torch.matmul(x, x)
z = y * y
return z
class MyMod(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = MySubmod()
self.norm = torch.nn.LayerNorm(4)
def forward(self, x):
out = torch.utils.checkpoint.checkpoint(
self.submod, x, use_reentrant=False
)
norm_out = self.norm(out)
return norm_out
def _factory_fn():
mod = MyMod()
x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True)
backend = "inductor"
return mod, x, backend
mod_no_hook, x, backend = _factory_fn()
mod_no_hook_fwd_outputs = set()
with torch._inductor.config.patch(
post_grad_custom_pre_pass=functools.partial(
collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs
)
):
self._validate(
mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True
)
torch._dynamo.reset()
mod_with_hook, x, backend = _factory_fn()
mod_with_hook.submod.register_forward_hook(my_post_forward_hook)
mod_with_hook_fwd_outputs = set()
with torch._inductor.config.patch(
post_grad_custom_pre_pass=functools.partial(
collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs
)
):
self._validate(
mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True
)
# If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors.
mod_no_hook_fwd_outputs_no_primal = {
x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_")
}
mod_with_hook_fwd_outputs_no_primal = {
x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_")
}
additional_saved_tensors = (
mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal
)
expected_additional_saved_tensors = {"mul"}
self.assertEqual(
additional_saved_tensors,
expected_additional_saved_tensors,
f"""
Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}.
Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}.
Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""",
)
@requires_cuda_and_triton
def test_fallback(self, device):
def gn(x, y):
torch._dynamo.graph_break()
a = torch.sigmoid(torch.matmul(x, y))
torch._dynamo.graph_break()
return torch.cos(a)
def fn(x, y):
return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
args = (x, y)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
# One graph for torch.sin on the input, and other for torch.cos.
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
self.assertEqual(len(cnt.graphs), 2)
@requires_cuda_and_triton
def test_kwargs(self, device):
def gn(x, y, z=None):
a = torch.matmul(x, y)
if z is not None:
return torch.matmul(a, z)
return a
def fn(x, y, z):
return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
z = torch.randn(4, 4, requires_grad=True, device=device)
args = (x, y, z)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
expected = fn(*args)
result = torch.compile(fn, backend=cnt)(*args)
self.assertEqual(result, expected)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(cnt.graphs), 1)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        # one for the checkpointed body function, and 3 for x, y, z
self.assertEqual(len(wrap_node.args), 4)
body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
self.assertEqual(op_count(body_function), 2)
@requires_cuda_and_triton
def test_symints_location(self, device):
def gn(x, y):
return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
backend = "aot_eager"
cnt = CompileCounterWithBackend(backend)
opt_fn = torch.compile(fn, backend=cnt)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
x = torch.randn(5, 5, requires_grad=True, device=device)
y = torch.randn(5, 5, requires_grad=True, device=device)
args = (x, y)
expected = fn(*args)
result = opt_fn(*args)
self.assertEqual(result.shape, expected.shape)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(len(cnt.graphs), 2)
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
self.assertEqual(len(wrap_node.args), 3)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(
must_recompute_list=must_recompute_list,
),
)
def context_fn_no_recompute_mm():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(
no_recompute_list=no_recompute_list,
),
)
def _test(context_fn, bw_compiler):
def gn(x):
return torch.sigmoid(torch.matmul(x, x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
gn,
x,
use_reentrant=False,
context_fn=context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x)
_test(
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
                freq=3,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3
op=torch.ops.aten.mm.default,
),
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=2, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
)
def test_sac_with_partial_context_fn(self):
class CustomPolicy:
def __init__(self):
super().__init__()
def __call__(self, ctx, out, func, *args, **kwargs):
return CheckpointPolicy.MUST_SAVE
def f(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
context_fn1 = functools.partial(
create_selective_checkpoint_contexts, CustomPolicy()
)
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
f,
x,
y,
use_reentrant=False,
context_fn=context_fn1,
)
opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True)
a = torch.randn(4, 4, requires_grad=True, device="cpu")
b = torch.randn(4, 4, requires_grad=True, device="cpu")
expected = fn(a, b)
result = opt_fn(a, b)
self.assertEqual(result, expected)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 6 here
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
# if we didn't enable selective checkpointing.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device
):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.sigmoid.default,
)
bw_compiler = functools.partial(
count_ops,
# Main check here is just that sigmoid is properly recomputed
# (we will see a sigmoid() and sigmoid_backward() in the bw graph)
freq=1,
op=torch.ops.aten.sigmoid.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
out = torch.empty_like(x)
n_elements = x.numel()
add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
class AddOne(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return add_one(x)
@staticmethod
def backward(ctx, x):
return x
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return (
torch.sigmoid(torch.matmul(torch.matmul(AddOne.apply(x.sin()), y), y))
* y
)
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 6 here
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
# if we didn't enable selective checkpointing.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
rand_tensor = torch.randn(4, 4, requires_grad=True, device=device)
# tensor subclasses as inputs
x = TwoTensor(rand_tensor, rand_tensor.clone())
y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())
fw_compiler = functools.partial(
count_ops,
freq=4,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 12 here
# (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12)
# if we didn't enable selective checkpointing.
freq=8,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
]
def _custom_policy(mode, func, *args, **kwargs):
mm_count_key = f"{mode}_mm_count"
if mm_count_key not in meta:
meta[mm_count_key] = 0
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except second mm
# (i.e. we will hint the partitioner to recompute second mm in backward pass)
return func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
)
return _custom_policy
def selective_checkpointing_context_fn():
meta = {}
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
def gn(x, y):
return torch.sigmoid(
torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
)
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# Q: How do we come to this number 4?
# A: We have 2 matmuls in the forward pass, each matmul contributes 2 `mm` ops in the backward pass,
# so we have at least 4 `mm` ops in backward pass. It's "at least" because whether second matmul in
# the forward pass is recomputed in the backward pass is up to the partitioner to decide.
freq_ge=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=functools.partial(
selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
),
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freq=2,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
# We would've expected 6 here
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
# if we didn't enable selective checkpointing.
freq=4,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list),
)
def gn(x, y):
return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[4, 0],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_list_ops(self, device):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.cat([x, y]).sin()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[1],
ops=[torch.ops.aten.cat.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[1],
ops=[torch.ops.aten.cat.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@unittest.skip(
"In-place op support in selective checkpointing + torch.compile "
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
def test_compile_selective_checkpoint_inplace_op(self, device):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x, y):
return torch.sigmoid(
torch.selu_(torch.matmul(torch.matmul(x, y), y))
).relu_()
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
y = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[4, 0],
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.sigmoid.default,
]
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)
def gn(x):
return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))
def fn(x):
return torch.utils.checkpoint.checkpoint(
gn,
x,
use_reentrant=False,
# Regardless of whether `preserve_rng_state` is True or False,
# we will always preserve RNG state when using `torch.compile`.
preserve_rng_state=preserve_rng_state,
context_fn=selective_checkpointing_context_fn,
)
x = torch.randn(4, 4, requires_grad=True, device=device)
fw_compiler = functools.partial(
count_ops,
freqs=[2, 1],
ops=[
torch.ops.aten.sigmoid.default,
torch.ops.aten.native_dropout.default,
],
)
bw_compiler = functools.partial(
count_ops,
# NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
freqs=[0, 1],
ops=[
torch.ops.aten.sigmoid.default,
torch.ops.aten.native_dropout.default,
],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
# because eager version doesn't preserve RNG state while torch.compile still does.
# Hence when `preserve_rng_state` is False, we skip the output and gradient comparison
# between torch.compile and eager.
self._validate(fn, backend, x, skip_check=not preserve_rng_state)
self._compare_orig_and_checkpointed_fns(gn, fn, x)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn,
x,
y,
use_reentrant=False,
context_fn=_invalid_context_gen,
)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
fw_compiler = functools.partial(
count_ops,
freq=1,
op=torch.ops.aten.mm.default,
)
bw_compiler = functools.partial(
count_ops,
freq_ge=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
):
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
to_recompute = func in {
torch.ops.aten.mul.Tensor,
torch.ops.aten.sigmoid.default,
}
return (
CheckpointPolicy.MUST_RECOMPUTE
if to_recompute
else CheckpointPolicy.MUST_SAVE
)
return _custom_policy
return create_selective_checkpoint_contexts(_recomp_policy())
class Parametrization(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def parametrization(self, x):
return torch.sigmoid(torch.mul(x, x))
def forward(self, x):
return checkpoint(
self.parametrization, x, use_reentrant=False, context_fn=sac_policy
)
def apply_parametrization(model):
modules = list(model.modules())
for mod in modules:
params_dict = dict(mod.named_parameters(recurse=False))
for p_name, p in params_dict.items():
mod.register_parameter(p_name, nn.Parameter(p))
nn.utils.parametrize.register_parametrization(
mod, p_name, Parametrization(), unsafe=True
)
return model
class MLPModule(nn.Module):
def __init__(self) -> None:
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(16, 16, bias=False)
def forward(self, x):
return self.net1(x)
def reset_parameters(self):
self.net1.reset_parameters()
fw_compiler = functools.partial(
count_ops,
freqs=[1, 1],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
)
bw_compiler = functools.partial(
count_ops,
freqs=[
2, # 1 from mul recompute, 1 from mul backward
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
)
model = MLPModule()
model = apply_parametrization(model)
model_compiled = torch.compile(
copy.deepcopy(model), backend=backend, fullgraph=True
)
input = torch.randn(8, 16, requires_grad=True)
input_compiled = copy.deepcopy(input)
out = model(input)
out.sum().backward()
out_compiled = model_compiled(input_compiled)
out_compiled.sum().backward()
self.assertEqual(out, out_compiled)
self.assertEqual(input.grad, input_compiled.grad)
@requires_cuda_and_triton
def test_autocast_flash_attention(self, device):
def fn(primals_1, primals_2, primals_3):
return torch.ops.aten._scaled_dot_product_efficient_attention.default(
primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
)[0]
def gn(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
with torch.autocast(device_type=device):
x = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
y = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
z = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
args = (x, y, z)
torch.manual_seed(0)
ref = gn(*args)
opt_gn = torch.compile(gn)
torch.manual_seed(0)
res = opt_gn(*args)
self.assertEqual(ref, res)
@requires_cuda_and_triton
def test_error_msg(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
x = torch.sin(x)
torch._dynamo.graph_break()
x = torch.cos(x)
return x
mod = MockModule().to(device)
def fn(x):
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
x = torch.randn(4, 4).to(device)
opt_fn = torch.compile(fn, fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "User-inserted graph break"
):
opt_fn(x)
@requires_cuda_and_triton
def test_list_inputs(self, device):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, ys):
a = torch.sin(x) # noqa: F841
b = torch.cos(ys[0])
c = torch.cos(ys[1])
return (x, [b, c])
mod = MockModule().to(device)
def fn(x, ys):
return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)
x = torch.randn(4, 4).to(device)
y = torch.randn(4, 4).to(device)
z = torch.randn(4, 4).to(device)
ref = fn(x, [y, z])
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x, [y, z])
self.assertEqual(ref, res)
@requires_cuda_and_triton
def test_pattern_matcher(self, device):
# Check that the sdpa op is recomputed in the backward graph
# tests percolate_tags
@checkpoint_wrapper
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return (
torch.matmul(query, key.transpose(-2, -1))
.mul(1.0 / math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)
def fn(query, key, value):
# Checks that sin is not recomputed in the backward graph
return dot_prod_attention(query.sin(), key, value)
tensor_shape = (4, 2, 16, 32)
dtype = torch.float16
args1 = [
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
]
# Save the AOT graphs
aot_graphs = []
from torch._inductor import compile_fx
def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
aot_graphs.append(graph)
return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)
backend = functools.partial(
compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
opt_fn(*args1).sum().backward()
fwd_graph = aot_graphs[0]
op1 = torch.ops.aten._scaled_dot_product_flash_attention.default
op2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default
self.assertTrue(
count_ops(
fwd_graph,
[],
freq=1,
op=op1,
)
or count_ops(
fwd_graph,
[],
freq=1,
op=op2,
)
)
bwd_graph = aot_graphs[1]
# Check that sin is not recomputed in the backward graph - checks percolate tags
self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
# Check that the sdpa op is recomputed in the backward graph
self.assertTrue(
count_ops(
bwd_graph,
[],
freq=1,
op=op1,
)
or count_ops(
bwd_graph,
[],
freq=1,
op=op2,
)
)
@requires_distributed()
@requires_cuda_and_triton
def test_distributed_utils_checkpoint_wrapper(self):
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as dist_checkpoint_wrapper,
)
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.c = 2
def forward(self, x):
x = torch.sin(x)
x = self.linear(x)
x = torch.cos(x)
return x * self.c
mod = dist_checkpoint_wrapper(MockModule())
x = torch.randn(4, 4)
ref = mod(x)
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
res = opt_mod(x)
self.assertEqual(ref, res)
@requires_distributed()
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_dynamo_does_not_trace_getattr_as_top_frame(self):
# inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
cnt = CompileCounterWithBackend("eager")
lin = torch.nn.Linear(1, 1)
mod = torch.nn.Sequential(lin, lin)
mod = CheckpointWrapper(mod)
mod._checkpoint_wrapped_module.a = torch.ones(1, 1)
def fn(x):
return mod(x) * mod.a
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
x = torch.randn(1, 1)
self.assertEqual(opt_fn(x), fn(x))
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_mutation(self):
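        # Exercises the config flag from this commit: with
        # skip_fwd_side_effects_in_bwd_under_checkpoint=True, torch.compile is
        # allowed to skip reapplying Python side effects (here, the nonlocal
        # counter increment) during the backward recomputation, diverging from
        # eager, which reapplies them.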
counter = 0
def gn(x):
nonlocal counter
counter += 1
return torch.sin(x)
def fn(x):
return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
x = torch.randn(4, 4, requires_grad=True)
fn(x).sum().backward()
# The mutation is reapplied in the backward as well
self.assertEqual(counter, 2)
counter = 0
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
opt_fn(x).sum().backward()
# The mutation is not reapplied in the backward because the flag was on.
self.assertEqual(counter, 1)
devices = ["cuda", "hpu"]
instantiate_device_type_tests(
ActivationCheckpointingViaTagsTests, globals(), only_for=devices
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()