[ca] trace saved variable unpacking (#147242)

## Before

Previously, CA would always unpack all saved variables stored in the autograd graph before executing it. This meant that we couldn't capture unpack hooks as part of the CA graph, and they would fire out of order with respect to other backward hooks. For memory-saving APIs built on top of saved tensor hooks, like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and live at the same time, making those APIs a no-op.
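
For context, this is the kind of saved-tensor-hooks pattern that was affected; a minimal CPU-offloading sketch (assuming a CUDA device is available; it mirrors the `test_cpu_offloading` test added below, and the shapes are illustrative):
```python
import torch

# Offload saved activations to CPU at pack time, bring them back at unpack time.
def pack(x):
    return x.cpu()

def unpack(x):
    return x.cuda()

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    loss = (x @ x).sum()

# Before this change, running backward under compiled autograd would call `unpack`
# for every saved tensor up front, so all offloaded activations were resident on the
# GPU at the same time and the offload saved no memory.
loss.backward()
```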

## After

We add unpack hooks into the CA graph so that they can be executed progressively. The Python hook and its input are themselves wrapped by non-traceable code, so CA polyfills the wrapping as:
```python
# pseudocode
class SavedVariable:
    def unpack(self):
        if self.hook:
            # saved tensor hooks are registered: call the user unpack hook
            # on the packed payload, as a node in the CA graph
            return self.hook(self.packed_data)
        else:
            # no hooks: return the stored tensor directly
            return self.packed_data

# This approach won't directly work when we add support for Forward AD or double-backward.
```

When directly executing the CA graph (without torch.compile-ing it) under checkpointing/offloading, the memory profile is expected to match that of the eager autograd engine. If an AOT backward is in the autograd graph, the memory profile is expected to be better than the eager autograd engine's, since we can now delay unpacking of saved activations into the AOT backward's execution.
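
As a rough illustration of "running the CA graph directly", here is a minimal sketch using an identity compiler with non-reentrant checkpointing (analogous to the `ca_eager` test backend added below; `torch._dynamo.compiled_autograd._enable` is the private API the tests use):
```python
import torch
from torch.utils.checkpoint import checkpoint

def f(x):
    y = x.sin()
    z = y.cos()
    return (y * z).sum()

x = torch.rand(10, 10, requires_grad=True)
out = checkpoint(f, x, use_reentrant=False)

# Identity "compiler": the captured CompiledAutograd graph module is run as-is,
# so unpack hooks (and the checkpoint recomputation they trigger) fire lazily,
# interleaved with the rest of the backward instead of all at once up front.
with torch._dynamo.compiled_autograd._enable(lambda gm: gm):
    out.backward()
```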

All tests pass when running the CA graph directly; the remaining issues are in Dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
Simon Fan
2025-02-25 19:57:54 -08:00
committed by PyTorch MergeBot
parent 08f4c1a233
commit 0a2da008f8
10 changed files with 527 additions and 122 deletions

View File

@ -10,6 +10,7 @@ import os
import re
import subprocess
import sys
import tempfile
import unittest
from copy import deepcopy
from importlib.machinery import SourceFileLoader
@ -48,11 +49,17 @@ from torch.testing._internal.logging_utils import logs_to_string
# note: these tests are not run on windows due to inductor_utils.HAS_CPU
def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"):
assert backend in ["inductor", "aot_eager"]
def make_compiler_fn(
fullgraph=True, dynamic=True, backend="inductor", gm_hook=lambda gm: None
):
assert backend in ["inductor", "aot_eager", "ca_eager"]
def _compiler_fn(gm):
"""Same as torch.compile() but counts number of compiles"""
gm_hook(gm)
if backend == "ca_eager":
return gm
def _inner_compiler(gm_, example_inputs_):
counters["compiled_autograd"]["compiles"] += 1
@ -112,7 +119,10 @@ class TestCompiledAutograd(TestCase):
torch.manual_seed(123)
expected = list(fn())
torch.manual_seed(123)
with compiled_autograd._enable(compiler_fn):
with compiled_autograd._enable(compiler_fn), mock.patch(
"torch._functorch.aot_autograd.AOT_COUNTER",
new_callable=itertools.count,
):
opt_fn = torch.compile(fn) if compile_fn else fn
actual = list(opt_fn())
self.assertEqual(expected, actual)
@ -915,7 +925,8 @@ main()
inputs=[param, activ],
sizes=(),
scalars=(),
hooks=(),
hooks=[],
packed_inputs=[],
)
finally:
handle.remove()
@ -3322,7 +3333,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
graphs.append(gm)
return inner_compiler_fn(gm)
with compiled_autograd._enable(compiler_fn):
with compiled_autograd._enable(compiler_fn), mock.patch(
"torch._functorch.aot_autograd.AOT_COUNTER",
new_callable=itertools.count,
):
res = fn(x)
res.sum().backward()
@ -3336,7 +3350,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
graph_code,
"""\
class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]
getitem_2 = inputs[2]
@ -3511,9 +3525,7 @@ class CompiledAutograd0(torch.nn.Module):
fn, count=2, compiler_fn=make_compiler_fn(backend="aot_eager")
)
@unittest.expectedFailure
def test_saved_tensor_unpack_hook_ordering(self):
# not the correct behaviour, I'm just preventing this from changing silently
def f(x, y):
return x * y
@ -3531,8 +3543,6 @@ class CompiledAutograd0(torch.nn.Module):
return x
def tensor_hook(_):
# in eager, tensor_hook is fired before unpack_hook
# but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
self.assertEqual(unpack_count, 0)
x = torch.ones(4, requires_grad=True)
@ -3544,21 +3554,252 @@ class CompiledAutograd0(torch.nn.Module):
self.assertEqual(pack_count, 1)
self.assertEqual(unpack_count, 0)
loss = out_test.sum()
loss.register_hook(tensor_hook)
loss.register_hook(
tensor_hook
) # scheduled to fire before any saved activations
loss.backward()
self.assertEqual(pack_count, 1)
self.assertEqual(unpack_count, 1)
def test_reentrant_checkpointing(self):
def fn(x):
@parametrize("reentrant", (True, False))
def test_checkpointing_simple(self, reentrant):
def fn():
def _fn(x):
y = x.sin()
z = y.cos()
return (y * z).sum()
inp = torch.rand(10, 10, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True)
with torch._dynamo.compiled_autograd._enable(torch.compile):
out = torch.utils.checkpoint.checkpoint(_fn, inp, use_reentrant=reentrant)
out.backward()
yield inp.grad
if reentrant:
self.check_output_and_recompiles(
fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False)
)
else:
# dynamo issues, just run the CA graph directly for now
def check(gm):
graph_code = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
getitem_2 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_2], [True], [10, 10]); getitem_2 = None
getitem_3 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_3], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_3 = None
getitem_4 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_5 = hooks[0]
getitem_6 = packed_data[0]
getitem_7 = hooks[1]
getitem_8 = packed_data[1]
call_hook = torch__dynamo_external_utils_call_hook(getitem_5, getitem_6, hook_type = 'unpack_hook'); getitem_5 = getitem_6 = None
call_hook_1 = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
mul_backward0 = torch__dynamo_compiled_autograd_ops_MulBackward0([getitem_4], [True, True], call_hook, 6, call_hook_1, 6); getitem_4 = call_hook = call_hook_1 = None
getitem_9 = mul_backward0[0]
getitem_10 = mul_backward0[1]; mul_backward0 = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_9, getitem_10], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False), ((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_9 = getitem_10 = None
getitem_11 = validate_outputs_2[0]
getitem_12 = validate_outputs_2[1]; validate_outputs_2 = None
getitem_13 = hooks[2]
getitem_14 = packed_data[2]
call_hook_2 = torch__dynamo_external_utils_call_hook(getitem_13, getitem_14, hook_type = 'unpack_hook'); getitem_13 = getitem_14 = None
cos_backward0 = torch__dynamo_compiled_autograd_ops_CosBackward0([getitem_12], [True], call_hook_2); getitem_12 = call_hook_2 = None
getitem_15 = cos_backward0[0]; cos_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_15], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_15 = None
getitem_16 = validate_outputs_3[0]; validate_outputs_3 = None
add = torch.add(getitem_11, getitem_16); getitem_11 = getitem_16 = None
getitem_17 = hooks[3]; hooks = None
getitem_18 = packed_data[3]; packed_data = None
call_hook_3 = torch__dynamo_external_utils_call_hook(getitem_17, getitem_18, hook_type = 'unpack_hook'); getitem_17 = getitem_18 = None
sin_backward0 = torch__dynamo_compiled_autograd_ops_SinBackward0([add], [True], call_hook_3); add = call_hook_3 = None
getitem_19 = sin_backward0[0]; sin_backward0 = None
validate_outputs_4 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_19], [((None, None, device(type='cpu'), 6, 0, None), [10, 10], False)]); getitem_19 = None
getitem_20 = validate_outputs_4[0]; validate_outputs_4 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_20); getitem_1 = getitem_20 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
self.check_output_and_recompiles(
fn,
count=[1, 0],
compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check),
)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_cpu_offloading(self):
def fn():
def pack(x):
return x.cpu()
def unpack(x):
return x.cuda()
class MyMatMul(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.matmul(x, x)
@staticmethod
def backward(ctx, grad_out):
(x,) = ctx.saved_tensors
return grad_out * x
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
for i in [10, 100, 10, 20, 30]:
x = torch.randn(i, requires_grad=True).cuda()
MyMatMul.apply(x).sum().backward()
yield x.grad
i = 0
def check(gm):
nonlocal i
if i == 0:
i += 1
return
graph_code = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd1(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
getitem_2 = sizes[0]; getitem_2 = None
getitem_3 = sizes[1]
getitem_4 = sizes[2]; sizes = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem = None
getitem_5 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], []); getitem_5 = None
getitem_6 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cuda', index=0), 6, 0, None), [], False)]); getitem_6 = None
getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_8 = hooks[0]
getitem_9 = packed_data[0]; packed_data = None
getitem_10 = hooks[1]; hooks = None
call_hook = torch__dynamo_external_utils_call_hook(getitem_8, getitem_9, hook_type = 'unpack_hook'); getitem_8 = getitem_9 = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_10, (call_hook,), getitem_7); getitem_10 = call_hook = getitem_7 = None
getitem_12 = call_backward[0]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_12], [((None, None, device(type='cuda', index=0), 6, 0, None), [getitem_3], False)]); getitem_12 = getitem_3 = None
getitem_13 = validate_outputs_2[0]; validate_outputs_2 = None
to_copy_backward0 = torch__dynamo_compiled_autograd_ops_ToCopyBackward0([getitem_13], [True], (None, None, device(type='cpu'), 6, 0, None)); getitem_13 = None
getitem_14 = to_copy_backward0[0]; to_copy_backward0 = None
validate_outputs_3 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_14], [((None, None, device(type='cpu'), 6, 0, None), [getitem_4], False)]); getitem_14 = getitem_4 = None
getitem_15 = validate_outputs_3[0]; validate_outputs_3 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_15); getitem_1 = getitem_15 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
self.check_output_and_recompiles(
fn, count=2, compiler_fn=make_compiler_fn(gm_hook=check)
)
@skipIfWindows(msg="temp dir not compatible")
def test_disk_offloading(self):
with tempfile.TemporaryDirectory() as d:
def fn():
pack_count = 0
def pack(x):
nonlocal pack_count
path = f"{d}/{pack_count}.pt"
torch.save(x, path)
return path
def unpack(path):
x = torch.load(path)
return x
class MyMatMul(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.matmul(x, x)
@staticmethod
def backward(ctx, grad_out):
(x,) = ctx.saved_tensors
return grad_out * x
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
for i in [10, 100, 10, 20, 30]:
x = torch.randn(i, requires_grad=True)
MyMatMul.apply(x).sum().backward()
yield x.grad
i = 0
def check(gm):
nonlocal i
if i == 0:
i += 1
return
graph_code = normalize_gm(gm.print_readable(print_output=False))
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd1(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks, packed_data):
getitem = inputs[0]
getitem_1 = inputs[1]; inputs = None
getitem_2 = sizes[0]; getitem_2 = None
getitem_3 = sizes[1]; sizes = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem = None
getitem_4 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_4], [True], []); getitem_4 = None
getitem_5 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_5], [((None, None, device(type='cpu'), 6, 0, None), [], False)]); getitem_5 = None
getitem_6 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_7 = hooks[0]
getitem_8 = packed_data[0]; packed_data = None
getitem_9 = hooks[1]; hooks = None
call_hook = torch__dynamo_external_utils_call_hook(getitem_7, getitem_8, hook_type = 'unpack_hook'); getitem_7 = getitem_8 = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_9, (call_hook,), getitem_6); getitem_9 = call_hook = getitem_6 = None
getitem_11 = call_backward[0]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_11], [((None, None, device(type='cpu'), 6, 0, None), [getitem_3], False)]); getitem_11 = getitem_3 = None
getitem_12 = validate_outputs_2[0]; validate_outputs_2 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_12); getitem_1 = getitem_12 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
# 1 graph break on torch.load -> 2 dynamo graphs
self.check_output_and_recompiles(
fn,
count=[2, 4],
compiler_fn=make_compiler_fn(fullgraph=False, gm_hook=check),
)
@skipIfWindows(msg="node name demangling inconsistent on windows")
def test_backward_hook_relative_ordering_partial(self):
@ -3617,7 +3858,7 @@ class CompiledAutograd0(torch.nn.Module):
self.check_output_and_recompiles(fn)
def test_sac(self):
def test_checkpointing_sac(self):
# circular import
from torch.utils.checkpoint import (
checkpoint,
@ -3666,7 +3907,9 @@ class CompiledAutograd0(torch.nn.Module):
yield model.layer4.weight.grad
yield model.layer4.bias.grad
self.check_output_and_recompiles(fn)
self.check_output_and_recompiles(
fn, count=[1, 5], compiler_fn=make_compiler_fn(fullgraph=False)
)
def load_test_module(name):
@ -3754,6 +3997,26 @@ known_graph_breaks_tests = {
"test_deep_reentrant", # reentrant .backward
"test_reentrant_priority", # reentrant .backward
"test_simple_reentrant", # reentrant .backward
"test_checkpoint_detects_non_determinism", # unpack hook in skip files
"test_checkpoint_valid_reset_on_error", # unpack hook in skip files
"test_checkpointing_non_reentrant_autocast_cpu", # unpack hook in skip files
"test_checkpointing_non_reentrant_autocast_gpu", # unpack hook in skip files
"test_checkpointing_without_reentrant_arbitrary_input_output", # unpack hook in skip files
"test_checkpointing_without_reentrant_correct_grad", # unpack hook in skip files
"test_checkpointing_without_reentrant_custom_function_works", # unpack hook in skip files
"test_checkpointing_without_reentrant_dataparallel", # _get_device_index in skip files
"test_checkpointing_without_reentrant_detached_tensor_use_reentrant_True", # reentrant .backward
"test_checkpointing_without_reentrant_parameter_used_in_an_out", # unpack hook in skip files
"test_checkpointing_without_reentrant_with_context_fn", # unpack hook in skip files
"test_save_on_cpu_and_checkpoint", # unpack hook in skip files
"test_saved_tensor_hooks_custom_error_propagation", # CustomError
"test_access_saved_tensor_twice_without_recomputation_works", # unpack hook in skip files
"test_saved_tensor_hooks_extra_enter_during_bw_no_leak", # ctx in skip files
"test_saved_tensor_hooks_extra_exit_during_bw_no_crash", # ctx in skip files
"test_checkpointing", # reentrant .backward
"test_checkpointing_without_reentrant_input_requires_grad_False", # reentrant .backward
"test_checkpointing_without_reentrant_input_requires_grad_True", # reentrant .backward
"test_checkpointing_without_reentrant_memory_savings", # reentrant .backward
}
test_contexts = {
@ -3764,9 +4027,7 @@ test_contexts = {
}
# These groups of tests aren't supported yet
known_failures_re = re.compile(
r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)"
)
known_failures_re = re.compile(r"^test_(sparse|profiler|gradcheck|named_tensor)")
# Bugs needing investigation:
skipped_tests = {
@ -3837,7 +4098,7 @@ known_failing_tests = {
# IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
"test_grad_nonleaf_register_hook",
"test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938
# Category: Dynamo
# Category: Dynamo (pass when directly running CA graph)
"test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None
"test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
"test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None
@ -3849,7 +4110,20 @@ known_failing_tests = {
"test_return_duplicate", # gradient batching rule not implemented for aten::sym_size.int
"test_return_duplicate_inplace", # gradient batching rule not implemented for aten::sym_size.int
"test_setitem", # CopySlices accuracy error
# Category: Inductor
"test_save_on_cpu_and_checkpoint", # https://github.com/pytorch/pytorch/issues/147565
"test_checkpoint_detects_non_determinism", # different error
"test_checkpointing_non_reentrant_autocast_cpu", # saved != recompute
"test_checkpointing_non_reentrant_autocast_gpu", # saved != recompute
"test_checkpointing_without_reentrant_saved_object_identity", # same as https://github.com/pytorch/pytorch/issues/136193
"test_saved_variable_packing_unpacking_did_not_save_original_with_hooks", # register_hooks multiple times
"test_saved_variable_saved_original_inplace_detach", # RuntimeError not raised
"test_access_saved_tensor_twice_without_recomputation_works", # saved != recompute
"test_checkpointing_without_reentrant_dataparallel", # https://github.com/pytorch/pytorch/issues/127115
"test_checkpointing", # takes very very long
"test_checkpointing_without_reentrant_input_requires_grad_False", # takes very very long
"test_checkpointing_without_reentrant_input_requires_grad_True", # takes very very long
"test_checkpointing_without_reentrant_memory_savings", # takes very very long
# Category: Inductor (pass on backend="aot_eager")
"test_input_buffer_accum", # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267
"test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173
# Category: FakeTensor
@ -3861,6 +4135,7 @@ known_failing_tests = {
"test_invalid_gradients", # can't give autograd error due to inaccurate output metadata of lifted backward
"test_autograd_node_isinstance", # backward ctx is a fake cls and not directly a Node instance
"test_backward_hook_relative_ordering", # compiled autograd collects breadth first, and module backward hook not supported
"test_checkpointing_without_reentrant_custom_function_works", # ctx.saved_tensors are cached by CA
# Category: Subclasses
"test_dtensor_basic",
"test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",

View File

@ -134,7 +134,7 @@ class Op:
ops = OpNamespace()
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
_impure_targets = OrderedSet(
[
call_hook,
@ -206,7 +206,13 @@ class AutogradCompilerInstance:
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
self.symnode_proxy_lookup = {}
args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
(
args_proxy,
self.sizes_proxy,
self.scalars_proxy,
self.hooks_proxy,
self.packed_data_proxy,
) = (
self.fx_tracer.create_proxy("placeholder", name, (), {})
for name in _graph_placeholders
)
@ -268,7 +274,12 @@ class AutogradCompilerInstance:
self.stack.enter_context(
torch.fx.experimental.symbolic_shapes._suppress_guards(env)
)
return str(CompileContext.current_compile_id()), inputs, sizes, scalars
return (
str(CompileContext.current_compile_id()),
inputs,
sizes,
scalars,
)
def log_compile_reasons(
self,
@ -567,6 +578,19 @@ class AutogradCompilerInstance:
kwargs,
)
def unpack_hook(self, hook_id, data_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
data = self.packed_data_proxy[data_id] # type: ignore[index]
proxy = self.proxy_call_hook(
hook,
data,
hook_type="unpack_hook",
)
out = self.allocate_dummy()
self.bind_objects_to_proxies([out], [proxy])
return out
def tensor_pre_hook(self, inputs, hook_id, i: int):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
@ -706,6 +730,9 @@ class AutogradCompilerInstance:
after = len(self.fx_tracer.graph.nodes)
verbose_log.debug("DCE removed %d nodes", before - after)
def create_graph_module(self, id):
return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
def end_capture(self, outputs):
self.fx_tracer.create_proxy(
"call_function",
@ -745,6 +772,7 @@ class AutogradCompilerInstance:
).print_readable(print_output=False),
)
self.rename_aot_dispatcher_nodes()
self.delay_unpack_hook_nodes()
self.reorder_tensor_pre_hook_nodes()
self.reorder_pre_hook_nodes_to_schedule_asap()
self.reorder_accumulate_grad_nodes()
@ -763,9 +791,7 @@ class AutogradCompilerInstance:
# should prevent these ops from going into the CA graph.
self.dce()
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
)
graph = self.create_graph_module(f"CompiledAutograd{self.id}")
set_locals_to_steal(graph, ["inputs"])
lazy_graph_code = lazy_format_graph_code(
"Compiled autograd graph",
@ -781,7 +807,7 @@ class AutogradCompilerInstance:
payload_fn=lambda: graph.print_readable(print_output=False),
)
def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs):
global in_compiled_autograd_region
try:
in_compiled_autograd_region = True
@ -789,7 +815,7 @@ class AutogradCompilerInstance:
inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
with _disable(), make_compile_context(self.id):
return compiled_fn(inputs, sizes, scalars, hooks)
return compiled_fn(inputs, sizes, scalars, hooks, packed_inputs)
finally:
in_compiled_autograd_region = False
@ -938,6 +964,19 @@ class AutogradCompilerInstance:
if getitem_node is not None:
arg.append(getitem_node)
def delay_unpack_hook_nodes(self):
"""
We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
"""
for node in self.fx_tracer.graph.find_nodes(
op="call_function", target=call_hook
):
if node.kwargs.get("hook_type", None) != "unpack_hook":
continue
first_user = min(node.users)
first_user.prepend(node)
def reorder_tensor_pre_hook_nodes(self):
"""
Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed

View File

@ -1334,6 +1334,7 @@ auto Engine::execute(
!AnomalyMode::is_enabled(),
"compiled_autograd does not support AnomalyMode")
GraphTaskGuard guard(graph_task);
CheckpointValidGuard cpvguard(graph_task);
return (*compiled_autograd)(
graph_root, *graph_task, accumulate_grad, outputs);
}

View File

@ -46,6 +46,15 @@ at::Tensor PySavedVariableHooks::call_unpack_hook() {
// unpack_hook_ will be manually decrefed when the saved variable is released
}
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
PySavedVariableHooks::retrieve_unpack_hook_data() const {
Py_INCREF(unpack_hook_);
Py_INCREF(data_);
return std::make_pair(
c10::SafePyObject(unpack_hook_, getPyInterpreter()),
c10::SafePyObject(data_, getPyInterpreter()));
}
// NOLINTNEXTLINE(bugprone-exception-escape)
PySavedVariableHooks::~PySavedVariableHooks() {
// If python is already dead, leak the wrapped python objects

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <c10/core/SafePyObject.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/python_variable.h>
@ -17,6 +18,8 @@ struct PySavedVariableHooks : public SavedVariableHooks {
void call_pack_hook(const at::Tensor& tensor) override;
at::Tensor call_unpack_hook() override;
~PySavedVariableHooks() override;
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
retrieve_unpack_hook_data() const override;
private:
PyObject* pack_hook_;

View File

@ -59,6 +59,7 @@ SavedVariable::SavedVariable(
if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
save_metadata(variable);
set_hooks_and_pack_data(std::move(maybe_hooks), variable);
TORCH_INTERNAL_ASSERT(!data_.defined());
return;
}
@ -134,9 +135,14 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
// We want grad_fn here to provide the most helpful debug message to the user
// if versions don't match
auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
: !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
: grad_fn_;
std::shared_ptr<Node> grad_fn;
if (is_inplace_on_view_) {
grad_fn = weak_grad_fn_.lock();
} else if (!hooks_) {
grad_fn = saved_original_ ? data_.grad_fn() : nullptr;
} else {
grad_fn = grad_fn_;
}
if (!is_leaf_ && !grad_fn) {
// This issue was introduced when we added logic to save the original

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
@ -53,6 +54,15 @@ class TORCH_API SavedVariable {
return (bool)hooks_;
}
// Used by compiled autograd
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
retrieve_unpack_hook_data() const {
if (!hooks_) {
return std::nullopt;
}
return hooks_->retrieve_unpack_hook_data();
}
private:
// This field contains either:
// 1. the variable to save

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/core/SafePyObject.h>
namespace torch::autograd {
@ -8,6 +9,11 @@ struct TORCH_API SavedVariableHooks {
virtual void call_pack_hook(const at::Tensor& tensor) = 0;
virtual at::Tensor call_unpack_hook() = 0;
virtual ~SavedVariableHooks() = default;
virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
retrieve_unpack_hook_data() const {
throw std::runtime_error(
"Compiled Autograd only supports python saved tensor hooks ");
}
};
} // namespace torch::autograd

View File

@ -17,6 +17,78 @@
namespace torch::dynamo::autograd {
using namespace torch::autograd;
// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
PyCompilerInterface() = default;
PyCompilerInterface(const PyCompilerInterface&) = delete;
PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
PyCompilerInterface(PyCompilerInterface&&) = delete;
PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
virtual ~PyCompilerInterface() = default;
// Invokes py_compiler.bind_function
virtual std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
functional_apply_t fn,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false,
bool is_traceable = true) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
virtual variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_epilogue(
PyObject* py_compiler,
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual at::Tensor call_unpack(
PyObject* py_compiler,
std::optional<size_t> hook_id,
size_t hook_input_id) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
TORCH_API void setPyCompilerInterface(
std::unique_ptr<PyCompilerInterface>&& impl);
TORCH_API void resetPyCompilerInterface();
// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambigous symbols...)
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges);
struct SizeInput {
// Note: int value is still needed when dynamic to pass as an arg
enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
@ -154,19 +226,22 @@ struct TensorArgs {
}
TensorArg& lookup(const SavedVariable& sv) {
auto it = _saved_variables.find(&sv);
TORCH_INTERNAL_ASSERT(it != _saved_variables.end());
if (auto it = _saved_variables.find(&sv); it != _saved_variables.end()) {
// unpacked before graph
return *it->second;
}
// unpacked in graph
auto it2 = _saved_variables_proxies.find(&sv);
TORCH_INTERNAL_ASSERT(it2 != _saved_variables_proxies.end());
return *it2->second;
}
TensorArg& add(const at::Tensor& tensor) {
return lookup(tensor, true);
}
TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
// TODO(jansel): Here we unpack the SavedVariable exactly once. This might
// fire SavedTensor hooks. In the future we should try to put saved tensor
// hooks into the graph.
// no unpack hooks in this codepath
at::Tensor tensor = sv.unpack(node);
TensorArg& arg = add(tensor);
_saved_variables.emplace(&sv, &arg);
@ -185,6 +260,7 @@ struct TensorArgs {
// Every TensorArg from this is actually owned by _args (or _undefined) and
// that's why we have an un-owned pointer here.
std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables_proxies;
TensorArg _undefined;
uint32_t _next_id = 1; // id=0 used by _undefined
};
@ -245,6 +321,11 @@ struct AutogradCompilerCall {
return hooks.size() - 1;
}
size_t emplace_packed_input(c10::SafePyObject&& input) {
packed_inputs.emplace_back(std::move(input));
return packed_inputs.size() - 1;
}
void set_active_node_call_idx(size_t node_call_idx) {
active_node_call_idx = node_call_idx;
}
@ -255,10 +336,13 @@ struct AutogradCompilerCall {
LiftedIValueArgs lifted_ivalue_args;
std::vector<int64_t> dyn_size_inputs;
std::vector<c10::SafePyObject> hooks;
std::vector<c10::SafePyObject> packed_inputs;
NodeCalls node_calls;
SizeInput::DynType default_dyn_type;
// NodeCall id of each size, only when verbose logging is enabled
std::vector<uint32_t> size_input_origins;
std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
sv_to_hooks;
};
class CompiledNodeArgs {
@ -285,9 +369,20 @@ class CompiledNodeArgs {
collect(_compiler.tensor_args.add(t));
}
void collect(const SavedVariable& sv, bool is_output) {
if (auto hook_data = sv.retrieve_unpack_hook_data();
hook_data.has_value()) {
// hooks, unpack in graph
auto& [hook, packed_input] = hook_data.value();
size_t hook_id = _compiler.emplace_hook(std::move(hook));
// rely on dynamo to dedup packed tensors from unpacked tensors
size_t input_id = _compiler.emplace_packed_input(std::move(packed_input));
_compiler.sv_to_hooks.emplace(&sv, std::make_pair(hook_id, input_id));
} else {
// no hooks, unpack now
collect(
_compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
}
}
void collect(const c10::SymInt& t) {
_compiler.add_size_input(t);
}
@ -655,6 +750,18 @@ class SwapSavedVariables {
}
void before(SavedVariable& t) {
if (auto it = compiler.sv_to_hooks.find(&t);
it != compiler.sv_to_hooks.end()) {
const auto& pyinterface =
torch::dynamo::autograd::getPyCompilerInterface();
auto proxy_tensor = pyinterface->call_unpack(
get_py_compiler(), it->second.first, it->second.second);
stashed_variables.save(&t, std::move(t));
bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
t = SavedVariable(proxy_tensor, false);
at::SavedTensorDefaultHooks::set_tracing(prior);
} else {
// no hooks, was already unpacked
TensorArg& arg = compiler.tensor_args.lookup(t);
stashed_variables.save(&t, std::move(t));
if (arg.defined()) {
@ -664,6 +771,7 @@ class SwapSavedVariables {
at::SavedTensorDefaultHooks::set_tracing(prior);
}
}
}
void after(SavedVariable& t) {
stashed_variables.restore(&t);
}
@ -1370,73 +1478,6 @@ struct PackedArgs {
int64_t idx = 0;
};
// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
PyCompilerInterface() = default;
PyCompilerInterface(const PyCompilerInterface&) = delete;
PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
PyCompilerInterface(PyCompilerInterface&&) = delete;
PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
virtual ~PyCompilerInterface() = default;
// Invokes py_compiler.bind_function
virtual std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
functional_apply_t fn,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false,
bool is_traceable = true) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
virtual variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_epilogue(
PyObject* py_compiler,
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
TORCH_API void setPyCompilerInterface(
std::unique_ptr<PyCompilerInterface>&& impl);
TORCH_API void resetPyCompilerInterface();
// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambigous symbols...)
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges);
} // namespace torch::dynamo::autograd
template <>

View File

@ -203,6 +203,16 @@ struct PyCompilerInterfaceImpl : PyCompilerInterface {
auto output = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
return toTensorList(output);
}
at::Tensor call_unpack(
PyObject* py_compiler,
std::optional<size_t> hook_id,
size_t hook_input_id) override {
py::handle handle(py_compiler);
py::object proxy = handle.attr("unpack_hook")(hook_id, hook_input_id);
auto tmp = py::cast<std::optional<at::Tensor>>(proxy);
TORCH_INTERNAL_ASSERT(tmp.has_value());
return tmp.value();
}
};
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
@ -213,7 +223,7 @@ static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
return pyinput;
}
static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
// inplace, consumes the input hooks
PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
for (const auto i : c10::irange(inputs.size())) {
@ -654,7 +664,7 @@ static PyObject* wrap_string_list(const std::vector<std::string>& strs) {
return pystrs;
}
std::string unwrap_string(PyObject* pystr) {
static std::string unwrap_string(PyObject* pystr) {
TORCH_INTERNAL_ASSERT(PyUnicode_Check(pystr));
const char* str = PyUnicode_AsUTF8(pystr);
TORCH_INTERNAL_ASSERT(str != nullptr);
@ -796,7 +806,8 @@ static CacheNode* _compiled_autograd_impl(
THPObjectPtr* graph_arg_inputs,
THPObjectPtr* graph_arg_sizes,
THPObjectPtr* graph_arg_ivalue_args,
THPObjectPtr* graph_arg_hooks) {
THPObjectPtr* graph_arg_hooks,
THPObjectPtr* graph_arg_packed_inputs) {
std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
std::vector<std::shared_ptr<Node>> worklist{graph_root};
AutogradCompilerCall compiler_call(get_default_dyn_type());
@ -1052,7 +1063,8 @@ static CacheNode* _compiled_autograd_impl(
*graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
*graph_arg_ivalue_args =
wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
*graph_arg_hooks = convert_hook_list(compiler_call.hooks);
*graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
*graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
return cache;
}
@ -1093,6 +1105,7 @@ static variable_list compiled_autograd(
THPObjectPtr sizes;
THPObjectPtr ivalue_args;
THPObjectPtr hooks;
THPObjectPtr packed_inputs;
CacheNode* cache = _compiled_autograd_impl(
graph_root,
graph_task,
@ -1101,7 +1114,8 @@ static variable_list compiled_autograd(
&inputs,
&sizes,
&ivalue_args,
&hooks);
&hooks,
&packed_inputs);
THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
cache->runtime_wrapper.get(),
@ -1110,6 +1124,7 @@ static variable_list compiled_autograd(
sizes.get(),
ivalue_args.get(),
hooks.get(),
packed_inputs.get(),
NULL)));
variable_list outputs = THPVariable_UnpackList(pyresult);
TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());