[compiled autograd] Proxy opaque nodes for built-in autograd nodes (#143296)

This PR is a step toward getting compiled autograd's initial capture to
stop specializing on Tensor metadata.

This PR changes compiled autograd's initial capture to proxy an opaque
(w.r.t. Dynamo) function into the graph for all built-in codegen'ed
autograd nodes and validate_outputs.

We changed each codegen'ed apply_with_saved (e.g.
MulBackward0::apply_with_saved) to call into Python and proxy a function
(compiled_autograd.ops.MulBackward0) into the graph. We then use the
node's InputMetadata to "guess" the properties of the output Tensors and
create new FakeTensors for them.
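
For illustration only (this is not part of the diff): a rough, hand-written
sketch of what the generated MulBackward0::apply_with_saved looks like after
this PR, simplified from the codegen template in the diff below. The schema
computation and the SwapSavedVariables before/after bookkeeping are elided.

variable_list MulBackward0::apply_with_saved(
    const variable_list& grads,
    SwapSavedVariables& saved) {
  // (saved-variable swapping elided)
  // On first use, bind the functional variant and its IValue schema to the
  // Python side as torch._dynamo.compiled_autograd.ops.MulBackward0.
  static std::once_flag flag;
  std::call_once(flag, [&]() {
    std::vector<at::TypePtr> schema = {/* IValuePacker<T>::packed_type() per packed arg */};
    const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
    interface->bind_function(
        saved.get_py_compiler(),
        name(),
        MulBackward0_apply_functional_ivalue,
        schema);
  });
  // Pack the saved values into IValues and collect the next edges'
  // InputMetadata, which Python uses to guess the output FakeTensors.
  auto packed_args = get_packed_args();
  auto output_metadata = torch::dynamo::autograd::IValuePacker<
      std::vector<std::optional<InputMetadata>>>::pack(
      torch::dynamo::autograd::get_input_metadata(next_edges()));
  // Ask the py_compiler to proxy ops.MulBackward0(grads, *packed_args) into
  // the graph and return placeholder outputs built from output_metadata.
  const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
  variable_list result = interface->call_function(
      saved.get_py_compiler(),
      "apply_functional",
      name(),
      grads,
      packed_args,
      output_metadata);
  return result;
}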

Some details:
- MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to
  call into Python via libtorch_python. There is an indirection
  (PyCompilerInterface) to do this.
- MulBackward0::apply_with_saved passes a C++ function to Python. To make
  our lives easier, every codegen'ed apply_with_saved passes a C++
  function with the same signature
  `(variable_list, ivalue_list) -> variable_list`.
- We define how to pack arbitrary C++ types into IValue via a helper
  IValuePacker struct and codegen functional variants of each built-in
  C++ autograd node (e.g. MulBackward0_apply_functional_ivalue); see the
  sketch after this list.
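
As a further illustration (also not part of the diff): a minimal sketch of how
saved values travel through PackedArgs/IValuePacker. The helper names
pack_example/unpack_example and the arguments (alpha, sizes, a 2-element
needs_input_grad) are made up for the example; the API matches the helpers
added to torch/csrc/dynamo/compiled_autograd.h below.

#include <torch/csrc/dynamo/compiled_autograd.h>

#include <array>
#include <vector>

using torch::dynamo::autograd::IValuePacker;
using torch::dynamo::autograd::PackedArgs;

// Pack heterogeneous saved values into the common ivalue_list format.
// The codegen packs needs_input_grad first, then each saved member.
static std::vector<at::IValue> pack_example(
    const at::Scalar& alpha,
    const std::vector<at::SymInt>& sizes,
    const std::array<bool, 2>& needs_input_grad) {
  PackedArgs packed;
  packed.pack(needs_input_grad);
  packed.pack(alpha);
  packed.pack(sizes);
  return std::move(packed).vec();
}

// The functional variant unpacks them in the same order on the other side.
static void unpack_example(const std::vector<at::IValue>& args) {
  auto r = PackedArgs(args);
  auto needs_input_grad = r.unpack<std::array<bool, 2>>();
  auto alpha = r.unpack<at::Scalar>();
  auto sizes = r.unpack<std::vector<at::SymInt>>();
  (void)needs_input_grad;
  (void)alpha;
  (void)sizes;
}

// The per-argument schema handed to Python (so it can convert PyObject*
// back into IValues) is built from IValuePacker<T>::packed_type() for each
// packed argument, e.g. IValuePacker<std::vector<at::SymInt>>::packed_type().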

MulBackward0 before this PR:
https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de

MulBackward0 after this PR:
https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296
Approved by: https://github.com/jansel
Author: rzou
Date: 2025-01-22 07:08:05 -08:00
Committed by: PyTorch MergeBot
Parent: 0cb9b2284a
Commit: 5531fafffe
16 changed files with 961 additions and 46 deletions

View File

@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
has_symbolic_sizes_strides_(
t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
explicit TensorGeometry(
std::vector<at::SymInt> sizes,
std::vector<at::SymInt> strides,
at::SymInt storage_offset)
: sizes_(std::move(sizes)),
strides_(std::move(strides)),
storage_offset_(std::move(storage_offset)) {
recompute();
}
// true if the tensor is contiguous
bool is_contiguous() const;

View File

@ -138,6 +138,7 @@ core_trainer_sources = [
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/utils/warnings.cpp",
"torch/csrc/autograd/jit_decomp_interface.cpp",
"torch/csrc/dynamo/compiled_autograd.cpp",
"torch/csrc/jit/frontend/name_mangler.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
"torch/csrc/jit/serialization/pickler.cpp",

View File

@ -121,23 +121,27 @@ class _multiply_invoke(torch.nn.Module):
out.backward(grad_out)
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(x.grad, grad_out * grad_out)
self.assertExpectedInline(
actual,
"""\
if backend in ["aot_eager", "inductor"]:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list):
l_inputs_ = L_inputs_
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
new_grad: "f32[s0]" = torch.clone(getitem)
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None
result: "f32[s0]" = getitem * getitem; getitem = None
new_grad: "f32[2]" = torch.clone(getitem_3)
new_grad_1: "f32[s0]" = torch.clone(result); result = None
result: "f32[2]" = getitem_3 * getitem_3; getitem_3 = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
)
graph = None
@ -187,26 +191,30 @@ class GraphModule(torch.nn.Module):
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(obj.counter, 1)
self.assertEqual(x.grad, grad_out + grad_out)
self.assertExpectedInline(
actual,
"""\
if backend in ["aot_eager", "inductor"]:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"):
l_inputs_ = L_inputs_
l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None
new_grad: "f32[s0]" = torch.clone(getitem)
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None
getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None
new_grad: "f32[2]" = torch.clone(getitem_3)
add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None
result: "f32[s0]" = getitem * getitem; getitem = None
result: "f32[2]" = getitem_3 * getitem_3; getitem_3 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
new_grad_1: "f32[2]" = torch.clone(result); result = None
return (new_grad, new_grad_1, add)
""",
)
)
out = fn(x, y)
out.backward(grad_out)

View File

@ -2924,7 +2924,6 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
"aot0_le",
"aot0_permute_2",
"code: CompiledFunctionBackward0 (NodeCall 2)",
"aot0_tangents_1",
"aot0_full_default",
"aot0_where",
"aot0_mm",
@ -2974,20 +2973,17 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
expected_logs = [
"CompiledFunctionBackward1",
"aot1_tangents_1",
"aot1_sin_1",
"aot1_primals_2",
"aot1_neg",
"aot0_tangents_2",
"aot1_cos_1",
"aot1_primals_1",
"aot0_tangents_1",
"CompiledFunctionBackward0",
"aot0_sin_1",
"aot0_neg",
"aot0_sin",
"aot0_mul",
"aot0_cos_1",
"aot0_mul_1",
"aot0_cos",
"aot0_add",
]
@ -3618,6 +3614,7 @@ known_failing_tests = {
"test_tp_compile_comm_reordering",
"test_unwrap_async_collective_tensor_tangent",
# Uncategorized
"test_not_implemented_grad", # Dynamo changes the types of exceptions
}
if not HAS_CUDA:

View File

@ -337,7 +337,9 @@ class DistributedPatternTests(TestCase):
self.assertEqual(fw_cnt.frame_count, 1)
self.assertEqual(fw_cnt.op_count, 5)
self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None
self.assertEqual(bw_cnt.op_count, 48)
self.assertEqual(
bw_cnt.op_count, 72
) # Number of ops in the Dynamo-produced graphs
def test_module_backward_hooks_aot(self):
m1, inp1 = init_module_bw_hooks(True)

View File

@ -68,6 +68,7 @@ struct TORCH_API ${op} : public ${superclass} {
}
${will_release_variables}
void compiled_args(CompiledNodeArgs& args) override;
ivalue_list get_packed_args();
variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
${saved_variables}
${saved_list_sizes}
@ -107,6 +108,13 @@ static variable_list ${op}_apply_functional(
${body}
return grad_inputs;
}
static variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args)
{
auto packed_args = PackedArgs(args);
auto needs_input_grad = packed_args.unpack<std::array<bool, ${num_inputs}>>();
${unpack_ivalues}
return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args});
}
variable_list ${op}::apply(variable_list&& grads) {
${thread_lock}
@ -120,11 +128,42 @@ void ${op}::compiled_args(CompiledNodeArgs& args) {
${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
${apply_with_saved_before}
variable_list result = apply(variable_list(grads));
${apply_with_saved_after}
return result;
${apply_with_saved_before}
static std::once_flag flag;
std::call_once(flag, [&](){
${compute_schema}
const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
interface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema);
});
variable_list result;
auto packed_args = get_packed_args();
auto output_metadata = torch::dynamo::autograd::IValuePacker<
std::vector<std::optional<InputMetadata>>>::pack(
torch::dynamo::autograd::get_input_metadata(next_edges()));
const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
result = interface->call_function(
saved.get_py_compiler(),
"apply_functional",
name(),
grads,
packed_args,
output_metadata);
${apply_with_saved_after}
return result;
}
ivalue_list ${op}::get_packed_args() {
PackedArgs packed_args;
${asserts}
${unpacks}
${compute_needs_input_grad}
packed_args.pack(needs_input_grad);
${get_packed_args}
return std::move(packed_args).vec();
}
"""
)
@ -993,14 +1032,38 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
f"{T} {x}"
for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
]
get_packed_args = "\n".join(
f"packed_args.pack({name});" for name in apply_functional_args
)
unpack_ivalues = []
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
if typ.endswith("&"):
typ = typ[:-1]
unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
for typ in apply_functional_args_ref_types:
if typ.endswith("&"):
typ = typ[:-1]
if typ.startswith("const"):
typ = typ[5:]
schema_args.append(typ.strip())
compute_schema = ["std::vector<at::TypePtr> schema = {"]
for schema_arg in schema_args:
compute_schema.append(
f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type(),"
)
compute_schema.append("};")
return template.substitute(
unpacks="\n".join(unpack),
op=info.op,
compute_schema="\n".join(compute_schema),
apply_functional_args=apply_functional_args,
apply_functional_args_signature=apply_functional_args_signature,
compute_needs_input_grad=compute_needs_input_grad,
num_inputs=len(input_name_to_idx),
unpack_ivalues="\n".join(unpack_ivalues),
compute_index_ranges=compute_index_ranges,
saved_variables=saved_variables,
release_variables=release_variables,
@ -1015,4 +1078,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
compiled_args=compiled_args,
apply_with_saved_before=apply_with_saved_before,
apply_with_saved_after=apply_with_saved_after,
get_packed_args=get_packed_args,
)

View File

@ -8,6 +8,7 @@ from collections import defaultdict
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
from torch._dynamo.external_utils import (
call_backward,
call_hook,
@ -65,6 +66,39 @@ def maybe_clone(x):
return x
# We lazily bind "functional backward" variants for PyTorch built-in autograd
# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
# Each "functional backward" is bound the first time the node's apply_with_saved
# function is called. It's possible to avoid lazy binding and instead bind
# all of this upfront (perhaps at import time) via codegen changes.
class OpNamespace:
def add(self, name, fn):
assert not hasattr(self, name)
result = Op(name, fn)
torch._dynamo.allow_in_graph(result)
setattr(self, name, result)
return result
def get(self, name):
return getattr(self, name)
class Op:
def __init__(self, name, fn):
self.fn = fn
self.__name__ = name
self.__module__ = "torch._dynamo.compiled_autograd.ops"
def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def __repr__(self):
return self.__module__ + "." + self.__name__
ops = OpNamespace()
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
_impure_targets = OrderedSet(
[
@ -137,7 +171,8 @@ class AutogradCompilerInstance:
self.fx_tracer.root = torch.nn.Module()
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
self.symnode_proxy_lookup = {}
args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
self.fx_tracer.create_proxy("placeholder", name, (), {})
for name in _graph_placeholders
)
@ -160,7 +195,9 @@ class AutogradCompilerInstance:
)
for idx, val in enumerate(sizes)
]
self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins)
self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins)
for i, symint in enumerate(sizes):
self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i]
for idx, val in enumerate(scalars):
source = self.source("scalars", idx)
@ -182,7 +219,9 @@ class AutogradCompilerInstance:
)
else:
raise AssertionError("Unexpected scalar type: ", type(val))
self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins)
self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins)
for i, symval in enumerate(scalars):
self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr]
# TODO(jansel): are all these modes needed?
self.stack.enter_context(decompose({}))
@ -216,7 +255,6 @@ class AutogradCompilerInstance:
),
kwargs={},
)
with disable_proxy_modes_tracing():
# create fake Tensors
grad_ins: list[Optional[torch.Tensor]] = []
@ -232,6 +270,65 @@ class AutogradCompilerInstance:
self.bind_tensors_to_proxies(grad_ins, proxies)
return tuple(grad_ins)
# Guess what the outputs should be from the InputMetadata.
# This is not sound in general (we guess contiguous strides
# and no Tensor subclass-ness); we will stop guessing
# the output metadata in a follow-up.
def guess_output(self, input_metadata):
if input_metadata is None:
return None
tensoroptions, shape, _ = input_metadata
kwargs = {}
names = [
"requires_grad",
"memory_format",
"device",
"dtype",
"layout",
"pinned_memory",
]
for name, option in zip(names, tensoroptions):
if option is not None:
kwargs[name] = option
with disable_proxy_modes_tracing():
return torch.ops.aten.zeros(shape, **kwargs)
def bind_function(self, fn_name, fn):
"""Binds ops.fn_name = fn"""
ops.add(fn_name, fn)
def apply_functional(self, fn_name, grads, args, output_metadata):
"""Proxies a call to ops.fn_name(grads, *args) into the graph"""
op = ops.get(fn_name)
return self.proxy_call(op, (grads, *args), output_metadata)
def proxy_call(self, fn, args, output_metadata):
"""Proxies a call to fn(*args) into the graph"""
flat_args, _ = pytree.tree_flatten(args)
proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
proxy_out = self.fx_tracer.create_proxy(
"call_function", fn, args=proxy_args, kwargs={}
)
result = [self.guess_output(metadata) for metadata in output_metadata]
self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))])
return result
def validate_outputs(self, _, outputs, args, output_metadata):
"""Proxies a call to ops.validate_outputs(outputs, *args) into the graph"""
op = ops.get("validate_outputs")
proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args))
new_proxy_outputs = self.fx_tracer.create_proxy(
"call_function", op, args=proxy_args, kwargs={}
)
assert len(output_metadata) == len(outputs)
outputs = [
None if output is None or metadata is None else self.guess_output(metadata)
for output, metadata in zip(outputs, output_metadata)
]
self.bind_tensors_to_proxies(outputs, new_proxy_outputs)
return outputs
def proxy_call_hook(self, hook, *args, **kwargs):
return self.fx_tracer.create_proxy(
"call_function",
@ -314,6 +411,7 @@ class AutogradCompilerInstance:
assert nodes[first_getitem_idx] == inputs_users[0]
last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
assert nodes[last_getitem_idx] == inputs_users[-1]
# getitem nodes on inputs
for i, node in enumerate(inputs_users):
if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
has_cuda_inputs = True
@ -323,9 +421,13 @@ class AutogradCompilerInstance:
is_scalar = len(node.meta["val"].size()) == 0
if is_cpu and is_scalar:
node_users = list(node.users.keys())
# We can only move the cpu scalar if it is not exposed to user code.
if all(
isinstance(user.target, torch._ops.OpOverload)
and user.target.namespace in ("prims", "aten")
(
isinstance(user.target, torch._ops.OpOverload)
and user.target.namespace in ("prims", "aten")
)
or isinstance(user.target, Op)
for user in node_users
):
# all users are prims/aten, can move safely
@ -335,6 +437,7 @@ class AutogradCompilerInstance:
# this is to handle the case where cudagraphs is enabled on a cpu-only graph
if has_cuda_inputs:
for node in to_move.values():
verbose_log.debug("Moving node %s from cpu to cuda", node)
node.meta["val"] = node.meta["val"].cuda()
# return runtime indices we need to move to cuda
@ -368,7 +471,10 @@ class AutogradCompilerInstance:
or (node.op == "call_function" and node.target in _impure_targets)
)
before = len(self.fx_tracer.graph.nodes)
self.fx_tracer.graph.eliminate_dead_code(is_impure)
after = len(self.fx_tracer.graph.nodes)
verbose_log.debug("DCE removed %d nodes", before - after)
def end_capture(self, outputs):
self.fx_tracer.create_proxy(
@ -384,6 +490,10 @@ class AutogradCompilerInstance:
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
{},
)
runtime_inputs_to_move: list[int] = []
if snapshot_cudagraph_enabled():
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
# TODO(rzou): the guessed metadata is incorrect, we will remove it at the end of the PR stack.
self.rename_aot_dispatcher_nodes()
self.reorder_tensor_pre_hook_nodes()
self.reorder_pre_hook_nodes_to_schedule_asap()
@ -402,9 +512,6 @@ class AutogradCompilerInstance:
# Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
# should prevent these ops from going into the CA graph.
self.dce()
runtime_inputs_to_move: list[int] = []
if snapshot_cudagraph_enabled():
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
@ -778,8 +885,11 @@ class AutogradCompilerInstance:
return [self.to_proxy(x) for x in t]
if isinstance(t, tuple):
return tuple(self.to_proxy(x) for x in t)
# can it be torch.SymInt as the code used to imply?
assert isinstance(t, torch.Tensor)
if isinstance(t, (torch.SymInt, torch.SymFloat)):
return self.symnode_proxy_lookup[t.node]
if not isinstance(t, torch.Tensor):
# constant types like device, dtype, str
return t
proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
return proxy_tensor.proxy

View File

@ -898,6 +898,19 @@ bool has_input_metadata<Edge>(const Edge& thing) {
return thing.is_valid();
}
std::vector<std::optional<InputMetadata>> collect_input_metadata(
const edge_list& edges) {
std::vector<std::optional<InputMetadata>> input_metadata;
for (const auto& edge : edges) {
if (!edge.is_valid()) {
input_metadata.emplace_back(std::nullopt);
continue;
}
input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr));
}
return input_metadata;
}
// Given a vector<Edge> or vector<optional<InputMetadata>>, validate the
// outputs. This involves using the InputMetadata to check the outputs and also
// potentially calling .sum_to on the outputs.

View File

@ -47,6 +47,8 @@ TORCH_API void validate_outputs(
const std::vector<std::optional<InputMetadata>>& input_metadata,
variable_list& grads,
const std::function<std::string(const std::string&)>& format_error);
TORCH_API std::vector<std::optional<InputMetadata>> collect_input_metadata(
const edge_list& edges);
struct NodeTask {
std::weak_ptr<GraphTask> base_;

View File

@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
using variable_list = std::vector<Variable>;
using edge_list = std::vector<Edge>;
using saved_variable_list = std::vector<SavedVariable>;
using ivalue_list = std::vector<c10::IValue>;
using functional_apply_t = std::function<
variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
using IndexRange = std::pair<size_t, size_t>;
using torch::dynamo::autograd::CompiledNodeArgs;
using torch::dynamo::autograd::PackedArgs;
using torch::dynamo::autograd::SwapSavedVariables;
// Custom deleter to prevent stack overflows.
@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
std::string("apply_with_saved not implemented: ") + name());
}
// Whether this node is the AOTBackward node produced by torch.compile.
// Compiled Autograd special-cases on this information.
virtual bool is_aot_backward() const {
return false;
}
protected:
/// Performs the `Node`'s actual operation.
virtual variable_list apply(variable_list&& inputs) = 0;

View File

@ -8,6 +8,7 @@
namespace torch::dynamo::autograd {
class CompiledNodeArgs;
class SwapSavedVariables;
struct PackedArgs;
} // namespace torch::dynamo::autograd
// A hook that's called on gradients

View File

@ -288,6 +288,11 @@ auto PyNode::name() const -> std::string {
return name;
}
bool PyNode::is_aot_backward() const {
py::handle handle(obj);
return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
}
auto PyNode::compiled_autograd_should_lift() const -> bool {
pybind11::gil_scoped_acquire gil;
static PyObject* attr_name =

View File

@ -43,6 +43,8 @@ struct PyNode : public Node {
std::string name() const override;
bool is_traceable() override;
bool is_aot_backward() const override;
void compiled_args(CompiledNodeArgs& args) override;
variable_list apply_with_saved(
const variable_list& inputs,

View File

@ -0,0 +1,27 @@
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
namespace torch::dynamo::autograd {
std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;
const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
return kPyCompilerInterface;
}
void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
TORCH_INTERNAL_ASSERT(impl != nullptr);
kPyCompilerInterface = std::move(impl);
}
void resetPyCompilerInterface() {
kPyCompilerInterface.reset();
}
std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges) {
return torch::autograd::collect_input_metadata(edges);
}
} // namespace torch::dynamo::autograd

View File

@ -900,6 +900,506 @@ class SwapSavedVariables {
StashedVars<at::IValue> stashed_ivalues;
};
// NOTE: [Compiled Autograd and backward functions]
// Built-in autograd nodes have functional apply variants
// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph
// capture wants to take a variant of this function and proxy it into the graph.
// Every autograd node defines an apply_with_saved function that, when invoked,
// proxies a call to a function into the Compiled Autograd graph.
//
// Some requirements that we have are:
// - The proxy'ed function must have inputs that are FX-graphable types.
// - Windows has a DLL symbol limit of 65536.
// - Node::apply_with_saved is in libtorch_cpu which does not have direct access
// to Python
//
// There were multiple ways to skin the cat, but what we end up doing is:
// - for e.g. MulBackward0_apply_functional, we create a new C++ function
// MulBackward0_apply_functional_ivalue that accepts vector<IValue>.
// - We define how to pack and unpack arbitrary C++ types into IValues.
// - apply_with_saved passes MulBackward0_apply_functional_ivalue and
// the IValue arguments to Python via an indirection.
// In Python, these get proxy'ed into a graph.
// Helper struct for packing/unpacking an arbitrary C++ type into a single
// IValue. There are various full and partial specializations for IValuePacker
// to handle packing specific types (like TensorOptions) into an IValue.
template <typename T>
struct IValuePacker {
// Defines how to pack T into an IValue.
static at::IValue pack(const T& t) {
return t;
}
// Defines how to unpack an IValue into T.
static T unpack(const at::IValue& t) {
return t.to<T>();
}
// Returns the TypePtr for the IValue (this is like the "type" of the IValue).
// We use this when passing the packed IValue from Python to C++.
// In Python, the IValue is just a PyObject* with the native type.
// For example, it may be a Python int, a Python List[int], etc.
// When passing this PyObject* into C++, we need to know how to parse it
// into a C++ type that then gets put into an IValue.
// That's what the TypePtr is for: it contains the information to do the
// parsing. See torch::jit::toIValue for more information.
static at::TypePtr packed_type() {
if constexpr (::std::is_same_v<T, at::Tensor>) {
return at::TensorType::get();
} else if constexpr (::std::is_same_v<T, int64_t>) {
return at::IntType::get();
} else if constexpr (::std::is_same_v<T, c10::SymInt>) {
return at::SymIntType::get();
} else if constexpr (::std::is_same_v<T, bool>) {
return at::BoolType::get();
} else if constexpr (::std::is_same_v<T, double>) {
return at::FloatType::get();
} else if constexpr (::std::is_same_v<T, c10::SymFloat>) {
return at::SymFloatType::get();
} else if constexpr (::std::is_same_v<T, c10::SymBool>) {
return at::SymBoolType::get();
} else if constexpr (::std::is_same_v<T, c10::Layout>) {
return at::LayoutType::get();
} else if constexpr (::std::is_same_v<T, ::std::string>) {
return at::StringType::get();
} else if constexpr (::std::is_same_v<T, at::Device>) {
return at::DeviceObjType::get();
} else if constexpr (::std::is_same_v<T, at::Scalar>) {
return at::NumberType::get();
} else if constexpr (::std::is_same_v<T, at::MemoryFormat>) {
return at::MemoryFormatType::get();
} else if constexpr (::std::is_same_v<T, at::ScalarType>) {
return at::ScalarTypeType::get();
} else {
// If you got here, you have probably added a member of a new type
// to a built-in C++ autograd node.
// Unfortunately, we don't know how to handle this type yet.
// To get this new type to work with Compiled Autograd, please
// either change it to be an IValue-constructible type, or
// define how to pack and unpack an object of this type into an IValue
// by creating a specialization of IValuePacker for this type.
// See NOTE: [Compiled Autograd and backward functions] for context.
TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type");
return at::NoneType::get();
}
}
};
template <>
struct IValuePacker<size_t> {
static at::IValue pack(const size_t& t) {
// We generally use size_t as the size of a list of Tensors or number of
// dimensions. The number of dimensions generally does not exceed 64
// (TensorIterator has that limitation), and lists of Tensors generally do
// not exceed the int64_t max (you'd probably run out of RAM or run into
// significant Tensor overhead). If you run into this limitation the fix is
// to figure out how to pack size_t into int64_t. Note that size_t has some
// weird behavior on Mac OS.
uint64_t maximum_value = std::numeric_limits<int64_t>::max();
TORCH_INTERNAL_ASSERT(
static_cast<uint64_t>(t) <= maximum_value,
"size_t too large to pack into IValue");
return static_cast<int64_t>(t); // pack as int64_t
}
static size_t unpack(const at::IValue& t) {
return static_cast<size_t>(t.toInt());
}
static at::TypePtr packed_type() {
return IValuePacker<int64_t>::packed_type();
}
};
template <>
struct IValuePacker<std::vector<at::SymInt>> {
static at::IValue pack(const std::vector<at::SymInt>& t) {
return t;
}
static std::vector<at::SymInt> unpack(const at::IValue& t) {
// We need this because there's no t.to<std::vector<at::SymInt>>() override?
return t.toSymIntVector();
}
static at::TypePtr packed_type() {
return at::ListType::create(at::SymIntType::get());
}
};
template <>
struct IValuePacker<VariableInfo> {
static at::IValue pack(const VariableInfo& t) {
auto tuple = std::make_tuple(
t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty);
return tuple;
}
static VariableInfo unpack(const at::IValue& t) {
auto tuple = t.to<std::tuple<
at::Layout,
at::Device,
at::ScalarType,
std::vector<at::SymInt>,
bool,
bool>>();
VariableInfo v;
v.layout = std::get<0>(tuple);
v.device = std::get<1>(tuple);
v.scalar_type = std::get<2>(tuple);
v.size = std::get<3>(tuple);
v.requires_grad = std::get<4>(tuple);
v.is_empty = std::get<5>(tuple);
return v;
}
static at::TypePtr packed_type() {
return at::TupleType::create({
at::LayoutType::get(),
at::DeviceObjType::get(),
at::ScalarTypeType::get(),
at::ListType::create(at::SymIntType::get()),
at::BoolType::get(),
at::BoolType::get(),
});
}
};
template <>
struct IValuePacker<caffe2::TypeMeta> {
static at::IValue pack(const caffe2::TypeMeta& t) {
return at::typeMetaToScalarType(t); // pack as at::ScalarType
}
static caffe2::TypeMeta unpack(const at::IValue& t) {
return caffe2::TypeMeta::fromScalarType(t.to<at::ScalarType>());
}
static at::TypePtr packed_type() {
return IValuePacker<at::ScalarType>::packed_type();
}
};
inline std::optional<at::ScalarType> optTypeMetaToScalarType(
const std::optional<caffe2::TypeMeta>& t) {
if (t.has_value()) {
return at::typeMetaToScalarType(t.value());
} else {
return std::nullopt;
}
}
using packed_tensoroptions_t = std::tuple<
std::optional<bool>,
std::optional<at::MemoryFormat>,
std::optional<at::Device>,
std::optional<at::ScalarType>,
std::optional<at::Layout>,
std::optional<bool>>;
inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) {
auto tuple = std::make_tuple(
t.requires_grad_opt(),
t.memory_format_opt(),
t.device_opt(),
optTypeMetaToScalarType(t.dtype_opt()),
t.layout_opt(),
t.pinned_memory_opt());
return tuple;
}
inline at::TensorOptions unpack_TensorOptions(
const packed_tensoroptions_t& tuple) {
at::TensorOptions result;
auto maybe_requires_grad = std::get<0>(tuple);
if (maybe_requires_grad.has_value()) {
result = result.requires_grad(maybe_requires_grad.value());
}
auto maybe_memory_format = std::get<1>(tuple);
if (maybe_memory_format.has_value()) {
result = result.memory_format(maybe_memory_format.value());
}
auto maybe_device = std::get<2>(tuple);
if (maybe_device.has_value()) {
result = result.device(maybe_device.value());
}
auto maybe_dtype = std::get<3>(tuple);
if (maybe_dtype.has_value()) {
result =
result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value()));
}
auto maybe_layout = std::get<4>(tuple);
if (maybe_layout.has_value()) {
result = result.layout(maybe_layout.value());
}
auto maybe_pinned_memory = std::get<5>(tuple);
if (maybe_pinned_memory.has_value()) {
result = result.pinned_memory(maybe_pinned_memory.value());
}
return result;
}
template <>
struct IValuePacker<at::TensorOptions> {
static at::IValue pack(const at::TensorOptions& t) {
return pack_TensorOptions(t);
}
static at::TensorOptions unpack(const at::IValue& t) {
auto tuple = t.to<packed_tensoroptions_t>();
return unpack_TensorOptions(tuple);
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{at::OptionalType::create(at::BoolType::get()),
at::OptionalType::create(at::MemoryFormatType::get()),
at::OptionalType::create(at::DeviceObjType::get()),
at::OptionalType::create(at::ScalarTypeType::get()),
at::OptionalType::create(at::LayoutType::get()),
at::OptionalType::create(at::BoolType::get())});
}
};
template <>
struct IValuePacker<TypeAndSize> {
static at::IValue pack(const TypeAndSize& t) {
auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options));
return tuple;
}
static TypeAndSize unpack(const at::IValue& t) {
auto tuple =
t.to<std::tuple<std::vector<at::SymInt>, packed_tensoroptions_t>>();
TypeAndSize result;
result.sym_sizes = std::get<0>(tuple);
result.options = unpack_TensorOptions(std::get<1>(tuple));
return result;
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<std::vector<at::SymInt>>::packed_type(),
IValuePacker<at::TensorOptions>::packed_type()});
}
};
template <typename T>
struct IValuePacker<std::optional<T>> {
static at::IValue pack(const std::optional<T>& t) {
if (t.has_value()) {
return IValuePacker<T>::pack(t.value());
} else {
return std::nullopt;
}
}
static std::optional<T> unpack(const at::IValue& t) {
if (t.isNone()) {
return std::nullopt;
} else {
return IValuePacker<T>::unpack(t);
}
}
static at::TypePtr packed_type() {
return at::OptionalType::create(IValuePacker<T>::packed_type());
}
};
template <typename T>
struct IValuePacker<std::vector<T>> {
static at::IValue pack(const std::vector<T>& t) {
if constexpr (::std::is_constructible_v<at::IValue, T>) {
return t;
}
if (t.empty()) {
auto lst = c10::impl::GenericList(at::AnyType::get());
return lst;
}
auto type_ptr = IValuePacker<T>::pack(t[0]).type();
auto lst = c10::impl::GenericList(type_ptr);
for (const auto& elt : t) {
lst.emplace_back(IValuePacker<T>::pack(elt));
}
return lst;
}
static std::vector<T> unpack(const at::IValue& t) {
if constexpr (::std::is_constructible_v<at::IValue, T>) {
return t.to<::std::vector<T>>();
}
std::vector<T> result;
auto lst = t.toList();
for (const at::IValue& elt : lst) {
result.emplace_back(IValuePacker<T>::unpack(elt));
}
return result;
}
static at::TypePtr packed_type() {
return at::ListType::create(IValuePacker<T>::packed_type());
}
};
template <typename T>
struct IValuePacker<c10::List<T>> {
static at::IValue pack(const c10::List<T>& t) {
return IValuePacker<std::vector<T>>::pack(t.vec());
}
static c10::List<T> unpack(const at::IValue& t) {
return c10::List<T>(IValuePacker<std::vector<T>>::unpack(t));
}
static at::TypePtr packed_type() {
return IValuePacker<std::vector<T>>::packed_type();
}
};
template <size_t N>
struct IValuePacker<std::array<bool, N>> {
static at::IValue pack(const std::array<bool, N>& t) {
std::vector<bool> result(t.begin(), t.end());
return IValuePacker<std::vector<bool>>::pack(result);
}
static std::array<bool, N> unpack(const at::IValue& t) {
std::array<bool, N> result;
auto packed = IValuePacker<std::vector<bool>>::unpack(t);
for (size_t i = 0; i < packed.size(); i++) {
result[i] = packed[i];
}
return result;
}
static at::TypePtr packed_type() {
return IValuePacker<std::vector<bool>>::packed_type();
}
};
template <>
struct IValuePacker<at::TensorGeometry> {
static at::IValue pack(const at::TensorGeometry& t) {
auto tuple = std::make_tuple(
t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset());
return tuple;
}
static at::TensorGeometry unpack(const at::IValue& t) {
auto tuple = t.to<std::tuple<
std::vector<at::SymInt>,
std::vector<at::SymInt>,
at::SymInt>>();
return at::TensorGeometry(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple));
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<std::vector<at::SymInt>>::packed_type(),
IValuePacker<std::vector<at::SymInt>>::packed_type(),
at::SymIntType::get()});
}
};
template <>
struct IValuePacker<InputMetadata> {
static at::IValue pack(const InputMetadata& t) {
TORCH_INTERNAL_ASSERT(!t.is_nested_tensor());
auto tuple = std::make_tuple(
pack_TensorOptions(t.options()),
t.shape_as_dim_vector().vec(),
t.is_tensor_subclass());
return tuple;
}
static InputMetadata unpack(const at::IValue& t) {
auto tuple = t.to<
std::tuple<packed_tensoroptions_t, std::vector<at::SymInt>, bool>>();
return InputMetadata(
unpack_TensorOptions(std::get<0>(tuple)),
SymIntSmallVec(std::get<1>(tuple)),
std::get<2>(tuple),
false);
}
static at::TypePtr packed_type() {
return at::TupleType::create(
{IValuePacker<at::TensorOptions>::packed_type(),
IValuePacker<std::vector<at::SymInt>>::packed_type(),
at::BoolType::get()});
}
};
template <typename T>
struct IValuePacker<at::OptionalArray<T>> {
static at::IValue pack(const at::OptionalArray<T>& t) {
return IValuePacker<std::optional<std::vector<T>>>::pack(t.list);
}
static at::OptionalArray<T> unpack(const at::IValue& t) {
auto result = IValuePacker<std::optional<std::vector<T>>>::unpack(t);
if (result.has_value()) {
return {result.value()};
} else {
return {};
}
}
static at::TypePtr packed_type() {
return IValuePacker<std::optional<std::vector<T>>>::packed_type();
}
};
// This is a helper struct for packing and unpacking multiple arguments into
// an ivalue_list. It leverages IValuePacker<T>.
struct PackedArgs {
PackedArgs() = default;
explicit PackedArgs(std::vector<at::IValue> stack_)
: stack(std::move(stack_)) {}
std::vector<at::IValue> vec() && {
return std::move(stack);
}
template <typename T>
void pack(const T& t) {
stack.emplace_back(IValuePacker<T>::pack(t));
}
template <typename T>
T unpack() {
return IValuePacker<T>::unpack(std::move(stack[idx++]));
}
private:
std::vector<at::IValue> stack;
int64_t idx = 0;
};
// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
PyCompilerInterface() = default;
PyCompilerInterface(const PyCompilerInterface&) = delete;
PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
PyCompilerInterface(PyCompilerInterface&&) = delete;
PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
virtual ~PyCompilerInterface() = default;
// Invokes py_compiler.bind_function(fn_name, fn)
virtual void bind_function(
PyObject* py_compiler,
const std::string& fn_name,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
functional_apply_t fn,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::TypePtr> packed_args_schema) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
virtual variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
}
};
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
TORCH_API void setPyCompilerInterface(
std::unique_ptr<PyCompilerInterface>&& impl);
TORCH_API void resetPyCompilerInterface();
// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead of requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambiguous symbols...)
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
const edge_list& edges);
} // namespace torch::dynamo::autograd
template <>

View File

@ -52,6 +52,118 @@ Notes:
namespace torch::dynamo::autograd {
using c10::SymInt;
// List[Optional[Tensor]] in Python can't be directly parsed into a
// List[Tensor], so we need to do this conversion manually.
static std::vector<at::Tensor> toTensorList(
const std::vector<std::optional<at::Tensor>>& inputs) {
std::vector<at::Tensor> result;
result.reserve(inputs.size());
for (const auto& inp : inputs) {
if (inp.has_value()) {
result.emplace_back(*inp);
} else {
result.emplace_back();
}
}
return result;
}
// Binds a function (that represents some backward computation) to Python.
// All of these functions have a common signature, which is
// (in C++) (vector<Tensor>, vector<ivalue>) -> vector<Tensor>
// (in Python) (List[Optional[Tensor]], *packed_args: IValue) ->
// List[Optional[Tensor]]
//
// The vector<Tensor> is the list of gradient Tensors, each of which may be
// undefined (in C++) which corresponds to None (in Python).
static void bind_function(
PyObject* py_compiler,
const std::string& fn_name,
functional_apply_t fn,
std::vector<at::TypePtr> packed_args_schema) {
// This is the function that can be called from Python.
auto py_func = py::cpp_function(
[packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)](
std::vector<std::optional<at::Tensor>>& inputs,
const py::args& py_args) -> py::object {
// py_args is a tuple of PyObject*.
// We need to reconstruct a vector<IValue> to invoke `fn`.
// To do so, we use the packed_args_schema to convert each PyObject*
// to its corresponding C++ type that can be stored into IValue.
TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size());
std::vector<at::IValue> args;
args.reserve(py_args.size());
auto tuple_args = jit::tuple_slice(py_args);
for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) {
args.emplace_back(jit::toIValue(
tuple_args[idx], packed_args_schema[idx], std::nullopt));
}
// None in Python corresponds to undefined Tensor in C++
auto inputs_ = toTensorList(inputs);
auto outputs = fn(inputs_, args);
return jit::toPyObject(at::IValue(outputs));
});
py::handle handle(py_compiler);
handle.attr("bind_function")(fn_name, py_func);
}
// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
// output_metadata)
static variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) {
// convert ivalue_list -> PyObject*
PyObject* py_packed_args =
PyTuple_New(static_cast<Py_ssize_t>(packed_args.size()));
for (const auto i : c10::irange(packed_args.size())) {
py::object obj = jit::toPyObject(packed_args[i]);
Py_INCREF(obj.ptr());
PyTuple_SET_ITEM(py_packed_args, i, obj.ptr());
}
// call the corresponding method on the py_compiler
py::handle handle(py_compiler);
py::object stuff = handle.attr(method_name)(
fn_name,
inputs,
py::handle(py_packed_args),
jit::toPyObject(output_metadata));
// Convert the output from PyObject* to vector<Tensor>
auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
return toTensorList(tmp);
}
struct PyCompilerInterfaceImpl : PyCompilerInterface {
void bind_function(
PyObject* py_compiler,
const std::string& fn_name,
functional_apply_t fn,
std::vector<at::TypePtr> packed_args_schema) override {
return torch::dynamo::autograd::bind_function(
py_compiler, fn_name, std::move(fn), std::move(packed_args_schema));
}
variable_list call_function(
PyObject* py_compiler,
const char* method_name,
const std::string& fn_name,
const variable_list& inputs,
const ivalue_list& packed_args,
const c10::IValue& output_metadata) override {
return torch::dynamo::autograd::call_function(
py_compiler,
method_name,
fn_name,
inputs,
packed_args,
output_metadata);
}
};
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
for (const auto i : c10::irange(inputs.size())) {
@ -88,6 +200,22 @@ static void check(bool result) {
check(nullptr);
}
static variable_list validate_outputs(
const variable_list& outputs,
const ivalue_list& args) {
auto r = PackedArgs(args);
auto value = r.unpack<std::vector<std::optional<InputMetadata>>>();
auto new_outputs = outputs;
torch::autograd::validate_outputs(
value, new_outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "[Compiled Autograd Tracing:]" << msg;
return ss.str();
});
return new_outputs;
}
// snapshot of python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;
@ -657,6 +785,8 @@ static CacheNode* _compiled_autograd_impl(
ClosingTHPObjectPtr py_compiler(
check(PyObject_CallNoArgs((the_autograd_compiler))));
setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
TraceState state = call_begin_capture(
py_compiler, *cache, compiler_call, output_edges.size());
InputBuffers input_buffers;
@ -723,16 +853,48 @@ static CacheNode* _compiled_autograd_impl(
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
variable_list outputs = call.node->apply_with_saved(inputs, saved);
saved.debug_asserts();
saved.before(call.node->next_edges());
validate_outputs(
call.node->next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
<< msg;
return ss.str();
});
auto input_metadata = get_input_metadata(call.node->next_edges());
TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size());
// Lazily bind the `validate_outputs` function to Python.
static c10::once_flag flag;
c10::call_once(flag, [&]() {
auto schema = std::vector<at::TypePtr>{IValuePacker<
std::vector<std::optional<InputMetadata>>>::packed_type()};
bind_function(
py_compiler.get(), "validate_outputs", validate_outputs, schema);
});
// Don't emit validate_outputs nodes that follow a CompiledBackward node.
// These nodes would otherwise prevent reordering of accumulate_grad
// nodes.
//
// Note that this will not cause correctness issues, because
// 1) AOTAutograd already coerces gradients to have the same metadata as
// the inputs. 2) the AOTAutograd graph already has the necessary
// aten::sum_to nodes in it (so it doesn't need to rely on
// validate_outputs to handle that).
//
// However, we may be dropping some (edge case) safety checks compared to
// eager: a backward that would have errored out in eager may not error
// out in compiled autograd (for example, if the user provided an
// incorrect number of gradients).
if (!call.node->is_aot_backward()) {
PackedArgs args;
args.pack(input_metadata);
ivalue_list input_metadata_state = std::move(args).vec();
outputs = call_function(
py_compiler,
"validate_outputs",
"validate_outputs",
outputs,
input_metadata_state,
input_metadata_state[0]);
}
saved.after(call.node->next_edges());
saved.debug_asserts();
@ -761,6 +923,7 @@ static CacheNode* _compiled_autograd_impl(
}
}
resetPyCompilerInterface();
PyObject* res = check(call_end_capture(py_compiler, state.outputs));
TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
TORCH_CHECK(