Eager AC/SAC reapplies mutations (such as global dict mutations) in the backward pass while recomputing the forward. torch.compile has no easy way to reapply Python mutations in the backward, but many users may be fine with skipping that reapplication of side effects. They can set this config flag to accept this divergence between eager and compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775 Approved by: https://github.com/zou3519 ghstack dependencies: #165734
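A minimal sketch of what opting in looks like (illustrative only, not part of the test file below; the demo wrapper is hypothetical, while the flag name, checkpoint call, and expected counts are taken from test_nonlocal_mutation at the bottom of the file):

import torch
import torch._dynamo.config
import torch.utils.checkpoint


@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def demo():
    counter = 0

    def gn(x):
        nonlocal counter
        counter += 1  # Python side effect inside the checkpointed region
        return torch.sin(x)

    def fn(x):
        return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

    x = torch.randn(4, 4, requires_grad=True)
    torch.compile(fn, backend="eager", fullgraph=True)(x).sum().backward()
    # Eager reruns gn (and its increment) when it recomputes the forward during
    # the backward, so counter would be 2; with the flag set, torch.compile
    # skips reapplying the side effect and counter stays at 1.
    return counter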
# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import math
import unittest  # noqa: F811
from importlib import import_module

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


if HAS_CUDA_AND_TRITON:
    import triton
    from triton import language as tl

    @triton.jit
    def add_one_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = x + 1
        tl.store(out_ptr + offsets, output, mask=mask)


requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)


def checkpoint_wrapper(fn):
    def inner(*args):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

    return inner


def count_ops(
    gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
):
    def match_rng_op(node, op):
        if isinstance(node.target, torch._ops.HigherOrderOperator):
            if node.name == "run_and_save_rng_state":
                return node.args[0] == op
            elif node.name == "run_with_rng_state":
                return node.args[1] == op
            elif node.name == "graphsafe_run_with_rng_state":
                return node.args[0] == op
        return False

    # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
    if op is not None:
        assert not isinstance(op, list)
        ops = [op]
    if freq is not None:
        freqs = [freq]
    if freq_ge is not None:
        freqs_ge = [freq_ge]
    if freqs:
        for op, freq in zip(ops, freqs):
            actual_count = 0
            for node in gm.graph.nodes:
                if match_rng_op(node, op) or node.target == op:
                    actual_count += 1
            err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
            assert actual_count == freq, err_msg
    else:
        assert freqs_ge is not None
        for op, freq_ge in zip(ops, freqs_ge):
            actual_count = 0
            for node in gm.graph.nodes:
                if match_rng_op(node, op) or node.target == op:
                    actual_count += 1
            assert actual_count >= freq_ge, (
                f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
            )
    return gm


def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:  # fwd graph
        return_node = list(graph.nodes)[-1]
        assert return_node.target == "output"
        for x in return_node.args[0]:
            fwd_outputs.add(str(x))


class _InvalidContext:
    def __init__(self) -> None:
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def _invalid_context_gen():
    return _InvalidContext(), _InvalidContext()


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def op_count(gm):
    result = 0
    for node in gm.graph.nodes:
        if "call" in node.op:
            result += 1
    return result


def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
    def _custom_policy(ctx, func, *args, **kwargs):
        if no_recompute_list is not None and func in no_recompute_list:
            return CheckpointPolicy.MUST_SAVE
        if must_recompute_list is not None and func in must_recompute_list:
            return CheckpointPolicy.MUST_RECOMPUTE
        else:
            return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy


class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
    def _validate(
        self,
        fn,
        backend,
        *args,
        skip_check=False,
        fullgraph=True,
        compiled_autograd=False,
    ):
        cloned_args = []
        for arg in args:
            cloned_args.append(arg.detach().clone().requires_grad_(arg.requires_grad))

        cloned_fn = copy.deepcopy(fn)

        torch.manual_seed(0)
        expected = fn(*args)
        expected.sum().backward()

        torch.manual_seed(0)
        compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend)
        ctx = contextlib.nullcontext()
        if compiled_autograd:
            ctx = torch._dynamo.compiled_autograd._enable(
                lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend)
            )
        with ctx:
            result = compiled_fn(*cloned_args)
            result.sum().backward()

        if not skip_check:
            self.assertEqual(
                result,
                expected,
                msg="Output mismatch between torch.compile and eager versions",
            )
            for arg, cloned_arg in zip(args, cloned_args):
                self.assertEqual(
                    arg.grad,
                    cloned_arg.grad,
                    msg="Gradient mismatch between torch.compile and eager versions",
                )

    def _compare_orig_and_checkpointed_fns(
        self, orig_fn, checkpointed_fn, *args, fullgraph=True
    ):
        # The original version and the checkpointed version of the same function
        # should produce the same outputs and the same gradients under torch.compile.

        def clone_args(args):
            cloned_args = []
            for arg in args:
                cloned_args.append(
                    arg.detach().clone().requires_grad_(arg.requires_grad)
                )
            return cloned_args

        def run(compiler):
            # Run original version
            cloned_args_orig_fn = clone_args(args)
            torch.manual_seed(0)
            compiled_orig_fn = compiler(orig_fn)
            result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
            result_orig_fn.sum().backward()

            # Run checkpointed version
            cloned_args_checkpointed_fn = clone_args(args)
            torch.manual_seed(0)
            compiled_checkpointed_fn = compiler(copy.deepcopy(checkpointed_fn))
            result_checkpointed_fn = compiled_checkpointed_fn(
                *cloned_args_checkpointed_fn
            )
            result_checkpointed_fn.sum().backward()

            # Check that outputs and gradients are equal
            self.assertEqual(
                result_orig_fn,
                result_checkpointed_fn,
                msg="Output mismatch between the original version and the checkpointed version of the same function",
            )
            for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
                cloned_args_orig_fn, cloned_args_checkpointed_fn
            ):
                self.assertEqual(
                    cloned_arg_orig_fn.grad,
                    cloned_arg_checkpointed_fn.grad,
                    msg="Gradient mismatch between the original version and the checkpointed version of the same function",
                )

        run(functools.partial(torch.compile, fullgraph=fullgraph))
        if fullgraph:

            def export_compiler(fn):
                class WrapAsModule(nn.Module):
                    def forward(self, *args, **kwargs):
                        return fn(*args, **kwargs)

                mod = WrapAsModule()

                def runtime_wrapper(*runtime_args):
                    from torch.export import _trace

                    gm = _trace._export_to_torch_ir(
                        f=mod,
                        args=tuple(clone_args(args)),
                        kwargs={},
                        dynamic_shapes=None,
                        preserve_module_call_signature=(),
                        restore_fqn=False,
                        prefer_deferred_runtime_asserts_over_guards=False,
                        _log_export_usage=False,
                    )
                    # NOTE: this is necessary for rng to be added to the exported graph
                    return torch.compile(gm, fullgraph=False)(*runtime_args)

                return runtime_wrapper

            run(export_compiler)

    def test_tags_function(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_function_via_global_checkpoint(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            # This goes through VariableBuilder
            return checkpoint(gn, torch.sin(x), y, use_reentrant=True)

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_function_with_kwargs(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
            )

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_sequential_layers(self, device):
        def gn(x):
            x = x.cos()
            for _ in range(3):
                x = torch.mm(x, x)
            x = x.cos()
            return x

        def fn(x):
            x = torch.utils.checkpoint.checkpoint(gn, x)
            x = torch.utils.checkpoint.checkpoint(gn, x)
            return x

        x = torch.randn(4, 4, device=device, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=6, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops,
            freqs=[2, 18],
            ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    def test_tags_multiple_checkpoints(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            x = torch.sin(x)
            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(z)
            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            return z

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=6, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_module(self, device):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return torch.sigmoid(self.linear(x))

        mod = MockModule().to(device)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(
                mod, torch.sin(x), use_reentrant=True
            )

        x = torch.randn(10, 10, device=device, requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    def test_tags_decomps(self, device):
        # Ensures that tags are passed on through decompositions as well
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return torch.nn.functional.gelu(self.linear(x))

        mod = MockModule().to(device)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(
                mod, torch.sin(x), use_reentrant=True
            )

        x = torch.randn(10, 10, device=device, requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.erf.default
        )
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.erf.default
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
        )
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    @torch._inductor.config.patch(fallback_random=True)
    def test_tags_recomputed_rand(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.rand_like(x) * y) * x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            return z

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        # bw_compiler = functools.partial(
        #     count_ops, freq=6, op=torch.ops.aten.mm.default
        # )  # mm recomputed in the bwd
        # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = "inductor"
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @torch._inductor.config.patch(fallback_random=True)
    def test_tags_rand(self, device):
        def gn(x, y):
            x = torch.mm(x, y)
            x = torch.mm(x, y)
            return x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            # x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            return x

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        # bw_compiler = functools.partial(
        #     count_ops, freq=6, op=torch.ops.aten.mm.default
        # )  # mm recomputed in the bwd
        # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        # backend = "aot_eager"
        backend = "inductor"
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @torch._inductor.config.patch(fallback_random=True)
    def test_tags_dropout(self, device):
        # Figure out a way to test the number of inductor_random calls
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.dropout = torch.nn.Dropout(0.2)

            def forward(self, x):
                return self.dropout(self.linear(x))

        mod = MockModule().to(device)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)

        x = torch.randn(10, 10, device=device, requires_grad=True)
        backend = "inductor"
        # rand decomps do not have the same numerical results as eager
        self._validate(fn, backend, x, skip_check=True)

    @skipIfHpu
    @torch._functorch.config.patch(recompute_views=True)
    @torch._inductor.config.patch(fx_graph_cache=False)
    def test_tags_must_save_tensor_that_has_backward_hook(self):
        def my_post_forward_hook(submod, args, output):
            output.register_hook(my_backward_hook)
            return output

        def my_backward_hook(grad):
            return grad

        class MySubmod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                y = torch.matmul(x, x)
                z = y * y
                return z

        class MyMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submod = MySubmod()
                self.norm = torch.nn.LayerNorm(4)

            def forward(self, x):
                out = torch.utils.checkpoint.checkpoint(
                    self.submod, x, use_reentrant=False
                )
                norm_out = self.norm(out)
                return norm_out

        def _factory_fn():
            mod = MyMod()
            x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True)
            backend = "inductor"
            return mod, x, backend

        mod_no_hook, x, backend = _factory_fn()
        mod_no_hook_fwd_outputs = set()

        with torch._inductor.config.patch(
            post_grad_custom_pre_pass=functools.partial(
                collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs
            )
        ):
            self._validate(
                mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True
            )

        torch._dynamo.reset()
        mod_with_hook, x, backend = _factory_fn()
        mod_with_hook.submod.register_forward_hook(my_post_forward_hook)
        mod_with_hook_fwd_outputs = set()

        with torch._inductor.config.patch(
            post_grad_custom_pre_pass=functools.partial(
                collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs
            )
        ):
            self._validate(
                mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True
            )

        # If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors.
        mod_no_hook_fwd_outputs_no_primal = {
            x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_")
        }
        mod_with_hook_fwd_outputs_no_primal = {
            x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_")
        }
        additional_saved_tensors = (
            mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal
        )
        expected_additional_saved_tensors = {"mul"}
        self.assertEqual(
            additional_saved_tensors,
            expected_additional_saved_tensors,
            f"""
Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}.
Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}.
Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""",
        )

    @requires_cuda_and_triton
    def test_fallback(self, device):
        def gn(x, y):
            torch._dynamo.graph_break()
            a = torch.sigmoid(torch.matmul(x, y))
            torch._dynamo.graph_break()
            return torch.cos(a)

        def fn(x, y):
            return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)
        args = (x, y)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        # One graph for torch.sin on the input, and the other for torch.cos.
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(len(cnt.graphs), 2)

    @requires_cuda_and_triton
    def test_kwargs(self, device):
        def gn(x, y, z=None):
            a = torch.matmul(x, y)
            if z is not None:
                return torch.matmul(a, z)
            return a

        def fn(x, y, z):
            return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)
        z = torch.randn(4, 4, requires_grad=True, device=device)
        args = (x, y, z)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(len(cnt.graphs), 1)

        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        # one for the checkpointed function, and 3 for x, y, z
        self.assertEqual(len(wrap_node.args), 4)

        body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
        self.assertEqual(op_count(body_function), 2)

    @requires_cuda_and_triton
    def test_symints_location(self, device):
        def gn(x, y):
            return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)
        opt_fn = torch.compile(fn, backend=cnt)

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)
        args = (x, y)
        expected = fn(*args)
        result = opt_fn(*args)

        x = torch.randn(5, 5, requires_grad=True, device=device)
        y = torch.randn(5, 5, requires_grad=True, device=device)
        args = (x, y)
        expected = fn(*args)
        result = opt_fn(*args)

        self.assertEqual(result.shape, expected.shape)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(len(cnt.graphs), 2)
        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        self.assertEqual(len(wrap_node.args), 3)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_recompute(self, device):
        def context_fn_must_recompute_mm():
            must_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(
                    must_recompute_list=must_recompute_list,
                ),
            )

        def context_fn_no_recompute_mm():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(
                    no_recompute_list=no_recompute_list,
                ),
            )

        def _test(context_fn, bw_compiler):
            def gn(x):
                return torch.sigmoid(torch.matmul(x, x))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
                    gn,
                    x,
                    use_reentrant=False,
                    context_fn=context_fn,
                )

            x = torch.randn(4, 4, requires_grad=True, device=device)

            fw_compiler = functools.partial(
                count_ops,
                freq=1,
                op=torch.ops.aten.mm.default,
            )

            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
            )
            self._validate(fn, backend, x)

        _test(
            context_fn=context_fn_must_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=3,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3
                op=torch.ops.aten.mm.default,
            ),
        )
        _test(
            context_fn=context_fn_no_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=2,  # 2 bwd mm ops per fwd matmul
                op=torch.ops.aten.mm.default,
            ),
        )

    def test_sac_with_partial_context_fn(self):
        class CustomPolicy:
            def __init__(self):
                super().__init__()

            def __call__(self, ctx, out, func, *args, **kwargs):
                return CheckpointPolicy.MUST_SAVE

        def f(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        context_fn1 = functools.partial(
            create_selective_checkpoint_contexts, CustomPolicy()
        )

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                f,
                x,
                y,
                use_reentrant=False,
                context_fn=context_fn1,
            )

        opt_fn = torch.compile(fn, backend="aot_eager_decomp_partition", fullgraph=True)
        a = torch.randn(4, 4, requires_grad=True, device="cpu")
        b = torch.randn(4, 4, requires_grad=True, device="cpu")

        expected = fn(a, b)
        result = opt_fn(a, b)
        self.assertEqual(result, expected)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 6 here
            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
            # if we didn't enable selective checkpointing.
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
        self, device
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freq=1,
            op=torch.ops.aten.sigmoid.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # Main check here is just that sigmoid is properly recomputed
            # (we will see a sigmoid() and sigmoid_backward() in the bw graph)
            freq=1,
            op=torch.ops.aten.sigmoid.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            disable_functionalization=True,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_triton_kernel(self, device):
        # Copy of the above test, but make sure that having a triton kernel in the
        # region does not error.
        def add_one(x):
            out = torch.empty_like(x)
            n_elements = x.numel()
            add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
            return out

        class AddOne(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return add_one(x)

            @staticmethod
            def backward(ctx, x):
                return x

        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return (
                torch.sigmoid(torch.matmul(torch.matmul(AddOne.apply(x.sin()), y), y))
                * y
            )

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 6 here
            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
            # if we didn't enable selective checkpointing.
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_tensor_subclass(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        rand_tensor = torch.randn(4, 4, requires_grad=True, device=device)

        # tensor subclasses as inputs
        x = TwoTensor(rand_tensor, rand_tensor.clone())
        y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())

        fw_compiler = functools.partial(
            count_ops,
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 12 here
            # (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12)
            # if we didn't enable selective checkpointing.
            freq=8,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_custom_rule(self, device):
        def _get_custom_policy(meta):
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]

            def _custom_policy(mode, func, *args, **kwargs):
                mm_count_key = f"{mode}_mm_count"
                if mm_count_key not in meta:
                    meta[mm_count_key] = 0
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except second mm
                # (i.e. we will hint the partitioner to recompute second mm in backward pass)
                return func in no_recompute_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = {}
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        def gn(x, y):
            return torch.sigmoid(
                torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
            )

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # Q: How do we come to this number 4?
            # A: We have 2 matmuls in the forward pass, each matmul contributes 2 `mm` ops in the backward pass,
            # so we have at least 4 `mm` ops in backward pass. It's "at least" because whether second matmul in
            # the forward pass is recomputed in the backward pass is up to the partitioner to decide.
            freq_ge=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
        def selective_checkpointing_context_fn(no_recompute_list):
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=functools.partial(
                    selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
                ),
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 6 here
            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
            # if we didn't enable selective checkpointing.
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_outplace_op(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
                torch.ops.aten.sigmoid.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list),
            )

        def gn(x, y):
            return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freqs=[2, 1],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[4, 0],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_list_ops(self, device):
        def selective_checkpointing_context_fn():
            # recompute everything
            no_recompute_list = []
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.cat([x, y]).sin()

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freqs=[1],
            ops=[torch.ops.aten.cat.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[1],
            ops=[torch.ops.aten.cat.default],
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @unittest.skip(
        "In-place op support in selective checkpointing + torch.compile "
        "requires TorchDispatchMode + torch.compile work to complete"
    )
    @requires_cuda_and_triton
    def test_compile_selective_checkpoint_inplace_op(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
                torch.ops.aten.sigmoid.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(
                torch.selu_(torch.matmul(torch.matmul(x, y), y))
            ).relu_()

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device=device)
        y = torch.randn(4, 4, requires_grad=True, device=device)

        fw_compiler = functools.partial(
            count_ops,
            freqs=[2, 1],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[4, 0],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @torch._inductor.config.patch(fallback_random=True)
    def test_compile_selective_checkpoint_random_op(self, device):
        for preserve_rng_state in [True, False]:

            def selective_checkpointing_context_fn():
                no_recompute_list = [
                    torch.ops.aten.sigmoid.default,
                ]
                return create_selective_checkpoint_contexts(
                    _get_custom_policy(no_recompute_list=no_recompute_list)
                )

            def gn(x):
                return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
                    gn,
                    x,
                    use_reentrant=False,
                    # Regardless of whether `preserve_rng_state` is True or False,
                    # we will always preserve RNG state when using `torch.compile`.
                    preserve_rng_state=preserve_rng_state,
                    context_fn=selective_checkpointing_context_fn,
                )

            x = torch.randn(4, 4, requires_grad=True, device=device)

            fw_compiler = functools.partial(
                count_ops,
                freqs=[2, 1],
                ops=[
                    torch.ops.aten.sigmoid.default,
                    torch.ops.aten.native_dropout.default,
                ],
            )
            bw_compiler = functools.partial(
                count_ops,
                # NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
                freqs=[0, 1],
                ops=[
                    torch.ops.aten.sigmoid.default,
                    torch.ops.aten.native_dropout.default,
                ],
            )
            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
            )

            # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
            # because eager version doesn't preserve RNG state while torch.compile still does.
            # Hence when `preserve_rng_state` is False, we skip the output and gradient comparison
            # between torch.compile and eager.
            self._validate(fn, backend, x, skip_check=not preserve_rng_state)
            self._compare_orig_and_checkpointed_fns(gn, fn, x)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_invalid_context(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=_invalid_context_gen,
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)

        fw_compiler = functools.partial(
            count_ops,
            freq=1,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            freq_ge=2,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        with self.assertRaisesRegex(
            Exception, "must generate a tuple of two `TorchDispatchMode`s"
        ):
            self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_compile_selective_checkpoint_parametrization(self):
        def sac_policy():
            def _recomp_policy():
                def _custom_policy(ctx, func, *args, **kwargs):
                    to_recompute = func in {
                        torch.ops.aten.mul.Tensor,
                        torch.ops.aten.sigmoid.default,
                    }
                    return (
                        CheckpointPolicy.MUST_RECOMPUTE
                        if to_recompute
                        else CheckpointPolicy.MUST_SAVE
                    )

                return _custom_policy

            return create_selective_checkpoint_contexts(_recomp_policy())

        class Parametrization(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def parametrization(self, x):
                return torch.sigmoid(torch.mul(x, x))

            def forward(self, x):
                return checkpoint(
                    self.parametrization, x, use_reentrant=False, context_fn=sac_policy
                )

        def apply_parametrization(model):
            modules = list(model.modules())

            for mod in modules:
                params_dict = dict(mod.named_parameters(recurse=False))
                for p_name, p in params_dict.items():
                    mod.register_parameter(p_name, nn.Parameter(p))
                    nn.utils.parametrize.register_parametrization(
                        mod, p_name, Parametrization(), unsafe=True
                    )

            return model

        class MLPModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(5)
                self.net1 = nn.Linear(16, 16, bias=False)

            def forward(self, x):
                return self.net1(x)

            def reset_parameters(self):
                self.net1.reset_parameters()

        fw_compiler = functools.partial(
            count_ops,
            freqs=[1, 1],
            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[
                2,  # 1 from mul recompute, 1 from mul backward
                1,
            ],
            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
        )

        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )

        model = MLPModule()
        model = apply_parametrization(model)
        model_compiled = torch.compile(
            copy.deepcopy(model), backend=backend, fullgraph=True
        )
        input = torch.randn(8, 16, requires_grad=True)
        input_compiled = copy.deepcopy(input)

        out = model(input)
        out.sum().backward()
        out_compiled = model_compiled(input_compiled)
        out_compiled.sum().backward()

        self.assertEqual(out, out_compiled)
        self.assertEqual(input.grad, input_compiled.grad)

    @requires_cuda_and_triton
    def test_autocast_flash_attention(self, device):
        def fn(primals_1, primals_2, primals_3):
            return torch.ops.aten._scaled_dot_product_efficient_attention.default(
                primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
            )[0]

        def gn(*args):
            return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

        with torch.autocast(device_type=device):
            x = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
            y = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
            z = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
            args = (x, y, z)

            torch.manual_seed(0)
            ref = gn(*args)

            opt_gn = torch.compile(gn)
            torch.manual_seed(0)
            res = opt_gn(*args)
            self.assertEqual(ref, res)

    @requires_cuda_and_triton
    def test_error_msg(self, device):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x = torch.sin(x)
                torch._dynamo.graph_break()
                x = torch.cos(x)
                return x

        mod = MockModule().to(device)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)

        x = torch.randn(4, 4).to(device)
        opt_fn = torch.compile(fn, fullgraph=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "User-inserted graph break"
        ):
            opt_fn(x)

    @requires_cuda_and_triton
    def test_list_inputs(self, device):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, ys):
                a = torch.sin(x)  # noqa: F841
                b = torch.cos(ys[0])
                c = torch.cos(ys[1])
                return (x, [b, c])

        mod = MockModule().to(device)

        def fn(x, ys):
            return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)

        x = torch.randn(4, 4).to(device)
        y = torch.randn(4, 4).to(device)
        z = torch.randn(4, 4).to(device)
        ref = fn(x, [y, z])
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x, [y, z])
        self.assertEqual(ref, res)

    @requires_cuda_and_triton
    def test_pattern_matcher(self, device):
        # Check that the sdpa op is recomputed in the backward graph
        # tests percolate_tags

        @checkpoint_wrapper
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        def fn(query, key, value):
            # Checks that sin is not recomputed in the backward graph
            return dot_prod_attention(query.sin(), key, value)

        tensor_shape = (4, 2, 16, 32)
        dtype = torch.float16
        args1 = [
            torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
            torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
            torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
        ]

        # Save the AOT graphs
        aot_graphs = []
        from torch._inductor import compile_fx

        def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
            aot_graphs.append(graph)
            return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)

        backend = functools.partial(
            compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
        )

        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        opt_fn(*args1).sum().backward()

        fwd_graph = aot_graphs[0]
        op1 = torch.ops.aten._scaled_dot_product_flash_attention.default
        op2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default
        self.assertTrue(
            count_ops(
                fwd_graph,
                [],
                freq=1,
                op=op1,
            )
            or count_ops(
                fwd_graph,
                [],
                freq=1,
                op=op2,
            )
        )
        bwd_graph = aot_graphs[1]
        # Check that sin is not recomputed in the backward graph - checks percolate tags
        self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
        # Check that the sdpa op is recomputed in the backward graph
        self.assertTrue(
            count_ops(
                bwd_graph,
                [],
                freq=1,
                op=op1,
            )
            or count_ops(
                bwd_graph,
                [],
                freq=1,
                op=op2,
            )
        )

    @requires_distributed()
    @requires_cuda_and_triton
    def test_distributed_utils_checkpoint_wrapper(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            checkpoint_wrapper as dist_checkpoint_wrapper,
        )

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = 2

            def forward(self, x):
                x = torch.sin(x)
                x = self.linear(x)
                x = torch.cos(x)
                return x * self.c

        mod = dist_checkpoint_wrapper(MockModule())
        x = torch.randn(4, 4)
        ref = mod(x)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        res = opt_mod(x)
        self.assertEqual(ref, res)

    @requires_distributed()
    @requires_cuda_and_triton
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_dynamo_does_not_trace_getattr_as_top_frame(self):
        # inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do.
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointWrapper,
        )

        cnt = CompileCounterWithBackend("eager")

        lin = torch.nn.Linear(1, 1)
        mod = torch.nn.Sequential(lin, lin)
        mod = CheckpointWrapper(mod)
        mod._checkpoint_wrapped_module.a = torch.ones(1, 1)

        def fn(x):
            return mod(x) * mod.a

        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        x = torch.randn(1, 1)

        self.assertEqual(opt_fn(x), fn(x))

    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
    def test_nonlocal_mutation(self):
        counter = 0

        def gn(x):
            nonlocal counter
            counter += 1
            return torch.sin(x)

        def fn(x):
            return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)

        x = torch.randn(4, 4, requires_grad=True)
        fn(x).sum().backward()
        # The mutation is reapplied in the backward as well
        self.assertEqual(counter, 2)
        counter = 0

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        opt_fn(x).sum().backward()
        # The mutation is not reapplied in the backward because the flag was on.
        self.assertEqual(counter, 1)


devices = ["cuda", "hpu"]
instantiate_device_type_tests(
    ActivationCheckpointingViaTagsTests, globals(), only_for=devices
)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()