functional compiled autograd (#144707)

This PR squashes together the following commits:

https://github.com/pytorch/pytorch/pull/144115
https://github.com/pytorch/pytorch/pull/143417
https://github.com/pytorch/pytorch/pull/143405
https://github.com/pytorch/pytorch/pull/143387
https://github.com/pytorch/pytorch/pull/143304
https://github.com/pytorch/pytorch/pull/143296

This is a refactor of compiled autograd to use "functional autograd". The end goal is for compiled autograd's initial capture to stop specializing on Tensor metadata, which allows compiled autograd to handle Tensor subclasses better.

For more information, please read the commit messages for each PR.
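As a rough, illustrative sketch of the entry point the new tests exercise (the `compiled_autograd._enable` context manager and the `aot_eager` backend appear in the tests in this diff; the toy function, shapes, and names here are made up):

    import torch
    from torch._dynamo import compiled_autograd

    def compiler_fn(gm):
        # gm is the captured compiled autograd graph (e.g. CompiledAutograd0 in
        # the new tensor-subclass test below); compile it with any backend.
        return torch.compile(gm, backend="aot_eager", fullgraph=True)

    @torch.compile(backend="aot_eager", fullgraph=True)
    def fn(t):
        return (t * t + 2).sum()

    x = torch.randn(4, 4, requires_grad=True)
    with compiled_autograd._enable(compiler_fn):
        fn(x).backward()
    assert torch.allclose(x.grad, 2 * x)

The user-facing flow is unchanged; what changes is what the initial capture records (functional calls such as validate_outputs and call_aot_bwd_prologue instead of graphs specialized on the saved tensors' metadata).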

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144707
Approved by: https://github.com/bdhirsh, https://github.com/xmfan, https://github.com/jansel
rzou
2025-01-27 05:20:56 +00:00
committed by PyTorch MergeBot
parent 87fdadde1d
commit ea141d8134
28 changed files with 1809 additions and 223 deletions


@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
has_symbolic_sizes_strides_(
t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
explicit TensorGeometry(
std::vector<at::SymInt> sizes,
std::vector<at::SymInt> strides,
at::SymInt storage_offset)
: sizes_(std::move(sizes)),
strides_(std::move(strides)),
storage_offset_(std::move(storage_offset)) {
recompute();
}
// true if the tensor is contiguous
bool is_contiguous() const;


@ -138,6 +138,7 @@ core_trainer_sources = [
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/utils/warnings.cpp",
"torch/csrc/autograd/jit_decomp_interface.cpp",
"torch/csrc/dynamo/compiled_autograd.cpp",
"torch/csrc/jit/frontend/name_mangler.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
"torch/csrc/jit/serialization/pickler.cpp",


@ -121,23 +121,30 @@ class _multiply_invoke(torch.nn.Module):
out.backward(grad_out)
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(x.grad, grad_out * grad_out)
self.assertExpectedInline(
actual,
"""\
if backend in ["aot_eager", "inductor"]:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list):
l_inputs_ = L_inputs_
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
new_grad: "f32[s0]" = torch.clone(getitem)
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None
result: "f32[s0]" = getitem * getitem; getitem = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None
getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
new_grad: "f32[2]" = torch.clone(getitem_5)
result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
)
graph = None
@ -162,7 +169,7 @@ class GraphModule(torch.nn.Module):
gm, backend=inner_compiler, fullgraph=True, dynamic=True
)
for backend in ["eager", "aot_eager", "inductor"]:
for backend in ["inductor"]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)
@ -187,26 +194,33 @@ class GraphModule(torch.nn.Module):
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(obj.counter, 1)
self.assertEqual(x.grad, grad_out + grad_out)
self.assertExpectedInline(
actual,
"""\
if backend in ["aot_eager", "inductor"]:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"):
def forward(self, L_inputs_ : list, L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s1)"):
l_inputs_ = L_inputs_
l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter
l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
new_grad: "f32[s0]" = torch.clone(getitem)
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None
add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_3); getitem_3 = None
getitem_5: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
result: "f32[s0]" = getitem * getitem; getitem = None
new_grad: "f32[2]" = torch.clone(getitem_5)
new_grad_1: "f32[s0]" = torch.clone(result); result = None
add: "Sym(s1 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
result: "f32[2]" = getitem_5 * getitem_5; getitem_5 = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1, add)
""",
)
)
out = fn(x, y)
out.backward(grad_out)


@ -22,6 +22,7 @@ from torch import _inductor as inductor
from torch._dynamo import compiled_autograd, config
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
@ -2821,8 +2822,11 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
opt_bwd()
self.assertEqual(counters["compiled_autograd"]["captures"], 1)
# always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
# Compiled autograd's initial capture lifts custom C++ autograd::Function bwd instead of tracing
# into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
# In the future, we can consider having a cpu scalar movement pass sometime after we trace
# into the custom C++ autograd::Function (like in AOTDispatcher)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
def test_logs(self):
logs, ctx = logs_to_string(
@ -2941,12 +2945,11 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
expected_logs = [
"code: CompiledFunctionBackward (NodeCall 2)",
"code: CompiledFunctionBackward0 (NodeCall 2)",
"aot0_primals_3",
"aot0_relu",
"aot0_le",
"aot0_permute_2",
"code: CompiledFunctionBackward0 (NodeCall 2)",
"aot0_tangents_1",
"aot0_full_default",
"aot0_where",
"aot0_mm",
@ -2996,20 +2999,17 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
expected_logs = [
"CompiledFunctionBackward1",
"aot1_tangents_1",
"aot1_sin_1",
"aot1_primals_2",
"aot1_neg",
"aot0_tangents_2",
"aot1_cos_1",
"aot1_primals_1",
"aot0_tangents_1",
"CompiledFunctionBackward0",
"aot0_sin_1",
"aot0_neg",
"aot0_sin",
"aot0_mul",
"aot0_cos_1",
"aot0_mul_1",
"aot0_cos",
"aot0_add",
]
@ -3154,6 +3154,120 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)
def test_tensor_subclass_basic(self):
from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define("to_twotensor(Tensor a, Tensor b) -> Tensor")
lib.define("from_twotensor(Tensor c) -> (Tensor, Tensor)")
def to_twotensor_backward(ctx, grad):
return torch.ops.mylib.from_twotensor(grad)
def from_twotensor_backward(ctx, grad_a, grad_b):
raise AssertionError("shouldn't get hit")
torch.library.register_autograd(
"mylib::to_twotensor", to_twotensor_backward, lib=lib
)
torch.library.register_autograd(
"mylib::from_twotensor", from_twotensor_backward, lib=lib
)
@torch.library.register_torch_dispatch(
"mylib::to_twotensor", TwoTensorMode, lib=lib
)
def _(_0, _1, _2, args, kwargs):
assert not kwargs
a, b = args
return TwoTensor(a.clone(), b.clone())
@torch.library.register_torch_dispatch(
"mylib::from_twotensor", TwoTensor, lib=lib
)
def _(_0, _1, _2, args, kwargs):
assert not kwargs
(c,) = args
return c.a.clone(), c.b.clone()
@torch.compile(backend="aot_eager", fullgraph=True)
def fn(x):
return x * x + 2
param1 = torch.randn(4, 4, requires_grad=True)
param2 = torch.randn(4, 4, requires_grad=True)
with TwoTensorMode():
x = torch.ops.mylib.to_twotensor(param1, param2)
inner_compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager")
graphs = []
def compiler_fn(gm):
graphs.append(gm)
return inner_compiler_fn(gm)
with compiled_autograd._enable(compiler_fn):
res = fn(x)
res.sum().backward()
self.assertEqual(param1.grad, 2 * param1)
self.assertEqual(param2.grad, 2 * param2)
self.assertEqual(len(graphs), 1)
graph_code = normalize_gm(graphs[0].print_readable(print_output=False))
# The graph should have make_subclass calls in it.
self.assertExpectedInline(
graph_code,
"""\
class CompiledAutograd0(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks):
getitem = inputs[0]
getitem_1 = inputs[1]
getitem_2 = inputs[2]
getitem_3 = inputs[3]
getitem_4 = inputs[4]; inputs = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [], True)]); getitem = None
getitem_5 = validate_outputs[0]; validate_outputs = None
sum_backward0 = torch__dynamo_compiled_autograd_ops_SumBackward0([getitem_5], [True], [4, 4]); getitem_5 = None
getitem_6 = sum_backward0[0]; sum_backward0 = None
validate_outputs_1 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_6], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], True)]); getitem_6 = None
getitem_7 = validate_outputs_1[0]; validate_outputs_1 = None
getitem_8 = hooks[0]; getitem_8 = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], getitem_7); getitem_1 = getitem_2 = getitem_7 = None
aot0_primals_1 = call_aot_bwd_prologue[0]
aot0_primals_2 = call_aot_bwd_prologue[1]
aot0_tangents_1 = call_aot_bwd_prologue[2]
aot0_tangents_2 = call_aot_bwd_prologue[3]; call_aot_bwd_prologue = None
aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None
aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None
aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None
aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None
make_subclass = torch__dynamo_compiled_autograd_make_subclass(aot0_add_2, aot0_add_3); aot0_add_2 = aot0_add_3 = None
getitem_13 = hooks[1]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_13, (), make_subclass); getitem_13 = make_subclass = None
getitem_16 = call_backward[0]
getitem_17 = call_backward[1]; call_backward = None
validate_outputs_2 = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem_16, getitem_17], [((None, None, device(type='cpu'), 6, 0, None), [4, 4], False), ((None, None, device(type='cpu'), 6, 0, None), [4, 4], False)]); getitem_16 = getitem_17 = None
getitem_19 = validate_outputs_2[0]
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_19); getitem_4 = getitem_19 = accumulate_grad__1 = None
getitem_20 = validate_outputs_2[1]; validate_outputs_2 = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_20); getitem_3 = getitem_20 = accumulate_grad_ = None
_exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
return []
""", # noqa: B950
)
# https://github.com/pytorch/pytorch/issues/138920
def test_compiled_autograd_does_not_specialize_on_bw_symints(self):
class Mod(torch.nn.Module):
@ -3247,7 +3361,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
# because we ignore all of these guards anyway in CA.
# Once we stop using make_fx in CA, we won't have to worry about this specialization.
view_nodes = graphs[1].graph.find_nodes(
op="call_function", target=torch.ops.aten.view.default
op="call_function", target=torch.ops.aten.reshape.default
)
# First 2 view nodes have a first argument that is a SymInt, not an int burned into the graph
self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node))
@ -3640,6 +3754,7 @@ known_failing_tests = {
"test_tp_compile_comm_reordering",
"test_unwrap_async_collective_tensor_tangent",
# Uncategorized
"test_not_implemented_grad", # Dynamo changes the types of exceptions
}
if not HAS_CUDA:


@ -337,7 +337,9 @@ class DistributedPatternTests(TestCase):
self.assertEqual(fw_cnt.frame_count, 1)
self.assertEqual(fw_cnt.op_count, 5)
self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None
self.assertEqual(bw_cnt.op_count, 48)
self.assertEqual(
bw_cnt.op_count, 72
) # Number of ops in the Dynamo-produced graphs
def test_module_backward_hooks_aot(self):
m1, inp1 = init_module_bw_hooks(True)


@ -107,6 +107,17 @@ static variable_list ${op}_apply_functional(
${body}
return grad_inputs;
}
inline variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args)
{
#ifdef C10_MOBILE
TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile");
#else
auto packed_args = PackedArgs(args);
auto needs_input_grad = packed_args.unpack<std::array<bool, ${num_inputs}>>();
${unpack_ivalues}
return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args});
#endif
}
variable_list ${op}::apply(variable_list&& grads) {
${thread_lock}
@ -120,11 +131,35 @@ void ${op}::compiled_args(CompiledNodeArgs& args) {
${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
${apply_with_saved_before}
variable_list result = apply(variable_list(grads));
${apply_with_saved_after}
return result;
#ifdef C10_MOBILE
TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile");
#else
${apply_with_saved_before}
static bool called = false;
if (!called) {
called = true;
${compute_schema}
const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface();
pyinterface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema);
}
variable_list output_result;
PackedArgs packed_args;
${asserts}
${unpacks}
${compute_needs_input_grad}
packed_args.pack(needs_input_grad);
${get_packed_args}
output_result = compiled_autograd_apply_functional(packed_args, next_edges(), saved, grads, name());
${apply_with_saved_after}
return output_result;
#endif
}
"""
)
@ -993,14 +1028,38 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
f"{T} {x}"
for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
]
get_packed_args = "\n".join(
f"packed_args.pack({name});" for name in apply_functional_args
)
unpack_ivalues = []
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
if typ.endswith("&"):
typ = typ[:-1]
unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
for typ in apply_functional_args_ref_types:
if typ.endswith("&"):
typ = typ[:-1]
if typ.startswith("const"):
typ = typ[5:]
schema_args.append(typ.strip())
compute_schema = ["std::vector<at::TypePtr> schema = {"]
for schema_arg in schema_args:
compute_schema.append(
f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type(),"
)
compute_schema.append("};")
return template.substitute(
unpacks="\n".join(unpack),
op=info.op,
compute_schema="\n".join(compute_schema),
apply_functional_args=apply_functional_args,
apply_functional_args_signature=apply_functional_args_signature,
compute_needs_input_grad=compute_needs_input_grad,
num_inputs=len(input_name_to_idx),
unpack_ivalues="\n".join(unpack_ivalues),
compute_index_ranges=compute_index_ranges,
saved_variables=saved_variables,
release_variables=release_variables,
@ -1015,4 +1074,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
compiled_args=compiled_args,
apply_with_saved_before=apply_with_saved_before,
apply_with_saved_after=apply_with_saved_after,
get_packed_args=get_packed_args,
)


@ -15,6 +15,30 @@ using at::TensorList;
namespace torch::autograd::generated {
static at::IValue compute_output_metadata(const torch::autograd::edge_list& next_edges) {
auto output_metadata = torch::dynamo::autograd::IValuePacker<
std::vector<std::optional<InputMetadata>>>::pack(
torch::dynamo::autograd::get_input_metadata(next_edges));
return output_metadata;
}
static C10_NOINLINE variable_list compiled_autograd_apply_functional(
const PackedArgs& packed_args,
const edge_list& next_edges,
SwapSavedVariables& saved,
const variable_list& grads,
const std::string& name) {
auto output_metadata = compute_output_metadata(next_edges);
const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface();
return pyinterface->call_function(
saved.get_py_compiler(),
"apply_functional",
name,
grads,
packed_args.vec(),
output_metadata);
}
${autograd_function_definitions}
} // namespace torch::autograd::generated


@ -4,10 +4,11 @@ import functools
import itertools
import operator
import time
from collections import defaultdict
from collections import Counter, defaultdict
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
from torch._dynamo.external_utils import (
call_backward,
call_hook,
@ -65,6 +66,50 @@ def maybe_clone(x):
return x
# We lazily bind "functional backward" variants for PyTorch built-in autograd
# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
# Each "functional backward" is bound the first time the node's apply_with_saved
# function is called. It's possible to avoid lazy binding and instead bind
# all of this upfront (perhaps at import time) via codegen changes.
class OpNamespace:
def __init__(self):
self.custom_function_name_counter: Counter[str] = Counter()
def add(self, name, fn, is_custom_function=False):
if is_custom_function:
name = "CppNode" + name
count = self.custom_function_name_counter[name]
self.custom_function_name_counter[name] += 1
name = f"{name}{count}"
else:
assert not hasattr(self, name)
result = Op(name, fn, is_custom_function)
torch._dynamo.allow_in_graph(result)
setattr(self, name, result)
return name
def get(self, name):
return getattr(self, name)
class Op:
def __init__(self, name, fn, is_custom_function):
self.fn = fn
self.is_custom_function = is_custom_function
self.__name__ = name
self.__module__ = "torch._dynamo.compiled_autograd.ops"
def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def __repr__(self):
return self.__module__ + "." + self.__name__
ops = OpNamespace()
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
_impure_targets = OrderedSet(
[
@ -137,7 +182,8 @@ class AutogradCompilerInstance:
self.fx_tracer.root = torch.nn.Module()
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
self.symnode_proxy_lookup = {}
args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
self.fx_tracer.create_proxy("placeholder", name, (), {})
for name in _graph_placeholders
)
@ -160,7 +206,9 @@ class AutogradCompilerInstance:
)
for idx, val in enumerate(sizes)
]
self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins)
self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins)
for i, symint in enumerate(sizes):
self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i]
for idx, val in enumerate(scalars):
source = self.source("scalars", idx)
@ -182,7 +230,9 @@ class AutogradCompilerInstance:
)
else:
raise AssertionError("Unexpected scalar type: ", type(val))
self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins)
self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins)
for i, symval in enumerate(scalars):
self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr]
# TODO(jansel): are all these modes needed?
self.stack.enter_context(decompose({}))
@ -197,31 +247,203 @@ class AutogradCompilerInstance:
)
return inputs, sizes, scalars
def proxy_call_aot_backward(
self,
pinputs,
psaved_tensors,
saved_tensors,
pctx,
ctx,
maybe_backward_state_idx,
):
# The AOTBackward call consists of three things: the prologue, the
# backward graph, and the epilogue.
# Our strategy is:
# - allow_in_graph the prologue (in the CA graph and Dynamo graph),
# - copy-paste the backward graph into the CA graph so that CA passes and Dynamo can see it
# - trace directly through the epilogue. Anything that gets baked in is
# constant metadata (for example, metadata about the number of outputs, or removing
# RNG arguments or effect tokens).
# If Dynamo graph capture were better, then we could add a node for the prologue
# into the CA graph and have Dynamo trace into it.
psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()]
# NOTE: we should only close over constants
CompiledFunction = ctx._forward_cls
metadata = CompiledFunction.metadata
maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
del CompiledFunction
@torch._dynamo.allow_in_graph # type: ignore[misc]
def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(
ctx_saved_tensors,
ctx_symints,
metadata,
maybe_subclass_metadata,
*flat_args,
)
return out
pgrads = self.fx_tracer.create_proxy(
kind="call_function",
target=call_aot_bwd_prologue,
args=(
psaved_tensors,
psymints,
*pinputs,
),
kwargs={},
)
pbackward_state = None
if maybe_backward_state_idx is not None:
pbackward_state = self.hooks_proxy[maybe_backward_state_idx] # type: ignore[index]
# Copy-paste the AOT backward graph into the compiled autograd graph
def copy_paste_aot_backward_graph():
def num_inputs(graph):
num_args = 0
for node in graph.nodes:
if node.op == "placeholder":
num_args += 1
continue
else:
break
return num_args
# set up the proxy inputs to ctx._bw_module
# the calling convention is: [*symints, *args (primals and tangents), backward_state]
num_args = num_inputs(ctx._bw_module.graph)
pall_args = [
pgrads[i] for i in range(num_args - int(pbackward_state is not None))
]
# replace the symints with our symints
symints = ctx._get_compiled_autograd_symints()
assert len(symints) == len(ctx.symints)
psymints = [self.to_proxy(e) for e in symints]
pall_args[: len(symints)] = psymints
# Add backward_state
if pbackward_state is not None:
pall_args.append(pbackward_state)
# run over all nodes of the aot_backward graph.
# copy and paste them all into the compiled autograd graph.
args_idx = 0
value_remap = {}
poutputs: Optional[list[torch.fx.Proxy]] = None
for node in ctx._bw_module.graph.nodes:
if node.op == "placeholder":
value_remap[node] = pall_args[args_idx].node
args_idx += 1
elif node.op == "output":
assert len(node.args) == 1
poutputs = [
torch.fx.Proxy(value_remap[n], self.fx_tracer)
if isinstance(n, torch.fx.Node)
else n
for n in node.args[0]
]
elif node.op == "get_attr":
name = node.target
qualname = self.fx_tracer.get_fresh_qualname(name)
setattr(
self.fx_tracer.root, qualname, getattr(ctx._bw_module, name)
)
result = self.fx_tracer.create_node("get_attr", qualname, (), {})
value_remap[node] = result
elif node.op == "call_function":
result = self.fx_tracer.graph.node_copy(
node, lambda n: value_remap[n]
)
value_remap[node] = result
else:
raise AssertionError("shouldn't get here")
assert poutputs is not None
# In general we don't know what the shapes of the outputs are, so allocate
# some dummy sizes for them.
def dummy():
with disable_proxy_modes_tracing():
return torch.zeros(0, 0, 0, 0, 123)
outputs = [
dummy() if isinstance(o, torch.fx.Proxy) else o for o in poutputs
]
self.bind_tensors_to_proxies(outputs, poutputs)
return outputs
outputs = copy_paste_aot_backward_graph()
def proxy_subclass_constructor(subclass_meta, is_runtime, unwrapped_args):
@torch._dynamo.allow_in_graph
def make_subclass(*unwrapped_args):
return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)
punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args)
poutput = self.fx_tracer.create_proxy(
kind="call_function",
target=make_subclass,
args=tuple(punwrapped_args),
kwargs={},
)
output = self.allocate_dummy()
self.bind_tensors_to_proxies([output], [poutput])
return output
results = torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional(
metadata,
maybe_subclass_metadata,
outputs,
make_subclass_override=proxy_subclass_constructor,
)
presults = pytree.tree_map(self.to_proxy, results)
return presults
def proxy_call_backward(
self,
inputs,
output_metadatas,
saved_tensors,
backward_idx: int,
ctx: torch.autograd.function.BackwardCFunction,
maybe_backward_state_idx: Optional[int],
):
assert self.hooks_proxy is not None
backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index]
proxies = self.fx_tracer.create_proxy(
kind="call_function",
target=call_backward,
args=(
backward_c_function,
self.to_proxy(saved_tensors),
*self.to_proxy(inputs),
),
kwargs={},
)
pctx = self.hooks_proxy[backward_idx] # type: ignore[index]
pinputs = self.to_proxy(inputs)
psaved_tensors = self.to_proxy(saved_tensors)
if hasattr(ctx._forward_cls, "_aot_id"): # type: ignore[attr-defined]
# AOT backward
proxies = self.proxy_call_aot_backward(
pinputs,
psaved_tensors,
saved_tensors,
pctx,
ctx,
maybe_backward_state_idx,
)
else:
proxies = self.fx_tracer.create_proxy(
kind="call_function",
target=call_backward,
args=(
pctx,
psaved_tensors,
*pinputs,
),
kwargs={},
)
assert proxies is not None
with disable_proxy_modes_tracing():
# create fake Tensors
grad_ins: list[Optional[torch.Tensor]] = []
for output_metadata in output_metadatas:
if output_metadata is None:
for idx, output_metadata in enumerate(output_metadatas):
if output_metadata is None or proxies[idx] is None:
grad_ins.append(None)
continue
@ -232,6 +454,71 @@ class AutogradCompilerInstance:
self.bind_tensors_to_proxies(grad_ins, proxies)
return tuple(grad_ins)
def call_copy_slices_prologue(self, inputs, base, view):
args = (
inputs,
base.sizes(),
base.strides(),
base.storage_offset(),
view.sizes(),
view.strides(),
view.storage_offset(),
)
return self.proxy_call(copy_slices_prologue, args, [None] * 3)
def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice):
return self.proxy_call(
copy_slices_epilogue,
(needs_input_grad, result, res, grad_slice),
[None] * len(needs_input_grad),
)
def allocate_dummy(self):
with disable_proxy_modes_tracing():
# Weird quantity so it's easy to grep
return torch.zeros([0, 123456789])
def bind_function(self, fn_name, fn, is_custom_function):
"""Binds ops.fn_name = fn"""
return ops.add(fn_name, fn, is_custom_function)
def apply_functional(self, fn_name, grads, args, output_metadata):
"""Proxies a call to ops.fn_name(grads, *args) into the graph"""
op = ops.get(fn_name)
return self.proxy_call(op, (grads, *args), output_metadata)
def proxy_call(self, fn, args, output_metadata):
"""Proxies a call to fn(*args) into the graph"""
flat_args, _ = pytree.tree_flatten(args)
proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
proxy_out = self.fx_tracer.create_proxy(
"call_function", fn, args=proxy_args, kwargs={}
)
result = [self.allocate_dummy() for _ in output_metadata]
self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))])
return result
def validate_outputs(self, _, outputs, args, output_metadata):
"""Proxies a call to ops.validate_outputs(outputs, *args) into the graph"""
op = ops.get("validate_outputs")
proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args))
new_proxy_outputs = self.fx_tracer.create_proxy(
"call_function", op, args=proxy_args, kwargs={}
)
assert len(output_metadata) == len(outputs)
self.bind_tensors_to_proxies(outputs, new_proxy_outputs)
return outputs
def accumulate(self, old_var, new_var):
old_var_proxy = self.to_proxy(old_var)
new_var_proxy = self.to_proxy(new_var)
proxy_out = self.fx_tracer.create_proxy(
"call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={}
)
result = self.allocate_dummy()
self.bind_tensors_to_proxies([result], [proxy_out])
return result
def proxy_call_hook(self, hook, *args, **kwargs):
return self.fx_tracer.create_proxy(
"call_function",
@ -314,6 +601,7 @@ class AutogradCompilerInstance:
assert nodes[first_getitem_idx] == inputs_users[0]
last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
assert nodes[last_getitem_idx] == inputs_users[-1]
# getitem nodes on inputs
for i, node in enumerate(inputs_users):
if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
has_cuda_inputs = True
@ -323,9 +611,16 @@ class AutogradCompilerInstance:
is_scalar = len(node.meta["val"].size()) == 0
if is_cpu and is_scalar:
node_users = list(node.users.keys())
# We can only move the cpu scalar if it is not exposed to user code.
if all(
isinstance(user.target, torch._ops.OpOverload)
and user.target.namespace in ("prims", "aten")
(
isinstance(user.target, torch._ops.OpOverload)
and user.target.namespace in ("prims", "aten")
)
or (
isinstance(user.target, Op)
and not user.target.is_custom_function
)
for user in node_users
):
# all users are prims/aten, can move safely
@ -335,6 +630,7 @@ class AutogradCompilerInstance:
# this is to handle the case where cudagraphs is enabled on a cpu-only graph
if has_cuda_inputs:
for node in to_move.values():
verbose_log.debug("Moving node %s from cpu to cuda", node)
node.meta["val"] = node.meta["val"].cuda()
# return runtime indices we need to move to cuda
@ -368,7 +664,10 @@ class AutogradCompilerInstance:
or (node.op == "call_function" and node.target in _impure_targets)
)
before = len(self.fx_tracer.graph.nodes)
self.fx_tracer.graph.eliminate_dead_code(is_impure)
after = len(self.fx_tracer.graph.nodes)
verbose_log.debug("DCE removed %d nodes", before - after)
def end_capture(self, outputs):
self.fx_tracer.create_proxy(
@ -384,6 +683,18 @@ class AutogradCompilerInstance:
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
{},
)
runtime_inputs_to_move: list[int] = []
if snapshot_cudagraph_enabled():
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
# We traced using dummy tensors. Delete all the metadata of the dummy tensors.
# It's probably better to refactor this class to use a different tracer
# than the make_fx tracer, but that is a larger change.
for node in self.fx_tracer.graph.nodes:
for field in ["tensor_meta", "example_value", "val"]:
if field in node.meta:
del node.meta[field]
self.rename_aot_dispatcher_nodes()
self.reorder_tensor_pre_hook_nodes()
self.reorder_pre_hook_nodes_to_schedule_asap()
@ -402,9 +713,6 @@ class AutogradCompilerInstance:
# Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
# should prevent these ops from going into the CA graph.
self.dce()
runtime_inputs_to_move: list[int] = []
if snapshot_cudagraph_enabled():
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
@ -778,8 +1086,11 @@ class AutogradCompilerInstance:
return [self.to_proxy(x) for x in t]
if isinstance(t, tuple):
return tuple(self.to_proxy(x) for x in t)
# can it be torch.SymInt as the code used to imply?
assert isinstance(t, torch.Tensor)
if isinstance(t, (torch.SymInt, torch.SymFloat)):
return self.symnode_proxy_lookup[t.node]
if not isinstance(t, torch.Tensor):
# constant types like device, dtype, str
return t
proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
return proxy_tensor.proxy
@ -921,3 +1232,39 @@ def reset() -> None:
torch._C._dynamo.compiled_autograd.clear_cache()
global COMPILE_COUNTER
COMPILE_COUNTER = itertools.count()
# Reimplementation of part of CopySlices::apply in Python.
# The shared code is really similar so we're not going to try to deduplicate.
def copy_slices_prologue(
inputs,
base_sizes,
base_strides,
base_storage_offset,
view_sizes,
view_strides,
view_storage_offset,
):
grad = inputs[0]
result = grad.new_empty_strided(base_sizes, base_strides)
assert grad is not None
result.copy_(grad)
offset = view_storage_offset - base_storage_offset
grad_slice = result.as_strided(view_sizes, view_strides, offset)
return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)]
# Reimplementation of part of CopySlices::apply in Python.
# The shared code is really similar so we're not going to try to deduplicate.
def copy_slices_epilogue(needs_input_grad, result, res, grad_slice):
grad_inputs = [None] * len(needs_input_grad)
for i in range(len(needs_input_grad)):
if needs_input_grad[i]:
if res[i] is None:
continue
if i == 0:
grad_slice.copy_(res[i])
grad_inputs[i] = result
else:
grad_inputs[i] = res[i]
return grad_inputs
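To make the new OpNamespace/Op plumbing concrete, here is a small sketch of the binding API it exposes (not part of the diff; the name "MyToyBackward0" and the toy backward are hypothetical, while ops.add/ops.get and the lazily bound built-ins such as ops.MulBackward0 come from the code above):

    import torch
    from torch._dynamo.compiled_autograd import ops

    # Stand-in for a generated functional backward: it receives the incoming
    # gradients plus everything that used to live on the node as explicit
    # arguments, and returns the gradient inputs.
    def toy_backward(grads, needs_input_grad, other):
        (grad,) = grads
        return [grad * other if needs_input_grad[0] else None]

    # ops.add attaches the function to the namespace and marks it
    # allow_in_graph so Dynamo can trace calls to it later.
    name = ops.add("MyToyBackward0", toy_backward)
    op = ops.get(name)  # torch._dynamo.compiled_autograd.ops.MyToyBackward0
    print(op, op([torch.ones(3)], [True], torch.full((3,), 2.0)))

During capture, apply_functional proxies a call to the bound op into the graph and returns dummy tensors from allocate_dummy whose only role is to carry the proxies.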


@ -116,6 +116,14 @@ def call_backward(
return grads
def normalize_as_list(x: Any) -> list[Any]:
if isinstance(x, tuple):
return list(x)
elif isinstance(x, list):
return x
return [x]
def untyped_storage_size(x: torch.Tensor) -> int:
return x.untyped_storage().size()


@ -72,6 +72,8 @@ def radians(x):
def accumulate_grad(x, new_grad):
if new_grad is None:
return
new_grad = torch.clone(new_grad)
if x.grad is None:
x.grad = new_grad


@ -3276,6 +3276,7 @@ if torch.distributed.is_available():
MOD_INLINELIST = [
"torch._decomp",
"torch._dynamo._trace_wrapped_higher_order_op",
"torch._dynamo.compiled_autograd",
"torch._dynamo.comptime",
"torch._dynamo.polyfills",
"torch._functorch._aot_autograd.subclass_parametrization",


@ -62,7 +62,6 @@ from .traced_function_transforms import aot_dispatch_subclass
from .utils import (
call_func_at_runtime_with_args,
make_boxed_func,
normalize_as_list,
partial_flatten_asdict,
strict_zip,
)
@ -1683,7 +1682,9 @@ def _backward_prologue_functional(
# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants.
def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out):
def _backward_epilogue_functional(
metadata, maybe_subclass_metadata, out, *, make_subclass_override=None
):
# Toss out the backward output tokens
num_bw_tokens = metadata.num_backward_tokens
if num_bw_tokens > 0:
@ -1703,6 +1704,7 @@ def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out):
subclass_metas=maybe_subclass_metadata.grad_input_metas,
included_subclass_symints=True,
is_runtime=True,
make_subclass_override=make_subclass_override,
)
return outs_wrapped
return out
@ -1728,6 +1730,13 @@ class AOTDispatchAutograd:
expected_meta = meta.meta
runtime_type = type(x)
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
# When we're inside compiled autograd's AOTDispatcher step,
# regular Tensors look like FunctionalTensors.
# Tensor subclasses still look like Tensor subclasses though.
if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor):
runtime_type = torch.Tensor
runtime_meta = None
runtime_subclass_keys: Sequence[str] = []
@ -2001,23 +2010,9 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
@staticmethod
def _backward_impl(ctx, all_args):
if ctx._is_compiled_autograd_tracing():
if lazy_backward_info is None:
raise RuntimeError(
"""This compiled backward function was saved by AOTAutogradCache, which does not support
compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
)
bw_module = lazy_backward_info.bw_module
# For compiled autograd, run raw FX graph so that it can be inlined into the larger graph
symints = ctx._get_compiled_autograd_symints()
assert len(symints) == len(ctx.symints)
all_args[: len(symints)] = symints
if backward_state_indices:
assert ctx._compiled_autograd_backward_state.proxy is not None
all_args.append(ctx._compiled_autograd_backward_state)
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
return normalize_as_list(bw_module(*all_args))
assert (
not ctx._is_compiled_autograd_tracing()
), "compiled autograd reimplements this function at proxy_call_aot_backward"
assert (
not backward_state_indices


@ -8,7 +8,7 @@ and this includes tensor subclasses that implement __torch_dispatch__.
import collections
import typing
from collections.abc import Iterable
from typing import Any, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
import torch
import torch.utils._pytree as pytree
@ -326,6 +326,7 @@ def wrap_tensor_subclasses(
num_fw_outs_saved_for_bw: Optional[int] = None,
included_subclass_symints: bool = False,
is_runtime: bool = False,
make_subclass_override: Optional[Callable] = None,
) -> tuple[Any, ...]:
wrapped_args = []
num_args_tallied = 0
@ -336,9 +337,15 @@ def wrap_tensor_subclasses(
else:
assert isinstance(subclass_meta, SubclassCreationMeta)
assert subclass_meta.included_subclass_symints == included_subclass_symints
wrapped_args.append(
subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)
)
if make_subclass_override:
wrapped_args.append(
make_subclass_override(subclass_meta, is_runtime, unwrapped_args)
)
else:
wrapped_args.append(
subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)
)
num_args_tallied += subclass_meta.arg_count
# Note: [Partitioner handling for Subclasses, Part 2]


@ -334,6 +334,9 @@ class FunctionMeta(type):
backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined]
"_compiled_autograd_should_lift", True
)
backward_fn._bw_module = None # type: ignore[attr-defined]
if getattr(cls, "_lazy_backward_info", None):
backward_fn._bw_module = cls._lazy_backward_info.bw_module # type: ignore[attr-defined]
cls._backward_cls = backward_fn
super().__init__(name, bases, attrs)


@ -503,6 +503,16 @@ void check_variable_result(
}
}
AutogradContext::AutogradContext(PackedArgs& packed_args) {
saved_data = packed_args.unpack_saved_data();
saved_variables_override_ = packed_args.unpack<variable_list>();
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
materialize_grads_ = packed_args.unpack<bool>();
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
has_freed_buffers_ = packed_args.unpack<bool>();
needs_input_grad_override_ = packed_args.unpack<std::vector<bool>>();
}
void AutogradContext::save_for_backward(variable_list to_save) {
to_save_ = std::move(to_save);
}
@ -527,6 +537,9 @@ void AutogradContext::save_variables() {
variable_list AutogradContext::get_saved_variables() const {
TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);
if (saved_variables_override_.has_value()) {
return *saved_variables_override_;
}
variable_list saved;
saved.reserve(saved_variables_.size());
auto ptr = grad_fn_.lock();
@ -538,6 +551,9 @@ variable_list AutogradContext::get_saved_variables() const {
}
bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
if (needs_input_grad_override_.has_value()) {
return needs_input_grad_override_.value().at(output_edge_index);
}
auto ptr = grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
return ptr->task_should_compute_output(output_edge_index);
@ -545,6 +561,15 @@ bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
bool AutogradContext::needs_input_grad(
std::initializer_list<IndexRange> idxs) const {
if (needs_input_grad_override_.has_value()) {
return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
bool result = false;
for (const auto i : c10::irange(range.first, range.second)) {
result |= needs_input_grad_override_.value().at(i);
}
return result;
});
}
auto ptr = grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
return ptr->task_should_compute_output(idxs);


@ -126,6 +126,8 @@ struct TORCH_API AutogradContext {
AutogradContext& operator=(AutogradContext&& other) = delete;
~AutogradContext() = default;
AutogradContext(PackedArgs& packed_args);
/// Can be used to save non-variable data for `backward`.
ska::flat_hash_map<std::string, at::IValue> saved_data;
@ -169,12 +171,103 @@ struct TORCH_API AutogradContext {
std::weak_ptr<Node> grad_fn_;
bool has_freed_buffers_{false};
// Compiled autograd overrides saved_variables() and needs_input_grad().
// We store the values we want to return here.
std::optional<variable_list> saved_variables_override_;
std::optional<std::vector<bool>> needs_input_grad_override_;
void save_variables();
template <class T>
friend struct CppNode;
template <class T>
friend variable_list CppNode_apply_functional(
variable_list&& inputs,
AutogradContext& ctx_,
const std::vector<bool>& is_variable_input_,
const std::vector<VariableInfo>& output_info_,
const std::string& name);
};
template <typename T>
inline variable_list CppNode_apply_functional(
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
variable_list&& inputs,
AutogradContext& ctx_,
const std::vector<bool>& is_variable_input_,
const std::vector<VariableInfo>& output_info_,
const std::string& name) {
at::OptionalDeviceGuard _device_guard;
auto num_inputs = inputs.size();
variable_list backward_inputs;
backward_inputs.reserve(num_inputs);
for (const auto i : c10::irange(num_inputs)) {
if (inputs[i].defined() || !ctx_.materialize_grads_) {
backward_inputs.emplace_back(std::move(inputs[i]));
} else {
backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
}
}
auto outputs = T::backward(&ctx_, backward_inputs);
const auto num_forward_inputs =
static_cast<int64_t>(is_variable_input_.size());
auto num_outputs = static_cast<int64_t>(outputs.size());
// Returning too many results is ok, but only as long as they're all
// undefined. Truncate the result vector in that case.
if (num_outputs > num_forward_inputs) {
bool all_undef = true;
for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
all_undef &= (!outputs[i].defined());
}
if (all_undef) {
outputs.resize(num_forward_inputs);
num_outputs = num_forward_inputs;
}
}
if (num_outputs != num_forward_inputs) {
std::string msg("function ");
msg += name + " returned an incorrect number of gradients (expected ";
msg += std::to_string(num_forward_inputs) + ", got ";
msg += std::to_string(num_outputs) + ")";
throw std::runtime_error(msg);
}
variable_list results;
results.reserve(num_outputs);
for (const auto i : c10::irange(num_outputs)) {
if (!is_variable_input_[i]) {
if (outputs[i].defined()) {
std::string msg("function ");
msg += name +
" returned a gradient different that is defined at position ";
msg += std::to_string(i + 1) +
", std the corresponding forward input was not a Variable";
throw std::runtime_error(msg);
}
continue;
}
results.emplace_back(outputs[i]);
}
return results;
}
template <typename T>
inline variable_list CppNode_apply_functional_ivalue(
const variable_list& inputs,
const std::vector<c10::IValue>& args) {
auto packed_args = PackedArgs(args);
auto ctx = AutogradContext(packed_args);
auto output_info = packed_args.unpack<std::vector<VariableInfo>>();
auto is_variable_input = packed_args.unpack<std::vector<bool>>();
auto name = packed_args.unpack<std::string>();
return CppNode_apply_functional<T>(
variable_list(inputs), ctx, is_variable_input, output_info, name);
}
// CppNode<T> is the Node in the autograd graph that represents the user defined
// backward function for Function<T>. Calls to CppNode::apply are forwarded to
// T::backward().
@ -232,7 +325,64 @@ struct CppNode : public Node {
saved.before(ctx_.has_freed_buffers_);
saved.before(input_info_);
saved.before(output_info_);
auto results = apply(variable_list(inputs));
PackedArgs packed_args;
packed_args.pack_saved_data(ctx_.saved_data);
variable_list saved_variables = ctx_.get_saved_variables();
packed_args.pack(saved_variables);
packed_args.pack(ctx_.materialize_grads_);
packed_args.pack(ctx_.has_freed_buffers_);
std::vector<bool> needs_input_grad;
{
auto ptr = ctx_.grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
for (const auto i : c10::irange(ptr->next_edges().size())) {
needs_input_grad.push_back(ptr->task_should_compute_output(i));
}
}
packed_args.pack(needs_input_grad);
packed_args.pack(output_info_);
packed_args.pack(is_variable_input_);
packed_args.pack(name());
auto args = std::move(packed_args).vec();
auto output_metadata = torch::dynamo::autograd::
IValuePacker<std::vector<std::optional<InputMetadata>>>::pack(
torch::dynamo::autograd::get_input_metadata(next_edges()));
const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface();
// Each time apply_with_saved is called, we bind a new function to Python.
// This is because the schema might be different on compiled autograd cache
// misses. An alternative is to pass the schema to Python so that it can be
// an input to a function, but the schema can't be put into an FX graph
// right now.
std::vector<at::TypePtr> schema;
schema.reserve(args.size());
for (const auto& ivalue : args) {
if (ivalue.isTensor()) {
schema.emplace_back(at::TensorType::get());
} else {
schema.emplace_back(ivalue.type());
}
}
auto fn_name = pyinterface->bind_function(
saved.get_py_compiler(),
std::string(typeid(T).name()),
CppNode_apply_functional_ivalue<T>,
schema,
/*is_custom_function*/ true);
auto results = pyinterface->call_function(
saved.get_py_compiler(),
"apply_functional",
fn_name,
inputs,
args,
output_metadata);
saved.after(ctx_.saved_data);
TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
@ -403,68 +553,13 @@ auto Function<T>::apply(Args&&... args)
template <class T>
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
variable_list CppNode<T>::apply(variable_list&& inputs) {
at::OptionalDeviceGuard _device_guard;
auto num_inputs = inputs.size();
variable_list backward_inputs;
backward_inputs.reserve(num_inputs);
for (const auto i : c10::irange(num_inputs)) {
if (inputs[i].defined() || !ctx_.materialize_grads_) {
backward_inputs.emplace_back(std::move(inputs[i]));
} else {
backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
}
}
// Acquire lock to here protect thread safety on custom C++ Autograd Node
// This is needed for the custom Autograd Node since we don't know if the
// user defined Node will write to the shared data during backward.
// see Note [Thread Safety on Autograd Node]
std::lock_guard<std::mutex> lock(mutex_);
auto outputs = T::backward(&ctx_, backward_inputs);
const auto num_forward_inputs =
static_cast<int64_t>(is_variable_input_.size());
auto num_outputs = static_cast<int64_t>(outputs.size());
// Returning too many results is ok, but only as long as they're all
// undefined. Truncate the result vector in that case.
if (num_outputs > num_forward_inputs) {
bool all_undef = true;
for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
all_undef &= (!outputs[i].defined());
}
if (all_undef) {
outputs.resize(num_forward_inputs);
num_outputs = num_forward_inputs;
}
}
if (num_outputs != num_forward_inputs) {
std::string msg("function ");
msg += name() + " returned an incorrect number of gradients (expected ";
msg += std::to_string(num_forward_inputs) + ", got ";
msg += std::to_string(num_outputs) + ")";
throw std::runtime_error(msg);
}
variable_list results;
results.reserve(num_outputs);
for (const auto i : c10::irange(num_outputs)) {
if (!is_variable_input_[i]) {
if (outputs[i].defined()) {
std::string msg("function ");
msg += name() +
" returned a gradient different that is defined at position ";
msg += std::to_string(i + 1) +
", std the corresponding forward input was not a Variable";
throw std::runtime_error(msg);
}
continue;
}
results.emplace_back(outputs[i]);
}
return results;
return CppNode_apply_functional<T>(
std::move(inputs), ctx_, is_variable_input_, output_info_, name());
}
template <class T>


@ -897,6 +897,19 @@ bool has_input_metadata<Edge>(const Edge& thing) {
return thing.is_valid();
}
std::vector<std::optional<InputMetadata>> collect_input_metadata(
const edge_list& edges) {
std::vector<std::optional<InputMetadata>> input_metadata;
for (const auto& edge : edges) {
if (!edge.is_valid()) {
input_metadata.emplace_back(std::nullopt);
continue;
}
input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr));
}
return input_metadata;
}
// Given a vector<Edge> or vector<optional<InputMetadata>>, validate the
// outputs. This involves using the InputMetadata to check the outputs and also
// potentially calling .sum_to on the outputs.


@ -47,6 +47,8 @@ TORCH_API void validate_outputs(
const std::vector<std::optional<InputMetadata>>& input_metadata,
variable_list& grads,
const std::function<std::string(const std::string&)>& format_error);
TORCH_API std::vector<std::optional<InputMetadata>> collect_input_metadata(
const edge_list& edges);
struct NodeTask {
std::weak_ptr<GraphTask> base_;


@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using ivalue_list = std::vector<c10::IValue>;
using functional_apply_t = std::function<
variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
using IndexRange = std::pair<size_t, size_t>;
using torch::dynamo::autograd::CompiledNodeArgs;
using torch::dynamo::autograd::PackedArgs;
using torch::dynamo::autograd::SwapSavedVariables;
// Custom deleter to prevent stack overflows.
@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
std::string("apply_with_saved not implemented: ") + name());
}
// Whether this node is the AOTBackward node produced by torch.compile.
// Compiled Autograd special-cases on this information.
virtual bool is_aot_backward() const {
return false;
}
protected:
/// Performs the `Node`'s actual operation.
virtual variable_list apply(variable_list&& inputs) = 0;


@ -8,6 +8,7 @@
namespace torch::dynamo::autograd {
class CompiledNodeArgs;
class SwapSavedVariables;
struct PackedArgs;
} // namespace torch::dynamo::autograd
// A hook that's called on gradients


@ -16,6 +16,8 @@
namespace torch::autograd {
using torch::dynamo::autograd::IValuePacker;
static variable_list CopyBackwards_apply_functional(
variable_list&& grads,
std::array<bool, 2> needs_input_grad,
@ -41,6 +43,16 @@ static variable_list CopyBackwards_apply_functional(
return grad_inputs;
}
static variable_list CopyBackwards_apply_functional_ivalue(
const variable_list& grads,
const ivalue_list& args) {
PackedArgs r(args);
auto needs_input_grad = r.unpack<std::array<bool, 2>>();
auto src_options = r.unpack<c10::TensorOptions>();
return CopyBackwards_apply_functional(
variable_list(grads), needs_input_grad, src_options);
}
auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
return CopyBackwards_apply_functional(
std::move(grads),
@ -51,11 +63,43 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
args.collect(src_options);
}
variable_list CopyBackwards::apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) {
saved.before(src_options);
auto result = apply(variable_list(inputs));
static c10::once_flag flag;
c10::call_once(flag, [&]() {
std::vector<at::TypePtr> schema = {
IValuePacker<std::array<bool, 2>>::packed_type(),
IValuePacker<c10::TensorOptions>::packed_type()};
const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
interface->bind_function(
saved.get_py_compiler(),
name(),
CopyBackwards_apply_functional_ivalue,
schema);
});
PackedArgs packed_args;
packed_args.pack<std::array<bool, 2>>(
{task_should_compute_output(0), task_should_compute_output(1)});
packed_args.pack(src_options);
auto output_metadata = torch::dynamo::autograd::
IValuePacker<std::vector<std::optional<InputMetadata>>>::pack(
torch::dynamo::autograd::get_input_metadata(next_edges()));
const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
auto result = interface->call_function(
saved.get_py_compiler(),
"apply_functional",
name(),
inputs,
std::move(packed_args).vec(),
output_metadata);
saved.after(src_options);
return result;
}
@ -80,38 +124,7 @@ CopySlices::CopySlices(
}
}
// common code between apply/apply_with_saved
template <typename T>
inline variable_list CopySlices::apply_impl(
variable_list&& inputs,
const T& call_fn) {
check_input_variables("CopySlices", inputs, 1, -1, true);
auto& grad = std::move(inputs)[0];
if (!grad.defined()) {
return variable_list(num_outputs());
}
// Acquire lock to here protect thread safety on fn
// see Note [Thread Safety on Autograd Node]
std::lock_guard<std::mutex> lock(mutex_);
if (!fn) {
throw std::runtime_error(ERR_BACKWARD_TWICE);
}
auto result =
grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides());
result.copy_(grad);
at::Tensor grad_slice;
if (view_fn) {
grad_slice = (*view_fn)(result);
} else {
auto offset = view.sym_storage_offset() - base.sym_storage_offset();
grad_slice =
result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset);
}
void CopySlices::update_exec_info() {
// See Note [View + Inplace update for view tensor] For more details on this
// block Since the gradient edge for the 0th input is different between `this`
// and `fn`, make sure that the one from `fn` has the same metadata in the
@ -154,6 +167,41 @@ inline variable_list CopySlices::apply_impl(
TORCH_INTERNAL_ASSERT(
fn->next_edge(i).function.get() == this->next_edge(i).function.get());
}
}
// common code between apply/apply_with_saved
template <typename T>
inline variable_list CopySlices::apply_impl(
variable_list&& inputs,
const T& call_fn) {
check_input_variables("CopySlices", inputs, 1, -1, true);
auto& grad = std::move(inputs)[0];
if (!grad.defined()) {
return variable_list(num_outputs());
}
// Acquire lock here to protect thread safety on fn
// see Note [Thread Safety on Autograd Node]
std::lock_guard<std::mutex> lock(mutex_);
if (!fn) {
throw std::runtime_error(ERR_BACKWARD_TWICE);
}
auto result =
grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides());
result.copy_(grad);
at::Tensor grad_slice;
if (view_fn) {
grad_slice = (*view_fn)(result);
} else {
auto offset = view.sym_storage_offset() - base.sym_storage_offset();
grad_slice =
result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset);
}
update_exec_info();
// TODO: We clone grad_slice because we modify it below and "fn" might save
// it for the backward of res. We might be able to avoid the clone() if
@ -201,17 +249,38 @@ variable_list CopySlices::apply_with_saved(
SwapSavedVariables& saved) {
saved.before(base);
saved.before(view);
int call_count = 0;
variable_list result = apply_impl(
variable_list(grads),
[this, &saved, &call_count](const variable_list& inputs2) {
call_count++;
return fn->apply_with_saved(inputs2, saved);
});
TORCH_INTERNAL_ASSERT(call_count == 1);
auto results = variable_list(num_outputs());
if (grads[0].defined()) {
if (!fn) {
throw std::runtime_error(ERR_BACKWARD_TWICE);
}
update_exec_info();
std::vector<bool> needs_input_grad;
for (const auto i : c10::irange(num_outputs())) {
needs_input_grad.emplace_back(task_should_compute_output(i));
}
// Not yet supported, also doesn't happen in typical eager mode execution
// (this only happens by default with torch-xla).
TORCH_INTERNAL_ASSERT(!view_fn);
const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
variable_list stuff = interface->call_copy_slices_prologue(
saved.get_py_compiler(), grads, base, view);
TORCH_INTERNAL_ASSERT(stuff.size() == 3);
// These variables are named the same as in CopySlices::apply_impl.
// Follow along there.
auto result = stuff[0];
auto grad_slice = stuff[1];
auto grad_slice_clone = stuff[2];
auto res = fn->apply_with_saved({grad_slice_clone}, saved);
results = interface->call_copy_slices_epilogue(
saved.get_py_compiler(), needs_input_grad, result, res, grad_slice);
}
saved.after(base);
saved.after(view);
return result;
return results;
}
auto CopySlices::apply(variable_list&& inputs1) -> variable_list {


@ -172,6 +172,7 @@ struct TORCH_API CopySlices : public Node {
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
void update_exec_info();
at::TensorGeometry base;
// view and view_fn are redundant and view_fn will be used if available.

View File

@ -131,6 +131,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
if (!ParameterClass)
return nullptr;
py::class_<at::TensorGeometry>(m, "TensorGeometry")
.def("sizes", &at::TensorGeometry::sizes)
.def("strides", &at::TensorGeometry::strides)
.def("storage_offset", &at::TensorGeometry::storage_offset);
py::class_<LegacyEvent>(m, "ProfilerEvent")
.def("kind", &LegacyEvent::kindStr)
.def("name", [](const LegacyEvent& e) { return e.name(); })

View File

@ -30,6 +30,7 @@
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>
@ -237,16 +238,23 @@ auto PyNode::defer_to_dynamo(
TORCH_INTERNAL_ASSERT(
_backward_idx.has_value(),
"indices should already be set by compiled_args, called before apply_with_saved");
TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value());
PyObject* backward_state_idx = Py_None;
if (_backward_state_idx.has_value()) {
backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
// this might be simplifiable now that we no longer inline
Py_CLEAR(py_fn->compiled_autograd_backward_state);
}
THPObjectPtr r(PyObject_CallMethod(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
compiler.value(),
"proxy_call_backward",
"OOOi",
"OOOiOO",
pyInputs.get(),
fwdInputMetadatas.get(),
saved_tensors.get(),
*_backward_idx));
*_backward_idx,
obj,
backward_state_idx));
if (!r)
throw_python_error();
@ -288,6 +296,11 @@ auto PyNode::name() const -> std::string {
return name;
}
bool PyNode::is_aot_backward() const {
py::handle handle(obj);
return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
}
auto PyNode::compiled_autograd_should_lift() const -> bool {
pybind11::gil_scoped_acquire gil;
static PyObject* attr_name =
@ -340,11 +353,8 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
args.collect(f->output_info);
args.collect(f->input_info);
if (compiled_autograd_should_lift()) {
Py_INCREF(obj);
_backward_idx =
args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
}
Py_INCREF(obj);
_backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
PyObject* bw_state = f->compiled_autograd_backward_state;
if (args.cond(bw_state != nullptr)) {
@ -366,28 +376,8 @@ variable_list PyNode::apply_with_saved(
saved.before(f->output_info);
saved.before(f->input_info);
f->compiled_autograd_tracing = true;
variable_list result;
if (!compiled_autograd_should_lift()) {
if (_backward_state_idx.has_value()) {
PyObject* r = PyObject_CallMethod(
saved.get_py_compiler(),
"bind_backward_state",
"i",
*_backward_state_idx);
if (r == nullptr) {
throw python_error();
}
THPObjectPtr prior(f->compiled_autograd_backward_state);
f->compiled_autograd_backward_state = r;
result = apply(variable_list(inputs));
Py_CLEAR(f->compiled_autograd_backward_state);
f->compiled_autograd_backward_state = prior.release();
} else {
result = apply(variable_list(inputs));
}
} else {
result = defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
}
variable_list result =
defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
f->compiled_autograd_tracing = false;
saved.after(f->compiled_autograd_symints);
saved.after(f->saved_variables);
@ -1092,6 +1082,7 @@ PyObject* process_outputs(
THPFunction* grad_fn,
const UnpackedInput& unpacked,
PyObject* inputs,
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
THPObjectPtr&& raw_output,
bool is_executable,
torch::jit::Node* node,

View File

@ -43,6 +43,8 @@ struct PyNode : public Node {
std::string name() const override;
bool is_traceable() override;
bool is_aot_backward() const override;
void compiled_args(CompiledNodeArgs& args) override;
variable_list apply_with_saved(
const variable_list& inputs,

View File

@ -0,0 +1,27 @@
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
namespace torch::dynamo::autograd {
std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;
const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
return kPyCompilerInterface;
}
void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
TORCH_INTERNAL_ASSERT(impl != nullptr);
kPyCompilerInterface = std::move(impl);
}
void resetPyCompilerInterface() {
kPyCompilerInterface.reset();
}
std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges) {
return torch::autograd::collect_input_metadata(edges);
}
} // namespace torch::dynamo::autograd

View File

@ -900,6 +900,542 @@ class SwapSavedVariables {
StashedVars<at::IValue> stashed_ivalues;
};
// NOTE: [Compiled Autograd and backward functions]
// Built-in autograd nodes have functional apply variants
// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph
// capture wants to take a variant of this function and proxy it into the graph.
// Every autograd node defines an apply_with_saved function that, when invoked,
// proxies a call to a function into the Compiled Autograd graph.
//
// Some requirements that we have are:
// - The proxied function must have inputs that are FX-graphable types.
// - Windows has a DLL symbol limit of 65536.
// - Node::apply_with_saved is in libtorch_cpu which does not have direct access
// to Python
//
// There were multiple ways to skin the cat, but what we ended up doing is:
// - for e.g. MulBackward0_apply_functional, we create a new C++ function
// MulBackward0_apply_functional_ivalue that accepts vector<IValue>.
// - We define how to pack and unpack arbitrary C++ types into IValues.
// - apply_with_saved passes MulBackward0_apply_functional_ivalue and
// the IValue arguments to Python via an indirection.
// In Python, these get proxied into a graph.
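//
// Illustrative sketch only (not part of this diff): from the point of view of
// an apply_with_saved override, the flow above looks roughly like the
// following. The "apply_functional" method name is a placeholder for whichever
// py_compiler method the node actually proxies through, and `saved_alpha`,
// `grads`, and `output_metadata` are hypothetical locals.
//
//   PackedArgs packed;
//   packed.pack(saved_alpha);  // pack saved state into IValues
//   const auto& interface = getPyCompilerInterface();
//   variable_list outputs = interface->call_function(
//       saved.get_py_compiler(),
//       /*method_name=*/"apply_functional",  // placeholder name
//       /*fn_name=*/"MulBackward0",
//       grads,
//       packed.vec(),
//       output_metadata);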
// Helper struct for packing/unpacking an arbitrary C++ type into a single
// IValue. There are various full and partial specializations for IValuePacker
// to handle packing specific types (like TensorOptions) into an IValue.
template <typename T>
struct IValuePacker {
// Defines how to pack T into an IValue.
static at::IValue pack(const T& t) {
return t;
}
// Defines how to unpack an IValue into T.
static T unpack(const at::IValue& t) {
return t.to<T>();
}
// Returns the TypePtr for the IValue (this is like the "type" of the IValue).
// We use this when passing the packed IValue from Python to C++.
// In Python, the IValue is just a PyObject* with the native type.
// For example, it may be a Python int, a Python List[int], etc.
// When passing this PyObject* into C++, we need to know how to parse it
// into a C++ type that then gets put into an IValue.
// That's what the TypePtr is for: it contains the information to do the
// parsing. See torch::jit::toIValue for more information.
static at::TypePtr packed_type() {
if constexpr (::std::is_same_v<T, at::Tensor>) {
return at::TensorType::get();
} else if constexpr (::std::is_same_v<T, int64_t>) {
return at::IntType::get();
} else if constexpr (::std::is_same_v<T, c10::SymInt>) {
return at::SymIntType::get();
} else if constexpr (::std::is_same_v<T, bool>) {
return at::BoolType::get();
} else if constexpr (::std::is_same_v<T, double>) {
return at::FloatType::get();
} else if constexpr (::std::is_same_v<T, c10::SymFloat>) {
return at::SymFloatType::get();
} else if constexpr (::std::is_same_v<T, c10::SymBool>) {
return at::SymBoolType::get();
} else if constexpr (::std::is_same_v<T, c10::Layout>) {
return at::LayoutType::get();
} else if constexpr (::std::is_same_v<T, ::std::string>) {
return at::StringType::get();
} else if constexpr (::std::is_same_v<T, at::Device>) {
return at::DeviceObjType::get();
} else if constexpr (::std::is_same_v<T, at::Scalar>) {
return at::NumberType::get();
} else if constexpr (::std::is_same_v<T, at::MemoryFormat>) {
return at::MemoryFormatType::get();
} else if constexpr (::std::is_same_v<T, at::ScalarType>) {
return at::ScalarTypeType::get();
} else {
// If you got here, you have probably added a member of a new type
// to a built-in C++ autograd node.
// Unfortunately, we don't know how to handle this type yet.
// To get this new type to work with Compiled Autograd, please
// either change it to be an IValue-constructible type, or
// define how to pack and unpack an object of this type into an IValue
// by creating a specialization of IValuePacker for this type.
// See NOTE: [Compiled Autograd and backward functions] for context.
TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type");
return at::NoneType::get();
}
}
};
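// Illustrative sketch only (not part of this diff): a value round-trips
// through pack/unpack, and packed_type() describes how Python hands the
// packed value back to C++.
//
//   at::IValue packed = IValuePacker<int64_t>::pack(42);
//   int64_t restored = IValuePacker<int64_t>::unpack(packed);  // 42
//   at::TypePtr ty = IValuePacker<int64_t>::packed_type();     // IntType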
template <>
struct IValuePacker<size_t> {
static at::IValue pack(const size_t& t) {
// We generally use size_t for the size of a list of Tensors or the number of
// dimensions. The number of dimensions generally does not exceed 64
// (TensorIterator has that limitation), and lists of Tensors generally do
// not exceed the int64_t max (you'd probably run out of RAM or run into
// significant Tensor overhead first). If you run into this limitation, the
// fix is to figure out how to pack size_t into int64_t. Note that size_t has
// some weird behavior on macOS.
uint64_t maximum_value = std::numeric_limits<int64_t>::max();
TORCH_INTERNAL_ASSERT(
static_cast<uint64_t>(t) <= maximum_value,
"size_t too large to pack into IValue");
return static_cast<int64_t>(t); // pack as int64_t
}
static size_t unpack(const at::IValue& t) {
return static_cast<size_t>(t.toInt());
}
static at::TypePtr packed_type() {
return IValuePacker<int64_t>::packed_type();
}
};
template <>
struct IValuePacker<std::vector<at::SymInt>> {
static at::IValue pack(const std::vector<at::SymInt>& t) {
return t;
}
static std::vector<at::SymInt> unpack(const at::IValue& t) {
// We need this because there is no t.to<std::vector<at::SymInt>>() overload.
return t.toSymIntVector();
}
static at::TypePtr packed_type() {
return at::ListType::create(at::SymIntType::get());
}
};
template <>
struct IValuePacker<VariableInfo> {
static at::IValue pack(const VariableInfo& t) {
auto tuple = std::make_tuple(
t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty);
return tuple;
}
static VariableInfo unpack(const at::IValue& t) {
auto tuple = t.toTuple();
const auto& tuple_elements = tuple->elements();
const auto elements = tuple_elements.asArrayRef();
TORCH_INTERNAL_ASSERT(elements.size() == 6);
VariableInfo v;
v.layout = elements[0].toLayout();
v.device = elements[1].toDevice();
v.scalar_type = elements[2].toScalarType();
v.size = elements[3].toSymIntVector();
v.requires_grad = elements[4].toBool();
v.is_empty = elements[5].toBool();
return v;
}
static at::TypePtr packed_type() {
return at::TupleType::create({
at::LayoutType::get(),
at::DeviceObjType::get(),
at::ScalarTypeType::get(),
at::ListType::create(at::SymIntType::get()),
at::BoolType::get(),
at::BoolType::get(),
});
}
};
template <>
struct IValuePacker<caffe2::TypeMeta> {
static at::IValue pack(const caffe2::TypeMeta& t) {
return at::typeMetaToScalarType(t); // pack as at::ScalarType
}
static caffe2::TypeMeta unpack(const at::IValue& t) {
return caffe2::TypeMeta::fromScalarType(t.to<at::ScalarType>());
}
static at::TypePtr packed_type() {
return IValuePacker<at::ScalarType>::packed_type();
}
};
inline std::optional<at::ScalarType> optTypeMetaToScalarType(
const std::optional<caffe2::TypeMeta>& t) {
if (t.has_value()) {
return at::typeMetaToScalarType(t.value());
} else {
return std::nullopt;
}
}
using packed_tensoroptions_t = std::tuple<
std::optional<bool>,
std::optional<at::MemoryFormat>,
std::optional<at::Device>,
std::optional<at::ScalarType>,
std::optional<at::Layout>,
std::optional<bool>>;
inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) {
auto tuple = std::make_tuple(
t.requires_grad_opt(),
t.memory_format_opt(),
t.device_opt(),
optTypeMetaToScalarType(t.dtype_opt()),
t.layout_opt(),
t.pinned_memory_opt());
return tuple;
}
inline at::TensorOptions unpack_TensorOptions(
const packed_tensoroptions_t& tuple) {
at::TensorOptions result;
auto maybe_requires_grad = std::get<0>(tuple);
if (maybe_requires_grad.has_value()) {
result = result.requires_grad(maybe_requires_grad.value());
}
auto maybe_memory_format = std::get<1>(tuple);
if (maybe_memory_format.has_value()) {
result = result.memory_format(maybe_memory_format.value());
}
auto maybe_device = std::get<2>(tuple);
if (maybe_device.has_value()) {
result = result.device(maybe_device.value());
}
auto maybe_dtype = std::get<3>(tuple);
if (maybe_dtype.has_value()) {
result =
result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value()));
}
auto maybe_layout = std::get<4>(tuple);
if (maybe_layout.has_value()) {
result = result.layout(maybe_layout.value());
}
auto maybe_pinned_memory = std::get<5>(tuple);
if (maybe_pinned_memory.has_value()) {
result = result.pinned_memory(maybe_pinned_memory.value());
}
return result;
}
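// Illustrative sketch only (not part of this diff): TensorOptions round-trip
// through the packed tuple representation field by field; fields that were
// not set stay unset after unpacking.
//
//   at::TensorOptions opts =
//       at::TensorOptions().dtype(at::kFloat).device(at::kCPU);
//   packed_tensoroptions_t packed = pack_TensorOptions(opts);
//   at::TensorOptions restored = unpack_TensorOptions(packed);
//   // restored has dtype float and device cpu; layout, pinned_memory, etc.
//   // remain std::nullopt, just like in `opts`.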
template <>
struct IValuePacker<at::TensorOptions> {
static at::IValue pack(const at::TensorOptions& t) {
return pack_TensorOptions(t);
}
static at::TensorOptions unpack(const at::IValue& t) {
auto tuple = t.to<packed_tensoroptions_t>();
return unpack_TensorOptions(tuple);
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{at::OptionalType::create(at::BoolType::get()),
at::OptionalType::create(at::MemoryFormatType::get()),
at::OptionalType::create(at::DeviceObjType::get()),
at::OptionalType::create(at::ScalarTypeType::get()),
at::OptionalType::create(at::LayoutType::get()),
at::OptionalType::create(at::BoolType::get())});
}
};
template <>
struct IValuePacker<TypeAndSize> {
static at::IValue pack(const TypeAndSize& t) {
auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options));
return tuple;
}
static TypeAndSize unpack(const at::IValue& t) {
auto tuple =
t.to<std::tuple<std::vector<at::SymInt>, packed_tensoroptions_t>>();
TypeAndSize result;
result.sym_sizes = std::get<0>(tuple);
result.options = unpack_TensorOptions(std::get<1>(tuple));
return result;
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<std::vector<at::SymInt>>::packed_type(),
IValuePacker<at::TensorOptions>::packed_type()});
}
};
template <typename T>
struct IValuePacker<std::optional<T>> {
static at::IValue pack(const std::optional<T>& t) {
if (t.has_value()) {
return IValuePacker<T>::pack(t.value());
} else {
return std::nullopt;
}
}
static std::optional<T> unpack(const at::IValue& t) {
if (t.isNone()) {
return std::nullopt;
} else {
return IValuePacker<T>::unpack(t);
}
}
static at::TypePtr packed_type() {
return at::OptionalType::create(IValuePacker<T>::packed_type());
}
};
template <typename T>
struct IValuePacker<std::vector<T>> {
static at::IValue pack(const std::vector<T>& t) {
if constexpr (::std::is_constructible_v<at::IValue, T>) {
return t;
}
if (t.empty()) {
auto lst = c10::impl::GenericList(at::AnyType::get());
return lst;
}
auto type_ptr = IValuePacker<T>::pack(t[0]).type();
auto lst = c10::impl::GenericList(type_ptr);
for (const auto& elt : t) {
lst.emplace_back(IValuePacker<T>::pack(elt));
}
return lst;
}
static std::vector<T> unpack(const at::IValue& t) {
if constexpr (::std::is_constructible_v<at::IValue, T>) {
return t.to<::std::vector<T>>();
}
std::vector<T> result;
auto lst = t.toList();
for (const at::IValue& elt : lst) {
result.emplace_back(IValuePacker<T>::unpack(elt));
}
return result;
}
static at::TypePtr packed_type() {
return at::ListType::create(IValuePacker<T>::packed_type());
}
};
template <typename T>
struct IValuePacker<c10::List<T>> {
static at::IValue pack(const c10::List<T>& t) {
return IValuePacker<std::vector<T>>::pack(t.vec());
}
static c10::List<T> unpack(const at::IValue& t) {
return c10::List<T>(IValuePacker<std::vector<T>>::unpack(t));
}
static at::TypePtr packed_type() {
return IValuePacker<std::vector<T>>::packed_type();
}
};
template <size_t N>
struct IValuePacker<std::array<bool, N>> {
static at::IValue pack(const std::array<bool, N>& t) {
std::vector<bool> result(t.begin(), t.end());
return IValuePacker<std::vector<bool>>::pack(result);
}
static std::array<bool, N> unpack(const at::IValue& t) {
std::array<bool, N> result;
auto packed = IValuePacker<std::vector<bool>>::unpack(t);
for (size_t i = 0; i < packed.size(); i++) {
result[i] = packed[i];
}
return result;
}
static at::TypePtr packed_type() {
return IValuePacker<std::vector<bool>>::packed_type();
}
};
template <>
struct IValuePacker<at::TensorGeometry> {
static at::IValue pack(const at::TensorGeometry& t) {
auto tuple = std::make_tuple(
t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset());
return tuple;
}
static at::TensorGeometry unpack(const at::IValue& t) {
auto tuple = t.to<std::tuple<
std::vector<at::SymInt>,
std::vector<at::SymInt>,
at::SymInt>>();
return at::TensorGeometry(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple));
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<std::vector<at::SymInt>>::packed_type(),
IValuePacker<std::vector<at::SymInt>>::packed_type(),
at::SymIntType::get()});
}
};
template <>
struct IValuePacker<InputMetadata> {
static at::IValue pack(const InputMetadata& t) {
TORCH_INTERNAL_ASSERT(!t.is_nested_tensor());
auto tuple = std::make_tuple(
pack_TensorOptions(t.options()),
t.shape_as_dim_vector().vec(),
t.is_tensor_subclass());
return tuple;
}
static InputMetadata unpack(const at::IValue& t) {
auto tuple = t.to<
std::tuple<packed_tensoroptions_t, std::vector<at::SymInt>, bool>>();
return InputMetadata(
unpack_TensorOptions(std::get<0>(tuple)),
SymIntSmallVec(std::get<1>(tuple)),
std::get<2>(tuple),
false);
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<at::TensorOptions>::packed_type(),
IValuePacker<std::vector<at::SymInt>>::packed_type(),
at::BoolType::get()});
}
};
template <typename T>
struct IValuePacker<at::OptionalArray<T>> {
static at::IValue pack(const at::OptionalArray<T>& t) {
return IValuePacker<std::optional<std::vector<T>>>::pack(t.list);
}
static at::OptionalArray<T> unpack(const at::IValue& t) {
auto result = IValuePacker<std::optional<std::vector<T>>>::unpack(t);
if (result.has_value()) {
return {result.value()};
} else {
return {};
}
}
static at::TypePtr packed_type() {
return IValuePacker<std::optional<std::vector<T>>>::packed_type();
}
};
// This is a helper struct for packing and unpacking multiple arguments into
// an ivalue_list. It leverages IValuePacker<T>.
struct PackedArgs {
PackedArgs() = default;
explicit PackedArgs(std::vector<at::IValue> stack_)
: stack(std::move(stack_)) {}
const std::vector<at::IValue>& vec() const {
return stack;
}
template <typename T>
void pack(const T& t) {
stack.emplace_back(IValuePacker<T>::pack(t));
}
template <typename T>
T unpack() {
return IValuePacker<T>::unpack(std::move(stack[idx++]));
}
void pack_saved_data(const ska::flat_hash_map<std::string, at::IValue>& dct) {
std::vector<std::string> keys;
std::vector<at::IValue> values;
for (const auto& [key, value] : dct) {
keys.emplace_back(key);
values.emplace_back(value);
}
pack(keys);
for (const auto& value : values) {
pack(value);
}
}
ska::flat_hash_map<std::string, at::IValue> unpack_saved_data() {
ska::flat_hash_map<std::string, at::IValue> dct;
auto keys = unpack<std::vector<std::string>>();
for (const auto& key : keys) {
dct.insert({key, std::move(stack[idx++])});
}
return dct;
}
private:
std::vector<at::IValue> stack;
int64_t idx = 0;
};
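// Illustrative sketch only (not part of this diff): arguments are packed on
// one side and must be unpacked in the same order on the other.
//
//   PackedArgs writer;
//   writer.pack<int64_t>(2);
//   writer.pack<std::vector<at::SymInt>>({at::SymInt(3), at::SymInt(4)});
//   PackedArgs reader(writer.vec());
//   auto ndims = reader.unpack<int64_t>();                  // 2
//   auto sizes = reader.unpack<std::vector<at::SymInt>>();  // {3, 4}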
// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
PyCompilerInterface() = default;
PyCompilerInterface(const PyCompilerInterface&) = delete;
PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
PyCompilerInterface(PyCompilerInterface&&) = delete;
PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
virtual ~PyCompilerInterface() = default;
// Invokes py_compiler.bind_function(fn_name, fn)
virtual std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
functional_apply_t fn,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
virtual variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
virtual variable_list call_copy_slices_epilogue(
PyObject* py_compiler,
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};
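// Illustrative sketch only (not part of this diff): code in libtorch_cpu
// reaches back into Python through this indirection, e.g. the CopySlices
// change in this PR does roughly:
//
//   const auto& interface = getPyCompilerInterface();
//   variable_list prologue_out = interface->call_copy_slices_prologue(
//       saved.get_py_compiler(), grads, base, view);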
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
TORCH_API void setPyCompilerInterface(
std::unique_ptr<PyCompilerInterface>&& impl);
TORCH_API void resetPyCompilerInterface();
// Including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead of requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambiguous symbols).
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges);
} // namespace torch::dynamo::autograd
template <>

View File

@ -52,6 +52,156 @@ Notes:
namespace torch::dynamo::autograd {
using c10::SymInt;
// List[Optional[Tensor]] in Python can't be directly parsed into a
// List[Tensor], so we need to do this conversion manually.
static std::vector<at::Tensor> toTensorList(
const std::vector<std::optional<at::Tensor>>& inputs) {
std::vector<at::Tensor> result;
result.reserve(inputs.size());
for (const auto& inp : inputs) {
if (inp.has_value()) {
result.emplace_back(*inp);
} else {
result.emplace_back();
}
}
return result;
}
// Binds a function (that represents some backward computation) to Python.
// All of these functions have a common signature, which is
// (in C++) (vector<Tensor>, vector<ivalue>) -> vector<Tensor>
// (in Python) (List[Optional[Tensor]], *packed_args: IValue) ->
// List[Optional[Tensor]]
//
// The vector<Tensor> is the list of gradient Tensors, each of which may be
// undefined (in C++), which corresponds to None (in Python).
static std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
functional_apply_t fn,
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function) {
// This is the function that can be called from Python.
auto py_func = py::cpp_function(
[packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)](
std::vector<std::optional<at::Tensor>>& inputs,
const py::args& py_args) -> py::object {
// py_args is a tuple of PyObject*.
// We need to reconstruct a vector<IValue> to invoke `fn`.
// To do so, we use the packed_args_schema to convert each PyObject*
// to its corresponding C++ type that can be stored into IValue.
TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size());
std::vector<at::IValue> args;
args.reserve(py_args.size());
auto tuple_args = jit::tuple_slice(py_args);
for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) {
if (packed_args_schema[idx]->isSubtypeOf(
*at::ListType::ofTensors())) {
// List[Tensor] might contain Nones, which jit::toIValue does not handle
auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(
tuple_args[idx]);
args.emplace_back(toTensorList(tmp));
} else {
args.emplace_back(jit::toIValue(
tuple_args[idx], packed_args_schema[idx], std::nullopt));
}
}
// None in Python corresponds to undefined Tensor in C++
auto inputs_ = toTensorList(inputs);
auto outputs = fn(inputs_, args);
return jit::toPyObject(at::IValue(outputs));
});
py::handle handle(py_compiler);
auto result =
handle.attr("bind_function")(fn_name, py_func, is_custom_function);
return result.cast<std::string>();
}
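// Illustrative sketch only (not part of this diff): binding a trivial
// functional backward. `identity` below is hypothetical; it assumes
// functional_apply_t is callable as (gradients, packed ivalues) -> gradients.
//
//   functional_apply_t identity =
//       [](const variable_list& grads,
//          const std::vector<at::IValue>& packed) -> variable_list {
//         return grads;
//       };
//   // Empty schema because `identity` consumes no packed args.
//   std::string bound_name = bind_function(
//       py_compiler, "identity_backward", identity,
//       /*packed_args_schema=*/{}, /*is_custom_function=*/false);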
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
static variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
// convert ivalue_list -> PyObject*
PyObject* py_packed_args =
PyTuple_New(static_cast<Py_ssize_t>(packed_args.size()));
for (const auto i : c10::irange(packed_args.size())) {
py::object obj = jit::toPyObject(packed_args[i]);
Py_INCREF(obj.ptr());
PyTuple_SET_ITEM(py_packed_args, i, obj.ptr());
}
// call the corresponding method on the py_compiler
py::handle handle(py_compiler);
py::object stuff = handle.attr(method_name)(
fn_name,
inputs,
py::handle(py_packed_args),
jit::toPyObject(output_metadata));
// Convert the output from PyObject* to vector<Tensor>
auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
return toTensorList(tmp);
}
struct PyCompilerInterfaceImpl : PyCompilerInterface {
std::string bind_function(
PyObject* py_compiler,
const std::string& fn_name,
functional_apply_t fn,
std::vector<at::TypePtr> packed_args_schema,
bool is_custom_function = false) override {
return torch::dynamo::autograd::bind_function(
py_compiler,
fn_name,
std::move(fn),
std::move(packed_args_schema),
is_custom_function);
}
variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) override {
return torch::dynamo::autograd::call_function(
py_compiler,
method_name,
fn_name,
inputs,
packed_args,
output_metadata);
}
variable_list call_copy_slices_prologue(
PyObject* py_compiler,
const variable_list& inputs,
const at::TensorGeometry& base,
const at::TensorGeometry& view) override {
py::handle handle(py_compiler);
py::object stuff =
handle.attr("call_copy_slices_prologue")(inputs, base, view);
return py::cast<std::vector<at::Tensor>>(stuff);
}
variable_list call_copy_slices_epilogue(
PyObject* py_compiler,
const std::vector<bool>& needs_input_grad,
const at::Tensor& result,
const variable_list& res,
const at::Tensor& grad_slice) override {
py::handle handle(py_compiler);
py::object stuff = handle.attr("call_copy_slices_epilogue")(
needs_input_grad, result, res, grad_slice);
auto output = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
return toTensorList(output);
}
};
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
for (const auto i : c10::irange(inputs.size())) {
@ -88,6 +238,22 @@ static void check(bool result) {
check(nullptr);
}
static variable_list validate_outputs(
const variable_list& outputs,
const ivalue_list& args) {
auto r = PackedArgs(args);
auto value = r.unpack<std::vector<std::optional<InputMetadata>>>();
auto new_outputs = outputs;
torch::autograd::validate_outputs(
value, new_outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "[Compiled Autograd Tracing:]" << msg;
return ss.str();
});
return new_outputs;
}
// snapshot of python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;
@ -498,6 +664,21 @@ static void set_ivalue_proxies(
}
}
static at::Tensor call_accumulate(
PyObject* py_compiler,
const at::Tensor& old_var,
const at::Tensor& new_var) {
if (!old_var.defined()) {
return new_var;
}
if (!new_var.defined()) {
return old_var;
}
py::handle handle(py_compiler);
py::object stuff = handle.attr("accumulate")(old_var, new_var);
return py::cast<at::Tensor>(stuff);
}
static TraceState call_begin_capture(
PyObject* self,
CacheNode& cache,
@ -657,6 +838,8 @@ static CacheNode* _compiled_autograd_impl(
ClosingTHPObjectPtr py_compiler(
check(PyObject_CallNoArgs((the_autograd_compiler))));
setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
TraceState state = call_begin_capture(
py_compiler, *cache, compiler_call, output_edges.size());
InputBuffers input_buffers;
@ -723,16 +906,52 @@ static CacheNode* _compiled_autograd_impl(
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
variable_list outputs = call.node->apply_with_saved(inputs, saved);
saved.debug_asserts();
saved.before(call.node->next_edges());
validate_outputs(
call.node->next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
<< msg;
return ss.str();
});
auto input_metadata = get_input_metadata(call.node->next_edges());
TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size());
// Lazily bind the `validate_outputs` function to Python.
static c10::once_flag flag;
c10::call_once(flag, [&]() {
auto schema = std::vector<at::TypePtr>{IValuePacker<
std::vector<std::optional<InputMetadata>>>::packed_type()};
bind_function(
py_compiler.get(),
"validate_outputs",
validate_outputs,
schema,
false);
});
// Don't emit validate_outputs nodes that follow a CompiledBackward node.
// These nodes would otherwise prevent reordering of accumulate_grad
// nodes.
//
// Note that this will not cause correctness issues, because
// 1) AOTAutograd already coerces gradients to have the same metadata as
// the inputs. 2) the AOTAutograd graph already has the necessary
// aten::sum_to nodes in it (so it doesn't need to rely on
// validate_outputs to handle that).
//
// However, we may be dropping some (edge case) safety checks compared to
// eager: a backward that would have errored out in eager may not error
// out in compiled autograd (for example, if the user provided an
// incorrect number of gradients).
if (!call.node->is_aot_backward()) {
PackedArgs args;
args.pack(input_metadata);
ivalue_list input_metadata_state = std::move(args).vec();
outputs = call_function(
py_compiler,
"validate_outputs",
"validate_outputs",
outputs,
input_metadata_state,
input_metadata_state[0]);
}
saved.after(call.node->next_edges());
saved.debug_asserts();
@ -754,13 +973,14 @@ static CacheNode* _compiled_autograd_impl(
auto& output = outputs[i];
const auto& next = call.node->next_edge(i);
if (next.is_valid() && output.defined()) {
input_buffers.lookup(next.function.get())
.add(
next.input_nr, std::move(output), std::nullopt, std::nullopt);
auto& buffer = input_buffers.lookup(next.function.get());
buffer.buffer[next.input_nr] = call_accumulate(
py_compiler, buffer.buffer[next.input_nr], output);
}
}
}
resetPyCompilerInterface();
PyObject* res = check(call_end_capture(py_compiler, state.outputs));
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
TORCH_CHECK(