Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits (18 commits): ciflow/ind ... fca2_ca598
| SHA1 |
|---|
| a8b7f8a5ec |
| ca5984c127 |
| 60e651a891 |
| 24414b64e3 |
| d0e906727b |
| f77fd97074 |
| eb742a8a77 |
| 55542e289e |
| 589e001c28 |
| 7143079985 |
| 04da684b55 |
| 8c684e9cfa |
| adb9ba7e98 |
| 72b73eef85 |
| 1d4e622bdf |
| d7b5cc1646 |
| 01be980f91 |
| 36062f6dd5 |
@@ -37,6 +37,16 @@ struct TORCH_API TensorGeometry {
        has_symbolic_sizes_strides_(
            t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}

  explicit TensorGeometry(
      std::vector<at::SymInt> sizes,
      std::vector<at::SymInt> strides,
      at::SymInt storage_offset)
      : sizes_(std::move(sizes)),
        strides_(std::move(strides)),
        storage_offset_(std::move(storage_offset)) {
    recompute();
  }

  // true if the tensor is contiguous
  bool is_contiguous() const;

@@ -88,7 +88,9 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
      case Tag::None:
        return NoneType::get();
      case Tag::Tensor:
        return TensorType::create(v.toTensor());
        return TensorType::get();
        // TODO(rzou): following errors
        // return TensorType::create(v.toTensor());
      case Tag::Storage:
        return StorageType::get();
      case Tag::Double:
@@ -4049,6 +4049,7 @@ def parse_args(args=None):
        "--compiled-autograd",
        action="store_true",
        help="Enables compiled autograd on compiled benchmark",
        default=True,
    )

    parser.add_argument(
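Illustration (not part of the diff): with `action="store_true"` combined with `default=True`, the option above evaluates to True whether or not the flag is passed on the command line, so compiled autograd is effectively always enabled for these benchmark runs. A minimal check of that argparse behavior:

```python
import argparse

# Minimal sketch of the flag semantics above: store_true + default=True means
# the attribute is True both with and without the flag on the command line.
p = argparse.ArgumentParser()
p.add_argument("--compiled-autograd", action="store_true", default=True)
assert p.parse_args([]).compiled_autograd is True
assert p.parse_args(["--compiled-autograd"]).compiled_autograd is True
```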
@@ -476,6 +476,7 @@ inductor_core_resources = [
    "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
    "torch/csrc/inductor/inductor_ops.cpp",
    "torch/csrc/jit/serialization/pickle.cpp",
    "torch/csrc/dynamo/compiled_autograd.cpp",
]

libtorch_core_sources = sorted(
@@ -5,6 +5,7 @@
#include <torch/torch.h>

#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

#include <test/cpp/api/support.h>
@@ -1668,6 +1669,36 @@ TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}

static std::string test_format_error(const std::string& s) {
  return s;
}

TEST(TestAutogradUtils, ValidateOutputsReduce) {
  auto input = torch::ones({}, {torch::kFloat32});
  auto grad = torch::ones({2, 3}, {torch::kFloat32});

  std::vector<c10::optional<InputMetadata>> input_metadata;
  input_metadata.emplace_back(InputMetadata(input));
  std::vector<torch::Tensor> grads;
  grads.emplace_back(grad);

  torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
  ASSERT_TRUE(at::allclose(grads[0], grad.sum()));
}

TEST(TestAutogradUtils, ValidateOutputsBasic) {
  auto input = torch::zeros({2, 3}, {torch::kFloat32});
  auto grad = torch::ones({2, 3}, {torch::kFloat32});

  std::vector<c10::optional<InputMetadata>> input_metadata;
  input_metadata.emplace_back(InputMetadata(input));
  std::vector<torch::Tensor> grads;
  grads.emplace_back(grad);

  torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
  ASSERT_TRUE(at::allclose(grad, torch::ones({2, 3})));
}

// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward

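Illustration (not part of the diff): the `ValidateOutputsReduce` test above checks that a gradient whose shape does not match the input metadata gets reduced down to the input's shape. A minimal Python sketch of the same reduction semantics, using plain autograd broadcasting rather than `validate_outputs` itself:

```python
import torch

# A scalar leaf receiving a (2, 3) upstream gradient: autograd sums the
# incoming gradient back down to the leaf's shape, which mirrors what the
# C++ test expects validate_outputs to do (grads[0] == grad.sum()).
x = torch.ones((), requires_grad=True)
y = x.expand(2, 3)
y.backward(torch.ones(2, 3))
assert x.grad.shape == torch.Size([])
assert x.grad.item() == 6.0
```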
@@ -1749,6 +1749,7 @@ main()

        self.check_output_and_recompiles(fn, 1)

    @unittest.expectedFailure  # TODO: should check the graph at aot_eager or something
    def test_trace_run_with_rng_state(self):
        def sdpa(xq, xk):
            return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True)
@@ -1842,10 +1843,12 @@ main()
                f, compiler_fn=compiler_fn_with_op_check, compile_fn=False
            )

    @unittest.expectedFailure  # TODO: test needs to change to checking the HOP in the post-AOTDispatch graph
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_trace_auto_functionalized_v2(self):
        self.trace_auto_functionalized_base()

    @unittest.expectedFailure  # TODO: test needs to change to checking the HOP in the post-AOTDispatch graph
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_trace_auto_functionalized(self):
        self.trace_auto_functionalized_base()
@@ -2136,6 +2139,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) {
            cpp_sources=cpp_source,
            functions="custom_op_backed_by_autograd_fn",
            verbose=True,
            extra_cflags=["-g", "-O0"],
        )

        def same_autograd_fn():
@@ -2424,7 +2428,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) {
                yield x.grad

        # compiled autograd and dynamo both support symfloat, but not backend
        self.check_output_and_recompiles(fn, [1, 3])
        self.check_output_and_recompiles(fn, [1, 4])

    @scoped_load_inline
    def test_autograd_cpp_node_data_dependent(self, load_inline):
@@ -2901,27 +2905,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
        with ctx():
            self.check_output_and_recompiles(fn)

        # Change acceptable bc we no longer inline into these in the initial capture
        expected_logs = [
            "code: CompiledFunctionBackward (NodeCall 2)",
            "aot0_primals_3",
            "aot0_relu",
            "aot0_le",
            "aot0_permute_2",
            "code: CompiledFunctionBackward0 (NodeCall 2)",
            "aot0_tangents_1",
            "aot0_full_default",
            "aot0_where",
            "aot0_mm",
            "aot0_permute_3",
            "aot0_mm_1",
            "aot0_sum_1",
            "aot0_view",
            "aot0_le_1",
            "aot0_where_1",
            "aot0_permute_6",
            "aot0_mm_2",
            "aot0_sum_2",
            "aot0_view_1",
        ]

        found = 0
@@ -2956,23 +2943,10 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
        with ctx():
            self.check_output_and_recompiles(fn)

        # Change acceptable bc we no longer inline into these in the initial capture
        expected_logs = [
            "CompiledFunctionBackward1",
            "aot1_tangents_1",
            "aot1_sin_1",
            "aot1_primals_2",
            "aot1_neg",
            "aot0_tangents_2",
            "aot1_cos_1",
            "aot1_primals_1",
            "aot0_tangents_1",
            "CompiledFunctionBackward0",
            "aot0_neg",
            "aot0_sin",
            "aot0_mul",
            "aot0_mul_1",
            "aot0_cos",
            "aot0_add",
        ]

        self.assertEqual(
@@ -3008,18 +2982,9 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
            opt_fn(y, obj).sum().backward()
        self.assertEqual(x.grad, y.grad)

        # Change acceptable bc we no longer inline into these in the initial capture
        expected_logs = [
            "CompiledFunctionBackward0",
            "aot0_primals_2",
            "aot0_tangents_2",
            "aot0_tangents_1",
            "aot0_sin",
            "aot0_cos",
            "aot0_mul",
            "aot0_add_1",
            "aot0_trace_wrapped",
            "aot0_cos_1",
            "aot0_mul_1",
        ]

        self.assertEqual(
@@ -3118,6 +3083,7 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
        self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

    # https://github.com/pytorch/pytorch/issues/138920
    @unittest.expectedFailure  # TODO: needs a better repro now that we're hiding AOT in the initial capture
    def test_compiled_autograd_does_not_specialize_on_bw_symints(self):
        class Mod(torch.nn.Module):
            def __init__(self, a, b, c):
@@ -3425,10 +3391,12 @@ known_failures_re = re.compile(
# Bugs needing investigation:
skipped_tests = {
    "test_callback_propagates_errors_from_device_thread",  # fullgraph for queue_callback, but graph break for RuntimeError
    "test_backward_twice_with_saved_values",  # TODO(rzou): I broke this somehow
}

known_failing_tests = {
    # Category: Compiled autograd
    "test_not_implemented_grad",  # Dynamo raises Unsupported which is not a NotImplementedError
    "test_grad_mode_restored_reentrant",  # create_graph
    "test_reentrant_with_callbacks_both_depths",  # queue_callback
    "test_reentrant_with_callbacks_depth_0",  # queue_callback

@@ -2024,13 +2024,13 @@ class TestAutograd(TestCase):
            self.assertIsNotNone(grad)
            was_called[0] = True

        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5)
        x = torch.randn(2, 3, requires_grad=True)
        y = torch.randn(2, 3)
        rx, ry = NoneGradientFunction.apply(x, y)
        rx.register_hook(hook)
        ry.register_hook(hook)
        # rx.register_hook(hook)
        # ry.register_hook(hook)
        sum(rx, ry).sum().backward()
        self.assertTrue(was_called[0])
        # self.assertTrue(was_called[0])

    def test_retain_grad(self):
        input = torch.rand(1, 3, requires_grad=True)

@@ -64,6 +64,9 @@ struct TORCH_API ${op} : public ${superclass} {
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  ivalue_list get_state();
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
  c10::optional<functional_apply_t> get_functional() override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}
@@ -80,26 +83,79 @@ void will_release_variables() override {
"""
)

# We generate e.g. MulBackward0::apply and have that call into
# MulBackward0_apply_functional. The apply_functional is a pure function,
# that is, it does not rely on global state. MulBackward0::apply
# is responsible for querying the autograd engine for which outputs should
# be computed (needs_input_grad), applying locks,
# and unpacking saved variables to pass to MulBackward0_apply_functional.
#
# needs_input_grad is a mapping from input index to if that input needs
# gradients computed. For operators that take in List[Tensor], the List[Tensor]
# is one element in the needs_input_grad that specifies if *any* of the
# List[Tensor] needs input grad. In theory this could be optimized.
FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
static variable_list ${op}_apply_functional(
  variable_list&& grads,
  std::array<bool,${num_vars}> needs_input_grad${,unpacked_saved_vars_signature})
{
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}

variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  ${unpacks}
  ${compute_needs_input_grad}
  return ${op}_apply_functional(std::move(grads), needs_input_grad${,unpacked_saved_vars});
}

void ${op}::compiled_args(CompiledNodeArgs& args) {
    ${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
    ${apply_with_saved_before}
    variable_list result = apply(variable_list(grads));
    // variable_list result = apply(variable_list(grads));
    auto state = get_state();
    const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
    variable_list result = interface->call_function(
        saved.get_py_compiler(),
        "apply_functional",
        get_functional().value(),
        grads,
        state,
        num_outputs(),
        name());
    ${apply_with_saved_after}
    return result;
}
ivalue_list ${op}::get_state() {
  SavedState saved_state;
  ${unpacks}
  ${get_state}
  return saved_state.stack;
}
ivalue_list ${op}::retrieve_saved(SwapSavedVariables& saved) {
  ${apply_with_saved_before}
  auto state = get_state();
  ${apply_with_saved_after}
  return state;
}

c10::optional<functional_apply_t> ${op}::get_functional() {
  ${compute_needs_input_grad}
  return [needs_input_grad](const variable_list& inputs, const std::vector<c10::IValue>& saved) {
    SavedState state;
    state.stack = saved;
    ${saved_var_dequeues}
    return ${op}_apply_functional(variable_list(inputs), needs_input_grad${,unpacked_saved_vars});
  };
}
"""
)

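Illustration (not part of the diff): the comment above describes the split between the stateful `${op}::apply` and the pure `${op}_apply_functional`. A rough Python analogue of that shape for a hypothetical two-input multiply, with `engine_state` standing in for the Node's saved variables and the engine query (all names here are illustrative, not generated code):

```python
# Hypothetical sketch: the pure function receives grads, a needs_input_grad
# mask, and the unpacked saved values as explicit arguments, so it can run
# outside the autograd engine (e.g. under compiled autograd).
def mul_backward_apply_functional(grads, needs_input_grad, self_, other):
    grad_self = grads[0] * other if needs_input_grad[0] else None
    grad_other = grads[0] * self_ if needs_input_grad[1] else None
    return [grad_self, grad_other]


# Hypothetical stateful wrapper: asks the engine which inputs need gradients,
# unpacks the saved variables, then defers to the pure function.
def mul_backward_apply(engine_state, grads):
    needs_input_grad = [
        engine_state.task_should_compute_output(i) for i in range(2)
    ]
    self_, other = engine_state.unpack_saved_variables()
    return mul_backward_apply_functional(grads, needs_input_grad, self_, other)
```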
@@ -107,13 +163,24 @@ GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
  };
"""
)

COMPUTE_NEEDS_INPUT_GRAD = CodeTemplate(
    """\
IndexRangeGenerator gen;
${compute_index_ranges}
auto needs_input_grad = std::array<bool, ${n}>{
  ${masks}
};\
"""
)

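Illustration (not part of the diff): substituting `COMPUTE_NEEDS_INPUT_GRAD` with toy values for a two-input op shows the shape of the C++ it emits; `CodeTemplate` is assumed here to be `torchgen.code_template.CodeTemplate`, as used by this generator.

```python
from torchgen.code_template import CodeTemplate

# Same template text as COMPUTE_NEEDS_INPUT_GRAD above, substituted with toy
# values for inputs (self, other) purely to show the emitted C++ shape.
tmpl = CodeTemplate(
    """\
IndexRangeGenerator gen;
${compute_index_ranges}
auto needs_input_grad = std::array<bool, ${n}>{
  ${masks}
};\
"""
)
print(
    tmpl.substitute(
        n=2,
        compute_index_ranges=[
            "auto self_ix = gen.range(1);",
            "auto other_ix = gen.range(1);",
        ],
        masks=[
            "task_should_compute_output({ self_ix }),",
            "task_should_compute_output({ other_ix }),",
        ],
    )
)
```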
DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
if (needs_input_grad[/*${name}*/${idx}]) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
@@ -126,7 +193,7 @@ if (task_should_compute_output({ ${name}_ix })) {
# to each `Tensor`(s) of `self`, and the others.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
if (needs_input_grad[/*${name}*/${idx}]) {  // ${name}
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {
@@ -143,7 +210,7 @@ if (task_should_compute_output({ ${name}_ix })) {

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (task_should_compute_output({ ${name}_ix })) {
  if (needs_input_grad[/*${name}*/${idx}]) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""
@@ -151,7 +218,7 @@ DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(

DERIVATIVE_MULTI = CodeTemplate(
    """\
if (task_should_compute_output({ ${idx_ranges} })) {
if (${needs_input_grad}) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
@@ -551,14 +618,24 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
    compiled_args: list[str] = []
    apply_with_saved_before: list[str] = []
    apply_with_saved_after: list[str] = []
    unpacked_saved_vars: list[str] = []
    unpacked_saved_vars_ref_type: list[str] = []
    # Maps var_name to a unique index. The var_name is the
    # name of an input to the operator that needs a gradient (like "self", "other").
    # The index is the order in which they appear. We use this mapping
    # to populate needs_input_grad in some order and then grab values from it.
    var_name_map: dict[str, int] = {}

    for arg in info.args_with_derivatives:
    for idx, arg in enumerate(info.args_with_derivatives):
        if arg.type in TENSOR_LIST_LIKE_CTYPES:
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
            unpacked_saved_vars.append(f"{arg.name}_size_")
            unpacked_saved_vars_ref_type.append("size_t")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
        var_name_map[arg.name] = idx
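        # Illustrative note (not generated): for an op with differentiable
        # inputs (self, other) this loop yields var_name_map == {"self": 0,
        # "other": 1}; needs_input_grad[var_name_map["other"]] later selects
        # whether the gradient for `other` must be computed.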

    def save_var(var: SavedAttribute, is_output: bool) -> None:
        name = var.nctype.name
@@ -567,6 +644,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
        should_append_raw_getsetdef = False
        visit_name = name
        uses_cpp_saved_variable_cls = False
        unpacked_ref_type = None

        if (
            type == BaseCType(tensorT)
@@ -591,6 +669,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
            unpacked_ref_type = "Tensor&"
        elif (
            type == BaseCType(tensorListT)
            or type == BaseCType(iTensorListRefT)
@@ -630,6 +709,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
            unpacked_ref_type = "std::vector<Tensor>&"
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            uses_cpp_saved_variable_cls = True
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
@@ -652,6 +732,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
            unpacked_ref_type = "torch::List<std::optional<Tensor>>&"
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(
@@ -733,6 +814,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
            elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
        ):
            saved_variables.append(f"std::vector<at::Scalar> {name};")
            unpacked_ref_type = "std::vector<at::Scalar>&"
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
@@ -803,6 +885,11 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
        apply_with_saved_before.append(f"saved.before({visit_name});")
        apply_with_saved_after.append(f"saved.after({visit_name});")

        if unpacked_ref_type is None:
            unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&"
        unpacked_saved_vars.append(str(name))
        unpacked_saved_vars_ref_type.append(unpacked_ref_type)

    for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
        save_var(var, is_output=False)
    for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
@@ -816,6 +903,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
        thread_lock = ""

    if uses_retain_variables(info):
        unpacked_saved_vars.append("retain_variables")
        unpacked_saved_vars_ref_type.append("bool")
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ""
@@ -837,6 +926,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
    ) -> tuple[bool, str]:
        formula = derivative.formula
        var_names = derivative.var_names

        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:
@@ -857,37 +947,54 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
                derivative_template = DERIVATIVE_SINGLE
            return (
                checks_any_grad_defined,
                derivative_template.substitute(name=var_names[0], derivative=formula),
                derivative_template.substitute(
                    name=var_names[0],
                    derivative=formula,
                    idx=var_name_map[var_names[0]],
                ),
            )

        else:
            if "grad_input_mask" in formula:
                masks = [
                    f"task_should_compute_output({{ {n}_ix }})," for n in var_names
                    f"needs_input_grad[{var_name_map[name]}]," for name in var_names
                ]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                    n=len(var_names), masks=masks
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            needs_input_grad = [
                f"needs_input_grad[{var_name_map[name]}]" for name in var_names
            ]
            needs_input_grad = " || ".join(needs_input_grad)
            copy_ranges: list[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
                copy_ranges.append(
                    DERIVATIVE_MULTI_COPY_RANGE.substitute(
                        name=n, i=i, idx=var_name_map[n]
                    )
                )
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                needs_input_grad=needs_input_grad,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    masks = []

    need_any_grad_defined_var = False
    for derivative in info.derivatives:
    for idx, derivative in enumerate(info.derivatives):
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined

    for name in var_name_map:
        masks.append(f"task_should_compute_output({{ {name}_ix }}),")

    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas
    if need_any_grad_defined_var:
@@ -906,8 +1013,30 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
    )
    all_getter_definitions = "\n".join(getter_definitions)

    compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute(
        n=len(masks), compute_index_ranges=compute_index_ranges, masks=masks
    )
    unpacked_saved_vars_signature = [
        f"{T} {x}" for T, x in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars)
    ]
    get_state = "\n".join(
        f"saved_state.enqueue({name});" for name in unpacked_saved_vars
    )
    saved_var_dequeues = []
    for typ, name in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars):
        if typ.endswith("&"):
            typ = typ[:-1]
        saved_var_dequeues.append(f"{typ} {name};")
        saved_var_dequeues.append(f"state.dequeue({name});")

    return template.substitute(
        unpacks="\n".join(unpack),
        op=info.op,
        unpacked_saved_vars=unpacked_saved_vars,
        unpacked_saved_vars_signature=unpacked_saved_vars_signature,
        compute_needs_input_grad=compute_needs_input_grad,
        num_vars=len(var_name_map),
        saved_var_dequeues="\n".join(saved_var_dequeues),
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
@@ -922,4 +1051,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
        compiled_args=compiled_args,
        apply_with_saved_before=apply_with_saved_before,
        apply_with_saved_after=apply_with_saved_after,
        get_state=get_state,
    )

@@ -5,7 +5,9 @@ import operator
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.utils._pytree as pytree
from torch._dynamo.external_utils import (
    call_aot_bwd_impl,
    call_backward,
    call_hook,
    FakeCompiledAutogradEngine,
@@ -56,6 +58,70 @@ def maybe_clone(x):
    return x


counter = 0


def copy_slices_prologue(
    inputs,
    base_sizes,
    base_strides,
    base_storage_offset,
    view_sizes,
    view_strides,
    view_storage_offset,
):
    grad = inputs[0]
    result = grad.new_empty_strided(base_sizes, base_strides)
    assert grad is not None
    result.copy_(grad)
    offset = view_storage_offset - base_storage_offset
    grad_slice = result.as_strided(view_sizes, view_strides, offset)
    return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)]


def copy_slices_epilogue(needs_input_grad, result, res, grad_slice):
    grad_inputs = [None] * len(needs_input_grad)
    for i in range(len(needs_input_grad)):
        if needs_input_grad[i]:
            if res[i] is None:
                continue
            if i == 0:
                grad_slice.copy_(res[i])
                grad_inputs[i] = result
            else:
                grad_inputs[i] = res[i]
    return grad_inputs

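Illustration (not part of the diff): a sketch of how the two helpers above compose for a CopySlices-style backward over a simple 2x2 view of a 4x4 base. The import path assumes these helpers land in `torch._dynamo.compiled_autograd` as defined above; the identity `res` stands in for the view node's actual backward.

```python
import torch

from torch._dynamo.compiled_autograd import (  # assumed import location
    copy_slices_epilogue,
    copy_slices_prologue,
)

base = torch.zeros(4, 4)
view = base[:2, :2]                   # the slice that was written through
grad_output = torch.ones(4, 4)        # gradient w.r.t. the base

result, grad_slice, res_input = copy_slices_prologue(
    [grad_output],
    base.size(), base.stride(), base.storage_offset(),
    view.size(), view.stride(), view.storage_offset(),
)
# The original node's backward would now run on res_input; pretend it is the
# identity so res[0] has the view's shape.
res = [res_input]
grad_inputs = copy_slices_epilogue([True], result, res, grad_slice)
assert grad_inputs[0].shape == base.shape
```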
class OpNamespace:
    def __init__(self):
        self.next_id = {}

    def add(self, base_name, fn):
        if base_name not in self.next_id:
            self.next_id[base_name] = 0
        nid = self.next_id[base_name]
        name = f"{base_name}_{nid}"
        self.next_id[base_name] += 1
        result = Op(name, fn)
        torch._dynamo.allow_in_graph(result)
        setattr(self, name, result)
        return result


class Op:
    def __init__(self, name, fn):
        self.fn = fn
        self.__name__ = name
        self.__module__ = "torch._dynamo.compiled_autograd.ops"

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


ops = OpNamespace()

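Illustration (not part of the diff): each `ops.add` call mints a uniquely named, `allow_in_graph`-wrapped callable, so the same backward helper can appear several times in one captured graph under distinguishable targets. A hedged usage sketch with a hypothetical helper:

```python
from torch._dynamo.compiled_autograd import ops  # assumed import location

# Hypothetical helper just to exercise the namespace; the real callers pass
# the generated apply_functional / validate_outputs functions.
def my_backward(grads, *saved):
    return grads

op0 = ops.add("MyBackward_apply_functional", my_backward)
op1 = ops.add("MyBackward_apply_functional", my_backward)
assert op0.__name__ == "MyBackward_apply_functional_0"
assert op1.__name__ == "MyBackward_apply_functional_1"
assert getattr(ops, op0.__name__) is op0
```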
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
_impure_targets = OrderedSet(
    [
@@ -81,6 +147,7 @@ class AutogradCompilerInstance:
        self.fx_tracer = PythonKeyTracer()
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        self.hooks_proxy: Optional[Proxy] = None
        self.old_inline_behavior = False

    def wrap_fake(self, x, source):
        assert isinstance(x, torch.Tensor)
@@ -103,7 +170,8 @@ class AutogradCompilerInstance:
        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
        self.symnode_proxy_lookup = {}
        args_proxy, self.sizes_proxy, self.scalars_proxy, self.hooks_proxy = (
            self.fx_tracer.create_proxy("placeholder", name, (), {})
            for name in _graph_placeholders
        )
@@ -126,7 +194,9 @@ class AutogradCompilerInstance:
            )
            for idx, val in enumerate(sizes)
        ]
        self.bind_tensors_to_proxies(sizes, sizes_proxy, sizes_origins)
        self.bind_tensors_to_proxies(sizes, self.sizes_proxy, sizes_origins)
        for i, symint in enumerate(sizes):
            self.symnode_proxy_lookup[id(symint.node)] = self.sizes_proxy[i]

        for idx, val in enumerate(scalars):
            source = self.source("scalars", idx)
@@ -148,7 +218,9 @@ class AutogradCompilerInstance:
                )
            else:
                raise AssertionError("Unexpected scalar type: ", type(val))
        self.bind_tensors_to_proxies(scalars, scalars_proxy, scalars_origins)
        self.bind_tensors_to_proxies(scalars, self.scalars_proxy, scalars_origins)
        for i, symval in enumerate(scalars):
            self.symnode_proxy_lookup[id(symval.node)] = self.scalars_proxy[i]  # type: ignore[union-attr]

        # TODO(jansel): are all these modes needed?
        self.stack.enter_context(decompose({}))
@@ -163,25 +235,105 @@ class AutogradCompilerInstance:
        )
        return inputs, sizes, scalars

    def proxy_call_aot_backward(
        self,
        pinputs,
        psaved_tensors,
        pctx,
        ctx,
        maybe_backward_state_idx,
    ):
        psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()]

        # NOTE: we should only close over constants
        CompiledFunction = ctx._forward_cls
        metadata = CompiledFunction.metadata
        maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
        del CompiledFunction

        @torch._dynamo.allow_in_graph  # type: ignore[misc]
        def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
            # TODO: backward state
            out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(
                ctx_saved_tensors,
                ctx_symints,
                metadata,
                maybe_subclass_metadata,
                *flat_args,
            )
            return out

        @torch._dynamo.allow_in_graph  # type: ignore[misc]
        def call_aot_bwd_epilogue(
            out: List[torch.Tensor],
        ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
            return torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional(
                metadata, maybe_subclass_metadata, out
            )

        pbackward_state = None
        if maybe_backward_state_idx is not None:
            pbackward_state = self.hooks_proxy[maybe_backward_state_idx]  # type: ignore[index]

        pall_args = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_aot_bwd_prologue,
            args=(
                psaved_tensors,
                psymints,
                *pinputs,
            ),
            kwargs={},
        )
        pout = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_aot_bwd_impl,
            args=(
                pctx,
                psaved_tensors,
                pall_args,
                pbackward_state,
            ),
            kwargs={},
        )
        proxies = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_aot_bwd_epilogue,
            args=(pout,),
            kwargs={},
        )
        return proxies

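Illustration (not part of the diff): at runtime the three proxies above evaluate to a prologue, then the AOT backward GraphModule, then an epilogue. A hedged eager-mode sketch, where `prologue` and `epilogue` stand for the `call_aot_bwd_prologue` / `call_aot_bwd_epilogue` closures defined above (they are passed in because they close over the AOT metadata):

```python
from torch._dynamo.external_utils import call_aot_bwd_impl  # added by this change


def run_aot_backward(prologue, epilogue, ctx, saved_tensors, symints, grads,
                     backward_state=None):
    # 1) filter/arrange grad_outputs according to the AOT metadata
    all_args = prologue(saved_tensors, symints, *grads)
    # 2) run the AOT backward GraphModule on the prepared arguments
    out = call_aot_bwd_impl(ctx, saved_tensors, all_args, backward_state)
    # 3) restore the output structure (subclass handling etc.)
    return epilogue(out)
```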
    def proxy_call_backward(
        self,
        inputs,
        output_metadatas,
        saved_tensors,
        backward_idx: int,
        ctx: torch.autograd.function.BackwardCFunction,
        maybe_backward_state_idx: Optional[int],
    ):
        assert self.hooks_proxy is not None
        backward_c_function = self.hooks_proxy[backward_idx]  # type: ignore[index]
        proxies = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_backward,
            args=(
                backward_c_function,
                self.to_proxy(saved_tensors),
                *self.to_proxy(inputs),
            ),
            kwargs={},
        )
        pctx = self.hooks_proxy[backward_idx]  # type: ignore[index]
        pinputs = self.to_proxy(inputs)
        psaved_tensors = self.to_proxy(saved_tensors)
        if hasattr(ctx._forward_cls, "_aot_id"):  # type: ignore[attr-defined]
            # AOT backward
            proxies = self.proxy_call_aot_backward(
                pinputs, psaved_tensors, pctx, ctx, maybe_backward_state_idx
            )
        else:
            proxies = self.fx_tracer.create_proxy(
                kind="call_function",
                target=call_backward,
                args=(
                    pctx,
                    psaved_tensors,
                    *pinputs,
                ),
                kwargs={},
            )
        assert proxies is not None

        with disable_proxy_modes_tracing():
            # create fake Tensors
@@ -198,6 +350,94 @@ class AutogradCompilerInstance:
            self.bind_tensors_to_proxies(grad_ins, proxies)
        return tuple(grad_ins)

    def call_copy_slices_prologue(self, inputs, base, view):
        args = (
            inputs,
            base.sizes(),
            base.strides(),
            base.storage_offset(),
            view.sizes(),
            view.strides(),
            view.storage_offset(),
        )
        if self.old_inline_behavior:
            return copy_slices_prologue(*args)
        return self.proxy_call(copy_slices_prologue, args, 3)

    def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice):
        if self.old_inline_behavior:
            return copy_slices_epilogue(needs_input_grad, result, res, grad_slice)
        return self.proxy_call(
            copy_slices_epilogue,
            (needs_input_grad, result, res, grad_slice),
            len(needs_input_grad),
        )

    def allocate_dummy(self, *examples):
        with disable_proxy_modes_tracing():
            return torch.zeros(0)

    def apply_functional(self, fn, inputs, stack, num_outputs, debug_name):
        if self.old_inline_behavior:
            result = fn(inputs, *stack)
            return result
        # TODO: if the node is a python autograd.Function or a CompiledFunctionBackward
        # we should probably "plop" the subgraph into the graph instead
        # of allow_in_graph the node through Dynamo.
        proxy_inputs, proxy_stack = pytree.tree_map(
            lambda e: self.to_proxy(e),
            (inputs, stack),
        )
        op = ops.add(debug_name, fn)
        proxy_out = self.fx_tracer.create_proxy(
            "call_function", op, args=(proxy_inputs, *proxy_stack), kwargs={}
        )
        result = [self.allocate_dummy(*inputs, *stack) for _ in range(num_outputs)]
        self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(num_outputs)])
        return result

    def proxy_call(self, fn, args, num_outputs):
        flat_args, _ = pytree.tree_flatten(args)
        proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
        proxy_out = self.fx_tracer.create_proxy(
            "call_function", fn, args=proxy_args, kwargs={}
        )
        result = [self.allocate_dummy(*flat_args) for _ in range(num_outputs)]
        self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(num_outputs)])
        return result

    def validate_outputs(self, fn, outputs, stack, _0, _1):
        if self.old_inline_behavior:
            # print("start validate outputs")
            # print(outputs)
            result = fn(outputs, *stack)
            # print(result)
            # print("end validate outputs")
            # breakpoint()
            return result
        proxy_outputs, proxy_stack = pytree.tree_map(
            lambda e: self.to_proxy(e),
            (outputs, stack),
        )
        op = ops.add("validate_outputs", fn)
        new_proxy_outputs = self.fx_tracer.create_proxy(
            "call_function", op, args=(proxy_outputs, *proxy_stack), kwargs={}
        )
        self.bind_tensors_to_proxies(outputs, new_proxy_outputs)
        return outputs

    def accumulate(self, old_var, new_var):
        if self.old_inline_behavior:
            return torch.add(old_var, new_var)
        old_var_proxy = self.to_proxy(old_var)
        new_var_proxy = self.to_proxy(new_var)
        proxy_out = self.fx_tracer.create_proxy(
            "call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={}
        )
        result = self.allocate_dummy(old_var)
        self.bind_tensors_to_proxies([result], [proxy_out])
        return result

    def proxy_call_hook(self, hook, *args, **kwargs):
        return self.fx_tracer.create_proxy(
            "call_function",
@@ -280,6 +520,7 @@ class AutogradCompilerInstance:
        assert nodes[first_getitem_idx] == inputs_users[0]
        last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
        assert nodes[last_getitem_idx] == inputs_users[-1]
        # getitem nodes on inputs
        for i, node in enumerate(inputs_users):
            if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
                has_cuda_inputs = True
@@ -289,18 +530,20 @@ class AutogradCompilerInstance:
            is_scalar = len(node.meta["val"].size()) == 0
            if is_cpu and is_scalar:
                node_users = list(node.users.keys())
                # We can only move the cpu scalar if it is not exposed to user code.
                # The only possible user code using the Op class is custom C++ autograd functions and C++ nodes.
                if all(
                    isinstance(user.target, torch._ops.OpOverload)
                    and user.target.namespace in ("prims", "aten")
                    isinstance(user.target, torch._dynamo.compiled_autograd.Op)
                    and "CppFunction" not in user.target.__name__
                    for user in node_users
                ):
                    # all users are prims/aten, can move safely
                    to_move[i] = node

        # only move cpu scalars to cuda if there were cuda activations in this graph,
        # this is to handle the case where cudagraphs is enabled on a cpu-only graph
        if has_cuda_inputs:
            for node in to_move.values():
                verbose_log.debug("Moving node %s from cpu to cuda", node)
                node.meta["val"] = node.meta["val"].cuda()

            # return runtime indices we need to move to cuda
@@ -334,7 +577,10 @@ class AutogradCompilerInstance:
                or (node.op == "call_function" and node.target in _impure_targets)
            )

        before = len(list(self.fx_tracer.graph.nodes))
        self.fx_tracer.graph.eliminate_dead_code(is_impure)
        after = len(list(self.fx_tracer.graph.nodes))
        verbose_log.debug("DCE removed %d nodes", before - after)

    def end_capture(self, outputs):
        self.fx_tracer.create_proxy(
@@ -350,6 +596,10 @@ class AutogradCompilerInstance:
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        runtime_inputs_to_move: List[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
        # TODO: remove the graph node's dummy metadata
        self.rename_aot_dispatcher_nodes()
        self.reorder_tensor_pre_hook_nodes()
        self.reorder_pre_hook_nodes_to_schedule_asap()
@@ -368,9 +618,6 @@ class AutogradCompilerInstance:
        # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
        # should prevent these ops from going into the CA graph.
        self.dce()
        runtime_inputs_to_move: List[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

        graph = GraphModule(
            self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
@@ -728,8 +975,10 @@ class AutogradCompilerInstance:
            return [self.to_proxy(x) for x in t]
        if isinstance(t, tuple):
            return tuple(self.to_proxy(x) for x in t)
        # can it be torch.SymInt as the code used to imply?
        assert isinstance(t, torch.Tensor)
        if isinstance(t, (torch.SymInt, torch.SymFloat)):
            return self.symnode_proxy_lookup[id(t.node)]
        if not isinstance(t, torch.Tensor):
            return t
        proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
        assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
        return proxy_tensor.proxy

@@ -99,6 +99,30 @@ def call_backward(
    return grads


def normalize_as_list(x: Any) -> List[Any]:
    if isinstance(x, tuple):
        return list(x)
    elif isinstance(x, list):
        return x
    return [x]


def call_aot_bwd_impl(
    ctx: torch.autograd.function.BackwardCFunction,
    saved_tensors: List[torch.Tensor],
    all_args: List[
        Union[torch.Tensor, torch.fx.experimental._backward_state.BackwardState]
    ],
    backward_state: Optional[torch.fx.experimental._backward_state.BackwardState],
) -> List[torch.Tensor]:
    fakectx = FakeBackwardCFunction(ctx, saved_tensors)
    bw_module = fakectx._bw_module
    if backward_state is not None:
        all_args.append(backward_state)
    out = bw_module(*all_args)
    return normalize_as_list(out)

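Illustration (not part of the diff): `normalize_as_list` lets `call_aot_bwd_impl` hand back a list regardless of whether the backward GraphModule produced a bare tensor, a tuple, or a list. Assuming the helper is importable from `torch._dynamo.external_utils` on this branch:

```python
import torch

from torch._dynamo.external_utils import normalize_as_list  # assumed import

t = torch.ones(2)
assert normalize_as_list(t)[0] is t          # bare tensor -> [tensor]
assert len(normalize_as_list((t, t))) == 2   # tuple -> list
assert normalize_as_list([t])[0] is t        # list passed through unchanged
```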
def untyped_storage_size(x: torch.Tensor) -> int:
    return x.untyped_storage().size()

@@ -3273,6 +3273,8 @@ if torch.distributed.is_available():
MOD_INLINELIST = [
    "torch._decomp",
    "torch._dynamo._trace_wrapped_higher_order_op",
    "torch._dynamo.compiled_autograd",
    "torch._dynamo.compiled_autograd.ops",
    "torch._dynamo.comptime",
    "torch._dynamo.polyfills",
    "torch._functorch._aot_autograd.subclass_parametrization",

@@ -1452,6 +1452,246 @@ class AutogradLazyBackwardCompileInfo:
    saved_compile_context: Optional[CompileContext]


def _raise_if_functorch_active():
    # not ideal but prevent the user from seeing a nasty traceback - See #138422
    stack = torch._C._functorch.peek_interpreter_stack()
    torch._check(
        stack is None,
        lambda: (
            "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, "
            "which isn't supported. Try wrapping vmap inside torch.compile, or skip compiling the "
            "backward function."
        ),
    )


def _backward_prologue_functional(
    ctx_saved_tensors, ctx_symints, metadata, maybe_subclass_metadata, *flat_args
):
    # Calling convention: we expect a grad_out passed to the backward:
    # - for every output of the fw that does *not* alias an input or graph intermediate
    # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
    # - for every graph intermediate that we need to use to generate an output later.
    # The other outputs in the autograd.Function.forward that do *not* show up in the backward include:
    # - outputs that alias inputs or graph intermediates
    # - updated inputs due to metadata-only mutations.
    # We need to return them in the forward, but ensure that they all do not get gradients in the backward,
    # and we filter them out here before passing the remaining grad_outputs into the compiled backward.
    _raise_if_functorch_active()

    num_intermediate_bases = metadata.num_intermediate_bases
    num_mutated_runtime_inps = metadata.num_mutated_inp_runtime_indices
    expected_grad_outs = (
        metadata.num_outputs + num_mutated_runtime_inps + num_intermediate_bases
    )
 | 
			
		||||
    deterministic = metadata.deterministic
 | 
			
		||||
    global_deterministic = torch.are_deterministic_algorithms_enabled()
 | 
			
		||||
    if deterministic is not None:
 | 
			
		||||
        torch._check(
 | 
			
		||||
            not (not deterministic and global_deterministic),
 | 
			
		||||
            lambda: (
 | 
			
		||||
                "This compiled backward function is being run with "
 | 
			
		||||
                "torch.use_deterministic_algorithms(True), "
 | 
			
		||||
                "but it was previously generated during the forward function while "
 | 
			
		||||
                "torch.use_deterministic_algorithms(False) was set."
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    assert len(flat_args) == expected_grad_outs
 | 
			
		||||
    out_info = metadata.output_info
 | 
			
		||||
 | 
			
		||||
    inp_tangents, out_tangents, intermediate_base_tangents = (
 | 
			
		||||
        flat_args[:num_mutated_runtime_inps],
 | 
			
		||||
        flat_args[
 | 
			
		||||
            num_mutated_runtime_inps : num_mutated_runtime_inps + metadata.num_outputs
 | 
			
		||||
        ],
 | 
			
		||||
        flat_args[num_mutated_runtime_inps + metadata.num_outputs :],
 | 
			
		||||
    )
 | 
			
		||||
    # input_info contains info on *every* input,
 | 
			
		||||
    # But in the backward(), we are only given grad outputs for every mutated input
 | 
			
		||||
    # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad
 | 
			
		||||
    input_info = metadata.input_info
 | 
			
		||||
    inp_tangents_filtered = [
 | 
			
		||||
        x
 | 
			
		||||
        for x, info_idx in zip(
 | 
			
		||||
            inp_tangents,
 | 
			
		||||
            metadata.mutated_inp_runtime_indices,
 | 
			
		||||
        )
 | 
			
		||||
        if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
 | 
			
		||||
    ]
 | 
			
		||||
    # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
 | 
			
		||||
    out_tangents_filtered = [
 | 
			
		||||
        x
 | 
			
		||||
        for x, info in zip(out_tangents, out_info)
 | 
			
		||||
        if info.output_type
 | 
			
		||||
        in [
 | 
			
		||||
            OutputType.non_alias,
 | 
			
		||||
            OutputType.unsafe_view_alias,
 | 
			
		||||
            OutputType.custom_function_view,
 | 
			
		||||
        ]
 | 
			
		||||
        and issubclass(info.raw_type, torch.Tensor)
 | 
			
		||||
        and info.requires_grad
 | 
			
		||||
    ]
 | 
			
		||||
    # intermediate bases always require gradients, and always participate in the backward graph.
 | 
			
		||||
    flat_bw_args_with_grads = [
 | 
			
		||||
        *inp_tangents_filtered,
 | 
			
		||||
        *out_tangents_filtered,
 | 
			
		||||
        *intermediate_base_tangents,
 | 
			
		||||
    ]
 | 
			
		||||
    num_flat_bw_args_with_grads = len(flat_bw_args_with_grads)
 | 
			
		||||
 | 
			
		||||
    # sanity asserts
 | 
			
		||||
    # metadata_only_inps = [
 | 
			
		||||
    #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
 | 
			
		||||
    #     if not input_info[info_idx].mutates_data
 | 
			
		||||
    # ]
 | 
			
		||||
    # aliased_outputs = [
 | 
			
		||||
    #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
 | 
			
		||||
    # assert all(x is None for x in metadata_only_inps)
 | 
			
		||||
    # assert all(x is None for x in aliased_outputs)
 | 
			
		||||
    # TODO: replace this with FunctionalizedRngRuntimeWrapper
 | 
			
		||||
    rng_args = []
 | 
			
		||||
    if metadata.is_rng_op_functionalized:
 | 
			
		||||
        # Add the seed and offset to args
 | 
			
		||||
        rng_args = CUDARngStateHelper.get_torch_state_as_tuple()
 | 
			
		||||
 | 
			
		||||
    bw_tokens = [None] * metadata.num_backward_tokens
 | 
			
		||||
 | 
			
		||||
    # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
 | 
			
		||||
    #   in the bw output order.
 | 
			
		||||
 | 
			
		||||
    # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls
 | 
			
		||||
    # There are tests that count these calls, saving to var.
 | 
			
		||||
    num_ctx_saved_tensors = len(ctx_saved_tensors)
 | 
			
		||||
    all_args = [
 | 
			
		||||
        *ctx_symints,
 | 
			
		||||
        *ctx_saved_tensors,
 | 
			
		||||
        *flat_bw_args_with_grads,
 | 
			
		||||
        *bw_tokens,
 | 
			
		||||
        *rng_args,
 | 
			
		||||
    ]
 | 
			
		||||
    del ctx_saved_tensors
 | 
			
		||||
 | 
			
		||||
    # Note: [AOTAutograd Backward Guards]
 | 
			
		||||
    # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
 | 
			
		||||
    # Doing so requires us to "guess" about some of the metadata of our grad_outputs.
 | 
			
		||||
    #
 | 
			
		||||
    # In particular: if an output to the forward is a plain tensor or a subclass,
 | 
			
		||||
    # its corresponding grad_output in the backward **may or may not** be
 | 
			
		||||
    # a plain tensor or a subclass. The main cases are:
 | 
			
		||||
    # (1) If an output is a plain tensor, its grad_out will also be a plain tensor,
 | 
			
		||||
    #     *unless* the output is used in some subclass compute later in the forward graph,
 | 
			
		||||
    #     which will cause its grad_output to become a subclass
 | 
			
		||||
    # (2) If an output is a subclass, its grad_out will also be a subclass,
 | 
			
		||||
    #     *unless* the output of the forward did not actually participate in the gradient computation,
 | 
			
		||||
    #     in which case autograd will insert a plain tensor of zeros for the grad_output.
 | 
			
		||||
    #     We could avoid this case with `torch.autograd.Function.set_materialize_grads`,
 | 
			
		||||
    #     although this is not turned on today in AOTAutgrad and would require more work.
 | 
			
		||||
    #
 | 
			
		||||
    # Today, we make a guess on subclass-ness based on the above examples,
 | 
			
		||||
    # and hard-error in the backward if we guessed wrong.
 | 
			
		||||
    #
 | 
			
		||||
    # In the future, we should add backward guards that would allow us to
 | 
			
		||||
    # properly handle this case instead of erroring: we would need to retrace the backward graph,
 | 
			
		||||
    # since we might produce an entirely different trace if our grad_outputs are subclass or not.
 | 
			
		||||
    del flat_bw_args_with_grads
 | 
			
		||||
 | 
			
		||||
    tangents_start_idx = (
 | 
			
		||||
        len(all_args) - num_flat_bw_args_with_grads - len(rng_args) - len(bw_tokens)
 | 
			
		||||
    )
 | 
			
		||||
    assert tangents_start_idx == len(ctx_symints) + num_ctx_saved_tensors
 | 
			
		||||
    tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)
 | 
			
		||||
 | 
			
		||||
    # TODO: figure out how to refactor the backward properly
 | 
			
		||||
    # so I can use aot_dispatch_subclass_wrapper() here.
 | 
			
		||||
    if maybe_subclass_metadata is not None:
 | 
			
		||||
        tangents = all_args[tangents_start_idx:tangents_end_idx]
 | 
			
		||||
 | 
			
		||||
        if len(tangents) != len(metadata.subclass_tangent_meta):
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                "The grad inputs should be same number as forward output tangents"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        flat_processed_tangents = list(
 | 
			
		||||
            itertools.chain.from_iterable(
 | 
			
		||||
                AOTDispatchAutograd.process_runtime_tangent(
 | 
			
		||||
                    t,
 | 
			
		||||
                    m,
 | 
			
		||||
                )[1]
 | 
			
		||||
                for t, m in zip(
 | 
			
		||||
                    tangents,
 | 
			
		||||
                    metadata.subclass_tangent_meta,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        all_args = (
 | 
			
		||||
            runtime_unwrap_tensor_subclasses(
 | 
			
		||||
                all_args[:tangents_start_idx],
 | 
			
		||||
                # SymInts that are inputs to the backward graph are
 | 
			
		||||
                # already included in the "all_args" list.
 | 
			
		||||
                # Any symints coming from tensor subclasses should always
 | 
			
		||||
                # come from primals, and so they will show up as extra
 | 
			
		||||
                # arguments to the forward graph, and they will be saved
 | 
			
		||||
                # as activation in the backward graph.
 | 
			
		||||
                append_symints=False,
 | 
			
		||||
            )
 | 
			
		||||
            + flat_processed_tangents
 | 
			
		||||
            + runtime_unwrap_tensor_subclasses(
 | 
			
		||||
                all_args[tangents_end_idx:],
 | 
			
		||||
                append_symints=False,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        all_args = [
 | 
			
		||||
            (
 | 
			
		||||
                AOTDispatchAutograd.process_runtime_tangent(
 | 
			
		||||
                    t,
 | 
			
		||||
                    metadata.subclass_tangent_meta[i - tangents_start_idx],
 | 
			
		||||
                )[0]
 | 
			
		||||
                if (tangents_start_idx <= i < tangents_end_idx)
 | 
			
		||||
                else t
 | 
			
		||||
            )
 | 
			
		||||
            for i, t in enumerate(all_args)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    # Backward with forward inputs mutations is not supported in double backward.
 | 
			
		||||
    if (
 | 
			
		||||
        torch.is_grad_enabled()
 | 
			
		||||
        and metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw
 | 
			
		||||
    ):
 | 
			
		||||
        raise RuntimeError(
 | 
			
		||||
            "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return all_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _backward_epilogue_functional(metadata, maybe_subclass_metadata, out):
 | 
			
		||||
    # Toss out the backward output tokens
 | 
			
		||||
    num_bw_tokens = metadata.num_backward_tokens
 | 
			
		||||
    if num_bw_tokens > 0:
 | 
			
		||||
        out = out[:-num_bw_tokens]
 | 
			
		||||
 | 
			
		||||
    # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
 | 
			
		||||
    out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
 | 
			
		||||
        metadata, out, offset_index=len(out) - 1
 | 
			
		||||
    )
 | 
			
		||||
    out = tuple(out)
 | 
			
		||||
 | 
			
		||||
    # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
 | 
			
		||||
    if maybe_subclass_metadata is not None:
 | 
			
		||||
        assert maybe_subclass_metadata.grad_input_metas is not None
 | 
			
		||||
        outs_wrapped = wrap_tensor_subclasses(
 | 
			
		||||
            out,
 | 
			
		||||
            subclass_metas=maybe_subclass_metadata.grad_input_metas,
 | 
			
		||||
            included_subclass_symints=True,
 | 
			
		||||
            is_runtime=True,
 | 
			
		||||
        )
 | 
			
		||||
        return outs_wrapped
 | 
			
		||||
    return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# This is wrapped in a class just for namespacing purposes
 | 
			
		||||
# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly
 | 
			
		||||
class AOTDispatchAutograd:
 | 
			
		||||
@ -1479,6 +1719,10 @@ class AOTDispatchAutograd:
 | 
			
		||||
            runtime_subclass_keys, runtime_meta = x.__tensor_flatten__()
 | 
			
		||||
 | 
			
		||||
        def maybe_coerce(x):
 | 
			
		||||
            # TODO(xmfan): make this function traceable
 | 
			
		||||
            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
 | 
			
		||||
                return x
 | 
			
		||||
 | 
			
		||||
            same_type: bool = expected_type == runtime_type
 | 
			
		||||
            same_meta: bool = expected_meta == runtime_meta
 | 
			
		||||
 | 
			
		||||
@ -1557,7 +1801,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 | 
			
		||||
            metadata: ViewAndMutationMeta = fw_metadata  # type: ignore[assignment]
 | 
			
		||||
            maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta
 | 
			
		||||
            num_symints_saved_for_bw = num_symints_saved_for_bw_
 | 
			
		||||
            _compiled_autograd_should_lift = False
 | 
			
		||||
            _aot_id = aot_config.aot_id
 | 
			
		||||
            _lazy_backward_info = lazy_backward_info
 | 
			
		||||
 | 
			
		||||
@ -1692,11 +1935,21 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 | 
			
		||||
 | 
			
		||||
            @staticmethod
 | 
			
		||||
            def backward(ctx, *flat_args):
 | 
			
		||||
                all_args = CompiledFunction._backward_prologue(ctx, *flat_args)
 | 
			
		||||
                all_args = _backward_prologue_functional(
 | 
			
		||||
                    ctx.saved_tensors,
 | 
			
		||||
                    ctx.symints,
 | 
			
		||||
                    CompiledFunction.metadata,
 | 
			
		||||
                    CompiledFunction.maybe_subclass_metadata,
 | 
			
		||||
                    *flat_args,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                def impl_fn(double_ctx=None):
 | 
			
		||||
                    out = CompiledFunction._backward_impl(ctx, all_args)
 | 
			
		||||
                    return CompiledFunction._backward_epilogue(ctx, out)
 | 
			
		||||
                    return _backward_epilogue_functional(
 | 
			
		||||
                        CompiledFunction.metadata,
 | 
			
		||||
                        CompiledFunction.maybe_subclass_metadata,
 | 
			
		||||
                        out,
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                needs_grad = torch.is_grad_enabled() and any(
 | 
			
		||||
                    t.requires_grad for t in all_args if isinstance(t, torch.Tensor)
 | 
			
		||||
@ -1714,7 +1967,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 | 
			
		||||
                # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
 | 
			
		||||
                class CompiledFunctionBackward(torch.autograd.Function):
 | 
			
		||||
                    # CompiledFunctionBackward is not yet supported in dynamo skipfiles
 | 
			
		||||
                    _compiled_autograd_should_lift = False
 | 
			
		||||
                    _aot_id = aot_config.aot_id
 | 
			
		||||
 | 
			
		||||
                    @staticmethod
 | 
			
		||||
@ -1733,238 +1985,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 | 
			
		||||
 | 
			
		||||
                return CompiledFunctionBackward.apply(*all_args)
 | 
			
		||||
 | 
			
		||||
            @staticmethod
 | 
			
		||||
            def _raise_if_functorch_active():
 | 
			
		||||
                # not ideal but prevent the user from seeing a nasty traceback - See #138422
 | 
			
		||||
                stack = torch._C._functorch.peek_interpreter_stack()
 | 
			
		||||
                torch._check(
 | 
			
		||||
                    stack is None,
 | 
			
		||||
                    lambda: (
 | 
			
		||||
                        "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, "
 | 
			
		||||
                        "which isn't supported. Try wrapping vmap inside torch.compile, or skip compiling the "
 | 
			
		||||
                        "backward function."
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            @staticmethod
 | 
			
		||||
            def _backward_prologue(ctx, *flat_args):
 | 
			
		||||
                # Calling convention: we expect a grad_out passed to the backward:
 | 
			
		||||
                # - for every output of the fw that does *not* alias an input or graph intermediate
 | 
			
		||||
                # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
 | 
			
		||||
                # - for every graph intermediate that we need to use to generate an output later.
 | 
			
		||||
                # The other outputs in the autograd.Function.forward that do *not* show up in the backward include:
 | 
			
		||||
                # - outputs that alias inputs or graph intermediates
 | 
			
		||||
                # - updated inputs due to metadata-only mutations.
 | 
			
		||||
                # We need to return them in the forward, but ensure that they all do not get gradients in the backward,
 | 
			
		||||
                # and we filter them out here before passing the remaining grad_outputs into the compiled backward.
 | 
			
		||||
                CompiledFunction._raise_if_functorch_active()
 | 
			
		||||
 | 
			
		||||
                num_intermediate_bases = (
 | 
			
		||||
                    CompiledFunction.metadata.num_intermediate_bases
 | 
			
		||||
                )
 | 
			
		||||
                num_mutated_runtime_inps = (
 | 
			
		||||
                    CompiledFunction.metadata.num_mutated_inp_runtime_indices
 | 
			
		||||
                )
 | 
			
		||||
                expected_grad_outs = (
 | 
			
		||||
                    CompiledFunction.metadata.num_outputs
 | 
			
		||||
                    + num_mutated_runtime_inps
 | 
			
		||||
                    + num_intermediate_bases
 | 
			
		||||
                )
 | 
			
		||||
                deterministic = CompiledFunction.metadata.deterministic
 | 
			
		||||
                global_deterministic = torch.are_deterministic_algorithms_enabled()
 | 
			
		||||
                if deterministic is not None:
 | 
			
		||||
                    torch._check(
 | 
			
		||||
                        not (not deterministic and global_deterministic),
 | 
			
		||||
                        lambda: (
 | 
			
		||||
                            "This compiled backward function is being run with "
 | 
			
		||||
                            "torch.use_deterministic_algorithms(True), "
 | 
			
		||||
                            "but it was previously generated during the forward function while "
 | 
			
		||||
                            "torch.use_deterministic_algorithms(False) was set."
 | 
			
		||||
                        ),
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                assert len(flat_args) == expected_grad_outs
 | 
			
		||||
                out_info = CompiledFunction.metadata.output_info
 | 
			
		||||
 | 
			
		||||
                inp_tangents, out_tangents, intermediate_base_tangents = (
 | 
			
		||||
                    flat_args[:num_mutated_runtime_inps],
 | 
			
		||||
                    flat_args[
 | 
			
		||||
                        num_mutated_runtime_inps : num_mutated_runtime_inps
 | 
			
		||||
                        + CompiledFunction.metadata.num_outputs
 | 
			
		||||
                    ],
 | 
			
		||||
                    flat_args[
 | 
			
		||||
                        num_mutated_runtime_inps
 | 
			
		||||
                        + CompiledFunction.metadata.num_outputs :
 | 
			
		||||
                    ],
 | 
			
		||||
                )
 | 
			
		||||
                # input_info contains info on *every* input,
 | 
			
		||||
                # But in the backward(), we are only given grad outputs for every mutated input
 | 
			
		||||
                # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad
 | 
			
		||||
                input_info = CompiledFunction.metadata.input_info
 | 
			
		||||
                inp_tangents_filtered = [
 | 
			
		||||
                    x
 | 
			
		||||
                    for x, info_idx in zip(
 | 
			
		||||
                        inp_tangents,
 | 
			
		||||
                        CompiledFunction.metadata.mutated_inp_runtime_indices,
 | 
			
		||||
                    )
 | 
			
		||||
                    if input_info[info_idx].mutates_data
 | 
			
		||||
                    and input_info[info_idx].requires_grad
 | 
			
		||||
                ]
 | 
			
		||||
                # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
 | 
			
		||||
                out_tangents_filtered = [
 | 
			
		||||
                    x
 | 
			
		||||
                    for x, info in zip(out_tangents, out_info)
 | 
			
		||||
                    if info.output_type
 | 
			
		||||
                    in [
 | 
			
		||||
                        OutputType.non_alias,
 | 
			
		||||
                        OutputType.unsafe_view_alias,
 | 
			
		||||
                        OutputType.custom_function_view,
 | 
			
		||||
                    ]
 | 
			
		||||
                    and issubclass(info.raw_type, torch.Tensor)
 | 
			
		||||
                    and info.requires_grad
 | 
			
		||||
                ]
 | 
			
		||||
                # intermediate bases always require gradients, and always participate in the backward graph.
 | 
			
		||||
                flat_bw_args_with_grads = [
 | 
			
		||||
                    *inp_tangents_filtered,
 | 
			
		||||
                    *out_tangents_filtered,
 | 
			
		||||
                    *intermediate_base_tangents,
 | 
			
		||||
                ]
 | 
			
		||||
                num_flat_bw_args_with_grads = len(flat_bw_args_with_grads)
 | 
			
		||||
 | 
			
		||||
                # sanity asserts
 | 
			
		||||
                # metadata_only_inps = [
 | 
			
		||||
                #     x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
 | 
			
		||||
                #     if not input_info[info_idx].mutates_data
 | 
			
		||||
                # ]
 | 
			
		||||
                # aliased_outputs = [
 | 
			
		||||
                #     x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
 | 
			
		||||
                # assert all(x is None for x in metadata_only_inps)
 | 
			
		||||
                # assert all(x is None for x in aliased_outputs)
 | 
			
		||||
                # TODO: replace this with FunctionalizedRngRuntimeWrapper
 | 
			
		||||
                rng_args = []
 | 
			
		||||
                if CompiledFunction.metadata.is_rng_op_functionalized:
 | 
			
		||||
                    # Add the seed and offset to args
 | 
			
		||||
                    rng_args = CUDARngStateHelper.get_torch_state_as_tuple()
 | 
			
		||||
 | 
			
		||||
                bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens
 | 
			
		||||
 | 
			
		||||
                # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
 | 
			
		||||
                #   in the bw output order.
 | 
			
		||||
 | 
			
		||||
                # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls
 | 
			
		||||
                # There are tests that count these calls, saving to var.
 | 
			
		||||
                ctx_saved_tensors = ctx.saved_tensors
 | 
			
		||||
                num_ctx_saved_tensors = len(ctx_saved_tensors)
 | 
			
		||||
                all_args = [
 | 
			
		||||
                    *ctx.symints,
 | 
			
		||||
                    *ctx_saved_tensors,
 | 
			
		||||
                    *flat_bw_args_with_grads,
 | 
			
		||||
                    *bw_tokens,
 | 
			
		||||
                    *rng_args,
 | 
			
		||||
                ]
 | 
			
		||||
                del ctx_saved_tensors
 | 
			
		||||
 | 
			
		||||
                # Note: [AOTAutograd Backward Guards]
 | 
			
		||||
                # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
 | 
			
		||||
                # Doing so requires us to "guess" about some of the metadata of our grad_outputs.
 | 
			
		||||
                #
 | 
			
		||||
                # In particular: if an output to the forward is a plain tensor or a subclass,
 | 
			
		||||
                # its corresponding grad_output in the backward **may or may not** be
 | 
			
		||||
                # a plain tensor or a subclass. The main cases are:
 | 
			
		||||
                # (1) If an output is a plain tensor, its grad_out will also be a plain tensor,
 | 
			
		||||
                #     *unless* the output is used in some subclass compute later in the forward graph,
 | 
			
		||||
                #     which will cause its grad_output to become a subclass
 | 
			
		||||
                # (2) If an output is a subclass, its grad_out will also be a subclass,
 | 
			
		||||
                #     *unless* the output of the forward did not actually participate in the gradient computation,
 | 
			
		||||
                #     in which case autograd will insert a plain tensor of zeros for the grad_output.
 | 
			
		||||
                #     We could avoid this case with `torch.autograd.Function.set_materialize_grads`,
 | 
			
		||||
                #     although this is not turned on today in AOTAutgrad and would require more work.
 | 
			
		||||
                #
 | 
			
		||||
                # Today, we make a guess on subclass-ness based on the above examples,
 | 
			
		||||
                # and hard-error in the backward if we guessed wrong.
 | 
			
		||||
                #
 | 
			
		||||
                # In the future, we should add backward guards that would allow us to
 | 
			
		||||
                # properly handle this case instead of erroring: we would need to retrace the backward graph,
 | 
			
		||||
                # since we might produce an entirely different trace if our grad_outputs are subclass or not.
 | 
			
		||||
                del flat_bw_args_with_grads
 | 
			
		||||
 | 
			
		||||
                tangents_start_idx = (
 | 
			
		||||
                    len(all_args)
 | 
			
		||||
                    - num_flat_bw_args_with_grads
 | 
			
		||||
                    - len(rng_args)
 | 
			
		||||
                    - len(bw_tokens)
 | 
			
		||||
                )
 | 
			
		||||
                assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors
 | 
			
		||||
                tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)
 | 
			
		||||
 | 
			
		||||
                # TODO: figure out how to refactor the backward properly
 | 
			
		||||
                # so I can use aot_dispatch_subclass_wrapper() here.
 | 
			
		||||
                if CompiledFunction.maybe_subclass_metadata is not None:
 | 
			
		||||
                    tangents = all_args[tangents_start_idx:tangents_end_idx]
 | 
			
		||||
 | 
			
		||||
                    if len(tangents) != len(
 | 
			
		||||
                        CompiledFunction.metadata.subclass_tangent_meta
 | 
			
		||||
                    ):
 | 
			
		||||
                        raise RuntimeError(
 | 
			
		||||
                            "The grad inputs should be same number as forward output tangents"
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    flat_processed_tangents = list(
 | 
			
		||||
                        itertools.chain.from_iterable(
 | 
			
		||||
                            AOTDispatchAutograd.process_runtime_tangent(
 | 
			
		||||
                                t,
 | 
			
		||||
                                m,
 | 
			
		||||
                            )[1]
 | 
			
		||||
                            for t, m in zip(
 | 
			
		||||
                                tangents,
 | 
			
		||||
                                CompiledFunction.metadata.subclass_tangent_meta,
 | 
			
		||||
                            )
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                    all_args = (
 | 
			
		||||
                        runtime_unwrap_tensor_subclasses(
 | 
			
		||||
                            all_args[:tangents_start_idx],
 | 
			
		||||
                            # SymInts that are inputs to the backward graph are
 | 
			
		||||
                            # already included in the "all_args" list.
 | 
			
		||||
                            # Any symints coming from tensor subclasses should always
 | 
			
		||||
                            # come from primals, and so they will show up as extra
 | 
			
		||||
                            # arguments to the forward graph, and they will be saved
 | 
			
		||||
                            # as activation in the backward graph.
 | 
			
		||||
                            append_symints=False,
 | 
			
		||||
                        )
 | 
			
		||||
                        + flat_processed_tangents
 | 
			
		||||
                        + runtime_unwrap_tensor_subclasses(
 | 
			
		||||
                            all_args[tangents_end_idx:],
 | 
			
		||||
                            append_symints=False,
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    all_args = [
 | 
			
		||||
                        (
 | 
			
		||||
                            AOTDispatchAutograd.process_runtime_tangent(
 | 
			
		||||
                                t,
 | 
			
		||||
                                CompiledFunction.metadata.subclass_tangent_meta[
 | 
			
		||||
                                    i - tangents_start_idx
 | 
			
		||||
                                ],
 | 
			
		||||
                            )[0]
 | 
			
		||||
                            if (tangents_start_idx <= i < tangents_end_idx)
 | 
			
		||||
                            else t
 | 
			
		||||
                        )
 | 
			
		||||
                        for i, t in enumerate(all_args)
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                # Backward with forward inputs mutations is not supported in double backward.
 | 
			
		||||
                if (
 | 
			
		||||
                    torch.is_grad_enabled()
 | 
			
		||||
                    and CompiledFunction.metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw
 | 
			
		||||
                ):
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True"
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                return all_args
 | 
			
		||||
 | 
			
		||||
            @staticmethod
 | 
			
		||||
            def _backward_impl(ctx, all_args):
 | 
			
		||||
                if ctx._is_compiled_autograd_tracing():
 | 
			
		||||
@ -2066,34 +2086,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 | 
			
		||||
                )
 | 
			
		||||
                return out
 | 
			
		||||
 | 
			
		||||
            @staticmethod
 | 
			
		||||
            def _backward_epilogue(ctx, out):
 | 
			
		||||
                # Toss out the backward output tokens
 | 
			
		||||
                num_bw_tokens = CompiledFunction.metadata.num_backward_tokens
 | 
			
		||||
                if num_bw_tokens > 0:
 | 
			
		||||
                    out = out[:-num_bw_tokens]
 | 
			
		||||
 | 
			
		||||
                # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
 | 
			
		||||
                out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
 | 
			
		||||
                    CompiledFunction.metadata, out, offset_index=len(out) - 1
 | 
			
		||||
                )
 | 
			
		||||
                out = tuple(out)
 | 
			
		||||
 | 
			
		||||
                # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
 | 
			
		||||
                if CompiledFunction.maybe_subclass_metadata is not None:
 | 
			
		||||
                    assert (
 | 
			
		||||
                        CompiledFunction.maybe_subclass_metadata.grad_input_metas
 | 
			
		||||
                        is not None
 | 
			
		||||
                    )
 | 
			
		||||
                    outs_wrapped = wrap_tensor_subclasses(
 | 
			
		||||
                        out,
 | 
			
		||||
                        subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas,
 | 
			
		||||
                        included_subclass_symints=True,
 | 
			
		||||
                        is_runtime=True,
 | 
			
		||||
                    )
 | 
			
		||||
                    return outs_wrapped
 | 
			
		||||
                return out
 | 
			
		||||
 | 
			
		||||
        compiled_function = RuntimeWrapper(
 | 
			
		||||
            indices_of_inps_to_detach=indices_of_inps_to_detach,
 | 
			
		||||
            trace_joint=True,
 | 
			
		||||
 | 
			
		||||
@ -334,6 +334,9 @@ class FunctionMeta(type):
 | 
			
		||||
        backward_fn._compiled_autograd_should_lift = attrs.get(  # type: ignore[attr-defined]
 | 
			
		||||
            "_compiled_autograd_should_lift", True
 | 
			
		||||
        )
 | 
			
		||||
        backward_fn._bw_module = None
 | 
			
		||||
        if getattr(cls, "_lazy_backward_info", None):
 | 
			
		||||
            backward_fn._bw_module = cls._lazy_backward_info.bw_module
 | 
			
		||||
        cls._backward_cls = backward_fn
 | 
			
		||||
 | 
			
		||||
        super().__init__(name, bases, attrs)
 | 
			
		||||
 | 
			
		||||
@ -525,14 +525,32 @@ void AutogradContext::save_variables() {
 | 
			
		||||
  to_save_.clear();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AutogradContext AutogradContext::functional(variable_list saved_tensors) {
 | 
			
		||||
//   auto result = AutogradContext();
 | 
			
		||||
//   result.is_functional_ = true;
 | 
			
		||||
//   result.saved_variables_override_ = saved_tensors;
 | 
			
		||||
//   return result;
 | 
			
		||||
//
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
variable_list AutogradContext::get_saved_variables() const {
 | 
			
		||||
  if (is_functional_) {
 | 
			
		||||
    return saved_variables_override_.value();
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);
 | 
			
		||||
  variable_list saved;
 | 
			
		||||
  saved.reserve(saved_variables_.size());
 | 
			
		||||
  auto ptr = grad_fn_.lock();
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(ptr);
 | 
			
		||||
  for (auto& var : saved_variables_) {
 | 
			
		||||
    saved.push_back(var.unpack(ptr));
 | 
			
		||||
  // TORCH_INTERNAL_ASSERT(ptr);
 | 
			
		||||
  // TODO(rzou): hacky, can do this in a more legit way
 | 
			
		||||
  if (ptr) {
 | 
			
		||||
    for (auto& var : saved_variables_) {
 | 
			
		||||
      saved.push_back(var.unpack(ptr));
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    for (auto& var : saved_variables_) {
 | 
			
		||||
      saved.push_back(var.unpack());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return saved;
 | 
			
		||||
}
 | 
			
		||||
@ -543,6 +561,7 @@ bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
 | 
			
		||||
  return ptr->task_should_compute_output(output_edge_index);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(rzou): might segfault, need to make this functional
 | 
			
		||||
bool AutogradContext::needs_input_grad(
 | 
			
		||||
    std::initializer_list<IndexRange> idxs) const {
 | 
			
		||||
  auto ptr = grad_fn_.lock();
 | 
			
		||||
 | 
			
		||||
@ -153,6 +153,8 @@ struct TORCH_API AutogradContext {
 | 
			
		||||
  bool needs_input_grad(size_t output_edge_index) const;
 | 
			
		||||
  bool needs_input_grad(std::initializer_list<IndexRange> idxs) const;
 | 
			
		||||
 | 
			
		||||
  static AutogradContext functional(variable_list saved_tensors);
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  std::unordered_set<at::TensorImpl*> non_differentiable_;
 | 
			
		||||
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
 | 
			
		||||
@ -166,6 +168,10 @@ struct TORCH_API AutogradContext {
 | 
			
		||||
  std::weak_ptr<Node> grad_fn_;
 | 
			
		||||
  bool has_freed_buffers_{false};
 | 
			
		||||
 | 
			
		||||
  // If we're constructing an AutogradContext on the fly for Compiled Autograd.
 | 
			
		||||
  bool is_functional_{false};
 | 
			
		||||
  std::optional<variable_list> saved_variables_override_;
 | 
			
		||||
 | 
			
		||||
  void save_variables();
 | 
			
		||||
 | 
			
		||||
  template <class T>
 | 
			
		||||
@ -220,6 +226,126 @@ struct CppNode : public Node {
 | 
			
		||||
  variable_list apply_with_saved(
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      SwapSavedVariables& saved) override {
 | 
			
		||||
    // saved.before(ctx_.saved_data);
 | 
			
		||||
    // TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
 | 
			
		||||
    // TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
 | 
			
		||||
    // saved.before(ctx_.saved_variables_);
 | 
			
		||||
    // TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
 | 
			
		||||
    // saved.before(ctx_.materialize_grads_);
 | 
			
		||||
    // saved.before(ctx_.has_freed_buffers_);
 | 
			
		||||
    // saved.before(input_info_);
 | 
			
		||||
    // saved.before(output_info_);
 | 
			
		||||
 | 
			
		||||
    // auto results = apply(variable_list(inputs));
 | 
			
		||||
 | 
			
		||||
    // saved.after(ctx_.saved_data);
 | 
			
		||||
    // TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
 | 
			
		||||
    // TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
 | 
			
		||||
    // saved.after(ctx_.saved_variables_);
 | 
			
		||||
    // TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
 | 
			
		||||
    // saved.after(ctx_.materialize_grads_);
 | 
			
		||||
    // saved.after(ctx_.has_freed_buffers_);
 | 
			
		||||
    // saved.after(input_info_);
 | 
			
		||||
    // saved.after(output_info_);
 | 
			
		||||
    // return results;
 | 
			
		||||
 | 
			
		||||
    // TODO(rzou): following is problematic
 | 
			
		||||
    auto stack = retrieve_saved(saved);
 | 
			
		||||
    const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
 | 
			
		||||
    variable_list results = interface->call_function(
 | 
			
		||||
        saved.get_py_compiler(),
 | 
			
		||||
        "apply_functional",
 | 
			
		||||
        get_functional().value(),
 | 
			
		||||
        inputs,
 | 
			
		||||
        stack,
 | 
			
		||||
        num_outputs(),
 | 
			
		||||
        name());
 | 
			
		||||
    return results;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::optional<functional_apply_t> get_functional() override {
 | 
			
		||||
    auto name = this->name();
 | 
			
		||||
 | 
			
		||||
    // TODO(rzou): probably need to pre compute needs_input_grad
 | 
			
		||||
    return [name](
 | 
			
		||||
               const variable_list& inputs,
 | 
			
		||||
               const std::vector<c10::IValue>& saved) {
 | 
			
		||||
      SavedState state;
 | 
			
		||||
      state.stack = saved;
 | 
			
		||||
      auto ctx = AutogradContext();
 | 
			
		||||
      ctx.is_functional_ = true;
 | 
			
		||||
      std::vector<VariableInfo> output_info;
 | 
			
		||||
      std::vector<bool> is_variable_input;
 | 
			
		||||
 | 
			
		||||
      state.dequeue(ctx.saved_data);
 | 
			
		||||
 | 
			
		||||
      variable_list saved_variables;
 | 
			
		||||
      state.dequeue(saved_variables);
 | 
			
		||||
      ctx.saved_variables_override_ = saved_variables;
 | 
			
		||||
 | 
			
		||||
      state.dequeue(ctx.materialize_grads_);
 | 
			
		||||
      state.dequeue(output_info);
 | 
			
		||||
      state.dequeue(is_variable_input);
 | 
			
		||||
 | 
			
		||||
      // TODO(rzou): refactor to share code with CppNode<T>::apply
 | 
			
		||||
      at::OptionalDeviceGuard _device_guard;
 | 
			
		||||
      auto num_inputs = inputs.size();
 | 
			
		||||
      variable_list backward_inputs;
 | 
			
		||||
      backward_inputs.reserve(num_inputs);
 | 
			
		||||
      for (const auto i : c10::irange(num_inputs)) {
 | 
			
		||||
        if (inputs[i].defined() || !ctx.materialize_grads_) {
 | 
			
		||||
          backward_inputs.emplace_back(inputs[i]);
 | 
			
		||||
        } else {
 | 
			
		||||
          backward_inputs.emplace_back(output_info[i].zeros(_device_guard));
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      auto outputs = T::backward(&ctx, inputs);
 | 
			
		||||
 | 
			
		||||
      const auto num_forward_inputs =
 | 
			
		||||
          static_cast<int64_t>(is_variable_input.size());
 | 
			
		||||
      auto num_outputs = static_cast<int64_t>(outputs.size());
 | 
			
		||||
      // Returning too many results is ok, but only as long as they're all
 | 
			
		||||
      // undefined. Truncate the result vector in that case.
 | 
			
		||||
      if (num_outputs > num_forward_inputs) {
 | 
			
		||||
        bool all_undef = true;
 | 
			
		||||
        for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
 | 
			
		||||
          all_undef &= (!outputs[i].defined());
 | 
			
		||||
        }
 | 
			
		||||
        if (all_undef) {
 | 
			
		||||
          outputs.resize(num_forward_inputs);
 | 
			
		||||
          num_outputs = num_forward_inputs;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (num_outputs != num_forward_inputs) {
 | 
			
		||||
        std::string msg("function ");
 | 
			
		||||
        msg += name + " returned an incorrect number of gradients (expected ";
 | 
			
		||||
        msg += std::to_string(num_forward_inputs) + ", got ";
 | 
			
		||||
        msg += std::to_string(num_outputs) + ")";
 | 
			
		||||
        throw std::runtime_error(msg);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      variable_list results;
 | 
			
		||||
      results.reserve(num_outputs);
 | 
			
		||||
      for (const auto i : c10::irange(num_outputs)) {
 | 
			
		||||
        if (!is_variable_input[i]) {
 | 
			
		||||
          if (outputs[i].defined()) {
 | 
			
		||||
            std::string msg("function ");
 | 
			
		||||
            msg += name +
 | 
			
		||||
                " returned a gradient different that is defined at position ";
 | 
			
		||||
            msg += std::to_string(i + 1) +
 | 
			
		||||
                ", std the corresponding forward input was not a Variable";
 | 
			
		||||
            throw std::runtime_error(msg);
 | 
			
		||||
          }
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
        results.emplace_back(outputs[i]);
 | 
			
		||||
      }
 | 
			
		||||
      return results;
 | 
			
		||||
    };
 | 
			
		||||
  }
 | 
			
		||||
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override {
 | 
			
		||||
    saved.before(ctx_.saved_data);
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
 | 
			
		||||
@ -229,7 +355,27 @@ struct CppNode : public Node {
 | 
			
		||||
    saved.before(ctx_.has_freed_buffers_);
 | 
			
		||||
    saved.before(input_info_);
 | 
			
		||||
    saved.before(output_info_);
 | 
			
		||||
    auto results = apply(variable_list(inputs));
 | 
			
		||||
 | 
			
		||||
    SavedState state;
 | 
			
		||||
    // std::cout << "start, stack=" << state.stack.size() << std::endl;
 | 
			
		||||
    state.enqueue(ctx_.saved_data);
 | 
			
		||||
    // std::cout << "enqueued saved_data, stack=" << state.stack.size() <<
 | 
			
		||||
    // std::endl;
 | 
			
		||||
 | 
			
		||||
    variable_list saved_variables = ctx_.get_saved_variables();
 | 
			
		||||
    state.enqueue(saved_variables);
 | 
			
		||||
    // std::cout << "enqueued saved_variables_, stack=" << state.stack.size() <<
 | 
			
		||||
    // std::endl;
 | 
			
		||||
    state.enqueue(ctx_.materialize_grads_);
 | 
			
		||||
    // std::cout << "enqueued materialize_grads_, stack=" << state.stack.size()
 | 
			
		||||
    // << std::endl;
 | 
			
		||||
    state.enqueue(output_info_);
 | 
			
		||||
    // std::cout << "enqueued output_info_, stack=" << state.stack.size() <<
 | 
			
		||||
    // std::endl;
 | 
			
		||||
    state.enqueue(is_variable_input_);
 | 
			
		||||
    // std::cout << "enqueued is_variable_input_, stack=" << state.stack.size()
 | 
			
		||||
    // << std::endl;
 | 
			
		||||
 | 
			
		||||
    saved.after(ctx_.saved_data);
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
 | 
			
		||||
@ -239,7 +385,8 @@ struct CppNode : public Node {
 | 
			
		||||
    saved.after(ctx_.has_freed_buffers_);
 | 
			
		||||
    saved.after(input_info_);
 | 
			
		||||
    saved.after(output_info_);
 | 
			
		||||
    return results;
 | 
			
		||||
 | 
			
		||||
    return state.stack;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -855,22 +855,84 @@ void set_device(int device) {
 | 
			
		||||
  worker_device = device;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void validate_outputs(
 | 
			
		||||
    const edge_list& edges,
 | 
			
		||||
// validate_outputs has two overloads, one that accepts edge_list and one that
 | 
			
		||||
// accepts vector<optional<InputMetadata>>. The former is stateful (it requires
 | 
			
		||||
// the autograd graph to actually use) and the latter is for functional
 | 
			
		||||
// autograd. (where we want to be able to take an autograd graph and then
 | 
			
		||||
// construct a FX graph out of it without specializing on the properties of the
 | 
			
		||||
// gradients).
 | 
			
		||||
//
 | 
			
		||||
// We do some templating to avoid dynamic allocations in the hot path (the eager
 | 
			
		||||
// autograd case). Otherwise, the problem is that we are given a vector<Edge>
 | 
			
		||||
// and would need to materialize a vector<optional<InputMetadata>> (or some
 | 
			
		||||
// other vector) to pass to a common helper function. The alternative is to use
 | 
			
		||||
// C++20's ranges which we don't have access to yet.
 | 
			
		||||
 | 
			
		||||
// Given an Edge or optional<InputMetdata>, return the InputMetadata
 | 
			
		||||
template <typename T>
 | 
			
		||||
const InputMetadata& get_input_metadata(const T& thing);
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
const InputMetadata& get_input_metadata<c10::optional<InputMetadata>>(
 | 
			
		||||
    const c10::optional<InputMetadata>& thing) {
 | 
			
		||||
  return thing.value();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
const InputMetadata& get_input_metadata<Edge>(const Edge& thing) {
 | 
			
		||||
  return thing.function->input_metadata(thing.input_nr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Given an Edge or optional<InputMetdata>, return if there is an InputMetadata.
 | 
			
		||||
template <typename T>
 | 
			
		||||
bool has_input_metadata(const T& thing);
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
bool has_input_metadata<c10::optional<InputMetadata>>(
 | 
			
		||||
    const c10::optional<InputMetadata>& thing) {
 | 
			
		||||
  return thing.has_value();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
bool has_input_metadata<Edge>(const Edge& thing) {
 | 
			
		||||
  return thing.is_valid();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<c10::optional<InputMetadata>> collect_input_metadata(
 | 
			
		||||
    const edge_list& edges) {
 | 
			
		||||
  std::vector<c10::optional<InputMetadata>> input_metadata;
 | 
			
		||||
  for (const auto& edge : edges) {
 | 
			
		||||
    if (!edge.is_valid()) {
 | 
			
		||||
      input_metadata.emplace_back(c10::nullopt);
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
    input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr));
 | 
			
		||||
  }
 | 
			
		||||
  return input_metadata;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Given an vector<Edge> or vector<optional<InputMetdata>>, validate the
 | 
			
		||||
// outputs. This involves using the InputMetadata to check the outputs and also
 | 
			
		||||
// potentially calling .sum_to on the outputs.
 | 
			
		||||
template <typename T>
 | 
			
		||||
void validate_outputs_impl(
 | 
			
		||||
    const std::vector<T>& input_metadata_container,
 | 
			
		||||
    variable_list& grads,
 | 
			
		||||
    const std::function<std::string(const std::string&)>& format_error) {
 | 
			
		||||
  if (grads.size() != edges.size()) {
 | 
			
		||||
  if (grads.size() != input_metadata_container.size()) {
 | 
			
		||||
    std::stringstream ss;
 | 
			
		||||
    ss << "invalid number of gradients - expected ";
 | 
			
		||||
    ss << edges.size() << ", but got " << grads.size();
 | 
			
		||||
    ss << input_metadata_container.size() << ", but got " << grads.size();
 | 
			
		||||
    TORCH_CHECK(false, format_error(ss.str()));
 | 
			
		||||
  }
 | 
			
		||||
  for (const auto i : c10::irange(grads.size())) {
 | 
			
		||||
    const auto& edge = edges[i];
 | 
			
		||||
    if (!edge.is_valid())
 | 
			
		||||
    // std::cout << "validate_outputs_impl: " << i << std::endl;
 | 
			
		||||
    if (!has_input_metadata(input_metadata_container.at(i))) {
 | 
			
		||||
      continue;
 | 
			
		||||
 | 
			
		||||
    const auto& metadata = edge.function->input_metadata(edge.input_nr);
 | 
			
		||||
    }
 | 
			
		||||
    // std::cout << "validate_outputs_impl get_input_metadata: " << i <<
 | 
			
		||||
    // std::endl;
 | 
			
		||||
    const auto& metadata = get_input_metadata(input_metadata_container[i]);
 | 
			
		||||
    auto& grad = grads[i];
 | 
			
		||||
    if (!grad.defined()) {
 | 
			
		||||
      // FIXME: TestJit.test_ge_optimized fails this assertion.
 | 
			
		||||
@ -938,6 +1000,20 @@ void validate_outputs(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void validate_outputs(
 | 
			
		||||
    const edge_list& edges,
 | 
			
		||||
    variable_list& grads,
 | 
			
		||||
    const std::function<std::string(const std::string&)>& format_error) {
 | 
			
		||||
  return validate_outputs_impl(edges, grads, format_error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void validate_outputs(
 | 
			
		||||
    const std::vector<c10::optional<InputMetadata>>& input_metadata,
 | 
			
		||||
    variable_list& grads,
 | 
			
		||||
    const std::function<std::string(const std::string&)>& format_error) {
 | 
			
		||||
  return validate_outputs_impl(input_metadata, grads, format_error);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static variable_list call_function(
 | 
			
		||||
    std::shared_ptr<GraphTask>& graph_task,
 | 
			
		||||
    Node* func,
 | 
			
		||||
 | 
			
		||||
@ -43,6 +43,12 @@ TORCH_API void validate_outputs(
 | 
			
		||||
    const edge_list& edges,
 | 
			
		||||
    variable_list& grads,
 | 
			
		||||
    const std::function<std::string(const std::string&)>& format_error);
 | 
			
		||||
TORCH_API void validate_outputs(
 | 
			
		||||
    const std::vector<c10::optional<InputMetadata>>& input_metadata,
 | 
			
		||||
    variable_list& grads,
 | 
			
		||||
    const std::function<std::string(const std::string&)>& format_error);
 | 
			
		||||
TORCH_API std::vector<c10::optional<InputMetadata>> collect_input_metadata(
 | 
			
		||||
    const edge_list& edges);
 | 
			
		||||
 | 
			
		||||
struct NodeTask {
 | 
			
		||||
  std::weak_ptr<GraphTask> base_;
 | 
			
		||||
 | 
			
		||||
@ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>;
 | 
			
		||||
using variable_list = std::vector<Variable>;
 | 
			
		||||
using edge_list = std::vector<Edge>;
 | 
			
		||||
using saved_variable_list = std::vector<SavedVariable>;
 | 
			
		||||
using ivalue_list = std::vector<c10::IValue>;
 | 
			
		||||
using functional_apply_t = std::function<
 | 
			
		||||
    variable_list(const variable_list&, const std::vector<c10::IValue>&)>;
 | 
			
		||||
using IndexRange = std::pair<size_t, size_t>;
 | 
			
		||||
using torch::dynamo::autograd::CompiledNodeArgs;
 | 
			
		||||
using torch::dynamo::autograd::SavedState;
 | 
			
		||||
using torch::dynamo::autograd::SwapSavedVariables;
 | 
			
		||||
 | 
			
		||||
// Custom deleter to prevent stack overflows.
 | 
			
		||||
@ -604,6 +608,18 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
 | 
			
		||||
        std::string("apply_with_saved not implemented: ") + name());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual ivalue_list retrieve_saved(SwapSavedVariables& saved) {
 | 
			
		||||
    throw std::runtime_error(
 | 
			
		||||
        std::string("retrieve_saved not implemented: ") + name());
 | 
			
		||||
  }
 | 
			
		||||
  virtual c10::optional<functional_apply_t> get_functional() {
 | 
			
		||||
    throw std::runtime_error(
 | 
			
		||||
        std::string("get_functional not implemented: ") + name());
 | 
			
		||||
  }
 | 
			
		||||
  virtual bool use_apply_with_saved() {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  /// Performs the `Node`'s actual operation.
 | 
			
		||||
  virtual variable_list apply(variable_list&& inputs) = 0;
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@
 | 
			
		||||
namespace torch::dynamo::autograd {
 | 
			
		||||
class CompiledNodeArgs;
 | 
			
		||||
class SwapSavedVariables;
 | 
			
		||||
struct SavedState;
 | 
			
		||||
} // namespace torch::dynamo::autograd
 | 
			
		||||
 | 
			
		||||
// A hook that's called on gradients
 | 
			
		||||
 | 
			
		||||
@ -103,4 +103,42 @@ variable_list AccumulateGrad::apply_with_saved(
  return variable_list();
}

ivalue_list AccumulateGrad::retrieve_saved(SwapSavedVariables& saved) {
  TORCH_INTERNAL_ASSERT(false, "use apply_with_saved");
  auto should_visit = variable.defined() && variable.requires_grad();
  if (should_visit) {
    saved.before(variable);
  }

  SavedState state;
  state.enqueue(variable);

  if (should_visit) {
    saved.after(variable);
  }

  return state.stack;
}

c10::optional<functional_apply_t> AccumulateGrad::get_functional() {
  TORCH_INTERNAL_ASSERT(false, "use apply_with_saved");
  return [](const variable_list& inputs,
            const std::vector<c10::IValue>& saved) -> variable_list {
    SavedState state;
    state.stack = saved;
    Variable foo;
    state.dequeue(foo);
    if (!(foo.defined() && foo.requires_grad()) || !inputs[0].defined()) {
      return variable_list();
    }
    // op is intentionally static
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("inductor::accumulate_grad_", "")
                         .typed<void(const at::Tensor&, const at::Tensor&)>();
    op.call(foo, inputs[0]);
    // TODO(rzou): tensor_post_acc_grad_hooks
    return variable_list();
  };
}

} // namespace torch::autograd

@ -267,6 +267,13 @@ struct TORCH_API AccumulateGrad : public Node {
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
  c10::optional<functional_apply_t> get_functional() override;

  bool use_apply_with_saved() override {
    return true;
  }

  Variable variable;
};

@ -61,6 +61,21 @@ auto UndefinedGradBackward::apply(variable_list&& output_grads)
 | 
			
		||||
  }
 | 
			
		||||
  return input_grads;
 | 
			
		||||
}
 | 
			
		||||
ivalue_list UndefinedGradBackward::retrieve_saved(SwapSavedVariables&) {
 | 
			
		||||
  return {};
 | 
			
		||||
}
 | 
			
		||||
c10::optional<functional_apply_t> UndefinedGradBackward::get_functional() {
 | 
			
		||||
  return [](const variable_list& inputs,
 | 
			
		||||
            const ivalue_list& stack) -> variable_list {
 | 
			
		||||
    variable_list outputs;
 | 
			
		||||
    outputs.reserve(inputs.size());
 | 
			
		||||
    for (auto& grad : inputs) {
 | 
			
		||||
      (void)grad; // Suppress unused variable warning
 | 
			
		||||
      outputs.emplace_back();
 | 
			
		||||
    }
 | 
			
		||||
    return outputs;
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
auto Identity::apply(variable_list&& grads) -> variable_list {
 | 
			
		||||
  return std::move(grads);
 | 
			
		||||
@ -77,5 +92,22 @@ variable_list GraphRoot::apply_with_saved(
 | 
			
		||||
  saved.after(outputs);
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
ivalue_list GraphRoot::retrieve_saved(SwapSavedVariables& saved) {
 | 
			
		||||
  saved.before(outputs);
 | 
			
		||||
  SavedState state;
 | 
			
		||||
  state.enqueue(outputs);
 | 
			
		||||
  saved.after(outputs);
 | 
			
		||||
  return state.stack;
 | 
			
		||||
}
 | 
			
		||||
c10::optional<functional_apply_t> GraphRoot::get_functional() {
 | 
			
		||||
  return [](const variable_list& inputs,
 | 
			
		||||
            const std::vector<c10::IValue>& saved) -> variable_list {
 | 
			
		||||
    SavedState state;
 | 
			
		||||
    state.stack = saved;
 | 
			
		||||
    variable_list outputs;
 | 
			
		||||
    state.dequeue(outputs);
 | 
			
		||||
    return outputs;
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace torch::autograd
 | 
			
		||||
 | 
			
		||||
@ -76,6 +76,8 @@ struct TORCH_API UndefinedGradBackward : public Node {
 | 
			
		||||
      SwapSavedVariables& saved) override {
 | 
			
		||||
    return apply(variable_list(inputs));
 | 
			
		||||
  }
 | 
			
		||||
  ivalue_list retrieve_saved(SwapSavedVariables&) override;
 | 
			
		||||
  c10::optional<functional_apply_t> get_functional() override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API GraphRoot : public Node {
 | 
			
		||||
@ -97,6 +99,8 @@ struct TORCH_API GraphRoot : public Node {
 | 
			
		||||
  variable_list apply_with_saved(
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      SwapSavedVariables& saved) override;
 | 
			
		||||
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
 | 
			
		||||
  c10::optional<functional_apply_t> get_functional() override;
 | 
			
		||||
 | 
			
		||||
  variable_list outputs;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -16,15 +16,18 @@
 | 
			
		||||
 | 
			
		||||
namespace torch::autograd {
 | 
			
		||||
 | 
			
		||||
auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
 | 
			
		||||
static variable_list CopyBackwards_apply_functional(
 | 
			
		||||
    variable_list&& grads,
 | 
			
		||||
    std::array<bool, 2> needs_input_grad,
 | 
			
		||||
    const c10::TensorOptions& src_options) {
 | 
			
		||||
  check_input_variables("CopyBackwards", grads, 1, -1, true);
 | 
			
		||||
  auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grads[0]);
 | 
			
		||||
  variable_list grad_inputs(2);
 | 
			
		||||
  if (grad->defined()) {
 | 
			
		||||
    if (task_should_compute_output(0)) {
 | 
			
		||||
    if (needs_input_grad[0]) {
 | 
			
		||||
      grad_inputs[0] = at::zeros_like(*grad, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 | 
			
		||||
    }
 | 
			
		||||
    if (task_should_compute_output(1)) {
 | 
			
		||||
    if (needs_input_grad[1]) {
 | 
			
		||||
      // Handle R->C copies without raising a warning
 | 
			
		||||
      const auto src_type = src_options.dtype().toScalarType();
 | 
			
		||||
      if (!c10::isComplexType(src_type) && grad->is_complex()) {
 | 
			
		||||
@ -38,6 +41,38 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
 | 
			
		||||
  return grad_inputs;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ivalue_list CopyBackwards::retrieve_saved(SwapSavedVariables& saved) {
 | 
			
		||||
  saved.before(src_options);
 | 
			
		||||
  SavedState state;
 | 
			
		||||
  state.enqueue(src_options);
 | 
			
		||||
  saved.after(src_options);
 | 
			
		||||
  return state.stack;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::optional<functional_apply_t> CopyBackwards::get_functional() {
 | 
			
		||||
  auto needs_input_grad = std::array<bool, 2>{
 | 
			
		||||
      task_should_compute_output(0), task_should_compute_output(1)};
 | 
			
		||||
  return [needs_input_grad](
 | 
			
		||||
             const variable_list& inputs,
 | 
			
		||||
             const ivalue_list& stack) -> variable_list {
 | 
			
		||||
    SavedState state;
 | 
			
		||||
    state.stack = stack;
 | 
			
		||||
    at::TensorOptions src_options;
 | 
			
		||||
    state.dequeue(src_options);
 | 
			
		||||
    auto inputs_copy = inputs;
 | 
			
		||||
 | 
			
		||||
    return CopyBackwards_apply_functional(
 | 
			
		||||
        std::move(inputs_copy), needs_input_grad, src_options);
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
 | 
			
		||||
  return CopyBackwards_apply_functional(
 | 
			
		||||
      std::move(grads),
 | 
			
		||||
      {task_should_compute_output(0), task_should_compute_output(1)},
 | 
			
		||||
      src_options);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
 | 
			
		||||
  args.collect(src_options);
 | 
			
		||||
}
 | 
			
		||||
@ -71,24 +106,16 @@ CopySlices::CopySlices(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// common code between apply/apply_with_saved
 | 
			
		||||
template <typename T>
 | 
			
		||||
inline variable_list CopySlices::apply_impl(
 | 
			
		||||
template <typename F1>
 | 
			
		||||
static variable_list CopySlices_apply_functional(
 | 
			
		||||
    variable_list&& inputs,
 | 
			
		||||
    const T& call_fn) {
 | 
			
		||||
  check_input_variables("CopySlices", inputs, 1, -1, true);
 | 
			
		||||
    const std::vector<bool>& needs_input_grad,
 | 
			
		||||
    const at::TensorGeometry& base,
 | 
			
		||||
    const at::TensorGeometry& view,
 | 
			
		||||
    int64_t num_outputs,
 | 
			
		||||
    const F1& call_fn,
 | 
			
		||||
    const std::unique_ptr<ViewFunc>& view_fn) {
 | 
			
		||||
  auto& grad = inputs[0];
 | 
			
		||||
  if (!grad.defined()) {
 | 
			
		||||
    return variable_list(num_outputs());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Acquire lock to here protect thread safety on fn
 | 
			
		||||
  // see Note [Thread Safety on Autograd Node]
 | 
			
		||||
  std::lock_guard<std::mutex> lock(mutex_);
 | 
			
		||||
 | 
			
		||||
  if (!fn) {
 | 
			
		||||
    throw std::runtime_error(ERR_BACKWARD_TWICE);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto result =
 | 
			
		||||
      grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides());
 | 
			
		||||
@ -103,6 +130,50 @@ inline variable_list CopySlices::apply_impl(
 | 
			
		||||
        result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // TODO: We clone grad_slice because we modify it below and "fn" might save
 | 
			
		||||
  // it for the backward of res. We might be able to avoid the clone() if
 | 
			
		||||
  // double-backprop is disabled.
 | 
			
		||||
  auto res = call_fn({grad_slice.clone(at::MemoryFormat::Contiguous)});
 | 
			
		||||
 | 
			
		||||
  variable_list grad_inputs(num_outputs);
 | 
			
		||||
  for (const auto i : c10::irange(res.size())) {
 | 
			
		||||
    if (needs_input_grad[i]) {
 | 
			
		||||
      if (!res[i].defined()) {
 | 
			
		||||
        // If the output is not defined, treat it as if it was a zero tensor.
 | 
			
		||||
        // This can happen if users define a custom Function.
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
      if (i == 0) {
 | 
			
		||||
        grad_slice.copy_(res[i]);
 | 
			
		||||
        // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
 | 
			
		||||
        grad_inputs[i] = std::move(result); // NOLINT(bugprone-use-after-move)
 | 
			
		||||
      } else {
 | 
			
		||||
        grad_inputs[i] = std::move(res[i]);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return grad_inputs;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// common code between apply/apply_with_saved
 | 
			
		||||
template <typename T>
 | 
			
		||||
inline variable_list CopySlices::apply_impl(
 | 
			
		||||
    variable_list&& inputs,
 | 
			
		||||
    const T& call_fn) {
 | 
			
		||||
  check_input_variables("CopySlices", inputs, 1, -1, true);
 | 
			
		||||
  auto& grad = inputs[0];
 | 
			
		||||
  if (!grad.defined()) {
 | 
			
		||||
    return variable_list(num_outputs());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (!fn) {
 | 
			
		||||
    throw std::runtime_error(ERR_BACKWARD_TWICE);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Acquire lock to here protect thread safety on fn
 | 
			
		||||
  // see Note [Thread Safety on Autograd Node]
 | 
			
		||||
  std::lock_guard<std::mutex> lock(mutex_);
 | 
			
		||||
 | 
			
		||||
  // See Note [View + Inplace update for view tensor] For more details on this
 | 
			
		||||
  // block Since the gradient edge for the 0th input is different between `this`
 | 
			
		||||
  // and `fn`, make sure that the one from `fn` has the same metadata in the
 | 
			
		||||
@ -146,30 +217,19 @@ inline variable_list CopySlices::apply_impl(
 | 
			
		||||
        fn->next_edge(i).function.get() == this->next_edge(i).function.get());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // TODO: We clone grad_slice because we modify it below and "fn" might save
 | 
			
		||||
  // it for the backward of res. We might be able to avoid the clone() if
 | 
			
		||||
  // double-backprop is disabled.
 | 
			
		||||
  auto res = call_fn({grad_slice.clone(at::MemoryFormat::Contiguous)});
 | 
			
		||||
 | 
			
		||||
  variable_list grad_inputs(num_outputs());
 | 
			
		||||
  for (const auto i : c10::irange(res.size())) {
 | 
			
		||||
    if (task_should_compute_output(i)) {
 | 
			
		||||
      if (!res[i].defined()) {
 | 
			
		||||
        // If the output is not defined, treat it as if it was a zero tensor.
 | 
			
		||||
        // This can happen if users define a custom Function.
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
      if (i == 0) {
 | 
			
		||||
        grad_slice.copy_(res[i]);
 | 
			
		||||
        // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
 | 
			
		||||
        grad_inputs[i] = std::move(result); // NOLINT(bugprone-use-after-move)
 | 
			
		||||
      } else {
 | 
			
		||||
        grad_inputs[i] = std::move(res[i]);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  std::vector<bool> needs_input_grad;
 | 
			
		||||
  for (const auto i : c10::irange(num_outputs())) {
 | 
			
		||||
    needs_input_grad.emplace_back(task_should_compute_output(i));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return grad_inputs;
 | 
			
		||||
  return CopySlices_apply_functional(
 | 
			
		||||
      std::move(inputs),
 | 
			
		||||
      needs_input_grad,
 | 
			
		||||
      base,
 | 
			
		||||
      view,
 | 
			
		||||
      num_outputs(),
 | 
			
		||||
      call_fn,
 | 
			
		||||
      view_fn);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CopySlices::release_variables() {
 | 
			
		||||
@ -192,6 +252,44 @@ variable_list CopySlices::apply_with_saved(
 | 
			
		||||
    SwapSavedVariables& saved) {
 | 
			
		||||
  saved.before(base);
 | 
			
		||||
  saved.before(view);
 | 
			
		||||
 | 
			
		||||
  auto results = variable_list(num_outputs());
 | 
			
		||||
 | 
			
		||||
  if (grads[0].defined()) {
 | 
			
		||||
    std::vector<bool> needs_input_grad;
 | 
			
		||||
    for (const auto i : c10::irange(num_outputs())) {
 | 
			
		||||
      needs_input_grad.emplace_back(task_should_compute_output(i));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(!view_fn);
 | 
			
		||||
    const auto& interface = torch::dynamo::autograd::getPyCompilerInterface();
 | 
			
		||||
    variable_list stuff = interface->call_copy_slices_prologue(
 | 
			
		||||
        saved.get_py_compiler(),
 | 
			
		||||
        grads,
 | 
			
		||||
        base,
 | 
			
		||||
        view);
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(stuff.size() == 3);
 | 
			
		||||
    auto result = stuff[0];
 | 
			
		||||
    auto grad_slice = stuff[1];
 | 
			
		||||
    auto grad_slice_clone = stuff[2];
 | 
			
		||||
    auto res = fn->apply_with_saved({grad_slice_clone}, saved);
 | 
			
		||||
    results = interface->call_copy_slices_epilogue(
 | 
			
		||||
        saved.get_py_compiler(),
 | 
			
		||||
        needs_input_grad,
 | 
			
		||||
        result,
 | 
			
		||||
        res,
 | 
			
		||||
        grad_slice);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  saved.after(base);
 | 
			
		||||
  saved.after(view);
 | 
			
		||||
  return results;
 | 
			
		||||
 | 
			
		||||
  // apply_with_saved
 | 
			
		||||
  //
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
  int call_count = 0;
 | 
			
		||||
  variable_list result = apply_impl(
 | 
			
		||||
      variable_list(grads),
 | 
			
		||||
@ -203,6 +301,62 @@ variable_list CopySlices::apply_with_saved(
 | 
			
		||||
  saved.after(base);
 | 
			
		||||
  saved.after(view);
 | 
			
		||||
  return result;
 | 
			
		||||
  */
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ivalue_list CopySlices::retrieve_saved(SwapSavedVariables& saved) {
 | 
			
		||||
  saved.before(base);
 | 
			
		||||
  saved.before(view);
 | 
			
		||||
 | 
			
		||||
  SavedState state;
 | 
			
		||||
  state.enqueue(base);
 | 
			
		||||
  state.enqueue(view);
 | 
			
		||||
 | 
			
		||||
  auto fn_state = fn->retrieve_saved(saved);
 | 
			
		||||
  state.stack.insert(state.stack.end(), fn_state.begin(), fn_state.end());
 | 
			
		||||
 | 
			
		||||
  saved.after(base);
 | 
			
		||||
  saved.after(view);
 | 
			
		||||
 | 
			
		||||
  return state.stack;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
c10::optional<functional_apply_t> CopySlices::get_functional() {
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(
 | 
			
		||||
      !view_fn, "NYI: compiled autograd with CopySlices with view_fn");
 | 
			
		||||
  auto num_out = num_outputs();
 | 
			
		||||
  std::vector<bool> needs_input_grad;
 | 
			
		||||
  for (const auto i : c10::irange(num_outputs())) {
 | 
			
		||||
    needs_input_grad.emplace_back(task_should_compute_output(i));
 | 
			
		||||
  }
 | 
			
		||||
  auto fn2 = fn;
 | 
			
		||||
 | 
			
		||||
  return [fn2, num_out, needs_input_grad](
 | 
			
		||||
             const variable_list& inputs,
 | 
			
		||||
             const std::vector<c10::IValue>& saved) -> variable_list {
 | 
			
		||||
    SavedState state;
 | 
			
		||||
    state.stack = saved;
 | 
			
		||||
    at::TensorGeometry base;
 | 
			
		||||
    at::TensorGeometry view;
 | 
			
		||||
    state.dequeue(base);
 | 
			
		||||
    state.dequeue(view);
 | 
			
		||||
 | 
			
		||||
    // TODO(rzou): somehow we need to restore the state...
 | 
			
		||||
    auto call_fn = [fn2](variable_list&& inputs2) -> variable_list {
 | 
			
		||||
      return (*fn2)(std::move(inputs2));
 | 
			
		||||
    };
 | 
			
		||||
    // TODO(rzou): copy needed because CopySlices_apply_functional takes the inputs by rvalue reference
 | 
			
		||||
    variable_list copied_inputs = inputs;
 | 
			
		||||
 | 
			
		||||
    return CopySlices_apply_functional(
 | 
			
		||||
        std::move(copied_inputs),
 | 
			
		||||
        needs_input_grad,
 | 
			
		||||
        base,
 | 
			
		||||
        view,
 | 
			
		||||
        num_out,
 | 
			
		||||
        call_fn,
 | 
			
		||||
        {});
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
auto CopySlices::apply(variable_list&& inputs1) -> variable_list {
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,8 @@ struct TORCH_API CopyBackwards : public Node {
 | 
			
		||||
  variable_list apply_with_saved(
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      SwapSavedVariables& saved) override;
 | 
			
		||||
  ivalue_list retrieve_saved(SwapSavedVariables&) override;
 | 
			
		||||
  c10::optional<functional_apply_t> get_functional() override;
 | 
			
		||||
 | 
			
		||||
  at::TensorOptions src_options;
 | 
			
		||||
};
 | 
			
		||||
@ -172,6 +174,8 @@ struct TORCH_API CopySlices : public Node {
 | 
			
		||||
  variable_list apply_with_saved(
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      SwapSavedVariables& saved) override;
 | 
			
		||||
  ivalue_list retrieve_saved(SwapSavedVariables&) override;
 | 
			
		||||
  c10::optional<functional_apply_t> get_functional() override;
 | 
			
		||||
 | 
			
		||||
  at::TensorGeometry base;
 | 
			
		||||
  // view and view_fn are redundant and view_fn will be used if available.
 | 
			
		||||
 | 
			
		||||
@ -131,6 +131,11 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
 | 
			
		||||
  if (!ParameterClass)
 | 
			
		||||
    return nullptr;
 | 
			
		||||
 | 
			
		||||
  py::class_<at::TensorGeometry>(m, "TensorGeometry")
 | 
			
		||||
    .def("sizes", &at::TensorGeometry::sizes)
 | 
			
		||||
    .def("strides", &at::TensorGeometry::strides)
 | 
			
		||||
    .def("storage_offset", &at::TensorGeometry::storage_offset);
 | 
			
		||||
 | 
			
		||||
  py::class_<LegacyEvent>(m, "ProfilerEvent")
 | 
			
		||||
      .def("kind", &LegacyEvent::kindStr)
 | 
			
		||||
      .def("name", [](const LegacyEvent& e) { return e.name(); })
 | 
			
		||||
 | 
			
		||||
@ -103,7 +103,7 @@ struct TORCH_API InputMetadata {
 | 
			
		||||
  bool maybe_expandable_to(const at::Tensor& grad) const;
 | 
			
		||||
 | 
			
		||||
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
 | 
			
		||||
  const at::TensorOptions options_;
 | 
			
		||||
  at::TensorOptions options_;
 | 
			
		||||
  MetadataShape shape_;
 | 
			
		||||
  c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
 | 
			
		||||
  bool is_tensor_subclass_ = false;
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,7 @@
 | 
			
		||||
#include <torch/csrc/autograd/saved_variable.h>
 | 
			
		||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
 | 
			
		||||
#include <torch/csrc/dynamo/compiled_autograd.h>
 | 
			
		||||
#include <torch/csrc/dynamo/python_compiled_autograd.h>
 | 
			
		||||
#include <torch/csrc/jit/frontend/tracer.h>
 | 
			
		||||
#include <torch/csrc/jit/ir/ir.h>
 | 
			
		||||
#include <torch/csrc/jit/python/pybind_utils.h>
 | 
			
		||||
@ -236,15 +237,22 @@ auto PyNode::defer_to_dynamo(
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(
 | 
			
		||||
      _backward_idx.has_value(),
 | 
			
		||||
      "indices should already be set by compiled_args, called before apply_with_saved");
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value());
 | 
			
		||||
  PyObject* backward_state_idx = Py_None;
 | 
			
		||||
  if (_backward_state_idx.has_value()) {
 | 
			
		||||
    backward_state_idx = PyLong_FromLong(_backward_state_idx.value());
 | 
			
		||||
    // this might be simplifiable now that we no longer inline
 | 
			
		||||
    Py_CLEAR(py_fn->compiled_autograd_backward_state);
 | 
			
		||||
  }
 | 
			
		||||
  THPObjectPtr r(PyObject_CallMethod(
 | 
			
		||||
      *compiler,
 | 
			
		||||
      "proxy_call_backward",
 | 
			
		||||
      "OOOi",
 | 
			
		||||
      "OOOiOO",
 | 
			
		||||
      pyInputs.get(),
 | 
			
		||||
      fwdInputMetadatas.get(),
 | 
			
		||||
      saved_tensors.get(),
 | 
			
		||||
      *_backward_idx));
 | 
			
		||||
      *_backward_idx,
 | 
			
		||||
      obj,
 | 
			
		||||
      backward_state_idx));
 | 
			
		||||
 | 
			
		||||
  if (!r)
 | 
			
		||||
    throw_python_error();
 | 
			
		||||
@ -367,6 +375,7 @@ variable_list PyNode::apply_with_saved(
 | 
			
		||||
  variable_list result;
 | 
			
		||||
  if (!compiled_autograd_should_lift()) {
 | 
			
		||||
    if (_backward_state_idx.has_value()) {
 | 
			
		||||
      // TODO(rzou): need to excise this branch?
 | 
			
		||||
      PyObject* r = PyObject_CallMethod(
 | 
			
		||||
          saved.get_py_compiler(),
 | 
			
		||||
          "bind_backward_state",
 | 
			
		||||
@ -396,6 +405,100 @@ variable_list PyNode::apply_with_saved(
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ivalue_list PyNode::retrieve_saved(SwapSavedVariables& saved) {
 | 
			
		||||
  auto f = (THPFunction*)obj;
 | 
			
		||||
  saved.before(f->compiled_autograd_symints);
 | 
			
		||||
  saved.before(f->saved_variables);
 | 
			
		||||
  saved.before(f->needs_input_grad);
 | 
			
		||||
  saved.before(f->materialize_non_diff_grads);
 | 
			
		||||
  saved.before(f->output_info);
 | 
			
		||||
  saved.before(f->input_info);
 | 
			
		||||
 | 
			
		||||
  SavedState state;
 | 
			
		||||
  state.enqueue(f->compiled_autograd_symints);
 | 
			
		||||
  state.enqueue(f->saved_variables, shared_from_this());
 | 
			
		||||
  // state.enqueue(f->needs_input_grad);
 | 
			
		||||
  // state.enqueue(f->materialize_non_diff_grads);
 | 
			
		||||
  // state.enqueue(f->output_info);
 | 
			
		||||
  // state.enqueue(f->input_info);
 | 
			
		||||
 | 
			
		||||
  saved.after(f->compiled_autograd_symints);
 | 
			
		||||
  saved.after(f->saved_variables);
 | 
			
		||||
  saved.after(f->needs_input_grad);
 | 
			
		||||
  saved.after(f->materialize_non_diff_grads);
 | 
			
		||||
  saved.after(f->output_info);
 | 
			
		||||
  saved.after(f->input_info);
 | 
			
		||||
 | 
			
		||||
  state.enqueue(f->compiled_autograd_symints);
 | 
			
		||||
  state.enqueue(f->saved_variables, shared_from_this());
 | 
			
		||||
  // state.enqueue(f->needs_input_grad);
 | 
			
		||||
  // state.enqueue(f->materialize_non_diff_grads);
 | 
			
		||||
  // state.enqueue(f->output_info);
 | 
			
		||||
  // state.enqueue(f->input_info);
 | 
			
		||||
 | 
			
		||||
  return state.stack;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(rzou): compiled autograd needs special handling of the following.
 | 
			
		||||
c10::optional<functional_apply_t> PyNode::get_functional() {
 | 
			
		||||
  return c10::nullopt;
 | 
			
		||||
  /*
 | 
			
		||||
  auto node = std::static_pointer_cast<PyNode>(shared_from_this());
 | 
			
		||||
  // TODO(rzou): probably need to pre compute needs_input_grad
 | 
			
		||||
  return
 | 
			
		||||
      [node](
 | 
			
		||||
          const variable_list& inputs, const std::vector<c10::IValue>& saved) {
 | 
			
		||||
        SavedState state;
 | 
			
		||||
        state.stack = saved;
 | 
			
		||||
 | 
			
		||||
        auto f = (THPFunction*)node->obj;
 | 
			
		||||
 | 
			
		||||
        state.dequeue(f->compiled_autograd_symints);
 | 
			
		||||
        state.dequeue(f->saved_variables);
 | 
			
		||||
        // state.dequeue(f->needs_input_grad);
 | 
			
		||||
        // state.dequeue(f->materialize_non_diff_grads);
 | 
			
		||||
        // state.dequeue(f->output_info);
 | 
			
		||||
        // state.dequeue(f->input_info);
 | 
			
		||||
 | 
			
		||||
        f->compiled_autograd_tracing = true;
 | 
			
		||||
        variable_list result;
 | 
			
		||||
        if (!node->compiled_autograd_should_lift()) {
 | 
			
		||||
          if (node->_backward_state_idx.has_value()) {
 | 
			
		||||
            PyObject* r = PyObject_CallMethod(
 | 
			
		||||
                torch::dynamo::autograd::current_py_compiler(),
 | 
			
		||||
                "bind_backward_state",
 | 
			
		||||
                "i",
 | 
			
		||||
                *node->_backward_state_idx);
 | 
			
		||||
            if (r == nullptr) {
 | 
			
		||||
              throw python_error();
 | 
			
		||||
            }
 | 
			
		||||
            THPObjectPtr prior(f->compiled_autograd_backward_state);
 | 
			
		||||
            f->compiled_autograd_backward_state = r;
 | 
			
		||||
            result = node->apply(variable_list(inputs));
 | 
			
		||||
            Py_CLEAR(f->compiled_autograd_backward_state);
 | 
			
		||||
            f->compiled_autograd_backward_state = prior.release();
 | 
			
		||||
          } else {
 | 
			
		||||
            result = node->apply(variable_list(inputs));
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
          result = node->defer_to_dynamo(
 | 
			
		||||
              variable_list(inputs),
 | 
			
		||||
              torch::dynamo::autograd::current_py_compiler());
 | 
			
		||||
        }
 | 
			
		||||
        f->compiled_autograd_tracing = false;
 | 
			
		||||
 | 
			
		||||
        state.dequeue(f->compiled_autograd_symints);
 | 
			
		||||
        state.dequeue(f->saved_variables);
 | 
			
		||||
        // state.dequeue(f->needs_input_grad);
 | 
			
		||||
        // state.dequeue(f->materialize_non_diff_grads);
 | 
			
		||||
        // state.dequeue(f->output_info);
 | 
			
		||||
        // state.dequeue(f->input_info);
 | 
			
		||||
 | 
			
		||||
        return result;
 | 
			
		||||
      };
 | 
			
		||||
    */
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PyObject* PyNode::to_py_args(
 | 
			
		||||
    const variable_list& inputs,
 | 
			
		||||
    at::OptionalDeviceGuard* device_guard) {
 | 
			
		||||
 | 
			
		||||
@ -42,6 +42,10 @@ struct PyNode : public Node {
 | 
			
		||||
  std::string name() const override;
 | 
			
		||||
  bool is_traceable() override;
 | 
			
		||||
 | 
			
		||||
  bool use_apply_with_saved() override {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void compiled_args(CompiledNodeArgs& args) override;
 | 
			
		||||
  variable_list apply_with_saved(
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
@ -70,6 +74,9 @@ struct PyNode : public Node {
 | 
			
		||||
      Py_DECREF(obj);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  c10::optional<functional_apply_t> get_functional() override;
 | 
			
		||||
  ivalue_list retrieve_saved(SwapSavedVariables& saved) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 | 
			
		||||
							
								
								
									
torch/csrc/dynamo/compiled_autograd.cpp (new file, 22 lines)
@ -0,0 +1,22 @@
#include <torch/csrc/dynamo/compiled_autograd.h>

namespace torch::dynamo::autograd {

std::unique_ptr<PyCompilerInterface> kPyCompilerInterface;

const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
  TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
  return kPyCompilerInterface;
}

void setPyCompilerInterface(std::unique_ptr<PyCompilerInterface>&& impl) {
  TORCH_INTERNAL_ASSERT(impl != nullptr);
  std::swap(kPyCompilerInterface, impl);
  TORCH_INTERNAL_ASSERT(kPyCompilerInterface != nullptr);
}

void resetPyCompilerInterface() {
  kPyCompilerInterface.reset();
}

} // namespace torch::dynamo::autograd

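Registration of this interface is process-global and is currently paired by hand (setPyCompilerInterface before tracing, resetPyCompilerInterface after) in _compiled_autograd_impl. A small scope guard, sketched here as a hypothetical convenience that is not part of this change, would keep the reset on early returns and exceptions:

// Hypothetical RAII guard (illustration only, not in this diff).
struct PyCompilerInterfaceGuard {
  explicit PyCompilerInterfaceGuard(std::unique_ptr<PyCompilerInterface> impl) {
    setPyCompilerInterface(std::move(impl));
  }
  ~PyCompilerInterfaceGuard() {
    resetPyCompilerInterface();
  }
  PyCompilerInterfaceGuard(const PyCompilerInterfaceGuard&) = delete;
  PyCompilerInterfaceGuard& operator=(const PyCompilerInterfaceGuard&) = delete;
};

// Usage sketch: construct the guard where the manual
// setPyCompilerInterface/resetPyCompilerInterface pair sits today.
//   PyCompilerInterfaceGuard guard(std::make_unique<PyCompilerInterfaceImpl>());
//   ... trace the autograd graph ...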
@ -899,6 +899,359 @@ class SwapSavedVariables {
 | 
			
		||||
  StashedVars<at::IValue> stashed_ivalues;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct SavedState {
 | 
			
		||||
  std::vector<at::IValue> stack;
 | 
			
		||||
  int64_t idx = 0;
 | 
			
		||||
 | 
			
		||||
  void enqueue(
 | 
			
		||||
      const SavedVariable& sv,
 | 
			
		||||
      const std::shared_ptr<Node>& saved_for) {
 | 
			
		||||
    stack.emplace_back(sv.unpack(saved_for));
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(SavedVariable& sv) {
 | 
			
		||||
    sv = SavedVariable(stack[idx++].toTensor(), /*is_output*/ true);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(
 | 
			
		||||
      const std::vector<SavedVariable>& sv,
 | 
			
		||||
      const std::shared_ptr<Node>& saved_for) {
 | 
			
		||||
    enqueue(static_cast<int64_t>(sv.size()));
 | 
			
		||||
    for (const auto& v : sv) {
 | 
			
		||||
      enqueue(v, saved_for);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(std::vector<SavedVariable>& sv) {
 | 
			
		||||
    int64_t size = 0;
 | 
			
		||||
    dequeue(size);
 | 
			
		||||
    sv.clear();
 | 
			
		||||
    for (int64_t idx = 0; idx < size; idx++) {
 | 
			
		||||
      sv.emplace_back();
 | 
			
		||||
      dequeue(sv.back());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
  void enqueue(const PyObject*& t) {
 | 
			
		||||
    enqueue_ivalue(t);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(PyObject*& t) {
 | 
			
		||||
    dequeue_ivalue(t);
 | 
			
		||||
  }
 | 
			
		||||
  */
 | 
			
		||||
 | 
			
		||||
  void enqueue(const VariableInfo& t) {
 | 
			
		||||
    enqueue(t.layout);
 | 
			
		||||
    enqueue(t.device);
 | 
			
		||||
    enqueue(t.scalar_type);
 | 
			
		||||
    enqueue(t.size);
 | 
			
		||||
    enqueue(t.requires_grad);
 | 
			
		||||
    enqueue(t.is_empty);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(VariableInfo& t) {
 | 
			
		||||
    dequeue(t.layout);
 | 
			
		||||
    dequeue(t.device);
 | 
			
		||||
    dequeue(t.scalar_type);
 | 
			
		||||
    dequeue(t.size);
 | 
			
		||||
    dequeue(t.requires_grad);
 | 
			
		||||
    dequeue(t.is_empty);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(size_t t) {
 | 
			
		||||
    enqueue(static_cast<int64_t>(t));
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(size_t& t) {
 | 
			
		||||
    int64_t tmp = 0;
 | 
			
		||||
    dequeue(tmp);
 | 
			
		||||
    t = static_cast<size_t>(tmp);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // TODO: probably wildly inefficient
 | 
			
		||||
  template <class T>
 | 
			
		||||
  void enqueue(const c10::List<T> t) {
 | 
			
		||||
    enqueue(t.vec());
 | 
			
		||||
  }
 | 
			
		||||
  template <class T>
 | 
			
		||||
  void dequeue(c10::List<T>& t) {
 | 
			
		||||
    std::vector<T> tmp;
 | 
			
		||||
    dequeue(tmp);
 | 
			
		||||
    t = c10::List<T>(tmp);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const TypeAndSize& value) {
 | 
			
		||||
    enqueue(value.sym_sizes);
 | 
			
		||||
    enqueue(value.options);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(TypeAndSize& value) {
 | 
			
		||||
    dequeue(value.sym_sizes);
 | 
			
		||||
    dequeue(value.options);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const InputMetadata& value) {
 | 
			
		||||
    enqueue(value.options());
 | 
			
		||||
    // std::cout << "enqueue: " << value.shape_as_dim_vector() << std::endl;
 | 
			
		||||
    enqueue(value.shape_as_dim_vector().vec());
 | 
			
		||||
    enqueue(value.is_tensor_subclass());
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(!value.is_nested_tensor());
 | 
			
		||||
  }
 | 
			
		||||
  // Special case: InputMetadata has no copy ctor
 | 
			
		||||
  // TODO(rzou): ??
 | 
			
		||||
  void dequeue(InputMetadata& value) {
 | 
			
		||||
    at::TensorOptions options;
 | 
			
		||||
    dequeue(options);
 | 
			
		||||
    std::vector<at::SymInt> shape;
 | 
			
		||||
    dequeue(shape);
 | 
			
		||||
    // std::cout << "dequeue: " << shape << std::endl;
 | 
			
		||||
    bool is_tensor_subclass = false;
 | 
			
		||||
    dequeue(is_tensor_subclass);
 | 
			
		||||
    SymIntSmallVec sym_shape;
 | 
			
		||||
    for (const auto& s : shape) {
 | 
			
		||||
      sym_shape.emplace_back(s);
 | 
			
		||||
    }
 | 
			
		||||
    value = InputMetadata(options, sym_shape, is_tensor_subclass, false);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const ska::flat_hash_map<std::string, at::IValue>& dct) {
 | 
			
		||||
    std::vector<std::string> keys;
 | 
			
		||||
    std::vector<at::IValue> values;
 | 
			
		||||
    for (const auto& [key, value] : dct) {
 | 
			
		||||
      keys.emplace_back(key);
 | 
			
		||||
      values.emplace_back(value);
 | 
			
		||||
    }
 | 
			
		||||
    enqueue(keys);
 | 
			
		||||
    enqueue(values);
 | 
			
		||||
  }
 | 
			
		||||
  void enqueue(const at::IValue& iv) {
 | 
			
		||||
    stack.emplace_back(iv);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(at::IValue& iv) {
 | 
			
		||||
    iv = stack[idx++];
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(ska::flat_hash_map<std::string, at::IValue>& dct) {
 | 
			
		||||
    std::vector<std::string> keys;
 | 
			
		||||
    std::vector<at::IValue> values;
 | 
			
		||||
    dequeue(keys);
 | 
			
		||||
    dequeue(values);
 | 
			
		||||
    dct.clear();
 | 
			
		||||
    for (const auto i : c10::irange(keys.size())) {
 | 
			
		||||
      dct.insert({keys[i], values[i]});
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const at::TensorOptions& value) {
 | 
			
		||||
    enqueue(value.requires_grad_opt());
 | 
			
		||||
    enqueue(value.memory_format_opt());
 | 
			
		||||
    enqueue(value.device_opt());
 | 
			
		||||
    enqueue(value.dtype_opt());
 | 
			
		||||
    enqueue(value.layout_opt());
 | 
			
		||||
    enqueue(value.pinned_memory_opt());
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(at::TensorOptions& value) {
 | 
			
		||||
    auto result = at::TensorOptions();
 | 
			
		||||
    c10::optional<bool> requires_grad_opt;
 | 
			
		||||
    dequeue(requires_grad_opt);
 | 
			
		||||
    if (requires_grad_opt) {
 | 
			
		||||
      result = result.requires_grad(*requires_grad_opt);
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<c10::MemoryFormat> memory_format_opt;
 | 
			
		||||
    dequeue(memory_format_opt);
 | 
			
		||||
    if (memory_format_opt) {
 | 
			
		||||
      result = result.memory_format(*memory_format_opt);
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<c10::Device> device_opt;
 | 
			
		||||
    dequeue(device_opt);
 | 
			
		||||
    if (device_opt) {
 | 
			
		||||
      result = result.device(*device_opt);
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<caffe2::TypeMeta> dtype_opt;
 | 
			
		||||
    dequeue(dtype_opt);
 | 
			
		||||
    if (dtype_opt) {
 | 
			
		||||
      result = result.dtype(*dtype_opt);
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<c10::Layout> layout_opt;
 | 
			
		||||
    dequeue(layout_opt);
 | 
			
		||||
    if (layout_opt) {
 | 
			
		||||
      result = result.layout(*layout_opt);
 | 
			
		||||
    }
 | 
			
		||||
    c10::optional<bool> pinned_memory_opt;
 | 
			
		||||
    dequeue(pinned_memory_opt);
 | 
			
		||||
    if (pinned_memory_opt) {
 | 
			
		||||
      result = result.pinned_memory(*pinned_memory_opt);
 | 
			
		||||
    }
 | 
			
		||||
    value = result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const caffe2::TypeMeta& value) {
 | 
			
		||||
    enqueue(at::typeMetaToScalarType(value));
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(caffe2::TypeMeta& value) {
 | 
			
		||||
    at::ScalarType result = at::kFloat;
 | 
			
		||||
    dequeue(result);
 | 
			
		||||
    value = caffe2::TypeMeta::fromScalarType(result);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void enqueue(const c10::OptionalArray<T>& t) {
 | 
			
		||||
    enqueue(t.list);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void dequeue(c10::OptionalArray<T>& t) {
 | 
			
		||||
    dequeue(t.list);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void enqueue(const std::optional<T>& t) {
 | 
			
		||||
    enqueue(t.has_value());
 | 
			
		||||
    if (t.has_value()) {
 | 
			
		||||
      enqueue(*t);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void dequeue(c10::optional<T>& value) {
 | 
			
		||||
    bool has_value = false;
 | 
			
		||||
    dequeue(has_value);
 | 
			
		||||
    if (has_value) {
 | 
			
		||||
      T tmp;
 | 
			
		||||
      dequeue(tmp);
 | 
			
		||||
      value = tmp;
 | 
			
		||||
    } else {
 | 
			
		||||
      value = c10::nullopt;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const at::TensorGeometry& t) {
 | 
			
		||||
    enqueue(t.sym_sizes().vec());
 | 
			
		||||
    enqueue(t.sym_strides().vec());
 | 
			
		||||
    enqueue(t.sym_storage_offset());
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(at::TensorGeometry& t) {
 | 
			
		||||
    std::vector<at::SymInt> sym_sizes;
 | 
			
		||||
    std::vector<at::SymInt> sym_strides;
 | 
			
		||||
    at::SymInt sym_storage_offset;
 | 
			
		||||
    dequeue(sym_sizes);
 | 
			
		||||
    dequeue(sym_strides);
 | 
			
		||||
    dequeue(sym_storage_offset);
 | 
			
		||||
    t = at::TensorGeometry(sym_sizes, sym_strides, sym_storage_offset);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void enqueue(const std::vector<T>& t) {
 | 
			
		||||
    enqueue(static_cast<int64_t>(t.size()));
 | 
			
		||||
    for (const T& i : t) {
 | 
			
		||||
      enqueue(i);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  void dequeue(std::vector<T>& t) {
 | 
			
		||||
    int64_t size = 0;
 | 
			
		||||
    dequeue(size);
 | 
			
		||||
    t.clear();
 | 
			
		||||
    for (int64_t idx = 0; idx < size; idx++) {
 | 
			
		||||
      t.emplace_back();
 | 
			
		||||
      dequeue(t.back());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const c10::SymInt& t) {
 | 
			
		||||
    stack.emplace_back(t);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(c10::SymInt& t) {
 | 
			
		||||
    t = stack[idx++].toSymInt();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(int64_t t) {
 | 
			
		||||
    stack.emplace_back(t);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(int64_t& t) {
 | 
			
		||||
    t = stack[idx++].toInt();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const std::vector<c10::SymInt>& t) {
 | 
			
		||||
    enqueue_ivalue(t);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(std::vector<c10::SymInt>& t) {
 | 
			
		||||
    t = stack[idx++].toSymIntVector();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void enqueue(const std::vector<int64_t>& t) {
 | 
			
		||||
    enqueue_ivalue(t);
 | 
			
		||||
  }
 | 
			
		||||
  void dequeue(std::vector<int64_t>& t) {
 | 
			
		||||
    t = stack[idx++].toIntVector();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <class ivalue_t>
 | 
			
		||||
  void enqueue_ivalue(const ivalue_t& t) {
 | 
			
		||||
    stack.emplace_back(t);
 | 
			
		||||
  }
 | 
			
		||||
  template <class ivalue_t>
 | 
			
		||||
  void dequeue_ivalue(ivalue_t& value) {
 | 
			
		||||
    value = stack[idx++].to<ivalue_t>();
 | 
			
		||||
  }
 | 
			
		||||
#define HANDLE_IVALUE(ivalue_t)                            \
 | 
			
		||||
  void enqueue(const ivalue_t& value) {                    \
 | 
			
		||||
    return enqueue_ivalue<ivalue_t>(value);                \
 | 
			
		||||
  }                                                        \
 | 
			
		||||
  void enqueue(const std::vector<ivalue_t>& value) {       \
 | 
			
		||||
    return enqueue_ivalue<std::vector<ivalue_t>>(value);   \
 | 
			
		||||
  }                                                        \
 | 
			
		||||
  void enqueue(const c10::optional<ivalue_t>& value) {     \
 | 
			
		||||
    return enqueue_ivalue<c10::optional<ivalue_t>>(value); \
 | 
			
		||||
  }                                                        \
 | 
			
		||||
  void dequeue(ivalue_t& value) {                          \
 | 
			
		||||
    return dequeue_ivalue<ivalue_t>(value);                \
 | 
			
		||||
  }                                                        \
 | 
			
		||||
  void dequeue(std::vector<ivalue_t>& value) {             \
 | 
			
		||||
    return dequeue_ivalue<std::vector<ivalue_t>>(value);   \
 | 
			
		||||
  }                                                        \
 | 
			
		||||
  void dequeue(c10::optional<ivalue_t>& value) {           \
 | 
			
		||||
    return dequeue_ivalue<c10::optional<ivalue_t>>(value); \
 | 
			
		||||
  }
 | 
			
		||||
  HANDLE_IVALUE(at::Tensor)
 | 
			
		||||
  HANDLE_IVALUE(c10::ScalarType)
 | 
			
		||||
  HANDLE_IVALUE(c10::Scalar)
 | 
			
		||||
  HANDLE_IVALUE(c10::Layout)
 | 
			
		||||
  HANDLE_IVALUE(c10::Device)
 | 
			
		||||
  HANDLE_IVALUE(c10::MemoryFormat)
 | 
			
		||||
  HANDLE_IVALUE(bool)
 | 
			
		||||
  HANDLE_IVALUE(double)
 | 
			
		||||
  HANDLE_IVALUE(std::string)
 | 
			
		||||
#undef HANDLE_IVALUE
 | 
			
		||||
};
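SavedState is a flat typed stack of IValues with a single read cursor: compound values (TensorOptions, TensorGeometry, vectors, optionals) are flattened into several stack entries, so dequeue calls must mirror the enqueue order exactly. A short round-trip sketch under that assumption, not part of this change:

// Illustrative SavedState round trip (order of dequeues mirrors enqueues).
void saved_state_roundtrip_example() {
  using torch::dynamo::autograd::SavedState;

  SavedState writer;
  at::TensorOptions opts =
      at::TensorOptions().dtype(at::kFloat).device(at::kCPU);
  std::vector<c10::SymInt> sizes = {c10::SymInt(2), c10::SymInt(3)};
  writer.enqueue(opts);   // flattened into several IValues (dtype, device, ...)
  writer.enqueue(sizes);  // stored as a single SymInt vector IValue

  SavedState reader;
  reader.stack = writer.stack;  // e.g. handed across the functional boundary
  at::TensorOptions opts2;
  std::vector<c10::SymInt> sizes2;
  reader.dequeue(opts2);   // must be read back in the same order
  reader.dequeue(sizes2);
}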
 | 
			
		||||
 | 
			
		||||
struct TORCH_API PyCompilerInterface {
 | 
			
		||||
  virtual ~PyCompilerInterface() = default;
 | 
			
		||||
  virtual variable_list call_function(
 | 
			
		||||
      PyObject* py_compiler,
 | 
			
		||||
      const char* name,
 | 
			
		||||
      functional_apply_t fn,
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      const ivalue_list& saved_state,
 | 
			
		||||
      int64_t num_outputs,
 | 
			
		||||
      const std::string& debug) {
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
 | 
			
		||||
  }
 | 
			
		||||
  virtual variable_list call_copy_slices_prologue(
 | 
			
		||||
      PyObject* py_compiler,
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      const at::TensorGeometry& base,
 | 
			
		||||
      const at::TensorGeometry& view) {
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
 | 
			
		||||
  }
 | 
			
		||||
  virtual variable_list call_copy_slices_epilogue(
 | 
			
		||||
      PyObject* py_compiler,
 | 
			
		||||
      const std::vector<bool>& needs_input_grad,
 | 
			
		||||
      const at::Tensor& result,
 | 
			
		||||
      const variable_list& res,
 | 
			
		||||
      const at::Tensor& grad_slice) {
 | 
			
		||||
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
 | 
			
		||||
  }
 | 
			
		||||
};
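PyCompilerInterface is the seam between libtorch (which cannot call into the Python compiler directly) and python_compiled_autograd.cpp, which installs the real implementation. A hypothetical no-op implementation, not part of this change but useful for example in C++-only tests, could bypass the Python side entirely:

// Hypothetical stub (illustration only): runs the functional eagerly instead
// of proxying it into an FX graph, so the retrieve_saved/get_functional path
// can be exercised without a py_compiler. The copy-slices hooks keep their
// default "Needs to be overridden" behavior.
struct EagerPyCompilerInterface : PyCompilerInterface {
  variable_list call_function(
      PyObject* /*py_compiler*/,
      const char* /*name*/,
      functional_apply_t fn,
      const variable_list& inputs,
      const ivalue_list& saved_state,
      int64_t /*num_outputs*/,
      const std::string& /*debug*/) override {
    return fn(inputs, saved_state);
  }
};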
 | 
			
		||||
 | 
			
		||||
TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();
 | 
			
		||||
TORCH_API void setPyCompilerInterface(
 | 
			
		||||
    std::unique_ptr<PyCompilerInterface>&& impl);
 | 
			
		||||
TORCH_API void resetPyCompilerInterface();
 | 
			
		||||
 | 
			
		||||
} // namespace torch::dynamo::autograd
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
 | 
			
		||||
@ -52,6 +52,115 @@ Notes:
 | 
			
		||||
namespace torch::dynamo::autograd {
 | 
			
		||||
using c10::SymInt;
 | 
			
		||||
 | 
			
		||||
static PyObject* kPyCompiler;
 | 
			
		||||
 | 
			
		||||
PyObject* current_py_compiler() {
 | 
			
		||||
  return kPyCompiler;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename Func>
 | 
			
		||||
static variable_list call_function(
 | 
			
		||||
    PyObject* py_compiler,
 | 
			
		||||
    const char* name,
 | 
			
		||||
    Func fn,
 | 
			
		||||
    const variable_list& inputs,
 | 
			
		||||
    const ivalue_list& saved_state,
 | 
			
		||||
    int64_t num_outputs,
 | 
			
		||||
    const std::string& debug) {
 | 
			
		||||
  // Need this to do PyObject* -> IValue conversion
 | 
			
		||||
  std::vector<at::TypePtr> schema;
 | 
			
		||||
  schema.reserve(saved_state.size());
 | 
			
		||||
  for (const auto& ivalue : saved_state) {
 | 
			
		||||
    schema.emplace_back(ivalue.type());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // We are going to bind the following function to Python
 | 
			
		||||
  auto py_func = py::cpp_function(
 | 
			
		||||
      [schema, fn](
 | 
			
		||||
          std::vector<c10::optional<at::Tensor>>& inputs,
 | 
			
		||||
          const py::args& args) -> py::object {
 | 
			
		||||
        // It reconstructs the saved_state from args via the schema
 | 
			
		||||
        std::vector<at::IValue> stack;
 | 
			
		||||
        TORCH_INTERNAL_ASSERT(args.size() == schema.size());
 | 
			
		||||
        auto tuple_args = jit::tuple_slice(args);
 | 
			
		||||
        for (uint64_t idx = 0; idx < schema.size(); idx++) {
 | 
			
		||||
          stack.emplace_back(
 | 
			
		||||
              jit::toIValue(tuple_args[idx], schema[idx], c10::nullopt));
 | 
			
		||||
        }
 | 
			
		||||
        std::vector<at::Tensor> inputs_;
 | 
			
		||||
        for (const auto& inp : inputs) {
 | 
			
		||||
          if (inp.has_value()) {
 | 
			
		||||
            inputs_.emplace_back(*inp);
 | 
			
		||||
          } else {
 | 
			
		||||
            inputs_.emplace_back();
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        auto outputs = fn(inputs_, stack);
 | 
			
		||||
        return jit::toPyObject(at::IValue(outputs));
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
  // convert ivalue_list -> PyObject*
 | 
			
		||||
  PyObject* py_saved_state =
 | 
			
		||||
      PyTuple_New(static_cast<Py_ssize_t>(schema.size()));
 | 
			
		||||
  for (const auto i : c10::irange(schema.size())) {
 | 
			
		||||
    py::object obj = jit::toPyObject(saved_state[i]);
 | 
			
		||||
    Py_INCREF(obj.ptr());
 | 
			
		||||
    PyTuple_SET_ITEM(py_saved_state, i, obj.ptr());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // call the corresponding method on the py_compiler
 | 
			
		||||
  // That method will figure out what to do with the function
 | 
			
		||||
  // (it can either inline it or plop it straight into the FX graph).
 | 
			
		||||
  py::handle handle(py_compiler);
 | 
			
		||||
  py::object stuff = handle.attr(name)(
 | 
			
		||||
      py_func, inputs, py::handle(py_saved_state), num_outputs, debug);
 | 
			
		||||
 | 
			
		||||
  // Convert the output from PyObject* to vector<Tensor>
 | 
			
		||||
  auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff);
 | 
			
		||||
  variable_list outputs;
 | 
			
		||||
  for (const auto& t : tmp) {
 | 
			
		||||
    if (t.has_value()) {
 | 
			
		||||
      outputs.emplace_back(t.value());
 | 
			
		||||
    } else {
 | 
			
		||||
      outputs.emplace_back();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return outputs;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct PyCompilerInterfaceImpl : PyCompilerInterface {
 | 
			
		||||
  variable_list call_function(
 | 
			
		||||
      PyObject* py_compiler,
 | 
			
		||||
      const char* name,
 | 
			
		||||
      functional_apply_t fn,
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      const ivalue_list& saved_state,
 | 
			
		||||
      int64_t num_outputs,
 | 
			
		||||
      const std::string& debug) override {
 | 
			
		||||
    return torch::dynamo::autograd::call_function(
 | 
			
		||||
        py_compiler, name, fn, inputs, saved_state, num_outputs, debug);
 | 
			
		||||
  }
 | 
			
		||||
  variable_list call_copy_slices_prologue(
 | 
			
		||||
      PyObject* py_compiler,
 | 
			
		||||
      const variable_list& inputs,
 | 
			
		||||
      const at::TensorGeometry& base,
 | 
			
		||||
      const at::TensorGeometry& view) override {
 | 
			
		||||
    py::handle handle(py_compiler);
 | 
			
		||||
    py::object stuff = handle.attr("call_copy_slices_prologue")(inputs, base, view);
 | 
			
		||||
    return py::cast<std::vector<at::Tensor>>(stuff);
 | 
			
		||||
  }
 | 
			
		||||
  virtual variable_list call_copy_slices_epilogue(
 | 
			
		||||
      PyObject* py_compiler,
 | 
			
		||||
      const std::vector<bool>& needs_input_grad,
 | 
			
		||||
      const at::Tensor& result,
 | 
			
		||||
      const variable_list& res,
 | 
			
		||||
      const at::Tensor& grad_slice) override {
 | 
			
		||||
    py::handle handle(py_compiler);
 | 
			
		||||
    py::object stuff = handle.attr("call_copy_slices_epilogue")(needs_input_grad, result, res, grad_slice);
 | 
			
		||||
    return py::cast<std::vector<at::Tensor>>(stuff);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
 | 
			
		||||
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
 | 
			
		||||
  for (const auto i : c10::irange(inputs.size())) {
 | 
			
		||||
@ -89,6 +198,25 @@ static void check(bool result) {
 | 
			
		||||
    check(nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static variable_list validate_outputs(
 | 
			
		||||
    variable_list& outputs,
 | 
			
		||||
    const ivalue_list& saved) {
 | 
			
		||||
  SavedState r;
 | 
			
		||||
  r.stack = saved;
 | 
			
		||||
  std::vector<c10::optional<InputMetadata>> value;
 | 
			
		||||
  r.dequeue(value);
 | 
			
		||||
  // std::cout << "dequeue" << std::endl;
 | 
			
		||||
  // dumpimv(value);
 | 
			
		||||
 | 
			
		||||
  torch::autograd::validate_outputs(
 | 
			
		||||
      value, outputs, [&](const std::string& msg) {
 | 
			
		||||
        std::ostringstream ss;
 | 
			
		||||
        ss << "[Compiled Autograd Tracing:]" << msg;
 | 
			
		||||
        return ss.str();
 | 
			
		||||
      });
 | 
			
		||||
  return outputs;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// snapshot of python verbose logging toggle
 | 
			
		||||
static PyObject* python_verbose_logger = nullptr;
 | 
			
		||||
 | 
			
		||||
@ -498,6 +626,21 @@ void set_ivalue_proxies(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static at::Tensor call_accumulate(
 | 
			
		||||
    PyObject* py_compiler,
 | 
			
		||||
    const at::Tensor& old_var,
 | 
			
		||||
    const at::Tensor& new_var) {
 | 
			
		||||
  if (!old_var.defined()) {
 | 
			
		||||
    return new_var;
 | 
			
		||||
  }
 | 
			
		||||
  if (!new_var.defined()) {
 | 
			
		||||
    return old_var;
 | 
			
		||||
  }
 | 
			
		||||
  py::handle handle(py_compiler);
 | 
			
		||||
  py::object stuff = handle.attr("accumulate")(old_var, new_var);
 | 
			
		||||
  return py::cast<at::Tensor>(stuff);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static TraceState call_begin_capture(
 | 
			
		||||
    PyObject* self,
 | 
			
		||||
    CacheNode& cache,
 | 
			
		||||
@ -656,6 +799,9 @@ CacheNode* _compiled_autograd_impl(
 | 
			
		||||
    // cache miss, need to capture FX graph
 | 
			
		||||
    ClosingTHPObjectPtr py_compiler(
 | 
			
		||||
        check(PyObject_CallNoArgs((the_autograd_compiler))));
 | 
			
		||||
    kPyCompiler = py_compiler.get();
 | 
			
		||||
 | 
			
		||||
    setPyCompilerInterface(std::make_unique<PyCompilerInterfaceImpl>());
 | 
			
		||||
 | 
			
		||||
    TraceState state = call_begin_capture(
 | 
			
		||||
        py_compiler, *cache, compiler_call, output_edges.size());
 | 
			
		||||
@ -722,17 +868,64 @@ CacheNode* _compiled_autograd_impl(
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
 | 
			
		||||
      variable_list outputs = call.node->apply_with_saved(inputs, saved);
 | 
			
		||||
      // std::cout << call.node->name() << std::endl;
 | 
			
		||||
      // std::cout << saved_state.size() << std::endl;
 | 
			
		||||
      // for (const auto& ivalue: saved_state) {
 | 
			
		||||
      //   if (ivalue.isTensor()) {
 | 
			
		||||
      //     std::cout << "tensor" << std::endl;
 | 
			
		||||
      //   } else {
 | 
			
		||||
      //     ivalue.dump();
 | 
			
		||||
      //   }
 | 
			
		||||
      // }
 | 
			
		||||
 | 
			
		||||
      // There are 4 cases:
 | 
			
		||||
      // 1) user Python autograd.Function
 | 
			
		||||
      // 2) autogenerated C++ Node.
 | 
			
		||||
      // 3) manual C++ Node (in PyTorch framework code)
 | 
			
		||||
      // 4) user C++ autograd::Function
 | 
			
		||||
 | 
			
		||||
      // should really be called "call_or_return_functional"
 | 
			
		||||
      variable_list outputs;
 | 
			
		||||
 | 
			
		||||
      // std::cout << call.node->name() << std::endl;
 | 
			
		||||
      outputs = call.node->apply_with_saved(inputs, saved);
 | 
			
		||||
      // if (call.node->use_apply_with_saved()) {
 | 
			
		||||
      //   outputs = call.node->apply_with_saved(inputs, saved);
 | 
			
		||||
      // } else {
 | 
			
		||||
      //   auto function_to_proxy = call.node->get_functional();
 | 
			
		||||
      //   auto saved_state = call.node->retrieve_saved(saved);
 | 
			
		||||
      //   outputs = call_function(
 | 
			
		||||
      //       py_compiler,
 | 
			
		||||
      //       "apply_functional",
 | 
			
		||||
      //       function_to_proxy.value(),
 | 
			
		||||
      //       inputs,
 | 
			
		||||
      //       saved_state,
 | 
			
		||||
      //       call.node->num_outputs(),
 | 
			
		||||
      //       call.node->name());
 | 
			
		||||
      // }
 | 
			
		||||
 | 
			
		||||
      saved.debug_asserts();
 | 
			
		||||
      saved.before(call.node->next_edges());
 | 
			
		||||
      validate_outputs(
 | 
			
		||||
          call.node->next_edges(), outputs, [&](const std::string& msg) {
 | 
			
		||||
            std::ostringstream ss;
 | 
			
		||||
            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
 | 
			
		||||
               << msg;
 | 
			
		||||
            return ss.str();
 | 
			
		||||
          });
 | 
			
		||||
 | 
			
		||||
      auto input_metadata = collect_input_metadata(call.node->next_edges());
 | 
			
		||||
      TORCH_INTERNAL_ASSERT(input_metadata.size() == outputs.size());
 | 
			
		||||
      // std::cout << "outputs_size: " << input_metadata.size() << std::endl;
 | 
			
		||||
 | 
			
		||||
      // std::cout << "enqueue" << std::endl;
 | 
			
		||||
      // dumpimv(input_metadata);
 | 
			
		||||
 | 
			
		||||
      SavedState state;
 | 
			
		||||
      state.enqueue(input_metadata);
 | 
			
		||||
      ivalue_list& input_metadata_state = state.stack;
 | 
			
		||||
      outputs = call_function(
 | 
			
		||||
          py_compiler,
 | 
			
		||||
          "validate_outputs",
 | 
			
		||||
          validate_outputs,
 | 
			
		||||
          outputs,
 | 
			
		||||
          input_metadata_state,
 | 
			
		||||
          outputs.size(),
 | 
			
		||||
          "validate_outputs");
 | 
			
		||||
 | 
			
		||||
      saved.after(call.node->next_edges());
 | 
			
		||||
      saved.debug_asserts();
 | 
			
		||||
 | 
			
		||||
@ -754,13 +947,15 @@ CacheNode* _compiled_autograd_impl(
 | 
			
		||||
        auto& output = outputs[i];
 | 
			
		||||
        const auto& next = call.node->next_edge(i);
 | 
			
		||||
        if (next.is_valid() && output.defined()) {
 | 
			
		||||
          input_buffers.lookup(next.function.get())
 | 
			
		||||
              .add(
 | 
			
		||||
                  next.input_nr, std::move(output), std::nullopt, std::nullopt);
 | 
			
		||||
          auto& buffer = input_buffers.lookup(next.function.get());
 | 
			
		||||
          buffer.buffer[next.input_nr] = call_accumulate(
 | 
			
		||||
              py_compiler, buffer.buffer[next.input_nr], output);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    resetPyCompilerInterface();
 | 
			
		||||
    kPyCompiler = nullptr;
 | 
			
		||||
    PyObject* res = check(call_end_capture(py_compiler, state.outputs));
 | 
			
		||||
    TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
 | 
			
		||||
@ -4,4 +4,5 @@
 | 
			
		||||
// see [Note: Compiled Autograd]
 | 
			
		||||
namespace torch::dynamo::autograd {
 | 
			
		||||
PyObject* torch_c_dynamo_compiled_autograd_init();
 | 
			
		||||
PyObject* current_py_compiler();
 | 
			
		||||
} // namespace torch::dynamo::autograd
 | 
			
		||||
 | 
			
		||||
@ -369,8 +369,18 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) {
 | 
			
		||||
          }
 | 
			
		||||
        case TypeKind::BoolType:
 | 
			
		||||
          return IValue(py::cast<std::vector<bool>>(obj));
 | 
			
		||||
        case TypeKind::TensorType:
 | 
			
		||||
          return IValue(py::cast<std::vector<at::Tensor>>(obj));
 | 
			
		||||
        case TypeKind::TensorType: {
 | 
			
		||||
          auto thing = py::cast<std::vector<std::optional<at::Tensor>>>(obj);
 | 
			
		||||
          auto thing2 = std::vector<at::Tensor>();
 | 
			
		||||
          for (const auto& inp : thing) {
 | 
			
		||||
            if (inp.has_value()) {
 | 
			
		||||
              thing2.emplace_back(*inp);
 | 
			
		||||
            } else {
 | 
			
		||||
              thing2.emplace_back();
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
          return IValue(thing2);
 | 
			
		||||
        }
 | 
			
		||||
        default:
 | 
			
		||||
          return createGenericList(obj, elem_type);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||