TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from the correct device
- [x] Tests
- [x] Handling of retain_graph
- [x] Respect fallback_random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph-safe rng states. We keep a pair of rng states, one for the forward and one for the backward. In both the forward and the backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes an rng state and hooks it onto the current rng generator before running the operator. The fwd/backward rng states are initialized with the same value, so for any given run of the forward, the corresponding backward run sees the same rng state for the op as was observed in the forward.

```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than they called the forwards. If a user has states fwd_rng_state0, bwd_rng_state0 and calls:

- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

then naively, when bwd1 is invoked, the bwd rng states would not equal the states that were observed in fwd1. I added handling for this in the aot runtime wrappers: they detect pending backward invocations and the current position of the bwd rng states, and update them when necessary.

Other notes: Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused one rng state across ops, the forward and backward would be run with different rng states, i.e. the states would not be applied in the same order.

Questions for reviewers: This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`. Edit: decided to apply this to non-cudagraph compilation as well, so long as fallback_random is not set.

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, like taking the seed from graph position, etc. Not sure; let me know your thoughts. Edit/Update: the rng states are now initialized from torch.randint() instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
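As an aside, here is a minimal, self-contained sketch of the paired-state invariant described above. It is an illustration only: `PairedRngState`, `run_forward`, and `run_backward` are hypothetical names, and the actual change works on the registered CUDA-graph-safe generators inside the aot runtime wrappers rather than on plain `torch.Generator` objects.

```
# Toy illustration only: one paired fwd/bwd generator per random op, seeded identically.
# The forward records the generator state it consumed; the backward rewinds its own
# generator to that recorded state, so backwards can run out of order (fwd0, fwd1,
# bwd1, bwd0) and still observe exactly the random values their forward saw.
import collections

import torch


class PairedRngState:
    def __init__(self, seed: int):
        self.fwd_gen = torch.Generator().manual_seed(seed)
        self.bwd_gen = torch.Generator().manual_seed(seed)
        # Generator states of forwards whose backward has not run yet.
        self._fwd_states = collections.deque()

    def run_forward(self, shape):
        self._fwd_states.append(self.fwd_gen.get_state())
        return torch.rand(shape, generator=self.fwd_gen)

    def run_backward(self, shape, which: int):
        # `which` indexes the pending forward whose backward is being run now.
        self.bwd_gen.set_state(self._fwd_states[which])
        return torch.rand(shape, generator=self.bwd_gen)


rng = PairedRngState(seed=0)
f0 = rng.run_forward((4, 4))            # fwd0
f1 = rng.run_forward((4, 4))            # fwd1
b1 = rng.run_backward((4, 4), which=1)  # bwd1 runs first
b0 = rng.run_backward((4, 4), which=0)  # then bwd0
assert torch.equal(f1, b1) and torch.equal(f0, b0)
```

In the real change, detecting pending backwards and repositioning the backward rng states happens in the aot runtime wrappers as described above; the sketch only demonstrates the property being maintained.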
# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import math
import unittest  # noqa: F811
from importlib import import_module

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    SM90OrLater,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)

def checkpoint_wrapper(fn):
    def inner(*args):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

    return inner


def count_ops(
    gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
):
    def match_rng_op(node, op):
        if isinstance(node.target, torch._ops.HigherOrderOperator):
            if node.name == "run_and_save_rng_state":
                return node.args[0] == op
            elif node.name == "run_with_rng_state":
                return node.args[1] == op
            elif node.name == "graphsafe_run_with_rng_state":
                return node.args[0] == op
        return False

    # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
    if op is not None:
        assert not isinstance(op, list)
        ops = [op]
    if freq is not None:
        freqs = [freq]
    if freq_ge is not None:
        freqs_ge = [freq_ge]
    if freqs:
        for op, freq in zip(ops, freqs):
            actual_count = 0
            for node in gm.graph.nodes:
                if match_rng_op(node, op) or node.target == op:
                    actual_count += 1
            err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
            assert actual_count == freq, err_msg
    else:
        assert freqs_ge is not None
        for op, freq_ge in zip(ops, freqs_ge):
            actual_count = 0
            for node in gm.graph.nodes:
                if match_rng_op(node, op) or node.target == op:
                    actual_count += 1
            assert (
                actual_count >= freq_ge
            ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
    return gm

def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:  # fwd graph
        return_node = list(graph.nodes)[-1]
        assert return_node.target == "output"
        for x in return_node.args[0]:
            fwd_outputs.add(str(x))


class _InvalidContext:
    def __init__(self) -> None:
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def _invalid_context_gen():
    return _InvalidContext(), _InvalidContext()

def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def op_count(gm):
    result = 0
    for node in gm.graph.nodes:
        if "call" in node.op:
            result += 1
    return result


def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
    def _custom_policy(ctx, func, *args, **kwargs):
        if no_recompute_list is not None and func in no_recompute_list:
            return CheckpointPolicy.MUST_SAVE
        if must_recompute_list is not None and func in must_recompute_list:
            return CheckpointPolicy.MUST_RECOMPUTE
        else:
            return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy

class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
|
|
def _validate(
|
|
self,
|
|
fn,
|
|
backend,
|
|
*args,
|
|
skip_check=False,
|
|
fullgraph=True,
|
|
compiled_autograd=False,
|
|
):
|
|
cloned_args = []
|
|
for arg in args:
|
|
cloned_args.append(arg.detach().clone().requires_grad_(arg.requires_grad))
|
|
|
|
cloned_fn = copy.deepcopy(fn)
|
|
|
|
torch.manual_seed(0)
|
|
expected = fn(*args)
|
|
expected.sum().backward()
|
|
|
|
torch.manual_seed(0)
|
|
compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend)
|
|
ctx = contextlib.nullcontext()
|
|
if compiled_autograd:
|
|
ctx = torch._dynamo.compiled_autograd._enable(
|
|
lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend)
|
|
)
|
|
with ctx:
|
|
result = compiled_fn(*cloned_args)
|
|
result.sum().backward()
|
|
|
|
if not skip_check:
|
|
self.assertEqual(
|
|
result,
|
|
expected,
|
|
msg="Output mismatch between torch.compile and eager versions",
|
|
)
|
|
for arg, cloned_arg in zip(args, cloned_args):
|
|
self.assertEqual(
|
|
arg.grad,
|
|
cloned_arg.grad,
|
|
msg="Gradient mismatch between torch.compile and eager versions",
|
|
)
|
|
|
|
def _compare_orig_and_checkpointed_fns(
|
|
self, orig_fn, checkpointed_fn, *args, fullgraph=True
|
|
):
|
|
# The original version and the checkpointed version of the same function
|
|
# should produce the same outputs and the same gradients under torch.compile.
|
|
|
|
# Run original version
|
|
cloned_args_orig_fn = []
|
|
for arg in args:
|
|
cloned_args_orig_fn.append(
|
|
arg.detach().clone().requires_grad_(arg.requires_grad)
|
|
)
|
|
torch.manual_seed(0)
|
|
compiled_orig_fn = torch.compile(
|
|
orig_fn, fullgraph=fullgraph, backend="inductor"
|
|
)
|
|
result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
|
|
result_orig_fn.sum().backward()
|
|
|
|
# Run checkpointed version
|
|
cloned_args_checkpointed_fn = []
|
|
for arg in args:
|
|
cloned_args_checkpointed_fn.append(
|
|
arg.detach().clone().requires_grad_(arg.requires_grad)
|
|
)
|
|
torch.manual_seed(0)
|
|
compiled_checkpointed_fn = torch.compile(
|
|
checkpointed_fn, fullgraph=fullgraph, backend="inductor"
|
|
)
|
|
result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn)
|
|
result_checkpointed_fn.sum().backward()
|
|
|
|
# Check that outputs and gradients are equal
|
|
self.assertEqual(
|
|
result_orig_fn,
|
|
result_checkpointed_fn,
|
|
msg="Output mismatch between the original version and the checkpointed version of the same function",
|
|
)
|
|
for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
|
|
cloned_args_orig_fn, cloned_args_checkpointed_fn
|
|
):
|
|
self.assertEqual(
|
|
cloned_arg_orig_fn.grad,
|
|
cloned_arg_checkpointed_fn.grad,
|
|
msg="Gradient mismatch between the original version and the checkpointed version of the same function",
|
|
)
|
|
|
|
def test_tags_function(self, device):
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(x, y))
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn, torch.sin(x), y, use_reentrant=True
|
|
)
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
|
|
bw_compiler = functools.partial(
|
|
count_ops, freq=3, op=torch.ops.aten.mm.default
|
|
) # mm recomputed in the bwd
|
|
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
def test_tags_function_via_global_checkpoint(self, device):
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(x, y))
|
|
|
|
def fn(x, y):
|
|
# This goes through VariableBuilder
|
|
return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
|
|
bw_compiler = functools.partial(
|
|
count_ops, freq=3, op=torch.ops.aten.mm.default
|
|
) # mm recomputed in the bwd
|
|
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
def test_tags_function_with_kwargs(self, device):
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(x, y))
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
|
|
)
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
|
|
bw_compiler = functools.partial(
|
|
count_ops, freq=3, op=torch.ops.aten.mm.default
|
|
) # mm recomputed in the bwd
|
|
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
def test_tags_sequential_layers(self, device):
|
|
def gn(x):
|
|
x = x.cos()
|
|
for _ in range(3):
|
|
x = torch.mm(x, x)
|
|
x = x.cos()
|
|
return x
|
|
|
|
def fn(x):
|
|
x = torch.utils.checkpoint.checkpoint(gn, x)
|
|
x = torch.utils.checkpoint.checkpoint(gn, x)
|
|
return x
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(count_ops, freq=6, op=torch.ops.aten.mm.default)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[2, 18],
|
|
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
|
|
) # mm recomputed in the bwd
|
|
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
self._validate(fn, backend, x)
|
|
|
|
@requires_cuda
|
|
def test_tags_multiple_checkpoints(self, device):
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(x, y))
|
|
|
|
def fn(x, y):
|
|
x = torch.sin(x)
|
|
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
x = torch.sin(z)
|
|
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
return z
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
|
|
bw_compiler = functools.partial(
|
|
count_ops, freq=6, op=torch.ops.aten.mm.default
|
|
) # mm recomputed in the bwd
|
|
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
def test_tags_module(self, device):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return torch.sigmoid(self.linear(x))
|
|
|
|
mod = MockModule().to(device)
|
|
|
|
def fn(x):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
mod, torch.sin(x), use_reentrant=True
|
|
)
|
|
|
|
x = torch.randn(10, 10, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
|
|
)
|
|
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
self._validate(fn, backend, x)
|
|
|
|
@requires_cuda
|
|
def test_tags_decomps(self, device):
|
|
# Ensures that tags are passed on through decompositions as well
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return torch.nn.functional.gelu(self.linear(x))
|
|
|
|
mod = MockModule().to(device)
|
|
|
|
def fn(x):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
mod, torch.sin(x), use_reentrant=True
|
|
)
|
|
|
|
x = torch.randn(10, 10, device=device, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops, freq=1, op=torch.ops.aten.erf.default
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops, freq=1, op=torch.ops.aten.erf.default
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
decompositions=lambda: import_module(
|
|
"torch._inductor.compile_fx"
|
|
).select_decomp_table(),
|
|
)
|
|
self._validate(fn, backend, x)
|
|
|
|
@requires_cuda
|
|
@torch._inductor.config.patch(fallback_random=True)
|
|
def test_tags_recomputed_rand(self, device):
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.rand_like(x) * y) * x
|
|
|
|
def fn(x, y):
|
|
x = torch.sin(x)
|
|
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
x = torch.sin(x)
|
|
z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
return z
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
|
|
# bw_compiler = functools.partial(
|
|
# count_ops, freq=6, op=torch.ops.aten.mm.default
|
|
# ) # mm recomputed in the bwd
|
|
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
backend = "inductor"
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
@torch._inductor.config.patch(fallback_random=True)
|
|
def test_tags_rand(self, device):
|
|
def gn(x, y):
|
|
x = torch.mm(x, y)
|
|
x = torch.mm(x, y)
|
|
return x
|
|
|
|
def fn(x, y):
|
|
x = torch.sin(x)
|
|
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
x = torch.sin(x)
|
|
# x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
return x
|
|
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
|
|
# fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
|
|
# bw_compiler = functools.partial(
|
|
# count_ops, freq=6, op=torch.ops.aten.mm.default
|
|
# ) # mm recomputed in the bwd
|
|
# backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
|
# backend = "aot_eager"
|
|
backend = "inductor"
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
@torch._inductor.config.patch(fallback_random=True)
|
|
def test_tags_dropout(self, device):
|
|
# Figure out a way to test the number of inductor_random calls
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
self.dropout = torch.nn.Dropout(0.2)
|
|
|
|
def forward(self, x):
|
|
return self.dropout(self.linear(x))
|
|
|
|
mod = MockModule().to(device)
|
|
|
|
def fn(x):
|
|
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
|
|
|
|
x = torch.randn(10, 10, device=device, requires_grad=True)
|
|
backend = "inductor"
|
|
# rand decomps do not have have numerical results as eager
|
|
self._validate(fn, backend, x, skip_check=True)
|
|
|
|
@skipIfHpu
|
|
@torch._functorch.config.patch(recompute_views=True)
|
|
@torch._inductor.config.patch(fx_graph_cache=False)
|
|
def test_tags_must_save_tensor_that_has_backward_hook(self):
|
|
def my_post_forward_hook(submod, args, output):
|
|
output.register_hook(my_backward_hook)
|
|
return output
|
|
|
|
def my_backward_hook(grad):
|
|
return grad
|
|
|
|
class MySubmod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
y = torch.matmul(x, x)
|
|
z = y * y
|
|
return z
|
|
|
|
class MyMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod = MySubmod()
|
|
self.norm = torch.nn.LayerNorm(4)
|
|
|
|
def forward(self, x):
|
|
out = torch.utils.checkpoint.checkpoint(
|
|
self.submod, x, use_reentrant=False
|
|
)
|
|
norm_out = self.norm(out)
|
|
return norm_out
|
|
|
|
def _factory_fn():
|
|
mod = MyMod()
|
|
x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True)
|
|
backend = "inductor"
|
|
return mod, x, backend
|
|
|
|
mod_no_hook, x, backend = _factory_fn()
|
|
mod_no_hook_fwd_outputs = set()
|
|
|
|
with torch._inductor.config.patch(
|
|
post_grad_custom_pre_pass=functools.partial(
|
|
collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs
|
|
)
|
|
):
|
|
self._validate(
|
|
mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True
|
|
)
|
|
|
|
torch._dynamo.reset()
|
|
mod_with_hook, x, backend = _factory_fn()
|
|
mod_with_hook.submod.register_forward_hook(my_post_forward_hook)
|
|
mod_with_hook_fwd_outputs = set()
|
|
|
|
with torch._inductor.config.patch(
|
|
post_grad_custom_pre_pass=functools.partial(
|
|
collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs
|
|
)
|
|
):
|
|
self._validate(
|
|
mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True
|
|
)
|
|
|
|
# If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors.
|
|
mod_no_hook_fwd_outputs_no_primal = {
|
|
x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_")
|
|
}
|
|
mod_with_hook_fwd_outputs_no_primal = {
|
|
x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_")
|
|
}
|
|
additional_saved_tensors = (
|
|
mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal
|
|
)
|
|
expected_additional_saved_tensors = {"mul"}
|
|
self.assertEqual(
|
|
additional_saved_tensors,
|
|
expected_additional_saved_tensors,
|
|
f"""
|
|
Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}.
|
|
Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}.
|
|
Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""",
|
|
)
|
|
|
|
@requires_cuda
|
|
def test_fallback(self, device):
|
|
def gn(x, y):
|
|
torch._dynamo.graph_break()
|
|
a = torch.sigmoid(torch.matmul(x, y))
|
|
torch._dynamo.graph_break()
|
|
return torch.cos(a)
|
|
|
|
def fn(x, y):
|
|
return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
args = (x, y)
|
|
|
|
backend = "aot_eager"
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
expected = fn(*args)
|
|
result = torch.compile(fn, backend=cnt)(*args)
|
|
|
|
self.assertEqual(result, expected)
|
|
|
|
# One graph for torch.sin on the input, and other for torch.cos.
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
self.assertEqual(len(cnt.graphs), 2)
|
|
|
|
@requires_cuda
|
|
def test_kwargs(self, device):
|
|
def gn(x, y, z=None):
|
|
a = torch.matmul(x, y)
|
|
if z is not None:
|
|
return torch.matmul(a, z)
|
|
return a
|
|
|
|
def fn(x, y, z):
|
|
return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
z = torch.randn(4, 4, requires_grad=True, device=device)
|
|
args = (x, y, z)
|
|
|
|
backend = "aot_eager"
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
expected = fn(*args)
|
|
result = torch.compile(fn, backend=cnt)(*args)
|
|
|
|
self.assertEqual(result, expected)
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(len(cnt.graphs), 1)
|
|
|
|
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
|
|
# one for checkpoint, and 3 for x, y, z
|
|
self.assertEqual(len(wrap_node.args), 4)
|
|
|
|
body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 2)
|
|
|
|
@requires_cuda
|
|
def test_symints_location(self, device):
|
|
def gn(x, y):
|
|
return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
|
|
backend = "aot_eager"
|
|
cnt = CompileCounterWithBackend(backend)
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
args = (x, y)
|
|
expected = fn(*args)
|
|
result = opt_fn(*args)
|
|
|
|
x = torch.randn(5, 5, requires_grad=True, device=device)
|
|
y = torch.randn(5, 5, requires_grad=True, device=device)
|
|
args = (x, y)
|
|
expected = fn(*args)
|
|
result = opt_fn(*args)
|
|
|
|
self.assertEqual(result.shape, expected.shape)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(len(cnt.graphs), 2)
|
|
wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
|
|
self.assertEqual(len(wrap_node.args), 3)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_must_recompute(self, device):
|
|
def context_fn_must_recompute_mm():
|
|
must_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(
|
|
must_recompute_list=must_recompute_list,
|
|
),
|
|
)
|
|
|
|
def context_fn_no_recompute_mm():
|
|
no_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(
|
|
no_recompute_list=no_recompute_list,
|
|
),
|
|
)
|
|
|
|
def _test(context_fn, bw_compiler):
|
|
def gn(x):
|
|
return torch.sigmoid(torch.matmul(x, x))
|
|
|
|
def fn(x):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
use_reentrant=False,
|
|
context_fn=context_fn,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freq=1,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x)
|
|
|
|
_test(
|
|
context_fn=context_fn_must_recompute_mm,
|
|
bw_compiler=functools.partial(
|
|
count_ops,
|
|
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
|
|
op=torch.ops.aten.mm.default,
|
|
),
|
|
)
|
|
_test(
|
|
context_fn=context_fn_no_recompute_mm,
|
|
bw_compiler=functools.partial(
|
|
count_ops,
|
|
freq=2, # 2 bwd mm ops per fwd matmul
|
|
op=torch.ops.aten.mm.default,
|
|
),
|
|
)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
|
|
def selective_checkpointing_context_fn():
|
|
no_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(no_recompute_list=no_recompute_list)
|
|
)
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=selective_checkpointing_context_fn,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freq=2,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
# We would've expected 6 here
|
|
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
|
|
# if we didn't enable selective checkpointing.
|
|
freq=4,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x, y)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_tensor_subclass(self, device):
|
|
def selective_checkpointing_context_fn():
|
|
no_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(no_recompute_list=no_recompute_list)
|
|
)
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=selective_checkpointing_context_fn,
|
|
)
|
|
|
|
rand_tensor = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
# tensor subclasses as inputs
|
|
x = TwoTensor(rand_tensor, rand_tensor.clone())
|
|
y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freq=4,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
# We would've expected 12 here
|
|
# (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12)
|
|
# if we didn't enable selective checkpointing.
|
|
freq=8,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x, y)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_custom_rule(self, device):
|
|
def _get_custom_policy(meta):
|
|
no_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
]
|
|
|
|
def _custom_policy(mode, func, *args, **kwargs):
|
|
mm_count_key = f"{mode}_mm_count"
|
|
if mm_count_key not in meta:
|
|
meta[mm_count_key] = 0
|
|
if func == torch.ops.aten.mm.default:
|
|
meta[mm_count_key] += 1
|
|
# Saves output of all compute ops, except second mm
|
|
# (i.e. we will hint the partitioner to recompute second mm in backward pass)
|
|
return func in no_recompute_list and not (
|
|
func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
|
|
)
|
|
|
|
return _custom_policy
|
|
|
|
def selective_checkpointing_context_fn():
|
|
meta = {}
|
|
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(
|
|
torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
|
|
)
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=selective_checkpointing_context_fn,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freq=2,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
# Q: How do we come to this number 4?
|
|
# A: We have 2 matmuls in the forward pass, each matmul contributes 2 `mm` ops in the backward pass,
|
|
# so we have at least 4 `mm` ops in backward pass. It's "at least" because whether second matmul in
|
|
# the forward pass is recomputed in the backward pass is up to the partitioner to decide.
|
|
freq_ge=4,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x, y)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
|
|
def selective_checkpointing_context_fn(no_recompute_list):
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(no_recompute_list=no_recompute_list)
|
|
)
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=functools.partial(
|
|
selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
|
|
),
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freq=2,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
# We would've expected 6 here
|
|
# (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
|
|
# if we didn't enable selective checkpointing.
|
|
freq=4,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x, y)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_outplace_op(self, device):
|
|
def selective_checkpointing_context_fn():
|
|
no_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
torch.ops.aten.sigmoid.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(no_recompute_list=no_recompute_list),
|
|
)
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=selective_checkpointing_context_fn,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[2, 1],
|
|
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[4, 0],
|
|
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x, y)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
@unittest.skip(
|
|
"In-place op support in selective checkpointing + torch.compile "
|
|
"requires TorchDispatchMode + torch.compile work to complete"
|
|
)
|
|
@requires_cuda
|
|
def test_compile_selective_checkpoint_inplace_op(self, device):
|
|
def selective_checkpointing_context_fn():
|
|
no_recompute_list = [
|
|
torch.ops.aten.mm.default,
|
|
torch.ops.aten.sigmoid.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(no_recompute_list=no_recompute_list)
|
|
)
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(
|
|
torch.selu_(torch.matmul(torch.matmul(x, y), y))
|
|
).relu_()
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=selective_checkpointing_context_fn,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
y = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[2, 1],
|
|
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[4, 0],
|
|
ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
self._validate(fn, backend, x, y)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
@torch._inductor.config.patch(fallback_random=True)
|
|
def test_compile_selective_checkpoint_random_op(self, device):
|
|
for preserve_rng_state in [True, False]:
|
|
|
|
def selective_checkpointing_context_fn():
|
|
no_recompute_list = [
|
|
torch.ops.aten.sigmoid.default,
|
|
]
|
|
return create_selective_checkpoint_contexts(
|
|
_get_custom_policy(no_recompute_list=no_recompute_list)
|
|
)
|
|
|
|
def gn(x):
|
|
return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))
|
|
|
|
def fn(x):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
use_reentrant=False,
|
|
# Regardless of whether `preserve_rng_state` is True or False,
|
|
# we will always preserve RNG state when using `torch.compile`.
|
|
preserve_rng_state=preserve_rng_state,
|
|
context_fn=selective_checkpointing_context_fn,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True, device=device)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[2, 1],
|
|
ops=[
|
|
torch.ops.aten.sigmoid.default,
|
|
torch.ops.aten.native_dropout.default,
|
|
],
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
# NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
|
|
freqs=[0, 1],
|
|
ops=[
|
|
torch.ops.aten.sigmoid.default,
|
|
torch.ops.aten.native_dropout.default,
|
|
],
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
|
|
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
|
|
# because eager version doesn't preserve RNG state while torch.compile still does.
|
|
# Hence when `preserve_rng_state` is False, we skip the output and gradient comparison
|
|
# between torch.compile and eager.
|
|
self._validate(fn, backend, x, skip_check=not preserve_rng_state)
|
|
self._compare_orig_and_checkpointed_fns(gn, fn, x)
|
|
|
|
@requires_cuda
|
|
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
|
def test_compile_selective_checkpoint_invalid_context(self):
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.matmul(x, y)) * y
|
|
|
|
def fn(x, y):
|
|
return torch.utils.checkpoint.checkpoint(
|
|
gn,
|
|
x,
|
|
y,
|
|
use_reentrant=False,
|
|
context_fn=_invalid_context_gen,
|
|
)
|
|
|
|
x = torch.randn(4, 4, requires_grad=True)
|
|
y = torch.randn(4, 4, requires_grad=True)
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freq=1,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
freq_ge=2,
|
|
op=torch.ops.aten.mm.default,
|
|
)
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
Exception, "must generate a tuple of two `TorchDispatchMode`s"
|
|
):
|
|
self._validate(fn, backend, x, y)
|
|
|
|
@requires_cuda
|
|
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
|
def test_compile_selective_checkpoint_parametrization(self):
|
|
def sac_policy():
|
|
def _recomp_policy():
|
|
def _custom_policy(ctx, func, *args, **kwargs):
|
|
to_recompute = func in {
|
|
torch.ops.aten.mul.Tensor,
|
|
torch.ops.aten.sigmoid.default,
|
|
}
|
|
return (
|
|
CheckpointPolicy.MUST_RECOMPUTE
|
|
if to_recompute
|
|
else CheckpointPolicy.MUST_SAVE
|
|
)
|
|
|
|
return _custom_policy
|
|
|
|
return create_selective_checkpoint_contexts(_recomp_policy())
|
|
|
|
class Parametrization(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def parametrization(self, x):
|
|
return torch.sigmoid(torch.mul(x, x))
|
|
|
|
def forward(self, x):
|
|
return checkpoint(
|
|
self.parametrization, x, use_reentrant=False, context_fn=sac_policy
|
|
)
|
|
|
|
def apply_parametrization(model):
|
|
modules = list(model.modules())
|
|
|
|
for mod in modules:
|
|
params_dict = dict(mod.named_parameters(recurse=False))
|
|
for p_name, p in params_dict.items():
|
|
mod.register_parameter(p_name, nn.Parameter(p))
|
|
nn.utils.parametrize.register_parametrization(
|
|
mod, p_name, Parametrization(), unsafe=True
|
|
)
|
|
|
|
return model
|
|
|
|
class MLPModule(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
torch.manual_seed(5)
|
|
self.net1 = nn.Linear(16, 16, bias=False)
|
|
|
|
def forward(self, x):
|
|
return self.net1(x)
|
|
|
|
def reset_parameters(self):
|
|
self.net1.reset_parameters()
|
|
|
|
fw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[1, 1],
|
|
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
|
|
)
|
|
bw_compiler = functools.partial(
|
|
count_ops,
|
|
freqs=[
|
|
2, # 1 from mul recompute, 1 from mul backward
|
|
1,
|
|
],
|
|
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
|
|
)
|
|
|
|
backend = aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
)
|
|
|
|
model = MLPModule()
|
|
model = apply_parametrization(model)
|
|
model_compiled = torch.compile(
|
|
copy.deepcopy(model), backend=backend, fullgraph=True
|
|
)
|
|
input = torch.randn(8, 16, requires_grad=True)
|
|
input_compiled = copy.deepcopy(input)
|
|
|
|
out = model(input)
|
|
out.sum().backward()
|
|
out_compiled = model_compiled(input_compiled)
|
|
out_compiled.sum().backward()
|
|
|
|
self.assertEqual(out, out_compiled)
|
|
self.assertEqual(input.grad, input_compiled.grad)
|
|
|
|
@skipIfRocm
|
|
@requires_cuda
|
|
def test_autocast_flash_attention(self, device):
|
|
def fn(primals_1, primals_2, primals_3):
|
|
return torch.ops.aten._scaled_dot_product_efficient_attention.default(
|
|
primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
|
|
)[0]
|
|
|
|
def gn(*args):
|
|
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
|
|
|
|
with torch.autocast(device_type=device):
|
|
x = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
|
|
y = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
|
|
z = torch.randn(4, 2, 16, 32, device=device, requires_grad=True)
|
|
args = (x, y, z)
|
|
|
|
torch.manual_seed(0)
|
|
ref = gn(*args)
|
|
|
|
opt_gn = torch.compile(gn)
|
|
torch.manual_seed(0)
|
|
res = opt_gn(*args)
|
|
self.assertEqual(ref, res)
|
|
|
|
@requires_cuda
|
|
def test_error_msg(self, device):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
x = torch.sin(x)
|
|
torch._dynamo.graph_break()
|
|
x = torch.cos(x)
|
|
return x
|
|
|
|
mod = MockModule().to(device)
|
|
|
|
def fn(x):
|
|
return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
|
|
|
|
x = torch.randn(4, 4).to(device)
|
|
opt_fn = torch.compile(fn, fullgraph=True)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.Unsupported, "User-inserted graph break"
|
|
):
|
|
opt_fn(x)
|
|
|
|
@requires_cuda
|
|
def test_list_inputs(self, device):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, ys):
|
|
a = torch.sin(x) # noqa: F841
|
|
b = torch.cos(ys[0])
|
|
c = torch.cos(ys[1])
|
|
return (x, [b, c])
|
|
|
|
mod = MockModule().to(device)
|
|
|
|
def fn(x, ys):
|
|
return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)
|
|
|
|
x = torch.randn(4, 4).to(device)
|
|
y = torch.randn(4, 4).to(device)
|
|
z = torch.randn(4, 4).to(device)
|
|
ref = fn(x, [y, z])
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x, [y, z])
|
|
self.assertEqual(ref, res)
|
|
|
|
@requires_cuda
|
|
def test_pattern_matcher(self, device):
|
|
# Check that the sdpa op is recomputed in the backward graph
|
|
# tests percolate_tags
|
|
|
|
@checkpoint_wrapper
|
|
def dot_prod_attention(
|
|
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
|
) -> torch.Tensor:
|
|
return (
|
|
torch.matmul(query, key.transpose(-2, -1))
|
|
.mul(1.0 / math.sqrt(key.shape[-1]))
|
|
.softmax(dim=-1)
|
|
.matmul(value)
|
|
)
|
|
|
|
def fn(query, key, value):
|
|
# Checks that sin is not recomputed in the backward graph
|
|
return dot_prod_attention(query.sin(), key, value)
|
|
|
|
tensor_shape = (4, 2, 16, 32)
|
|
dtype = torch.float16
|
|
args1 = [
|
|
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
|
|
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
|
|
torch.randn(tensor_shape, device=device, dtype=dtype, requires_grad=True),
|
|
]
|
|
|
|
# Save the AOT graphs
|
|
aot_graphs = []
|
|
from torch._inductor import compile_fx
|
|
|
|
def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
|
|
aot_graphs.append(graph)
|
|
return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)
|
|
|
|
backend = functools.partial(
|
|
compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
|
|
)
|
|
|
|
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
|
opt_fn(*args1).sum().backward()
|
|
if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
|
|
op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
|
|
else:
|
|
op = torch.ops.aten._scaled_dot_product_flash_attention.default
|
|
|
|
fwd_graph = aot_graphs[0]
|
|
self.assertTrue(
|
|
count_ops(
|
|
fwd_graph,
|
|
[],
|
|
freq=1,
|
|
op=op,
|
|
)
|
|
)
|
|
|
|
bwd_graph = aot_graphs[1]
|
|
# Check that sin is not recomputed in the backward graph - checks percolate tags
|
|
self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
|
|
# Check that the sdpa op is recomputed in the backward graph
|
|
self.assertTrue(
|
|
count_ops(
|
|
bwd_graph,
|
|
[],
|
|
freq=1,
|
|
op=op,
|
|
)
|
|
)
|
|
|
|
@requires_distributed()
|
|
@requires_cuda
|
|
def test_distributed_utils_checkpoint_wrapper(self):
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
|
checkpoint_wrapper as dist_checkpoint_wrapper,
|
|
)
|
|
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4)
|
|
self.c = 2
|
|
|
|
def forward(self, x):
|
|
x = torch.sin(x)
|
|
x = self.linear(x)
|
|
x = torch.cos(x)
|
|
return x * self.c
|
|
|
|
mod = dist_checkpoint_wrapper(MockModule())
|
|
x = torch.randn(4, 4)
|
|
ref = mod(x)
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
res = opt_mod(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
@requires_distributed()
|
|
@requires_cuda
|
|
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
|
def test_dynamo_does_not_trace_getattr_as_top_frame(self):
|
|
# inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do.
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
|
CheckpointWrapper,
|
|
)
|
|
|
|
cnt = CompileCounterWithBackend("eager")
|
|
|
|
lin = torch.nn.Linear(1, 1)
|
|
mod = torch.nn.Sequential(lin, lin)
|
|
mod = CheckpointWrapper(mod)
|
|
mod._checkpoint_wrapped_module.a = torch.ones(1, 1)
|
|
|
|
def fn(x):
|
|
return mod(x) * mod.a
|
|
|
|
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
|
x = torch.randn(1, 1)
|
|
|
|
self.assertEqual(opt_fn(x), fn(x))
|
|
|
|
|
|
devices = ["cuda", "hpu"]
instantiate_device_type_tests(
    ActivationCheckpointingViaTagsTests, globals(), only_for=devices
)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()