[ca] side-effect-free initial trace: compiled_args (#147804)

Make compiled_args a const method so the initial trace cannot accidentally mutate node state. The changes are mainly in the Error nodes and PyNode.
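
As an illustrative aside, a minimal standalone sketch of the pattern (simplified, hypothetical types; not the actual PyTorch classes): declaring the collection hook const lets the compiler reject any override that tries to mutate node state during the trace.

#include <cstdint>
#include <vector>

// Stand-in for the real CompiledNodeArgs: just records collected values.
struct CompiledNodeArgs {
  std::vector<uint64_t> collected;
  void collect(uint64_t v) { collected.push_back(v); }
};

struct Node {
  virtual ~Node() = default;
  // const hook: overrides may only read node state while collecting.
  virtual void compiled_args(CompiledNodeArgs& args) const = 0;
};

struct MyNode : Node {
  uint64_t key = 42;
  void compiled_args(CompiledNodeArgs& args) const override {
    args.collect(key);  // read-only access; `key = 0;` would not compile here
  }
};

int main() {
  CompiledNodeArgs args;
  MyNode node;
  node.compiled_args(args);
  return args.collected.size() == 1 ? 0 : 1;
}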

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804
Approved by: https://github.com/jansel
ghstack dependencies: #147242, #147796
Simon Fan
2025-02-25 19:57:55 -08:00
committed by PyTorch MergeBot
parent 5e3069dde8
commit fd1220e386
17 changed files with 105 additions and 79 deletions

View File

@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} {
${release_variables}
}
${will_release_variables}
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
${saved_variables}
${saved_list_sizes}
@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) {
return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
void ${op}::compiled_args(CompiledNodeArgs& args) const {
${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {

View File

@ -284,7 +284,7 @@ struct CppNode : public Node {
void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
void save_variables_to_ctx();
void compiled_args(CompiledNodeArgs& args) override {
void compiled_args(CompiledNodeArgs& args) const override {
// although neither of the 2 methods below has uniqueness guarantees,
// it is unlikely for them to collide at the same time
args.collect(static_cast<uint64_t>(typeid(T).hash_code()));

View File

@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
return tensor_pre_hooks_;
}
virtual std::unique_ptr<PostAccumulateGradHook>&
tensor_post_acc_grad_hooks() noexcept {
virtual std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
const noexcept {
static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
return empty;
}
@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// 2) Collect node information for specialization and caching
// Implementations in subclasses should call args.collect() with all node
// attrs. These functions are only called during backward.
virtual void compiled_args(CompiledNodeArgs& args) {
virtual void compiled_args(CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("compiled_args not implemented: ") + name());
}

View File

@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook {
virtual ~FunctionPreHook() = default;
virtual variable_list operator()(const variable_list& grads) = 0;
// only implemented for python hooks, registers hook with compiled autograd
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
virtual void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
typeid(*this).name());
@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook {
const variable_list& outputs /* grad_inputs */,
const variable_list& inputs /* grad_outputs */) = 0;
// only implemented for python hooks, registers hook with compiled autograd
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
virtual void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
typeid(*this).name());
@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook {
virtual void operator()(const Variable& tensor) = 0;
// only implemented for python hooks on nodes, registers hook with compiled
// autograd
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
virtual void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
throw std::runtime_error(
std::string("not yet implemented for compiled autograd: ") +
typeid(*this).name());

View File

@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
return variable_list();
}
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
if (args.cond(variable.defined() && variable.requires_grad())) {
args.collect(variable);
args.collect(variable.grad());
}
auto& hook = tensor_post_acc_grad_hooks();
const auto& hook = tensor_post_acc_grad_hooks();
if (hook != nullptr) {
hook->compiled_args(args);
}

View File

@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node {
return impl::hooks(variable);
}
std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
override {
std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
const noexcept override {
// NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
// it can be destroyed even though the Tensor is still alive (contrary
// to all other Nodes). So we must lazily read the Tensor hooks here.
@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node {
}
}
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;

View File

@ -12,11 +12,15 @@
namespace torch::autograd {
auto Error::apply(variable_list&& inputs) -> variable_list {
variable_list Error::apply(variable_list&& inputs) {
return static_cast<const Error*>(this)->apply(std::move(inputs));
}
variable_list Error::apply(variable_list&& inputs) const {
throw std::runtime_error(msg);
}
void Error::compiled_args(CompiledNodeArgs& args) {
void Error::compiled_args(CompiledNodeArgs& args) const {
// throw the error during collect, the graph won't get compiled
apply(variable_list());
}
@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list {
return std::move(grads);
}
void GraphRoot::compiled_args(CompiledNodeArgs& args) {
void GraphRoot::compiled_args(CompiledNodeArgs& args) const {
args.collect(outputs);
}
variable_list GraphRoot::apply_with_saved(

View File

@ -18,8 +18,9 @@ struct TORCH_API Error : public Node {
Error(std::string msg) : msg(std::move(msg)) {}
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node {
}
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
std::string msg;
};
@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node {
}
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
};
struct TORCH_API UndefinedGradBackward : public Node {
@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node {
UndefinedGradBackward() = default;
variable_list apply(variable_list&& inputs) override;
variable_list apply(variable_list&& inputs) const;
void compiled_args(CompiledNodeArgs& args) override {}
void compiled_args(CompiledNodeArgs& args) const override {}
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override {
@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node {
return outputs;
}
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;

View File

@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
src_options);
}
void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
void CopyBackwards::compiled_args(CompiledNodeArgs& args) const {
args.collect(src_options);
}
@ -235,7 +235,7 @@ void CopySlices::release_variables() {
fn = nullptr;
}
void CopySlices::compiled_args(CompiledNodeArgs& args) {
void CopySlices::compiled_args(CompiledNodeArgs& args) const {
TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
TORCH_INTERNAL_ASSERT((bool)fn);
args.collect(base);

View File

@ -15,7 +15,7 @@ namespace torch::autograd {
struct TORCH_API CopyBackwards : public Node {
variable_list apply(variable_list&& grads) override;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node {
variable_list apply(variable_list&& inputs) override;
void release_variables() override;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;

View File

@ -34,6 +34,7 @@
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <autograd/function.h>
#include <functional>
#include <memory>
#include <stdexcept>
@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
return to_variable_list(r.get(), is_variable_input);
}
auto PyNode::defer_to_dynamo(
auto PyNode::apply_with_saved_impl(
const variable_list& inputs,
const std::optional<PyObject*>& compiler) -> variable_list {
const SwapSavedVariables& saved) -> variable_list {
pybind11::gil_scoped_acquire gil;
at::OptionalDeviceGuard _device_guard;
THPFunction* py_fn = (THPFunction*)obj;
@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
}
THPObjectPtr saved_tensors(unpack_saved_variables(
py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
TORCH_INTERNAL_ASSERT(
_backward_idx.has_value(),
"indices should already be set by compiled_args, called before apply_with_saved");
auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
PyObject* backward_state_idx = Py_None;
if (_backward_state_idx.has_value()) {
backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
if (maybe_bwd_state_idx.has_value()) {
backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
// this might be simplifiable now that we no longer inline
Py_CLEAR(py_fn->compiled_autograd_backward_state);
}
THPObjectPtr r(PyObject_CallMethod(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
compiler.value(),
saved.get_py_compiler(),
"proxy_call_backward",
"OOOiOO",
pyInputs.get(),
fwdInputMetadatas.get(),
saved_tensors.get(),
*_backward_idx,
bwd_idx,
obj,
backward_state_idx));
@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
}
void PyNode::compiled_args(CompiledNodeArgs& args) {
void PyNode::compiled_args(CompiledNodeArgs& args) const {
static PyObject* method_name =
PyUnicode_InternFromString("_compiled_autograd_key");
THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));
@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
args.collect(f->input_info);
Py_INCREF(obj);
_backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
c10::SafePyObject backward_obj(obj, getPyInterpreter());
std::optional<c10::SafePyObject> backward_state_obj;
PyObject* bw_state = f->compiled_autograd_backward_state;
if (args.cond(bw_state != nullptr)) {
Py_INCREF(bw_state);
_backward_state_idx = args.add_backward_state(
c10::SafePyObject(bw_state, getPyInterpreter()));
backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
}
args.collect_pynode_objs(
this, std::move(backward_obj), std::move(backward_state_obj));
}
variable_list PyNode::apply_with_saved(
@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
saved.before(f->output_info);
saved.before(f->input_info);
f->compiled_autograd_tracing = true;
variable_list result =
defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
f->compiled_autograd_tracing = false;
saved.after(f->compiled_autograd_symints);
saved.after(f->saved_variables);

View File

@ -35,9 +35,9 @@ struct PyNode : public Node {
const std::vector<bool>& is_variable_input);
variable_list apply(variable_list&& inputs) override;
variable_list defer_to_dynamo(
variable_list apply_with_saved_impl(
const variable_list& inputs,
const std::optional<PyObject*>& compiler);
const SwapSavedVariables& saved);
void release_variables() override;
std::string name() const override;
@ -45,7 +45,7 @@ struct PyNode : public Node {
bool is_aot_backward() const override;
void compiled_args(CompiledNodeArgs& args) override;
void compiled_args(CompiledNodeArgs& args) const override;
variable_list apply_with_saved(
const variable_list& inputs,
SwapSavedVariables& saved) override;
@ -53,13 +53,6 @@ struct PyNode : public Node {
// THPFunction this Function is wrapping. Owning!
PyObject* obj;
// The AutogradCompilerCall::hooks idx corresponding to this node's backward
std::optional<int> _backward_idx;
// The AutogradCompilerCall::hooks idx corresponding to this node's
// backward_state
std::optional<int> _backward_state_idx;
// NOLINTNEXTLINE(bugprone-exception-escape)
~PyNode() override {
// Can't use THPObjectPtr as a field in this class; destructor won't take

View File

@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()(
return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
}
void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);
@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
Py_END_CRITICAL_SECTION();
}
void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);
@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
Py_END_CRITICAL_SECTION();
}
void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);
@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
}
void PyFunctionTensorPostAccGradHooks::compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) {
torch::dynamo::autograd::CompiledNodeArgs& args) const {
PyObject *key = nullptr, *value = nullptr;
Py_ssize_t pos = 0;
Py_BEGIN_CRITICAL_SECTION(dict);

View File

@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
~PyFunctionTensorPreHook() override;
variable_list operator()(const variable_list& values) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
PyObject* dict;
size_t value_idx;
};
@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook {
PyFunctionPreHook(PyObject* dict);
~PyFunctionPreHook() override;
variable_list operator()(const variable_list& values) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
PyObject* dict;
};
@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook {
variable_list operator()(
const variable_list& outputs,
const variable_list& inputs) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
PyObject* dict;
};
@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
PyFunctionTensorPostAccGradHooks(PyObject* dict);
~PyFunctionTensorPostAccGradHooks() override;
void operator()(const Variable& tensor) override;
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
void apply_with_saved(
Variable& tensor,
torch::dynamo::autograd::SwapSavedVariables& saved) override;

View File

@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
return fn_(outputs, inputs);
}
void compiled_args(CompiledNodeArgs& args) override {}
void compiled_args(CompiledNodeArgs& args) const override {}
protected:
std::function<variable_list(const variable_list&, const variable_list&)> fn_;

View File

@ -183,21 +183,23 @@ Reducer::Reducer(
#endif
// Hook to execute after the gradient accumulator has executed.
hooks_.emplace_back(
grad_accumulator->add_post_hook(
std::make_unique<torch::autograd::utils::LambdaPostHook>(
[this, variable_index](
const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& /* unused */) {
grad_accumulator->add_post_hook(std::make_unique<
torch::autograd::utils::
LambdaPostHook>(
[this, variable_index](
const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
#endif
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
// Make post_hook an noop if compiled_autograds is enabled.
})),
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
TORCH_INTERNAL_ASSERT(
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
})),
grad_accumulator);
// Map raw function pointer to parameter index.

View File

@ -343,6 +343,9 @@ struct AutogradCompilerCall {
std::vector<uint32_t> size_input_origins;
std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
sv_to_hooks;
// pynode -> backward and backward state idx
std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
pynode_objs;
};
class CompiledNodeArgs {
@ -613,12 +616,17 @@ class CompiledNodeArgs {
typeid(*node), _specialization_key, _specialization_key_size);
}
size_t add_backward(c10::SafePyObject&& obj) {
return _compiler.emplace_hook(std::move(obj));
}
size_t add_backward_state(c10::SafePyObject&& obj) {
return _compiler.emplace_hook(std::move(obj));
void collect_pynode_objs(
const Node* pynode,
c10::SafePyObject&& bwd,
std::optional<c10::SafePyObject>&& bwd_state) {
size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
std::optional<size_t> bwd_state_idx;
if (auto state = std::move(bwd_state); state.has_value()) {
bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
}
_compiler.pynode_objs.emplace(
pynode, std::make_pair(bwd_idx, bwd_state_idx));
}
void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
@ -737,6 +745,13 @@ class SwapSavedVariables {
// cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
// allows tracing to happen, then swaps them back afterwards.
public:
std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
Node* pynode) const {
auto it = compiler.pynode_objs.find(pynode);
TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
return it->second;
}
void before(at::Tensor& t) {
TensorArg& arg = compiler.tensor_args.lookup(t);
stashed_tensors.save(&t, std::move(t));
@ -942,7 +957,7 @@ class SwapSavedVariables {
const NodeCall& n)
: compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
PyObject* get_py_compiler() {
PyObject* get_py_compiler() const {
return py_compiler;
}