[ca] side-effect free initial trace: compiled_args (#147804)

Makes these methods const to prevent accidental mutation during the initial trace; the changes are mainly in the Error nodes and PyNode.
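
For context, the pattern applied throughout the diff is const-qualifying the virtual collection hook so overrides cannot mutate node state while the initial trace collects arguments. A minimal standalone sketch of that pattern (simplified stand-ins only; ExampleNode and its counter field are hypothetical and not code from this PR):

    #include <stdexcept>

    struct CompiledNodeArgs;  // stand-in for the real collector; only referenced here

    struct Node {
      virtual ~Node() = default;
      // const: collecting args for caching/specialization must not mutate the node
      virtual void compiled_args(CompiledNodeArgs& args) const {
        throw std::runtime_error("compiled_args not implemented");
      }
    };

    struct ExampleNode : Node {
      int counter = 0;  // hypothetical mutable state
      void compiled_args(CompiledNodeArgs& args) const override {
        // counter++;  // would no longer compile: *this is const during collection
      }
    };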

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804
Approved by: https://github.com/jansel
ghstack dependencies: #147242, #147796
This commit is contained in:
Simon Fan
2025-02-25 19:57:55 -08:00
committed by PyTorch MergeBot
parent 5e3069dde8
commit fd1220e386
17 changed files with 105 additions and 79 deletions

View File

@@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} {
     ${release_variables}
   }
   ${will_release_variables}
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
   ${saved_variables}
   ${saved_list_sizes}
@@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) {
   return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
 }

-void ${op}::compiled_args(CompiledNodeArgs& args) {
+void ${op}::compiled_args(CompiledNodeArgs& args) const {
   ${compiled_args}
 }
 variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {

View File

@@ -284,7 +284,7 @@ struct CppNode : public Node {
   void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
   void save_variables_to_ctx();

-  void compiled_args(CompiledNodeArgs& args) override {
+  void compiled_args(CompiledNodeArgs& args) const override {
     // although neither of the 2 methods below have uniqueness guarantees
     // it is unlikely for them to collide at the same time
     args.collect(static_cast<uint64_t>(typeid(T).hash_code()));

View File

@@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return tensor_pre_hooks_;
   }

-  virtual std::unique_ptr<PostAccumulateGradHook>&
-  tensor_post_acc_grad_hooks() noexcept {
+  virtual std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
+      const noexcept {
     static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
     return empty;
   }
@@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // 2) Collect node information for specialization and caching
   // Implementations in subclasses should call args.collect() with all node
   // attrs. These functions are only called durring backward.
-  virtual void compiled_args(CompiledNodeArgs& args) {
+  virtual void compiled_args(CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("compiled_args not implemented: ") + name());
   }

View File

@@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook {
   virtual ~FunctionPreHook() = default;
   virtual variable_list operator()(const variable_list& grads) = 0;
   // only implemented for python hooks, registers hook with compiled autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+  virtual void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
         typeid(*this).name());
@@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook {
       const variable_list& outputs /* grad_inputs */,
       const variable_list& inputs /* grad_outputs */) = 0;
   // only implemented for python hooks, registers hook with compiled autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+  virtual void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
         typeid(*this).name());
@@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook {
   virtual void operator()(const Variable& tensor) = 0;
   // only implemented for python hooks on nodes, registers hook with compiled
   // autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+  virtual void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("not yet implemented for compiled autograd: ") +
         typeid(*this).name());

View File

@@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
   return variable_list();
 }

-void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
+void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
   if (args.cond(variable.defined() && variable.requires_grad())) {
     args.collect(variable);
     args.collect(variable.grad());
   }
-  auto& hook = tensor_post_acc_grad_hooks();
+  const auto& hook = tensor_post_acc_grad_hooks();
   if (hook != nullptr) {
     hook->compiled_args(args);
   }

View File

@@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node {
     return impl::hooks(variable);
   }

-  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
-      override {
+  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
+      const noexcept override {
     // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
     // it can be destroyed even though the Tensor is still alive (contrary
     // to all other Nodes). So we must lazily read the Tensor hooks here.
@@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node {
     }
   }

-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;

View File

@@ -12,11 +12,15 @@
 namespace torch::autograd {

-auto Error::apply(variable_list&& inputs) -> variable_list {
+variable_list Error::apply(variable_list&& inputs) {
+  return static_cast<const Error*>(this)->apply(std::move(inputs));
+}
+
+variable_list Error::apply(variable_list&& inputs) const {
   throw std::runtime_error(msg);
 }

-void Error::compiled_args(CompiledNodeArgs& args) {
+void Error::compiled_args(CompiledNodeArgs& args) const {
   // throw the error durring collect, the graph won't get compiled
   apply(variable_list());
 }
@@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list {
   return std::move(grads);
 }

-void GraphRoot::compiled_args(CompiledNodeArgs& args) {
+void GraphRoot::compiled_args(CompiledNodeArgs& args) const {
   args.collect(outputs);
 }

 variable_list GraphRoot::apply_with_saved(

View File

@@ -18,8 +18,9 @@ struct TORCH_API Error : public Node {
   Error(std::string msg) : msg(std::move(msg)) {}

   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;

-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node {
   }

   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;

   std::string msg;
 };
@@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node {
   }

   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;
 };

 struct TORCH_API UndefinedGradBackward : public Node {
@@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node {
   UndefinedGradBackward() = default;

   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;

-  void compiled_args(CompiledNodeArgs& args) override {}
+  void compiled_args(CompiledNodeArgs& args) const override {}
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override {
@@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node {
     return outputs;
   }

-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;

View File

@@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
       src_options);
 }

-void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
+void CopyBackwards::compiled_args(CompiledNodeArgs& args) const {
   args.collect(src_options);
 }
@@ -235,7 +235,7 @@ void CopySlices::release_variables() {
   fn = nullptr;
 }

-void CopySlices::compiled_args(CompiledNodeArgs& args) {
+void CopySlices::compiled_args(CompiledNodeArgs& args) const {
   TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
   TORCH_INTERNAL_ASSERT((bool)fn);
   args.collect(base);

View File

@@ -15,7 +15,7 @@ namespace torch::autograd {
 struct TORCH_API CopyBackwards : public Node {
   variable_list apply(variable_list&& grads) override;

-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node {
   variable_list apply(variable_list&& inputs) override;
   void release_variables() override;

-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;

View File

@@ -34,6 +34,7 @@
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_dtypes.h>

+#include <autograd/function.h>
 #include <functional>
 #include <memory>
 #include <stdexcept>
@@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
   return to_variable_list(r.get(), is_variable_input);
 }

-auto PyNode::defer_to_dynamo(
+auto PyNode::apply_with_saved_impl(
     const variable_list& inputs,
-    const std::optional<PyObject*>& compiler) -> variable_list {
+    const SwapSavedVariables& saved) -> variable_list {
   pybind11::gil_scoped_acquire gil;
   at::OptionalDeviceGuard _device_guard;
   THPFunction* py_fn = (THPFunction*)obj;
@@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
   }
   THPObjectPtr saved_tensors(unpack_saved_variables(
       py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
-  TORCH_INTERNAL_ASSERT(
-      _backward_idx.has_value(),
-      "indices should already be set by compiled_args, called before apply_with_saved");
+  auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
   PyObject* backward_state_idx = Py_None;
-  if (_backward_state_idx.has_value()) {
-    backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
+  if (maybe_bwd_state_idx.has_value()) {
+    backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
     // this might be simplifiable now that we no longer inline
     Py_CLEAR(py_fn->compiled_autograd_backward_state);
   }
   THPObjectPtr r(PyObject_CallMethod(
       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-      compiler.value(),
+      saved.get_py_compiler(),
       "proxy_call_backward",
       "OOOiOO",
       pyInputs.get(),
       fwdInputMetadatas.get(),
       saved_tensors.get(),
-      *_backward_idx,
+      bwd_idx,
       obj,
       backward_state_idx));
@@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
   return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
 }

-void PyNode::compiled_args(CompiledNodeArgs& args) {
+void PyNode::compiled_args(CompiledNodeArgs& args) const {
   static PyObject* method_name =
       PyUnicode_InternFromString("_compiled_autograd_key");
   THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));
@@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
   args.collect(f->input_info);

   Py_INCREF(obj);
-  _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
+  c10::SafePyObject backward_obj(obj, getPyInterpreter());
+  std::optional<c10::SafePyObject> backward_state_obj;
   PyObject* bw_state = f->compiled_autograd_backward_state;
   if (args.cond(bw_state != nullptr)) {
     Py_INCREF(bw_state);
-    _backward_state_idx = args.add_backward_state(
-        c10::SafePyObject(bw_state, getPyInterpreter()));
+    backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
   }
+  args.collect_pynode_objs(
+      this, std::move(backward_obj), std::move(backward_state_obj));
 }
@@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
   saved.before(f->output_info);
   saved.before(f->input_info);
   f->compiled_autograd_tracing = true;
-  variable_list result =
-      defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
+  variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
   f->compiled_autograd_tracing = false;
   saved.after(f->compiled_autograd_symints);
   saved.after(f->saved_variables);

View File

@@ -35,9 +35,9 @@ struct PyNode : public Node {
       const std::vector<bool>& is_variable_input);

   variable_list apply(variable_list&& inputs) override;
-  variable_list defer_to_dynamo(
+  variable_list apply_with_saved_impl(
       const variable_list& inputs,
-      const std::optional<PyObject*>& compiler);
+      const SwapSavedVariables& saved);
   void release_variables() override;
   std::string name() const override;
@@ -45,7 +45,7 @@ struct PyNode : public Node {
   bool is_aot_backward() const override;

-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -53,13 +53,6 @@ struct PyNode : public Node {
   // THPFunction this Function is wrapping. Owning!
   PyObject* obj;

-  // The AutogradCompilerCall::hooks idx corresponding to this node's backward
-  std::optional<int> _backward_idx;
-
-  // The AutogradCompilerCall::hooks idx corresponding to this node's
-  // backward_state
-  std::optional<int> _backward_state_idx;
-
   // NOLINTNEXTLINE(bugprone-exception-escape)
   ~PyNode() override {
     // Can't use THPObjectPtr as a field in this class; destructor won't take

View File

@@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()(
   return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
 }

-void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
   Py_END_CRITICAL_SECTION();
 }

-void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
   Py_END_CRITICAL_SECTION();
 }

-void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
 }

 void PyFunctionTensorPostAccGradHooks::compiled_args(
-    torch::dynamo::autograd::CompiledNodeArgs& args) {
+    torch::dynamo::autograd::CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);

View File

@@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
   PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
   ~PyFunctionTensorPreHook() override;
   variable_list operator()(const variable_list& values) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
   size_t value_idx;
 };
@@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook {
   PyFunctionPreHook(PyObject* dict);
   ~PyFunctionPreHook() override;
   variable_list operator()(const variable_list& values) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
 };
@@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook {
   variable_list operator()(
       const variable_list& outputs,
       const variable_list& inputs) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
 };
@@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
   PyFunctionTensorPostAccGradHooks(PyObject* dict);
   ~PyFunctionTensorPostAccGradHooks() override;
   void operator()(const Variable& tensor) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   void apply_with_saved(
       Variable& tensor,
       torch::dynamo::autograd::SwapSavedVariables& saved) override;

View File

@@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
     return fn_(outputs, inputs);
   }

-  void compiled_args(CompiledNodeArgs& args) override {}
+  void compiled_args(CompiledNodeArgs& args) const override {}

  protected:
   std::function<variable_list(const variable_list&, const variable_list&)> fn_;

View File

@@ -183,21 +183,23 @@ Reducer::Reducer(
 #endif
       // Hook to execute after the gradient accumulator has executed.
       hooks_.emplace_back(
-          grad_accumulator->add_post_hook(
-              std::make_unique<torch::autograd::utils::LambdaPostHook>(
-                  [this, variable_index](
-                      const torch::autograd::variable_list& outputs,
-                      const torch::autograd::variable_list& /* unused */) {
+          grad_accumulator->add_post_hook(std::make_unique<
+                                          torch::autograd::utils::
+                                              LambdaPostHook>(
+              [this, variable_index](
+                  const torch::autograd::variable_list& outputs,
+                  const torch::autograd::variable_list& /* unused */) {
 #ifndef _WIN32
                 this->rpc_context_.set(
                     ThreadLocalDistAutogradContext::getContextPtr());
 #endif
                 this->autograd_hook(variable_index);
                 return outputs;
               },
               [=](torch::autograd::CompiledNodeArgs& args) {
-                // Make post_hook an noop if compiled_autograds is enabled.
-              })),
+                TORCH_INTERNAL_ASSERT(
+                    "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
+              })),
           grad_accumulator);

       // Map raw function pointer to parameter index.

View File

@@ -343,6 +343,9 @@ struct AutogradCompilerCall {
   std::vector<uint32_t> size_input_origins;
   std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
       sv_to_hooks;
+  // pynode -> backward and backward state idx
+  std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
+      pynode_objs;
 };

 class CompiledNodeArgs {
@@ -613,12 +616,17 @@ class CompiledNodeArgs {
         typeid(*node), _specialization_key, _specialization_key_size);
   }

-  size_t add_backward(c10::SafePyObject&& obj) {
-    return _compiler.emplace_hook(std::move(obj));
-  }
-
-  size_t add_backward_state(c10::SafePyObject&& obj) {
-    return _compiler.emplace_hook(std::move(obj));
+  void collect_pynode_objs(
+      const Node* pynode,
+      c10::SafePyObject&& bwd,
+      std::optional<c10::SafePyObject>&& bwd_state) {
+    size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
+    std::optional<size_t> bwd_state_idx;
+    if (auto state = std::move(bwd_state); state.has_value()) {
+      bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
+    }
+    _compiler.pynode_objs.emplace(
+        pynode, std::make_pair(bwd_idx, bwd_state_idx));
   }

   void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
@@ -737,6 +745,13 @@ class SwapSavedVariables {
   // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
   // allows tracing to happen, then swaps them back afterwards.
  public:
+  std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
+      Node* pynode) const {
+    auto it = compiler.pynode_objs.find(pynode);
+    TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
+    return it->second;
+  }
+
   void before(at::Tensor& t) {
     TensorArg& arg = compiler.tensor_args.lookup(t);
     stashed_tensors.save(&t, std::move(t));
@@ -942,7 +957,7 @@ class SwapSavedVariables {
       const NodeCall& n)
       : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}

-  PyObject* get_py_compiler() {
+  PyObject* get_py_compiler() const {
     return py_compiler;
   }