diff --git a/aten/src/ATen/TensorGeometry.h b/aten/src/ATen/TensorGeometry.h
index 06a064063c4e..41f14a15ba99 100644
--- a/aten/src/ATen/TensorGeometry.h
+++ b/aten/src/ATen/TensorGeometry.h
@@ -37,16 +37,6 @@ struct TORCH_API TensorGeometry {
         has_symbolic_sizes_strides_(
             t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
 
-  explicit TensorGeometry(
-      std::vector<at::SymInt> sizes,
-      std::vector<at::SymInt> strides,
-      at::SymInt storage_offset)
-      : sizes_(std::move(sizes)),
-        strides_(std::move(strides)),
-        storage_offset_(std::move(storage_offset)) {
-    recompute();
-  }
-
   // true if the tensor is contiguous
   bool is_contiguous() const;
 
diff --git a/build_variables.bzl b/build_variables.bzl
index a95c03cd0b34..8bd8ad3a8df0 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -138,7 +138,6 @@ core_trainer_sources = [
    "torch/csrc/autograd/variable.cpp",
    "torch/csrc/autograd/utils/warnings.cpp",
    "torch/csrc/autograd/jit_decomp_interface.cpp",
-    "torch/csrc/dynamo/compiled_autograd.cpp",
    "torch/csrc/jit/frontend/name_mangler.cpp",
    "torch/csrc/jit/ir/type_hashing.cpp",
    "torch/csrc/jit/serialization/pickler.cpp",
diff --git a/test/dynamo/test_backward_higher_order_ops.py b/test/dynamo/test_backward_higher_order_ops.py
index 6c5cd6a9f25f..14e3f2e044c1 100644
--- a/test/dynamo/test_backward_higher_order_ops.py
+++ b/test/dynamo/test_backward_higher_order_ops.py
@@ -121,27 +121,23 @@ class _multiply_invoke(torch.nn.Module):
         out.backward(grad_out)
         actual = normalize_gm(graph.print_readable(False))
         self.assertEqual(x.grad, grad_out * grad_out)
-        if backend in ["aot_eager", "inductor"]:
-            self.assertExpectedInline(
-                actual,
-                """\
+        self.assertExpectedInline(
+            actual,
+            """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_inputs_ : list):
         l_inputs_ = L_inputs_
 
-        getitem: "f32[2]" = l_inputs_[0];  l_inputs_ = None
+        getitem: "f32[s0]" = l_inputs_[0];  l_inputs_ = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]);  getitem = None
-        getitem_3: "f32[2]" = validate_outputs[0];  validate_outputs = None
+        new_grad: "f32[s0]" = torch.clone(getitem)
 
-        new_grad: "f32[2]" = torch.clone(getitem_3)
+        result: "f32[s0]" = getitem * getitem;  getitem = None
 
-        result: "f32[2]" = getitem_3 * getitem_3;  getitem_3 = None
-
-        new_grad_1: "f32[2]" = torch.clone(result);  result = None
+        new_grad_1: "f32[s0]" = torch.clone(result);  result = None
         return (new_grad, new_grad_1)
 """,
-            )
+        )
 
         graph = None
 
@@ -191,30 +187,26 @@ class GraphModule(torch.nn.Module):
         actual = normalize_gm(graph.print_readable(False))
         self.assertEqual(obj.counter, 1)
         self.assertEqual(x.grad, grad_out + grad_out)
-        if backend in ["aot_eager", "inductor"]:
-            self.assertExpectedInline(
-                actual,
-                """\
+        self.assertExpectedInline(
+            actual,
+            """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_inputs_ : list, L_hooks_0_keywords_fn_keywords_obj_counter: "Sym(s1)"):
         l_inputs_ = L_inputs_
         l_hooks_0_keywords_fn_keywords_obj_counter = L_hooks_0_keywords_fn_keywords_obj_counter
 
-        getitem: "f32[2]" = l_inputs_[0];  l_inputs_ = None
+        getitem: "f32[s0]" = l_inputs_[0];  l_inputs_ = None
 
-        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [2], False)]);  getitem = None
-        getitem_3: "f32[2]" = validate_outputs[0];  validate_outputs = None
-
-        new_grad: "f32[2]" = torch.clone(getitem_3)
+        new_grad: "f32[s0]" = torch.clone(getitem)
 
         add: "Sym(s1 + 1)" = l_hooks_0_keywords_fn_keywords_obj_counter + 1;  l_hooks_0_keywords_fn_keywords_obj_counter = None
 
-        result: "f32[2]" = getitem_3 * getitem_3;  getitem_3 = None
+        result: "f32[s0]" = getitem * getitem;  getitem = None
 
-        new_grad_1: "f32[2]" = torch.clone(result);  result = None
+        new_grad_1: "f32[s0]" = torch.clone(result);  result = None
         return (new_grad, new_grad_1, add)
 """,
-            )
+        )
 
         out = fn(x, y)
         out.backward(grad_out)
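For context on the expected-graph changes above (fixed sizes like f32[2] becoming symbolic f32[s0], and the validate_outputs graph nodes disappearing): a minimal sketch, not part of this patch, of how such a backward graph is captured and printed. It uses the torch._dynamo.compiled_autograd.enable context manager; the eager backend is an arbitrary choice for illustration.

    import torch
    import torch._dynamo.compiled_autograd as compiled_autograd

    def compiler_fn(gm: torch.fx.GraphModule):
        # print the captured backward graph, then hand it to a backend
        gm.print_readable()
        return torch.compile(gm, backend="eager")

    x = torch.randn(2, requires_grad=True)
    with compiled_autograd.enable(compiler_fn):
        out = (x * x).sum()
        out.backward()  # the backward pass is captured into an FX graph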
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 2a2b75690098..d916c4186a3c 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -2924,6 +2924,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
             "aot0_le",
             "aot0_permute_2",
             "code: CompiledFunctionBackward0 (NodeCall 2)",
+            "aot0_tangents_1",
             "aot0_full_default",
             "aot0_where",
             "aot0_mm",
@@ -2973,17 +2974,20 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
 
         expected_logs = [
             "CompiledFunctionBackward1",
+            "aot1_tangents_1",
             "aot1_sin_1",
+            "aot1_primals_2",
             "aot1_neg",
             "aot0_tangents_2",
             "aot1_cos_1",
+            "aot1_primals_1",
             "aot0_tangents_1",
             "CompiledFunctionBackward0",
-            "aot0_sin_1",
             "aot0_neg",
+            "aot0_sin",
             "aot0_mul",
-            "aot0_cos_1",
             "aot0_mul_1",
+            "aot0_cos",
             "aot0_add",
         ]
 
@@ -3614,7 +3618,6 @@ known_failing_tests = {
    "test_tp_compile_comm_reordering",
    "test_unwrap_async_collective_tensor_tangent",
    # Uncategorized
-    "test_not_implemented_grad",  # Dynamo changes the types of exceptions
 }
 
 if not HAS_CUDA:
diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py
index ed50a02d7b27..9d1a3202f147 100644
--- a/test/inductor/test_distributed_patterns.py
+++ b/test/inductor/test_distributed_patterns.py
@@ -337,9 +337,7 @@ class DistributedPatternTests(TestCase):
         self.assertEqual(fw_cnt.frame_count, 1)
         self.assertEqual(fw_cnt.op_count, 5)
         self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
-        self.assertEqual(
-            bw_cnt.op_count, 72
-        )  # Number of ops in the Dynamo-produced graphs
+        self.assertEqual(bw_cnt.op_count, 48)
 
     def test_module_backward_hooks_aot(self):
         m1, inp1 = init_module_bw_hooks(True)
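The bw_cnt.op_count assertion above counts ops across the Dynamo-produced backward graphs; dropping the validate_outputs/apply_functional plumbing is what shrinks the count from 72 to 48. A minimal sketch of how such counters are used (the function and shapes below are illustrative, not from the test):

    import torch
    from torch._dynamo.testing import CompileCounter

    cnt = CompileCounter()

    @torch.compile(backend=cnt)
    def f(x):
        return x.sin().cos()

    f(torch.randn(4))
    # frames compiled, and total ops across their FX graphs
    print(cnt.frame_count, cnt.op_count)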
diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index a2f9c60b2224..fa1d0ce4bc91 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -68,7 +68,6 @@ struct TORCH_API ${op} : public ${superclass} {
   }
   ${will_release_variables}
   void compiled_args(CompiledNodeArgs& args) override;
-  ivalue_list get_packed_args();
   variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
   ${saved_variables}
   ${saved_list_sizes}
@@ -108,13 +107,6 @@ static variable_list ${op}_apply_functional(
   ${body}
   return grad_inputs;
 }
-static variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args)
-{
-  auto packed_args = PackedArgs(args);
-  auto needs_input_grad = packed_args.unpack<std::array<bool, ${num_inputs}>>();
-  ${unpack_ivalues}
-  return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args});
-}
 
 variable_list ${op}::apply(variable_list&& grads) {
   ${thread_lock}
@@ -128,42 +120,11 @@ void ${op}::compiled_args(CompiledNodeArgs& args) {
   ${compiled_args}
 }
 variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
-  ${apply_with_saved_before}
-
-  static std::once_flag flag;
-  std::call_once(flag, [&](){
-    ${compute_schema}
-    const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
-    interface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema);
-  });
-
-  variable_list result;
-  auto packed_args = get_packed_args();
-  auto output_metadata = torch::dynamo::autograd::IValuePacker<
-      std::vector<std::optional<InputMetadata>>>::pack(
-      torch::dynamo::autograd::get_input_metadata(next_edges()));
-  const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
-  result = interface->call_function(
-      saved.get_py_compiler(),
-      "apply_functional",
-      name(),
-      grads,
-      packed_args,
-      output_metadata);
-
-  ${apply_with_saved_after}
-  return result;
+  ${apply_with_saved_before}
+  variable_list result = apply(variable_list(grads));
+  ${apply_with_saved_after}
+  return result;
 }
-ivalue_list ${op}::get_packed_args() {
-  PackedArgs packed_args;
-  ${asserts}
-  ${unpacks}
-  ${compute_needs_input_grad}
-  packed_args.pack(needs_input_grad);
-  ${get_packed_args}
-  return std::move(packed_args).vec();
-}
-
 """
 )
 
@@ -1032,38 +993,14 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
         f"{T} {x}"
         for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
     ]
-    get_packed_args = "\n".join(
-        f"packed_args.pack({name});" for name in apply_functional_args
-    )
-    unpack_ivalues = []
-    for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
-        if typ.endswith("&"):
-            typ = typ[:-1]
-        unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
-
-    schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
-    for typ in apply_functional_args_ref_types:
-        if typ.endswith("&"):
-            typ = typ[:-1]
-        if typ.startswith("const"):
-            typ = typ[5:]
-        schema_args.append(typ.strip())
-    compute_schema = ["std::vector<at::TypePtr> schema = {"]
-    for schema_arg in schema_args:
-        compute_schema.append(
-            f"    torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type(),"
-        )
-    compute_schema.append("};")
 
     return template.substitute(
         unpacks="\n".join(unpack),
         op=info.op,
-        compute_schema="\n".join(compute_schema),
         apply_functional_args=apply_functional_args,
         apply_functional_args_signature=apply_functional_args_signature,
         compute_needs_input_grad=compute_needs_input_grad,
         num_inputs=len(input_name_to_idx),
-        unpack_ivalues="\n".join(unpack_ivalues),
         compute_index_ranges=compute_index_ranges,
         saved_variables=saved_variables,
         release_variables=release_variables,
@@ -1078,5 +1015,4 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
         compiled_args=compiled_args,
         apply_with_saved_before=apply_with_saved_before,
         apply_with_saved_after=apply_with_saved_after,
-        get_packed_args=get_packed_args,
     )
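The ${...} placeholders above are torchgen CodeTemplate substitutions. A standalone sketch of that mechanism, applied to the simplified apply_with_saved body this patch restores (the substitution values here are made up for illustration):

    from torchgen.code_template import CodeTemplate

    APPLY_WITH_SAVED = CodeTemplate(
        """\
    variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
      ${apply_with_saved_before}
      variable_list result = apply(variable_list(grads));
      ${apply_with_saved_after}
      return result;
    }
    """
    )

    print(
        APPLY_WITH_SAVED.substitute(
            op="MulBackward0",
            apply_with_saved_before="saved.before(other_);",
            apply_with_saved_after="saved.after(other_);",
        )
    )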
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index cd4db23e332f..fb7017bc6dc9 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -8,7 +8,6 @@ from collections import defaultdict
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
-import torch.utils._pytree as pytree
 from torch._dynamo.external_utils import (
     call_backward,
     call_hook,
@@ -66,39 +65,6 @@ def maybe_clone(x):
     return x
 
 
-# We lazily bind "functional backward" variants for PyTorch built-in autograd
-# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
-# Each "functional backward" is bound the first time the node's apply_with_saved
-# function is called. It's possible to avoid lazy binding and instead bind
-# all of this upfront (perhaps at import time) via codegen changes.
-class OpNamespace:
-    def add(self, name, fn):
-        assert not hasattr(self, name)
-        result = Op(name, fn)
-        torch._dynamo.allow_in_graph(result)
-        setattr(self, name, result)
-        return result
-
-    def get(self, name):
-        return getattr(self, name)
-
-
-class Op:
-    def __init__(self, name, fn):
-        self.fn = fn
-        self.__name__ = name
-        self.__module__ = "torch._dynamo.compiled_autograd.ops"
-
-    def __call__(self, *args, **kwargs):
-        return self.fn(*args, **kwargs)
-
-    def __repr__(self):
-        return self.__module__ + "." + self.__name__
-
-
-ops = OpNamespace()
-
-
 _graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
 _impure_targets = OrderedSet(
     [
@@ -171,8 +137,7 @@ class AutogradCompilerInstance:
         self.fx_tracer.root = torch.nn.Module()
         self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
         self.fx_tracer.tensor_attrs = {}
-        self.symnode_proxy_lookup = {}
-        args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
+        args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
             self.fx_tracer.create_proxy("placeholder", name, (), {})
             for name in _graph_placeholders
         )
@@ -195,9 +160,7 @@ class AutogradCompilerInstance:
             )
             for idx, val in enumerate(sizes)
         ]
-        self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins)
-        for i, symint in enumerate(sizes):
-            self.symnode_proxy_lookup[symint.node] = self.sizes_proxy[i]
+        self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins)
 
         for idx, val in enumerate(scalars):
             source = self.source("scalars", idx)
@@ -219,9 +182,7 @@ class AutogradCompilerInstance:
                 )
             else:
                 raise AssertionError("Unexpected scalar type: ", type(val))
-        self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins)
-        for i, symval in enumerate(scalars):
-            self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i]  # type: ignore[union-attr]
+        self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins)
 
         # TODO(jansel): are all these modes needed?
         self.stack.enter_context(decompose({}))
@@ -255,6 +216,7 @@ class AutogradCompilerInstance:
             ),
             kwargs={},
         )
+
         with disable_proxy_modes_tracing():
             # create fake Tensors
             grad_ins: list[Optional[torch.Tensor]] = []
@@ -270,65 +232,6 @@ class AutogradCompilerInstance:
         self.bind_tensors_to_proxies(grad_ins, proxies)
         return tuple(grad_ins)
 
-    # Guess what the outputs should be from the InputMetadata.
-    # This is not sound in general (we guess contiguous strides
-    # and no Tensor subclass-ness); we will stop guessing
-    # the output metadata in a follow-up.
-    def guess_output(self, input_metadata):
-        if input_metadata is None:
-            return None
-        tensoroptions, shape, _ = input_metadata
-        kwargs = {}
-        names = [
-            "requires_grad",
-            "memory_format",
-            "device",
-            "dtype",
-            "layout",
-            "pinned_memory",
-        ]
-        for name, option in zip(names, tensoroptions):
-            if option is not None:
-                kwargs[name] = option
-
-        with disable_proxy_modes_tracing():
-            return torch.ops.aten.zeros(shape, **kwargs)
-
-    def bind_function(self, fn_name, fn):
-        """Binds ops.fn_name = fn"""
-        ops.add(fn_name, fn)
-
-    def apply_functional(self, fn_name, grads, args, output_metadata):
-        """Proxies a call to ops.fn_name(grads, *args) into the graph"""
-        op = ops.get(fn_name)
-        return self.proxy_call(op, (grads, *args), output_metadata)
-
-    def proxy_call(self, fn, args, output_metadata):
-        """Proxies a call to fn(*args) into the graph"""
-        flat_args, _ = pytree.tree_flatten(args)
-        proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
-        proxy_out = self.fx_tracer.create_proxy(
-            "call_function", fn, args=proxy_args, kwargs={}
-        )
-        result = [self.guess_output(metadata) for metadata in output_metadata]
-        self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(len(result))])
-        return result
-
-    def validate_outputs(self, _, outputs, args, output_metadata):
-        """Proxies a call to ops.validate_outputs(outputs, *args) into the graph"""
-        op = ops.get("validate_outputs")
-        proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args))
-        new_proxy_outputs = self.fx_tracer.create_proxy(
-            "call_function", op, args=proxy_args, kwargs={}
-        )
-        assert len(output_metadata) == len(outputs)
-        outputs = [
-            None if output is None or metadata is None else self.guess_output(metadata)
-            for output, metadata in zip(outputs, output_metadata)
-        ]
-        self.bind_tensors_to_proxies(outputs, new_proxy_outputs)
-        return outputs
-
     def proxy_call_hook(self, hook, *args, **kwargs):
         return self.fx_tracer.create_proxy(
             "call_function",
@@ -411,7 +314,6 @@ class AutogradCompilerInstance:
         assert nodes[first_getitem_idx] == inputs_users[0]
         last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
         assert nodes[last_getitem_idx] == inputs_users[-1]
-        # getitem nodes on inputs
         for i, node in enumerate(inputs_users):
             if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
                 has_cuda_inputs = True
@@ -421,13 +323,9 @@ class AutogradCompilerInstance:
             is_scalar = len(node.meta["val"].size()) == 0
             if is_cpu and is_scalar:
                 node_users = list(node.users.keys())
-                # We can only move the cpu scalar if it is not exposed to user code.
                 if all(
-                    (
-                        isinstance(user.target, torch._ops.OpOverload)
-                        and user.target.namespace in ("prims", "aten")
-                    )
-                    or isinstance(user.target, Op)
+                    isinstance(user.target, torch._ops.OpOverload)
+                    and user.target.namespace in ("prims", "aten")
                     for user in node_users
                 ):
                     # all users are prims/aten, can move safely
@@ -437,7 +335,6 @@ class AutogradCompilerInstance:
         # this is to handle the case where cudagraphs is enabled on a cpu-only graph
         if has_cuda_inputs:
             for node in to_move.values():
-                verbose_log.debug("Moving node %s from cpu to cuda", node)
                 node.meta["val"] = node.meta["val"].cuda()
 
             # return runtime indices we need to move to cuda
@@ -471,10 +368,7 @@ class AutogradCompilerInstance:
                 or (node.op == "call_function" and node.target in _impure_targets)
             )
 
-        before = len(self.fx_tracer.graph.nodes)
         self.fx_tracer.graph.eliminate_dead_code(is_impure)
-        after = len(self.fx_tracer.graph.nodes)
-        verbose_log.debug("DCE removed %d nodes", before - after)
 
     def end_capture(self, outputs):
         self.fx_tracer.create_proxy(
@@ -490,10 +384,6 @@ class AutogradCompilerInstance:
             (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
             {},
         )
-        runtime_inputs_to_move: list[int] = []
-        if snapshot_cudagraph_enabled():
-            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
-        # TODO(rzou): the guessed metadata is incorrect, we will remove it at the end of the PR stack.
         self.rename_aot_dispatcher_nodes()
         self.reorder_tensor_pre_hook_nodes()
         self.reorder_pre_hook_nodes_to_schedule_asap()
@@ -512,6 +402,9 @@ class AutogradCompilerInstance:
         # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
         # should prevent these ops from going into the CA graph.
         self.dce()
+        runtime_inputs_to_move: list[int] = []
+        if snapshot_cudagraph_enabled():
+            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
 
         graph = GraphModule(
             self.fx_tracer.root, self.fx_tracer.graph, f"CompiledAutograd{self.id}"
@@ -885,11 +778,8 @@ class AutogradCompilerInstance:
             return [self.to_proxy(x) for x in t]
         if isinstance(t, tuple):
             return tuple(self.to_proxy(x) for x in t)
-        if isinstance(t, (torch.SymInt, torch.SymFloat)):
-            return self.symnode_proxy_lookup[t.node]
-        if not isinstance(t, torch.Tensor):
-            # constant types like device, dtype, str
-            return t
+        # can it be torch.SymInt as the code used to imply?
+        assert isinstance(t, torch.Tensor)
         proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
         assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
         return proxy_tensor.proxy
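The deleted Op wrapper relied on torch._dynamo.allow_in_graph accepting an arbitrary callable, so that calls to it are proxied into the FX graph rather than traced through. A self-contained sketch of that trick (MyOp and double are illustrative names, not part of the patch):

    import torch

    class MyOp:
        """Callable wrapper, analogous to the removed Op class."""

        def __init__(self, name, fn):
            self.fn = fn
            self.__name__ = name

        def __call__(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    double = MyOp("double", lambda x: x * 2)
    torch._dynamo.allow_in_graph(double)  # calls to `double` become graph nodes

    @torch.compile(backend="eager", fullgraph=True)
    def f(x):
        return double(x) + 1

    f(torch.randn(3))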
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index c8d465211a6f..2aad75e0e74b 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -897,19 +897,6 @@ bool has_input_metadata(const Edge& thing) {
   return thing.is_valid();
 }
 
-std::vector<std::optional<InputMetadata>> collect_input_metadata(
-    const edge_list& edges) {
-  std::vector<std::optional<InputMetadata>> input_metadata;
-  for (const auto& edge : edges) {
-    if (!edge.is_valid()) {
-      input_metadata.emplace_back(std::nullopt);
-      continue;
-    }
-    input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr));
-  }
-  return input_metadata;
-}
-
 // Given an vector<Variable> or vector<optional<Variable>>, validate the
 // outputs. This involves using the InputMetadata to check the outputs and also
 // potentially calling .sum_to on the outputs.
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index 5bf00bac5378..4243f1b1d6ee 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -47,8 +47,6 @@ TORCH_API void validate_outputs(
     const std::vector<std::optional<InputMetadata>>& input_metadata,
     variable_list& grads,
     const std::function<std::string(const std::string&)>& format_error);
-TORCH_API std::vector<std::optional<InputMetadata>> collect_input_metadata(
-    const edge_list& edges);
 
 struct NodeTask {
   std::weak_ptr<GraphTask> base_;
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index abd11303eafe..ba2f6edbc6c0 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -34,12 +34,8 @@ using tensor_list = std::vector<at::Tensor>;
 using variable_list = std::vector<Variable>;
 using edge_list = std::vector<Edge>;
 using saved_variable_list = std::vector<SavedVariable>;
-using ivalue_list = std::vector<c10::IValue>;
-using functional_apply_t = std::function<
-    variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
 using IndexRange = std::pair<size_t, size_t>;
 using torch::dynamo::autograd::CompiledNodeArgs;
-using torch::dynamo::autograd::PackedArgs;
 using torch::dynamo::autograd::SwapSavedVariables;
 
 // Custom deleter to prevent stack overflows.
@@ -608,12 +604,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
         std::string("apply_with_saved not implemented: ") + name());
   }
 
-  // If this node is the AOTBackward node produced by torch.compile.
-  // Compiled Autograd special-cases on this information.
-  virtual bool is_aot_backward() const {
-    return false;
-  }
-
  protected:
   /// Performs the `Node`'s actual operation.
   virtual variable_list apply(variable_list&& inputs) = 0;
diff --git a/torch/csrc/autograd/function_hook.h b/torch/csrc/autograd/function_hook.h
index 4e8bba79a169..6342bf280a5c 100644
--- a/torch/csrc/autograd/function_hook.h
+++ b/torch/csrc/autograd/function_hook.h
@@ -8,7 +8,6 @@
 namespace torch::dynamo::autograd {
 class CompiledNodeArgs;
 class SwapSavedVariables;
-struct PackedArgs;
 } // namespace torch::dynamo::autograd
 
 // A hook that's called on gradients
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index abd8ff30e909..19151cbaafe6 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -288,11 +288,6 @@ auto PyNode::name() const -> std::string {
   return name;
 }
 
-bool PyNode::is_aot_backward() const {
-  py::handle handle(obj);
-  return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
-}
-
 auto PyNode::compiled_autograd_should_lift() const -> bool {
   pybind11::gil_scoped_acquire gil;
   static PyObject* attr_name =
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 2f28c765ab06..46faff8e4686 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -43,8 +43,6 @@ struct PyNode : public Node {
   std::string name() const override;
   bool is_traceable() override;
 
-  bool is_aot_backward() const override;
-
   void compiled_args(CompiledNodeArgs& args) override;
   variable_list apply_with_saved(
       const variable_list& inputs,
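The removed PyNode::is_aot_backward checked whether the node's _forward_cls carries AOTAutograd's _aot_id marker. The same check sketched in Python against a grad_fn object (attribute names taken from the removed C++; not a public API):

    def is_aot_backward(grad_fn) -> bool:
        # AOTAutograd-generated autograd.Function classes carry an _aot_id;
        # ordinary user-defined autograd.Functions do not.
        forward_cls = getattr(grad_fn, "_forward_cls", None)
        return hasattr(forward_cls, "_aot_id")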
diff --git a/torch/csrc/dynamo/compiled_autograd.cpp b/torch/csrc/dynamo/compiled_autograd.cpp
deleted file mode 100644
index 7e2aad576189..000000000000
--- a/torch/csrc/dynamo/compiled_autograd.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#include <torch/csrc/autograd/engine.h>
-#include <torch/csrc/dynamo/compiled_autograd.h>
-
-namespace torch::dynamo::autograd {
-
-std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;
-
-const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
-  TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
-  return kPyCompilerInterface;
-}
-
-void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
-  TORCH_INTERNAL_ASSERT(impl != nullptr);
-  kPyCompilerInterface = std::move(impl);
-}
-
-void resetPyCompilerInterface() {
-  kPyCompilerInterface.reset();
-}
-
-std::vector<std::optional<InputMetadata>> get_input_metadata(
-    const edge_list& edges) {
-  return torch::autograd::collect_input_metadata(edges);
-}
-
-} // namespace torch::dynamo::autograd
diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h
index b00ec6e00a40..383cff14b8e0 100644
--- a/torch/csrc/dynamo/compiled_autograd.h
+++ b/torch/csrc/dynamo/compiled_autograd.h
@@ -900,506 +900,6 @@ class SwapSavedVariables {
   StashedVars<at::IValue> stashed_ivalues;
 };
 
-// NOTE: [Compiled Autograd and backward functions]
-// Built-in autograd nodes have functional apply variants
-// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph
-// capture wants to take a variant of this function and proxy it into the graph.
-// Every autograd node defines an apply_with_saved function, that when invoked,
-// proxys a call to a function into the Compiled Autograd graph.
-//
-// Some requirements that we have are:
-// - The proxy'ed function must have inputs that are FX-graphable types.
-// - Windows has a DLL symbol limit of 65536.
-// - Node::apply_with_saved is in libtorch_cpu which does not have direct access
-//   to Python
-//
-// There were multiple ways to skin the cat, but what we end up doing is:
-// - for e.g. MulBackward0_apply_functional, we create a new C++ function
-//   MulBackward0_apply_functional_ivalue that accepts vector<IValue>.
-// - We define how to pack and unpack arbitrary C++ types into IValues.
-// - apply_with_saved passes MulBackward0_apply_functional_ivalue and
-//   the IValue arguments to Python via an indirection.
-//   In Python, these get proxy'ed into a graph.
-
-// Helper struct for packing/unpacking an arbitrary C++ type into a single
-// IValue. There are various full and partial specializations for IValuePacker
-// to handle packing specific types (like TensorOptions) into an IValue.
-template <typename T>
-struct IValuePacker {
-  // Defines how to pack T into an IValue.
-  static at::IValue pack(const T& t) {
-    return t;
-  }
-  // Defines how to unpack an IValue into T.
-  static T unpack(const at::IValue& t) {
-    return t.to<T>();
-  }
-  // Returns the TypePtr for the IValue (this is like the "type" of the IValue).
-  // We use this when passing the packed IValue from Python to C++.
-  // In Python, the IValue is just a PyObject* with the native type.
-  // For example, it may be a Python int, a Python List[int], etc.
-  // When passing this PyObject* into C++, we need to know how to parse it
-  // into a C++ type that then gets put into an IValue.
-  // That's what the TypePtr is for: it contains the information to do the
-  // parsing. See torch::jit::toIValue for more information.
-  static at::TypePtr packed_type() {
-    if constexpr (::std::is_same_v<T, at::Tensor>) {
-      return at::TensorType::get();
-    } else if constexpr (::std::is_same_v<T, int64_t>) {
-      return at::IntType::get();
-    } else if constexpr (::std::is_same_v<T, c10::SymInt>) {
-      return at::SymIntType::get();
-    } else if constexpr (::std::is_same_v<T, bool>) {
-      return at::BoolType::get();
-    } else if constexpr (::std::is_same_v<T, double>) {
-      return at::FloatType::get();
-    } else if constexpr (::std::is_same_v<T, c10::SymFloat>) {
-      return at::SymFloatType::get();
-    } else if constexpr (::std::is_same_v<T, c10::SymBool>) {
-      return at::SymBoolType::get();
-    } else if constexpr (::std::is_same_v<T, c10::Layout>) {
-      return at::LayoutType::get();
-    } else if constexpr (::std::is_same_v<T, ::std::string>) {
-      return at::StringType::get();
-    } else if constexpr (::std::is_same_v<T, at::Device>) {
-      return at::DeviceObjType::get();
-    } else if constexpr (::std::is_same_v<T, at::Scalar>) {
-      return at::NumberType::get();
-    } else if constexpr (::std::is_same_v<T, at::MemoryFormat>) {
-      return at::MemoryFormatType::get();
-    } else if constexpr (::std::is_same_v<T, at::ScalarType>) {
-      return at::ScalarTypeType::get();
-    } else {
-      // If you got here, you have probably added a member of a new type
-      // to a built-in C++ autograd node.
-      // Unfortunately, we don't know how to handle this type yet.
-      // To get this new type to work with Compiled Autograd, please
-      // either change it to be an IValue-constructible type, or
-      // define how to pack and unpack an object of this time into an IValue
-      // by creating a specialization of IValuePacker for this type.
-      // See NOTE: [Compiled Autograd and backward functions] for context.
-      TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type");
-      return at::NoneType::get();
-    }
-  }
-};
-
-template <>
-struct IValuePacker<size_t> {
-  static at::IValue pack(const size_t& t) {
-    // We generally use size_t as the size of a list of Tensors or number of
-    // dimensions. The number of dimensions generally do not exceed 64
-    // (TensorIterator has that limitation), and lists of Tensors generally do
-    // not exceed the int64_t max (you'd probably run out of RAM or run into
-    // significant Tensor overhead). If you run into this limitation the fix is
-    // to figure out how to pack size_t into int64_t. Note that size_t has some
-    // weird behavior on Mac OS.
-    uint64_t maximum_value = std::numeric_limits<int64_t>::max();
-    TORCH_INTERNAL_ASSERT(
-        static_cast<uint64_t>(t) <= maximum_value,
-        "size_t too large to pack into IValue");
-    return static_cast<int64_t>(t); // pack as int64_t
-  }
-  static size_t unpack(const at::IValue& t) {
-    return static_cast<size_t>(t.toInt());
-  }
-  static at::TypePtr packed_type() {
-    return IValuePacker<int64_t>::packed_type();
-  }
-};
-
-template <>
-struct IValuePacker<std::vector<at::SymInt>> {
-  static at::IValue pack(const std::vector<at::SymInt>& t) {
-    return t;
-  }
-  static std::vector<at::SymInt> unpack(const at::IValue& t) {
-    // We need this because there's no t.to<std::vector<at::SymInt>>() override?
-    return t.toSymIntVector();
-  }
-  static at::TypePtr packed_type() {
-    return at::ListType::create(at::SymIntType::get());
-  }
-};
-
-template <>
-struct IValuePacker<VariableInfo> {
-  static at::IValue pack(const VariableInfo& t) {
-    auto tuple = std::make_tuple(
-        t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty);
-    return tuple;
-  }
-  static VariableInfo unpack(const at::IValue& t) {
-    auto tuple = t.to<std::tuple<
-        at::Layout,
-        at::Device,
-        at::ScalarType,
-        std::vector<c10::SymInt>,
-        bool,
-        bool>>();
-    VariableInfo v;
-    v.layout = std::get<0>(tuple);
-    v.device = std::get<1>(tuple);
-    v.scalar_type = std::get<2>(tuple);
-    v.size = std::get<3>(tuple);
-    v.requires_grad = std::get<4>(tuple);
-    v.is_empty = std::get<5>(tuple);
-    return v;
-  }
-  static at::TypePtr packed_type() {
-    return at::TupleType::create({
-        at::LayoutType::get(),
-        at::DeviceObjType::get(),
-        at::ScalarTypeType::get(),
-        at::ListType::create(at::SymIntType::get()),
-        at::BoolType::get(),
-        at::BoolType::get(),
-    });
-  }
-};
-
-template <>
-struct IValuePacker<caffe2::TypeMeta> {
-  static at::IValue pack(const caffe2::TypeMeta& t) {
-    return at::typeMetaToScalarType(t); // pack as at::ScalarType
-  }
-  static caffe2::TypeMeta unpack(const at::IValue& t) {
-    return caffe2::TypeMeta::fromScalarType(t.to<at::ScalarType>());
-  }
-  static at::TypePtr packed_type() {
-    return IValuePacker<at::ScalarType>::packed_type();
-  }
-};
-
-inline std::optional<at::ScalarType> optTypeMetaToScalarType(
-    const std::optional<caffe2::TypeMeta>& t) {
-  if (t.has_value()) {
-    return at::typeMetaToScalarType(t.value());
-  } else {
-    return std::nullopt;
-  }
-}
-
-using packed_tensoroptions_t = std::tuple<
-    std::optional<bool>,
-    std::optional<at::MemoryFormat>,
-    std::optional<at::Device>,
-    std::optional<at::ScalarType>,
-    std::optional<at::Layout>,
-    std::optional<bool>>;
-
-inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) {
-  auto tuple = std::make_tuple(
-      t.requires_grad_opt(),
-      t.memory_format_opt(),
-      t.device_opt(),
-      optTypeMetaToScalarType(t.dtype_opt()),
-      t.layout_opt(),
-      t.pinned_memory_opt());
-  return tuple;
-}
-inline at::TensorOptions unpack_TensorOptions(
-    const packed_tensoroptions_t& tuple) {
-  at::TensorOptions result;
-  auto maybe_requires_grad = std::get<0>(tuple);
-  if (maybe_requires_grad.has_value()) {
-    result = result.requires_grad(maybe_requires_grad.value());
-  }
-  auto maybe_memory_format = std::get<1>(tuple);
-  if (maybe_memory_format.has_value()) {
-    result = result.memory_format(maybe_memory_format.value());
-  }
-  auto maybe_device = std::get<2>(tuple);
-  if (maybe_device.has_value()) {
-    result = result.device(maybe_device.value());
-  }
-  auto maybe_dtype = std::get<3>(tuple);
-  if (maybe_dtype.has_value()) {
-    result =
-        result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value()));
-  }
-  auto maybe_layout = std::get<4>(tuple);
-  if (maybe_layout.has_value()) {
-    result = result.layout(maybe_layout.value());
-  }
-  auto maybe_pinned_memory = std::get<5>(tuple);
-  if (maybe_pinned_memory.has_value()) {
-    result = result.pinned_memory(maybe_pinned_memory.value());
-  }
-  return result;
-}
-
-template <>
-struct IValuePacker<at::TensorOptions> {
-  static at::IValue pack(const at::TensorOptions& t) {
-    return pack_TensorOptions(t);
-  }
-  static at::TensorOptions unpack(const at::IValue& t) {
-    auto tuple = t.to<packed_tensoroptions_t>();
-    return unpack_TensorOptions(tuple);
-  }
-  static at::TypePtr packed_type() {
-    return at::TupleType::create(
-        {at::OptionalType::create(at::BoolType::get()),
-         at::OptionalType::create(at::MemoryFormatType::get()),
-         at::OptionalType::create(at::DeviceObjType::get()),
-         at::OptionalType::create(at::ScalarTypeType::get()),
-         at::OptionalType::create(at::LayoutType::get()),
-         at::OptionalType::create(at::BoolType::get())});
-  }
-};
-
-template <>
-struct IValuePacker<TypeAndSize> {
-  static at::IValue pack(const TypeAndSize& t) {
-    auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options));
-    return tuple;
-  }
-  static TypeAndSize unpack(const at::IValue& t) {
-    auto tuple =
-        t.to<std::tuple<std::vector<c10::SymInt>, packed_tensoroptions_t>>();
-    TypeAndSize result;
-    result.sym_sizes = std::get<0>(tuple);
-    result.options = unpack_TensorOptions(std::get<1>(tuple));
-    return result;
-  }
-  static at::TypePtr packed_type() {
-    return at::TupleType::create(
-        {IValuePacker<std::vector<c10::SymInt>>::packed_type(),
-         IValuePacker<at::TensorOptions>::packed_type()});
-  }
-};
-
-template <typename T>
-struct IValuePacker<std::optional<T>> {
-  static at::IValue pack(const std::optional<T>& t) {
-    if (t.has_value()) {
-      return IValuePacker<T>::pack(t.value());
-    } else {
-      return std::nullopt;
-    }
-  }
-  static std::optional<T> unpack(const at::IValue& t) {
-    if (t.isNone()) {
-      return std::nullopt;
-    } else {
-      return IValuePacker<T>::unpack(t);
-    }
-  }
-  static at::TypePtr packed_type() {
-    return at::OptionalType::create(IValuePacker<T>::packed_type());
-  }
-};
-
-template <typename T>
-struct IValuePacker<std::vector<T>> {
-  static at::IValue pack(const std::vector<T>& t) {
-    if constexpr (::std::is_constructible_v<at::IValue, T>) {
-      return t;
-    }
-    if (t.empty()) {
-      auto lst = c10::impl::GenericList(at::AnyType::get());
-      return lst;
-    }
-    auto type_ptr = IValuePacker<T>::pack(t[0]).type();
-    auto lst = c10::impl::GenericList(type_ptr);
-    for (const auto& elt : t) {
-      lst.emplace_back(IValuePacker<T>::pack(elt));
-    }
-    return lst;
-  }
-  static std::vector<T> unpack(const at::IValue& t) {
-    if constexpr (::std::is_constructible_v<at::IValue, T>) {
-      return t.to<::std::vector<T>>();
-    }
-    std::vector<T> result;
-    auto lst = t.toList();
-    for (const at::IValue& elt : lst) {
-      result.emplace_back(IValuePacker<T>::unpack(elt));
-    }
-    return result;
-  }
-  static at::TypePtr packed_type() {
-    return at::ListType::create(IValuePacker<T>::packed_type());
-  }
-};
-
-template <typename T>
-struct IValuePacker<c10::List<T>> {
-  static at::IValue pack(const c10::List<T>& t) {
-    return IValuePacker<std::vector<T>>::pack(t.vec());
-  }
-  static c10::List<T> unpack(const at::IValue& t) {
-    return c10::List<T>(IValuePacker<std::vector<T>>::unpack(t));
-  }
-  static at::TypePtr packed_type() {
-    return IValuePacker<std::vector<T>>::packed_type();
-  }
-};
-
-template <size_t N>
-struct IValuePacker<std::array<bool, N>> {
-  static at::IValue pack(const std::array<bool, N>& t) {
-    std::vector<bool> result(t.begin(), t.end());
-    return IValuePacker<std::vector<bool>>::pack(result);
-  }
-  static std::array<bool, N> unpack(const at::IValue& t) {
-    std::array<bool, N> result;
-    auto packed = IValuePacker<std::vector<bool>>::unpack(t);
-    for (size_t i = 0; i < packed.size(); i++) {
-      result[i] = packed[i];
-    }
-    return result;
-  }
-  static at::TypePtr packed_type() {
-    return IValuePacker<std::vector<bool>>::packed_type();
-  }
-};
-
-template <>
-struct IValuePacker<at::TensorGeometry> {
-  static at::IValue pack(const at::TensorGeometry& t) {
-    auto tuple = std::make_tuple(
-        t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset());
-    return tuple;
-  }
-  static at::TensorGeometry unpack(const at::IValue& t) {
-    auto tuple = t.to<std::tuple<
-        std::vector<at::SymInt>,
-        std::vector<at::SymInt>,
-        at::SymInt>>();
-    return at::TensorGeometry(
-        std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple));
-  }
-  static at::TypePtr packed_type() {
-    return at::TupleType::create(
-        {IValuePacker<std::vector<at::SymInt>>::packed_type(),
-         IValuePacker<std::vector<at::SymInt>>::packed_type(),
-         at::SymIntType::get()});
-  }
-};
-
-template <>
-struct IValuePacker<InputMetadata> {
-  static at::IValue pack(const InputMetadata& t) {
-    TORCH_INTERNAL_ASSERT(!t.is_nested_tensor());
-    auto tuple = std::make_tuple(
-        pack_TensorOptions(t.options()),
-        t.shape_as_dim_vector().vec(),
-        t.is_tensor_subclass());
-    return tuple;
-  }
-  static InputMetadata unpack(const at::IValue& t) {
-    auto tuple = t.to<
-        std::tuple<packed_tensoroptions_t, std::vector<at::SymInt>, bool>>();
-
-    return InputMetadata(
-        unpack_TensorOptions(std::get<0>(tuple)),
-        SymIntSmallVec(std::get<1>(tuple)),
-        std::get<2>(tuple),
-        false);
-  }
-  static at::TypePtr packed_type() {
-    return at::TupleType::create(
-        {IValuePacker<at::TensorOptions>::packed_type(),
-         IValuePacker<std::vector<at::SymInt>>::packed_type(),
-         at::BoolType::get()});
-  }
-};
-
-template <typename T>
-struct IValuePacker<at::OptionalArray<T>> {
-  static at::IValue pack(const at::OptionalArray<T>& t) {
-    return IValuePacker<std::optional<std::vector<T>>>::pack(t.list);
-  }
-  static at::OptionalArray<T> unpack(const at::IValue& t) {
-    auto result = IValuePacker<std::optional<std::vector<T>>>::unpack(t);
-    if (result.has_value()) {
-      return {result.value()};
-    } else {
-      return {};
-    }
-  }
-  static at::TypePtr packed_type() {
-    return IValuePacker<std::optional<std::vector<T>>>::packed_type();
-  }
-};
-
-// This is a helper struct for packing and unpacking multiple arguments into
-// an ivalue_list. It leverages IValuePacker.
-struct PackedArgs {
-  PackedArgs() = default;
-
-  explicit PackedArgs(std::vector<at::IValue> stack_)
-      : stack(std::move(stack_)) {}
-
-  std::vector<at::IValue> vec() && {
-    return std::move(stack);
-  }
-
-  template <typename T>
-  void pack(const T& t) {
-    stack.emplace_back(IValuePacker<T>::pack(t));
-  }
-  template <typename T>
-  T unpack() {
-    return IValuePacker<T>::unpack(std::move(stack[idx++]));
-  }
-
- private:
-  std::vector<at::IValue> stack;
-  int64_t idx = 0;
-};
-
-// This is a layer of indirection for calling methods on the Python
-// AutogradCompilerInstance (referred to as the "py_compiler") from
-// libtorch_cpu (where Python is not available).
-// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
-// overrides the methods to do the actual calls back to Python.
-struct TORCH_API PyCompilerInterface {
-  PyCompilerInterface() = default;
-  PyCompilerInterface(const PyCompilerInterface&) = delete;
-  PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
-  PyCompilerInterface(PyCompilerInterface&&) = delete;
-  PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
-  virtual ~PyCompilerInterface() = default;
-
-  // Invokes py_compiler.bind_function(fn_name, fn)
-  virtual void bind_function(
-      PyObject* py_compiler,
-      const std::string& fn_name,
-      // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      functional_apply_t fn,
-      // NOLINTNEXTLINE(performance-unnecessary-value-param)
-      std::vector<at::TypePtr> packed_args_schema) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-
-  // Invokes py_compiler.method_name(fn_name, inputs, packed_args,
-  // output_metadata)
-  virtual variable_list call_function(
-      PyObject* py_compiler,
-      const char* method_name,
-      const std::string& fn_name,
-      const variable_list& inputs,
-      const ivalue_list& packed_args,
-      const c10::IValue& output_metadata) {
-    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
-  }
-};
-
-TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
-TORCH_API void setPyCompilerInterface(
-    std::unique_ptr<PyCompilerInterface>&& impl);
-TORCH_API void resetPyCompilerInterface();
-
-// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
-// symbol resolution issues. Instead requiring downstream users to include
-// engine.h to access collect_input_metadata, we provide it here (with a
-// different name to avoid ambigous symbols...)
-TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
-    const edge_list& edges);
-
 } // namespace torch::dynamo::autograd
 
 template <>
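The deleted PackedArgs is a positional pack/unpack helper: whatever order apply_with_saved packs values in, the functional variant must unpack them in exactly the same order. A minimal Python model of that contract (illustrative only; the real type stores IValues and is driven by the IValuePacker specializations above):

    class PackedArgs:
        def __init__(self, stack=None):
            self.stack = list(stack or [])
            self.idx = 0

        def pack(self, value):
            self.stack.append(value)

        def unpack(self):
            value = self.stack[self.idx]
            self.idx += 1
            return value

    p = PackedArgs()
    p.pack([True, False])  # needs_input_grad comes first
    p.pack(2.0)            # then each saved attribute, in codegen order

    q = PackedArgs(p.stack)
    assert q.unpack() == [True, False]
    assert q.unpack() == 2.0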
diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp
index 473ed29c7eb6..b58f62c3d2fc 100644
--- a/torch/csrc/dynamo/python_compiled_autograd.cpp
+++ b/torch/csrc/dynamo/python_compiled_autograd.cpp
@@ -52,118 +52,6 @@ Notes:
 namespace torch::dynamo::autograd {
 using c10::SymInt;
 
-// List[Optional[Tensor]] in Python can't be directly parsed into a
-// List[Tensor], so we need to do this conversion manually.
-static std::vector<at::Tensor> toTensorList(
-    const std::vector<std::optional<at::Tensor>>& inputs) {
-  std::vector<at::Tensor> result;
-  result.reserve(inputs.size());
-  for (const auto& inp : inputs) {
-    if (inp.has_value()) {
-      result.emplace_back(*inp);
-    } else {
-      result.emplace_back();
-    }
-  }
-  return result;
-}
-
-// Binds a function (that represents some backward computation) to Python.
-// All of these functions have a common signature, which is
-// (in C++) (vector<Tensor>, vector<IValue>) -> vector<Tensor>
-// (in Python) (List[Optional[Tensor]], *packed_args: IValue) ->
-//             List[Optional[Tensor]]
-//
-// The vector<Tensor> are the list of gradient Tensors, each of which may be
-// undefined (in C++) which corresponds to None (in Python).
-static void bind_function(
-    PyObject* py_compiler,
-    const std::string& fn_name,
-    functional_apply_t fn,
-    std::vector<at::TypePtr> packed_args_schema) {
-  // This is the function that can be called from Python.
-  auto py_func = py::cpp_function(
-      [packed_args_schema = std::move(packed_args_schema), fn = std::move(fn)](
-          std::vector<std::optional<at::Tensor>>& inputs,
-          const py::args& py_args) -> py::object {
-        // py_args is a tuple of PyObject*.
-        // We need to reconstruct a vector<IValue> to invoke `fn`.
-        // To do so, we use the packed_args_schema to convert each PyObject*
-        // to its corresponding C++ type that can be stored into IValue.
-        TORCH_INTERNAL_ASSERT(py_args.size() == packed_args_schema.size());
-        std::vector<at::IValue> args;
-        args.reserve(py_args.size());
-        auto tuple_args = jit::tuple_slice(py_args);
-        for (uint64_t idx = 0; idx < packed_args_schema.size(); idx++) {
-          args.emplace_back(jit::toIValue(
-              tuple_args[idx], packed_args_schema[idx], std::nullopt));
-        }
-        // None in Python corresponds to undefined Tensor in C++
-        auto inputs_ = toTensorList(inputs);
-        auto outputs = fn(inputs_, args);
-        return jit::toPyObject(at::IValue(outputs));
-      });
-  py::handle handle(py_compiler);
-  handle.attr("bind_function")(fn_name, py_func);
-}
-
-// Invokes py_compiler.method_name(fn_name, inputs, packed_args,
-// output_metadata)
-static variable_list call_function(
-    PyObject* py_compiler,
-    const char* method_name,
-    const std::string& fn_name,
-    const variable_list& inputs,
-    const ivalue_list& packed_args,
-    const c10::IValue& output_metadata) {
-  // convert ivalue_list -> PyObject*
-  PyObject* py_packed_args =
-      PyTuple_New(static_cast<Py_ssize_t>(packed_args.size()));
-  for (const auto i : c10::irange(packed_args.size())) {
-    py::object obj = jit::toPyObject(packed_args[i]);
-    Py_INCREF(obj.ptr());
-    PyTuple_SET_ITEM(py_packed_args, i, obj.ptr());
-  }
-
-  // call the corresponding method on the py_compiler
-  py::handle handle(py_compiler);
-  py::object stuff = handle.attr(method_name)(
-      fn_name,
-      inputs,
-      py::handle(py_packed_args),
-      jit::toPyObject(output_metadata));
-
-  // Convert the output from PyObject* to vector<Tensor>
-  auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
-  return toTensorList(tmp);
-}
-
-struct PyCompilerInterfaceImpl : PyCompilerInterface {
-  void bind_function(
-      PyObject* py_compiler,
-      const std::string& fn_name,
-      functional_apply_t fn,
-      std::vector<at::TypePtr> packed_args_schema) override {
-    return torch::dynamo::autograd::bind_function(
-        py_compiler, fn_name, std::move(fn), std::move(packed_args_schema));
-  }
-  variable_list call_function(
-      PyObject* py_compiler,
-      const char* method_name,
-      const std::string& fn_name,
-      const variable_list& inputs,
-      const ivalue_list& packed_args,
-      const c10::IValue& output_metadata) override {
-    return torch::dynamo::autograd::call_function(
-        py_compiler,
-        method_name,
-        fn_name,
-        inputs,
-        packed_args,
-        output_metadata);
-  }
-};
-
 static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
   PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
   for (const auto i : c10::irange(inputs.size())) {
@@ -200,22 +88,6 @@ static void check(bool result) {
     check(nullptr);
 }
 
-static variable_list validate_outputs(
-    const variable_list& outputs,
-    const ivalue_list& args) {
-  auto r = PackedArgs(args);
-  auto value = r.unpack<std::vector<std::optional<InputMetadata>>>();
-  auto new_outputs = outputs;
-
-  torch::autograd::validate_outputs(
-      value, new_outputs, [&](const std::string& msg) {
-        std::ostringstream ss;
-        ss << "[Compiled Autograd Tracing:]" << msg;
-        return ss.str();
-      });
-  return new_outputs;
-}
-
 // snapshot of python verbose logging toggle
 static PyObject* python_verbose_logger = nullptr;
 
@@ -785,8 +657,6 @@ static CacheNode* _compiled_autograd_impl(
   ClosingTHPObjectPtr py_compiler(
       check(PyObject_CallNoArgs((the_autograd_compiler))));
 
-  setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
-
   TraceState state = call_begin_capture(
       py_compiler, *cache, compiler_call, output_edges.size());
   InputBuffers input_buffers;
@@ -853,48 +723,16 @@ static CacheNode* _compiled_autograd_impl(
 
       SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
       variable_list outputs = call.node->apply_with_saved(inputs, saved);
+      saved.debug_asserts();
       saved.before(call.node->next_edges());
-
-      auto input_metadata = get_input_metadata(call.node->next_edges());
-      TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size());
-
-      // Lazily bind the `validate_outputs` function to Python.
-      static c10::once_flag flag;
-      c10::call_once(flag, [&]() {
-        auto schema = std::vector<at::TypePtr>{IValuePacker<
-            std::vector<std::optional<InputMetadata>>>::packed_type()};
-        bind_function(
-            py_compiler.get(), "validate_outputs", validate_outputs, schema);
-      });
-
-      // Don't emit validate_outputs nodes that follow a CompiledBackward node.
-      // These nodes would otherwise prevent reordering of accumulate_grad
-      // nodes.
-      //
-      // Note that this will not cause correctness issues, because
-      // 1) AOTAutograd already coerces gradients to have the same metadata as
-      // the inputs. 2) the AOTAutograd graph already has the necessary
-      // aten::sum_to nodes in it (so it doesn't need to rely on
-      // validate_outputs to handle that).
-      //
-      // However, we may be dropping some (edge case) safety checks compared to
-      // eager: a backward that would have errored out in eager may not error
-      // out in compiled autograd (for example, if the user provided an
-      // incorrect number of gradients).
-      if (!call.node->is_aot_backward()) {
-        PackedArgs args;
-        args.pack(input_metadata);
-        ivalue_list input_metadata_state = std::move(args).vec();
-        outputs = call_function(
-            py_compiler,
-            "validate_outputs",
-            "validate_outputs",
-            outputs,
-            input_metadata_state,
-            input_metadata_state[0]);
-      }
-
+      validate_outputs(
+          call.node->next_edges(), outputs, [&](const std::string& msg) {
+            std::ostringstream ss;
+            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
+               << msg;
+            return ss.str();
+          });
       saved.after(call.node->next_edges());
       saved.debug_asserts();
 
@@ -923,7 +761,6 @@ static CacheNode* _compiled_autograd_impl(
     }
   }
 
-  resetPyCompilerInterface();
   PyObject* res = check(call_end_capture(py_compiler, state.outputs));
   TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
   TORCH_CHECK(
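With this change, validate_outputs runs directly in C++ for every node again (rather than being bound and traced as a graph op). The class of error it guards against can be seen in eager autograd too; a hedged illustration (the mismatched gradient shape below is deliberate):

    import torch

    x = torch.randn(2, requires_grad=True)
    y = x * 2
    try:
        y.backward(gradient=torch.ones(3))  # grad shape must match y's shape
    except RuntimeError as e:
        print(e)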