From 5531fafffefc45cd894040b2b07b0d5227430082 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 22 Jan 2025 07:08:05 -0800 Subject: [PATCH] [compiled autograd] Proxy opaque nodes for built-in autograd nodes (#143296) This PR is on the way to getting compiled autograd's initial capture to stop specializing on Tensor metadata. This PR changes compiled autograd's initial capture to proxy an opaque (w.r.t. Dynamo) function into the graph for all built-in codegen'ed autograd nodes and validate_outputs. We changed each codegen'ed apply_with_saved (e.g. MulBackward0::apply_with_saved) to call into Python to proxy a function (compiled_autograd.ops.MulBackward0) into the graph. Then, we use the node's InputMetadata to "guess" at the properties of the output Tensors to create some new FakeTensors. Some details: - MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to be call to Python via libtorch_python. There is an indirection (PyCompilerInterface) to do this. - MulBackward0::apply_with_saved passes a C++ function to Python. To make our lives easier, every codegen'ed apply_with_saved passes a C++ function with the same signature `(variable_list, ivalue_list) -> variable_list`. - We define how to pack arbitrary C++ types into IValue via a helper IValuePacker struct and codegen functional variants of each builtin C++ autograd node (e.g. MulBackward0_apply_functional_ivalue). MulBackward0 before this PR: https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de MulBackward0 after this PR: https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0 Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296 Approved by: https://github.com/jansel --- aten/src/ATen/TensorGeometry.h | 10 + build_variables.bzl | 1 + test/dynamo/test_backward_higher_order_ops.py | 40 +- test/inductor/test_compiled_autograd.py | 9 +- test/inductor/test_distributed_patterns.py | 4 +- tools/autograd/gen_autograd_functions.py | 72 ++- torch/_dynamo/compiled_autograd.py | 132 ++++- torch/csrc/autograd/engine.cpp | 13 + torch/csrc/autograd/engine.h | 2 + torch/csrc/autograd/function.h | 10 + torch/csrc/autograd/function_hook.h | 1 + torch/csrc/autograd/python_function.cpp | 5 + torch/csrc/autograd/python_function.h | 2 + torch/csrc/dynamo/compiled_autograd.cpp | 27 + torch/csrc/dynamo/compiled_autograd.h | 500 ++++++++++++++++++ .../csrc/dynamo/python_compiled_autograd.cpp | 179 ++++++- 16 files changed, 961 insertions(+), 46 deletions(-) create mode 100644 torch/csrc/dynamo/compiled_autograd.cpp diff --git a/aten/src/ATen/TensorGeometry.h b/aten/src/ATen/TensorGeometry.h index 41f14a15ba99..06a064063c4e 100644 --- a/aten/src/ATen/TensorGeometry.h +++ b/aten/src/ATen/TensorGeometry.h @@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry { has_symbolic_sizes_strides_( t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} + explicit TensorGeometry( + std::vector sizes, + std::vector strides, + at::SymInt storage_offset) + : sizes_(std::move(sizes)), + strides_(std::move(strides)), + storage_offset_(std::move(storage_offset)) { + recompute(); + } + // true if the tensor is contiguous bool is_contiguous() const; diff --git a/build_variables.bzl b/build_variables.bzl index 8bd8ad3a8df0..a95c03cd0b34 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -138,6 +138,7 @@ core_trainer_sources = [ "torch/csrc/autograd/variable.cpp", "torch/csrc/autograd/utils/warnings.cpp", "torch/csrc/autograd/jit_decomp_interface.cpp", + "torch/csrc/dynamo/compiled_autograd.cpp", "torch/csrc/jit/frontend/name_mangler.cpp", "torch/csrc/jit/ir/type_hashing.cpp", "torch/csrc/jit/serialization/pickler.cpp", diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py index 14e3f2e044c1..6c5cd6a9f25f 100644 --- a/test/dynamo/test_backward_higher_order_ops.py +++ b/test/dynamo/test_backward_higher_order_ops.py @@ -121,23 +121,27 @@ class _multiply_invoke(torch.nn.Module): out.backward(grad_out) actual = normalize_gm(graph.print_readable(False)) self.assertEqual(x.grad, grad_out * grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_ : list): l_inputs_ = L_inputs_ - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None - result: "f32[s0]" = getitem * getitem; getitem = None + new_grad: "f32[2]" = torch.clone(getitem_3) - new_grad_1: "f32[s0]" = torch.clone(result); result = None + result: "f32[2]" = getitem_3 * getitem_3; getitem_3 = None + + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1) """, - ) + ) graph = None @@ -187,26 +191,30 @@ class GraphModule(torch.nn.Module): actual = normalize_gm(graph.print_readable(False)) self.assertEqual(obj.counter, 1) self.assertEqual(x.grad, grad_out + grad_out) - self.assertExpectedInline( - actual, - """\ + if backend in ["aot_eager", "inductor"]: + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"): l_inputs_ = L_inputs_ l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter - getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None + getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None - new_grad: "f32[s0]" = torch.clone(getitem) + validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]); getitem = None + getitem_3: "f32[2]" = validate_outputs[0]; validate_outputs = None + + new_grad: "f32[2]" = torch.clone(getitem_3) add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1; l_hooks_0_keywords_fn_keywords_obj_counter = None - result: "f32[s0]" = getitem * getitem; getitem = None + result: "f32[2]" = getitem_3 * getitem_3; getitem_3 = None - new_grad_1: "f32[s0]" = torch.clone(result); result = None + new_grad_1: "f32[2]" = torch.clone(result); result = None return (new_grad, new_grad_1, add) """, - ) + ) out = fn(x, y) out.backward(grad_out) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index d916c4186a3c..2a2b75690098 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2924,7 +2924,6 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { "aot0_le", "aot0_permute_2", "code: CompiledFunctionBackward0 (NodeCall 2)", - "aot0_tangents_1", "aot0_full_default", "aot0_where", "aot0_mm", @@ -2974,20 +2973,17 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { expected_logs = [ "CompiledFunctionBackward1", - "aot1_tangents_1", "aot1_sin_1", - "aot1_primals_2", "aot1_neg", "aot0_tangents_2", "aot1_cos_1", - "aot1_primals_1", "aot0_tangents_1", "CompiledFunctionBackward0", + "aot0_sin_1", "aot0_neg", - "aot0_sin", "aot0_mul", + "aot0_cos_1", "aot0_mul_1", - "aot0_cos", "aot0_add", ] @@ -3618,6 +3614,7 @@ known_failing_tests = { "test_tp_compile_comm_reordering", "test_unwrap_async_collective_tensor_tangent", # Uncategorized + "test_not_implemented_grad", # Dynamo changes the types of exceptions } if not HAS_CUDA: diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 9d1a3202f147..ed50a02d7b27 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -337,7 +337,9 @@ class DistributedPatternTests(TestCase): self.assertEqual(fw_cnt.frame_count, 1) self.assertEqual(fw_cnt.op_count, 5) self.assertEqual(bw_cnt.frame_count, 2) # grad=None and grad!=None - self.assertEqual(bw_cnt.op_count, 48) + self.assertEqual( + bw_cnt.op_count, 72 + ) # Number of ops in the Dynamo-produced graphs def test_module_backward_hooks_aot(self): m1, inp1 = init_module_bw_hooks(True) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index fa1d0ce4bc91..a2f9c60b2224 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -68,6 +68,7 @@ struct TORCH_API ${op} : public ${superclass} { } ${will_release_variables} void compiled_args(CompiledNodeArgs& args) override; + ivalue_list get_packed_args(); variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override; ${saved_variables} ${saved_list_sizes} @@ -107,6 +108,13 @@ static variable_list ${op}_apply_functional( ${body} return grad_inputs; } +static variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args) +{ + auto packed_args = PackedArgs(args); + auto needs_input_grad = packed_args.unpack>(); + ${unpack_ivalues} + return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args}); +} variable_list ${op}::apply(variable_list&& grads) { ${thread_lock} @@ -120,11 +128,42 @@ void ${op}::compiled_args(CompiledNodeArgs& args) { ${compiled_args} } variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { - ${apply_with_saved_before} - variable_list result = apply(variable_list(grads)); - ${apply_with_saved_after} - return result; + ${apply_with_saved_before} + + static std::once_flag flag; + std::call_once(flag, [&](){ + ${compute_schema} + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + interface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema); + }); + + variable_list result; + auto packed_args = get_packed_args(); + auto output_metadata = torch::dynamo::autograd::IValuePacker< + std::vector>>::pack( + torch::dynamo::autograd::get_input_metadata(next_edges())); + const auto& interface = torch::dynamo::autograd::getPyCompilerInterface(); + result = interface->call_function( + saved.get_py_compiler(), + "apply_functional", + name(), + grads, + packed_args, + output_metadata); + + ${apply_with_saved_after} + return result; } +ivalue_list ${op}::get_packed_args() { + PackedArgs packed_args; + ${asserts} + ${unpacks} + ${compute_needs_input_grad} + packed_args.pack(needs_input_grad); + ${get_packed_args} + return std::move(packed_args).vec(); +} + """ ) @@ -993,14 +1032,38 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { f"{T} {x}" for T, x in zip(apply_functional_args_ref_types, apply_functional_args) ] + get_packed_args = "\n".join( + f"packed_args.pack({name});" for name in apply_functional_args + ) + unpack_ivalues = [] + for typ, name in zip(apply_functional_args_ref_types, apply_functional_args): + if typ.endswith("&"): + typ = typ[:-1] + unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();") + + schema_args = [f"std::array"] + for typ in apply_functional_args_ref_types: + if typ.endswith("&"): + typ = typ[:-1] + if typ.startswith("const"): + typ = typ[5:] + schema_args.append(typ.strip()) + compute_schema = ["std::vector schema = {"] + for schema_arg in schema_args: + compute_schema.append( + f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type()," + ) + compute_schema.append("};") return template.substitute( unpacks="\n".join(unpack), op=info.op, + compute_schema="\n".join(compute_schema), apply_functional_args=apply_functional_args, apply_functional_args_signature=apply_functional_args_signature, compute_needs_input_grad=compute_needs_input_grad, num_inputs=len(input_name_to_idx), + unpack_ivalues="\n".join(unpack_ivalues), compute_index_ranges=compute_index_ranges, saved_variables=saved_variables, release_variables=release_variables, @@ -1015,4 +1078,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { compiled_args=compiled_args, apply_with_saved_before=apply_with_saved_before, apply_with_saved_after=apply_with_saved_after, + get_packed_args=get_packed_args, ) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index fb7017bc6dc9..cd4db23e332f 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -8,6 +8,7 @@ from collections import defaultdict from typing import Any, Optional, TYPE_CHECKING, Union import torch +import torch.utils._pytree as pytree from torch._dynamo.external_utils import ( call_backward, call_hook, @@ -65,6 +66,39 @@ def maybe_clone(x): return x +# We lazily bind "functional backward" variants for PyTorch built-in autograd +# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0 +# Each "functional backward" is bound the first time the node's apply_with_saved +# function is called. It's possible to avoid lazy binding and instead bind +# all of this upfront (perhaps at import time) via codegen changes. +class OpNamespace: + def add(self, name, fn): + assert not hasattr(self, name) + result = Op(name, fn) + torch._dynamo.allow_in_graph(result) + setattr(self, name, result) + return result + + def get(self, name): + return getattr(self, name) + + +class Op: + def __init__(self, name, fn): + self.fn = fn + self.__name__ = name + self.__module__ = "torch._dynamo.compiled_autograd.ops" + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __repr__(self): + return self.__module__ + "." + self.__name__ + + +ops = OpNamespace() + + _graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] _impure_targets = OrderedSet( [ @@ -137,7 +171,8 @@ class AutogradCompilerInstance: self.fx_tracer.root = torch.nn.Module() self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) self.fx_tracer.tensor_attrs = {} - args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = ( + self.symnode_proxy_lookup = {} + args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = ( self.fx_tracer.create_proxy("placeholder", name, (), {}) for name in _graph_placeholders ) @@ -160,7 +195,9 @@ class AutogradCompilerInstance: ) for idx, val in enumerate(sizes) ] - self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins) + self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins) + for i, symint in enumerate(sizes): + self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i] for idx, val in enumerate(scalars): source = self.source("scalars", idx) @@ -182,7 +219,9 @@ class AutogradCompilerInstance: ) else: raise AssertionError("Unexpected scalar type: ", type(val)) - self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins) + self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins) + for i, symval in enumerate(scalars): + self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr] # TODO(jansel): are all these modes needed? self.stack.enter_context(decompose({})) @@ -216,7 +255,6 @@ class AutogradCompilerInstance: ), kwargs={}, ) - with disable_proxy_modes_tracing(): # create fake Tensors grad_ins: list[Optional[torch.Tensor]] = [] @@ -232,6 +270,65 @@ class AutogradCompilerInstance: self.bind_tensors_to_proxies(grad_ins, proxies) return tuple(grad_ins) + # Guess what the outputs should be from the InputMetadata. + # This is not sound in general (we guess contiguous strides + # and no Tensor subclass-ness); we will stop guessing + # the output metadata in a follow-up. + def guess_output(self, input_metadata): + if input_metadata is None: + return None + tensoroptions, shape, _ = input_metadata + kwargs = {} + names = [ + "requires_grad", + "memory_format", + "device", + "dtype", + "layout", + "pinned_memory", + ] + for name, option in zip(names, tensoroptions): + if option is not None: + kwargs[name] = option + + with disable_proxy_modes_tracing(): + return torch.ops.aten.zeros(shape, **kwargs) + + def bind_function(self, fn_name, fn): + """Binds ops.fn_name = fn""" + ops.add(fn_name, fn) + + def apply_functional(self, fn_name, grads, args, output_metadata): + """Proxies a call to ops.fn_name(grads, *args) into the graph""" + op = ops.get(fn_name) + return self.proxy_call(op, (grads, *args), output_metadata) + + def proxy_call(self, fn, args, output_metadata): + """Proxies a call to fn(*args) into the graph""" + flat_args, _ = pytree.tree_flatten(args) + proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) + proxy_out = self.fx_tracer.create_proxy( + "call_function", fn, args=proxy_args, kwargs={} + ) + result = [self.guess_output(metadata) for metadata in output_metadata] + self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))]) + return result + + def validate_outputs(self, _, outputs, args, output_metadata): + """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" + op = ops.get("validate_outputs") + proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) + new_proxy_outputs = self.fx_tracer.create_proxy( + "call_function", op, args=proxy_args, kwargs={} + ) + assert len(output_metadata) == len(outputs) + outputs = [ + None if output is None or metadata is None else self.guess_output(metadata) + for output, metadata in zip(outputs, output_metadata) + ] + self.bind_tensors_to_proxies(outputs, new_proxy_outputs) + return outputs + def proxy_call_hook(self, hook, *args, **kwargs): return self.fx_tracer.create_proxy( "call_function", @@ -314,6 +411,7 @@ class AutogradCompilerInstance: assert nodes[first_getitem_idx] == inputs_users[0] last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 assert nodes[last_getitem_idx] == inputs_users[-1] + # getitem nodes on inputs for i, node in enumerate(inputs_users): if not has_cuda_inputs and node.meta["val"].device.type == "cuda": has_cuda_inputs = True @@ -323,9 +421,13 @@ class AutogradCompilerInstance: is_scalar = len(node.meta["val"].size()) == 0 if is_cpu and is_scalar: node_users = list(node.users.keys()) + # We can only move the cpu scalar if it is not exposed to user code. if all( - isinstance(user.target, torch._ops.OpOverload) - and user.target.namespace in ("prims", "aten") + ( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + ) + or isinstance(user.target, Op) for user in node_users ): # all users are prims/aten, can move safely @@ -335,6 +437,7 @@ class AutogradCompilerInstance: # this is to handle the case where cudagraphs is enabled on a cpu-only graph if has_cuda_inputs: for node in to_move.values(): + verbose_log.debug("Moving node %s from cpu to cuda", node) node.meta["val"] = node.meta["val"].cuda() # return runtime indices we need to move to cuda @@ -368,7 +471,10 @@ class AutogradCompilerInstance: or (node.op == "call_function" and node.target in _impure_targets) ) + before = len(self.fx_tracer.graph.nodes) self.fx_tracer.graph.eliminate_dead_code(is_impure) + after = len(self.fx_tracer.graph.nodes) + verbose_log.debug("DCE removed %d nodes", before - after) def end_capture(self, outputs): self.fx_tracer.create_proxy( @@ -384,6 +490,10 @@ class AutogradCompilerInstance: (self.fx_tracer.create_arg(self.to_proxy(outputs)),), {}, ) + runtime_inputs_to_move: list[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + # TODO(rzou): the guessed metadata is incorrect, we will remove it at the end of the PR stack. self.rename_aot_dispatcher_nodes() self.reorder_tensor_pre_hook_nodes() self.reorder_pre_hook_nodes_to_schedule_asap() @@ -402,9 +512,6 @@ class AutogradCompilerInstance: # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and # should prevent these ops from going into the CA graph. self.dce() - runtime_inputs_to_move: list[int] = [] - if snapshot_cudagraph_enabled(): - runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) graph = GraphModule( self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}" @@ -778,8 +885,11 @@ class AutogradCompilerInstance: return [self.to_proxy(x) for x in t] if isinstance(t, tuple): return tuple(self.to_proxy(x) for x in t) - # can it be torch.SymInt as the code used to imply? - assert isinstance(t, torch.Tensor) + if isinstance(t, (torch.SymInt, torch.SymFloat)): + return self.symnode_proxy_lookup[t.node] + if not isinstance(t, torch.Tensor): + # constant types like device, dtype, str + return t proxy_tensor = fetch_object_proxy(self.fx_tracer, t) assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) return proxy_tensor.proxy diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 3b37b87391f8..53e8bcd06453 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -898,6 +898,19 @@ bool has_input_metadata(const Edge& thing) { return thing.is_valid(); } +std::vector> collect_input_metadata( + const edge_list& edges) { + std::vector> input_metadata; + for (const auto& edge : edges) { + if (!edge.is_valid()) { + input_metadata.emplace_back(std::nullopt); + continue; + } + input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr)); + } + return input_metadata; +} + // Given an vector or vector>, validate the // outputs. This involves using the InputMetadata to check the outputs and also // potentially calling .sum_to on the outputs. diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 4243f1b1d6ee..5bf00bac5378 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -47,6 +47,8 @@ TORCH_API void validate_outputs( const std::vector>& input_metadata, variable_list& grads, const std::function& format_error); +TORCH_API std::vector> collect_input_metadata( + const edge_list& edges); struct NodeTask { std::weak_ptr base_; diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index ba2f6edbc6c0..abd11303eafe 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -34,8 +34,12 @@ using tensor_list = std::vector; using variable_list = std::vector; using edge_list = std::vector; using saved_variable_list = std::vector; +using ivalue_list = std::vector; +using functional_apply_t = std::function< + variable_list(const variable_list&, const std::vector&)>; using IndexRange = std::pair; using torch::dynamo::autograd::CompiledNodeArgs; +using torch::dynamo::autograd::PackedArgs; using torch::dynamo::autograd::SwapSavedVariables; // Custom deleter to prevent stack overflows. @@ -604,6 +608,12 @@ struct TORCH_API Node : std::enable_shared_from_this { std::string("apply_with_saved not implemented: ") + name()); } + // If this node is the AOTBackward node produced by torch.compile. + // Compiled Autograd special-cases on this information. + virtual bool is_aot_backward() const { + return false; + } + protected: /// Performs the `Node`'s actual operation. virtual variable_list apply(variable_list&& inputs) = 0; diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h index 6342bf280a5c..4e8bba79a169 100644 --- a/torch/csrc/autograd/function_hook.h +++ b/torch/csrc/autograd/function_hook.h @@ -8,6 +8,7 @@ namespace torch::dynamo::autograd { class CompiledNodeArgs; class SwapSavedVariables; +struct PackedArgs; } // namespace torch::dynamo::autograd // A hook that's called on gradients diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 19151cbaafe6..abd8ff30e909 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -288,6 +288,11 @@ auto PyNode::name() const -> std::string { return name; } +bool PyNode::is_aot_backward() const { + py::handle handle(obj); + return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); +} + auto PyNode::compiled_autograd_should_lift() const -> bool { pybind11::gil_scoped_acquire gil; static PyObject* attr_name = diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 46faff8e4686..2f28c765ab06 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -43,6 +43,8 @@ struct PyNode : public Node { std::string name() const override; bool is_traceable() override; + bool is_aot_backward() const override; + void compiled_args(CompiledNodeArgs& args) override; variable_list apply_with_saved( const variable_list& inputs, diff --git a/torch/csrc/dynamo/compiled_autograd.cpp b/torch/csrc/dynamo/compiled_autograd.cpp new file mode 100644 index 000000000000..7e2aad576189 --- /dev/null +++ b/torch/csrc/dynamo/compiled_autograd.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace torch::dynamo::autograd { + +std::unique_ptr kPyCompilerInterface; + +const std::unique_ptr& getPyCompilerInterface() { + TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr); + return kPyCompilerInterface; +} + +void setPyCompilerInterface(std::unique_ptr&& impl) { + TORCH_INTERNAL_ASSERT(impl != nullptr); + kPyCompilerInterface = std::move(impl); +} + +void resetPyCompilerInterface() { + kPyCompilerInterface.reset(); +} + +std::vector> get_input_metadata( + const edge_list& edges) { + return torch::autograd::collect_input_metadata(edges); +} + +} // namespace torch::dynamo::autograd diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 383cff14b8e0..b00ec6e00a40 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -900,6 +900,506 @@ class SwapSavedVariables { StashedVars stashed_ivalues; }; +// NOTE: [Compiled Autograd and backward functions] +// Built-in autograd nodes have functional apply variants +// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph +// capture wants to take a variant of this function and proxy it into the graph. +// Every autograd node defines an apply_with_saved function, that when invoked, +// proxys a call to a function into the Compiled Autograd graph. +// +// Some requirements that we have are: +// - The proxy'ed function must have inputs that are FX-graphable types. +// - Windows has a DLL symbol limit of 65536. +// - Node::apply_with_saved is in libtorch_cpu which does not have direct access +// to Python +// +// There were multiple ways to skin the cat, but what we end up doing is: +// - for e.g. MulBackward0_apply_functional, we create a new C++ function +// MulBackward0_apply_functional_ivalue that accepts vector. +// - We define how to pack and unpack arbitrary C++ types into IValues. +// - apply_with_saved passes MulBackward0_apply_functional_ivalue and +// the IValue arguments to Python via an indirection. +// In Python, these get proxy'ed into a graph. + +// Helper struct for packing/unpacking an arbitrary C++ type into a single +// IValue. There are various full and partial specializations for IValuePacker +// to handle packing specific types (like TensorOptions) into an IValue. +template +struct IValuePacker { + // Defines how to pack T into an IValue. + static at::IValue pack(const T& t) { + return t; + } + // Defines how to unpack an IValue into T. + static T unpack(const at::IValue& t) { + return t.to(); + } + // Returns the TypePtr for the IValue (this is like the "type" of the IValue). + // We use this when passing the packed IValue from Python to C++. + // In Python, the IValue is just a PyObject* with the native type. + // For example, it may be a Python int, a Python List[int], etc. + // When passing this PyObject* into C++, we need to know how to parse it + // into a C++ type that then gets put into an IValue. + // That's what the TypePtr is for: it contains the information to do the + // parsing. See torch::jit::toIValue for more information. + static at::TypePtr packed_type() { + if constexpr (::std::is_same_v) { + return at::TensorType::get(); + } else if constexpr (::std::is_same_v) { + return at::IntType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymIntType::get(); + } else if constexpr (::std::is_same_v) { + return at::BoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::FloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymFloatType::get(); + } else if constexpr (::std::is_same_v) { + return at::SymBoolType::get(); + } else if constexpr (::std::is_same_v) { + return at::LayoutType::get(); + } else if constexpr (::std::is_same_v) { + return at::StringType::get(); + } else if constexpr (::std::is_same_v) { + return at::DeviceObjType::get(); + } else if constexpr (::std::is_same_v) { + return at::NumberType::get(); + } else if constexpr (::std::is_same_v) { + return at::MemoryFormatType::get(); + } else if constexpr (::std::is_same_v) { + return at::ScalarTypeType::get(); + } else { + // If you got here, you have probably added a member of a new type + // to a built-in C++ autograd node. + // Unfortunately, we don't know how to handle this type yet. + // To get this new type to work with Compiled Autograd, please + // either change it to be an IValue-constructible type, or + // define how to pack and unpack an object of this time into an IValue + // by creating a specialization of IValuePacker for this type. + // See NOTE: [Compiled Autograd and backward functions] for context. + TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type"); + return at::NoneType::get(); + } + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const size_t& t) { + // We generally use size_t as the size of a list of Tensors or number of + // dimensions. The number of dimensions generally do not exceed 64 + // (TensorIterator has that limitation), and lists of Tensors generally do + // not exceed the int64_t max (you'd probably run out of RAM or run into + // significant Tensor overhead). If you run into this limitation the fix is + // to figure out how to pack size_t into int64_t. Note that size_t has some + // weird behavior on Mac OS. + uint64_t maximum_value = std::numeric_limits::max(); + TORCH_INTERNAL_ASSERT( + static_cast(t) <= maximum_value, + "size_t too large to pack into IValue"); + return static_cast(t); // pack as int64_t + } + static size_t unpack(const at::IValue& t) { + return static_cast(t.toInt()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +template <> +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + return t; + } + static std::vector unpack(const at::IValue& t) { + // We need this because there's no t.to>() override? + return t.toSymIntVector(); + } + static at::TypePtr packed_type() { + return at::ListType::create(at::SymIntType::get()); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const VariableInfo& t) { + auto tuple = std::make_tuple( + t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty); + return tuple; + } + static VariableInfo unpack(const at::IValue& t) { + auto tuple = t.to, + bool, + bool>>(); + VariableInfo v; + v.layout = std::get<0>(tuple); + v.device = std::get<1>(tuple); + v.scalar_type = std::get<2>(tuple); + v.size = std::get<3>(tuple); + v.requires_grad = std::get<4>(tuple); + v.is_empty = std::get<5>(tuple); + return v; + } + static at::TypePtr packed_type() { + return at::TupleType::create({ + at::LayoutType::get(), + at::DeviceObjType::get(), + at::ScalarTypeType::get(), + at::ListType::create(at::SymIntType::get()), + at::BoolType::get(), + at::BoolType::get(), + }); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const caffe2::TypeMeta& t) { + return at::typeMetaToScalarType(t); // pack as at::ScalarType + } + static caffe2::TypeMeta unpack(const at::IValue& t) { + return caffe2::TypeMeta::fromScalarType(t.to()); + } + static at::TypePtr packed_type() { + return IValuePacker::packed_type(); + } +}; + +inline std::optional optTypeMetaToScalarType( + const std::optional& t) { + if (t.has_value()) { + return at::typeMetaToScalarType(t.value()); + } else { + return std::nullopt; + } +} + +using packed_tensoroptions_t = std::tuple< + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional>; + +inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) { + auto tuple = std::make_tuple( + t.requires_grad_opt(), + t.memory_format_opt(), + t.device_opt(), + optTypeMetaToScalarType(t.dtype_opt()), + t.layout_opt(), + t.pinned_memory_opt()); + return tuple; +} +inline at::TensorOptions unpack_TensorOptions( + const packed_tensoroptions_t& tuple) { + at::TensorOptions result; + auto maybe_requires_grad = std::get<0>(tuple); + if (maybe_requires_grad.has_value()) { + result = result.requires_grad(maybe_requires_grad.value()); + } + auto maybe_memory_format = std::get<1>(tuple); + if (maybe_memory_format.has_value()) { + result = result.memory_format(maybe_memory_format.value()); + } + auto maybe_device = std::get<2>(tuple); + if (maybe_device.has_value()) { + result = result.device(maybe_device.value()); + } + auto maybe_dtype = std::get<3>(tuple); + if (maybe_dtype.has_value()) { + result = + result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value())); + } + auto maybe_layout = std::get<4>(tuple); + if (maybe_layout.has_value()) { + result = result.layout(maybe_layout.value()); + } + auto maybe_pinned_memory = std::get<5>(tuple); + if (maybe_pinned_memory.has_value()) { + result = result.pinned_memory(maybe_pinned_memory.value()); + } + return result; +} + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorOptions& t) { + return pack_TensorOptions(t); + } + static at::TensorOptions unpack(const at::IValue& t) { + auto tuple = t.to(); + return unpack_TensorOptions(tuple); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {at::OptionalType::create(at::BoolType::get()), + at::OptionalType::create(at::MemoryFormatType::get()), + at::OptionalType::create(at::DeviceObjType::get()), + at::OptionalType::create(at::ScalarTypeType::get()), + at::OptionalType::create(at::LayoutType::get()), + at::OptionalType::create(at::BoolType::get())}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const TypeAndSize& t) { + auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options)); + return tuple; + } + static TypeAndSize unpack(const at::IValue& t) { + auto tuple = + t.to, packed_tensoroptions_t>>(); + TypeAndSize result; + result.sym_sizes = std::get<0>(tuple); + result.options = unpack_TensorOptions(std::get<1>(tuple)); + return result; + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker::packed_type()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::optional& t) { + if (t.has_value()) { + return IValuePacker::pack(t.value()); + } else { + return std::nullopt; + } + } + static std::optional unpack(const at::IValue& t) { + if (t.isNone()) { + return std::nullopt; + } else { + return IValuePacker::unpack(t); + } + } + static at::TypePtr packed_type() { + return at::OptionalType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::vector& t) { + if constexpr (::std::is_constructible_v) { + return t; + } + if (t.empty()) { + auto lst = c10::impl::GenericList(at::AnyType::get()); + return lst; + } + auto type_ptr = IValuePacker::pack(t[0]).type(); + auto lst = c10::impl::GenericList(type_ptr); + for (const auto& elt : t) { + lst.emplace_back(IValuePacker::pack(elt)); + } + return lst; + } + static std::vector unpack(const at::IValue& t) { + if constexpr (::std::is_constructible_v) { + return t.to<::std::vector>(); + } + std::vector result; + auto lst = t.toList(); + for (const at::IValue& elt : lst) { + result.emplace_back(IValuePacker::unpack(elt)); + } + return result; + } + static at::TypePtr packed_type() { + return at::ListType::create(IValuePacker::packed_type()); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const c10::List& t) { + return IValuePacker>::pack(t.vec()); + } + static c10::List unpack(const at::IValue& t) { + return c10::List(IValuePacker>::unpack(t)); + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const std::array& t) { + std::vector result(t.begin(), t.end()); + return IValuePacker>::pack(result); + } + static std::array unpack(const at::IValue& t) { + std::array result; + auto packed = IValuePacker>::unpack(t); + for (size_t i = 0; i < packed.size(); i++) { + result[i] = packed[i]; + } + return result; + } + static at::TypePtr packed_type() { + return IValuePacker>::packed_type(); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const at::TensorGeometry& t) { + auto tuple = std::make_tuple( + t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset()); + return tuple; + } + static at::TensorGeometry unpack(const at::IValue& t) { + auto tuple = t.to, + std::vector, + at::SymInt>>(); + return at::TensorGeometry( + std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple)); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker>::packed_type(), + IValuePacker>::packed_type(), + at::SymIntType::get()}); + } +}; + +template <> +struct IValuePacker { + static at::IValue pack(const InputMetadata& t) { + TORCH_INTERNAL_ASSERT(!t.is_nested_tensor()); + auto tuple = std::make_tuple( + pack_TensorOptions(t.options()), + t.shape_as_dim_vector().vec(), + t.is_tensor_subclass()); + return tuple; + } + static InputMetadata unpack(const at::IValue& t) { + auto tuple = t.to< + std::tuple, bool>>(); + + return InputMetadata( + unpack_TensorOptions(std::get<0>(tuple)), + SymIntSmallVec(std::get<1>(tuple)), + std::get<2>(tuple), + false); + } + static at::TypePtr packed_type() { + return at::TupleType::create( + {IValuePacker::packed_type(), + IValuePacker>::packed_type(), + at::BoolType::get()}); + } +}; + +template +struct IValuePacker> { + static at::IValue pack(const at::OptionalArray& t) { + return IValuePacker>>::pack(t.list); + } + static at::OptionalArray unpack(const at::IValue& t) { + auto result = IValuePacker>>::unpack(t); + if (result.has_value()) { + return {result.value()}; + } else { + return {}; + } + } + static at::TypePtr packed_type() { + return IValuePacker>>::packed_type(); + } +}; + +// This is a helper struct for packing and unpacking multiple arguments into +// an ivalue_list. It leverages IValuePacker. +struct PackedArgs { + PackedArgs() = default; + + explicit PackedArgs(std::vector stack_) + : stack(std::move(stack_)) {} + + std::vector vec() && { + return std::move(stack); + } + + template + void pack(const T& t) { + stack.emplace_back(IValuePacker::pack(t)); + } + template + T unpack() { + return IValuePacker::unpack(std::move(stack[idx++])); + } + + private: + std::vector stack; + int64_t idx = 0; +}; + +// This is a layer of indirection for calling methods on the Python +// AutogradCompilerInstance (referred to as the "py_compiler") from +// libtorch_cpu (where Python is not available). +// A PyCompilerInterfaceImpl in libtorch_python subclasses it and +// overrides the methods to do the actual calls back to Python. +struct TORCH_API PyCompilerInterface { + PyCompilerInterface() = default; + PyCompilerInterface(const PyCompilerInterface&) = delete; + PyCompilerInterface& operator=(const PyCompilerInterface&) = delete; + PyCompilerInterface(PyCompilerInterface&&) = delete; + PyCompilerInterface& operator=(PyCompilerInterface&&) = delete; + virtual ~PyCompilerInterface() = default; + + // Invokes py_compiler.bind_function(fn_name, fn) + virtual void bind_function( + PyObject* py_compiler, + const std::string& fn_name, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + functional_apply_t fn, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector packed_args_schema) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } + + // Invokes py_compiler.method_name(fn_name, inputs, packed_args, + // output_metadata) + virtual variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + TORCH_INTERNAL_ASSERT(false, "Needs to be overridden"); + } +}; + +TORCH_API const std::unique_ptr& getPyCompilerInterface(); +TORCH_API void setPyCompilerInterface( + std::unique_ptr&& impl); +TORCH_API void resetPyCompilerInterface(); + +// including torch/csrc/autograd/engine.h breaks BC by somehow introducing +// symbol resolution issues. Instead requiring downstream users to include +// engine.h to access collect_input_metadata, we provide it here (with a +// different name to avoid ambigous symbols...) +TORCH_API std::vector> get_input_metadata( + const edge_list& edges); + } // namespace torch::dynamo::autograd template <> diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 3753e5988cdb..12ef964f7e0b 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -52,6 +52,118 @@ Notes: namespace torch::dynamo::autograd { using c10::SymInt; +// List[Optional[Tensor]] in Python can't be directly parsed into a +// List[Tensor], so we need to do this conversion manually. +static std::vector toTensorList( + const std::vector>& inputs) { + std::vector result; + result.reserve(inputs.size()); + for (const auto& inp : inputs) { + if (inp.has_value()) { + result.emplace_back(*inp); + } else { + result.emplace_back(); + } + } + return result; +} + +// Binds a function (that represents some backward computation) to Python. +// All of these functions have a common signature, which is +// (in C++) (vector, vector) -> vector +// (in Python) (List[Optional[Tensor]], *packed_args: IValue) -> +// List[Optional[Tensor]] +// +// The vector are the list of gradient Tensors, each of which may be +// undefined (in C++) which corresponds to None (in Python). +static void bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema) { + // This is the function that can be called from Python. + auto py_func = py::cpp_function( + [packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)]( + std::vector>& inputs, + const py::args& py_args) -> py::object { + // py_args is a tuple of PyObject*. + // We need to reconstruct a vector to invoke `fn`. + // To do so, we use the packed_args_schema to convert each PyObject* + // to its corresponding C++ type that can be stored into IValue. + TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size()); + std::vector args; + args.reserve(py_args.size()); + auto tuple_args = jit::tuple_slice(py_args); + for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) { + args.emplace_back(jit::toIValue( + tuple_args[idx], packed_args_schema[idx], std::nullopt)); + } + // None in Python corresponds to undefined Tensor in C++ + auto inputs_ = toTensorList(inputs); + auto outputs = fn(inputs_, args); + return jit::toPyObject(at::IValue(outputs)); + }); + py::handle handle(py_compiler); + handle.attr("bind_function")(fn_name, py_func); +} + +// Invokes py_compiler.method_name(fn_name, inputs, packed_args, +// output_metadata) +static variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) { + // convert ivalue_list -> PyObject* + PyObject* py_packed_args = + PyTuple_New(static_cast(packed_args.size())); + for (const auto i : c10::irange(packed_args.size())) { + py::object obj = jit::toPyObject(packed_args[i]); + Py_INCREF(obj.ptr()); + PyTuple_SET_ITEM(py_packed_args, i, obj.ptr()); + } + + // call the corresponding method on the py_compiler + py::handle handle(py_compiler); + py::object stuff = handle.attr(method_name)( + fn_name, + inputs, + py::handle(py_packed_args), + jit::toPyObject(output_metadata)); + + // Convert the output from PyObject* to vector + auto tmp = py::cast>>(stuff); + return toTensorList(tmp); +} + +struct PyCompilerInterfaceImpl : PyCompilerInterface { + void bind_function( + PyObject* py_compiler, + const std::string& fn_name, + functional_apply_t fn, + std::vector packed_args_schema) override { + return torch::dynamo::autograd::bind_function( + py_compiler, fn_name, std::move(fn), std::move(packed_args_schema)); + } + variable_list call_function( + PyObject* py_compiler, + const char* method_name, + const std::string& fn_name, + const variable_list& inputs, + const ivalue_list& packed_args, + const c10::IValue& output_metadata) override { + return torch::dynamo::autograd::call_function( + py_compiler, + method_name, + fn_name, + inputs, + packed_args, + output_metadata); + } +}; + static PyObject* wrap_int_list(const std::vector& inputs) { PyObject* pyinput = PyTuple_New(static_cast(inputs.size())); for (const auto i : c10::irange(inputs.size())) { @@ -88,6 +200,22 @@ static void check(bool result) { check(nullptr); } +static variable_list validate_outputs( + const variable_list& outputs, + const ivalue_list& args) { + auto r = PackedArgs(args); + auto value = r.unpack>>(); + auto new_outputs = outputs; + + torch::autograd::validate_outputs( + value, new_outputs, [&](const std::string& msg) { + std::ostringstream ss; + ss << "[Compiled Autograd Tracing:]" << msg; + return ss.str(); + }); + return new_outputs; +} + // snapshot of python verbose logging toggle static PyObject* python_verbose_logger = nullptr; @@ -657,6 +785,8 @@ static CacheNode* _compiled_autograd_impl( ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); + setPyCompilerInterface(std::make_unique()); + TraceState state = call_begin_capture( py_compiler, *cache, compiler_call, output_edges.size()); InputBuffers input_buffers; @@ -723,16 +853,48 @@ static CacheNode* _compiled_autograd_impl( SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call); variable_list outputs = call.node->apply_with_saved(inputs, saved); - saved.debug_asserts(); saved.before(call.node->next_edges()); - validate_outputs( - call.node->next_edges(), outputs, [&](const std::string& msg) { - std::ostringstream ss; - ss << "[Compiled Autograd Tracing: " << call.node->name() << "] " - << msg; - return ss.str(); - }); + + auto input_metadata = get_input_metadata(call.node->next_edges()); + TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size()); + + // Lazily bind the `validate_outputs` function to Python. + static c10::once_flag flag; + c10::call_once(flag, [&]() { + auto schema = std::vector{IValuePacker< + std::vector>>::packed_type()}; + bind_function( + py_compiler.get(), "validate_outputs", validate_outputs, schema); + }); + + // Don't emit validate_outputs nodes that follow a CompiledBackward node. + // These nodes would otherwise prevent reordering of accumulate_grad + // nodes. + // + // Note that this will not cause correctness issues, because + // 1) AOTAutograd already coerces gradients to have the same metadata as + // the inputs. 2) the AOTAutograd graph already has the necessary + // aten::sum_to nodes in it (so it doesn't need to rely on + // validate_outputs to handle that). + // + // However, we may be dropping some (edge case) safety checks compared to + // eager: a backward that would have errored out in eager may not error + // out in compiled autograd (for example, if the user provided an + // incorrect number of gradients). + if (!call.node->is_aot_backward()) { + PackedArgs args; + args.pack(input_metadata); + ivalue_list input_metadata_state = std::move(args).vec(); + outputs = call_function( + py_compiler, + "validate_outputs", + "validate_outputs", + outputs, + input_metadata_state, + input_metadata_state[0]); + } + saved.after(call.node->next_edges()); saved.debug_asserts(); @@ -761,6 +923,7 @@ static CacheNode* _compiled_autograd_impl( } } + resetPyCompilerInterface(); PyObject* res = check(call_end_capture(py_compiler, state.outputs)); TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); TORCH_CHECK(