[ca] side-effect free inital trace: compiled_args (#147804)

const methods to prevent accidental mutation. changes mainly in Error nodes and PyNode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804
Approved by: https://github.com/jansel
ghstack dependencies: #147242, #147796
This commit is contained in:
Simon Fan
2025-02-24 20:03:09 -08:00
committed by PyTorch MergeBot
parent 5758743f3c
commit ec768d8dc0
17 changed files with 105 additions and 79 deletions

View File

@ -34,6 +34,7 @@
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <autograd/function.h>
#include <functional>
#include <memory>
#include <stdexcept>
@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
return to_variable_list(r.get(), is_variable_input);
}
auto PyNode::defer_to_dynamo(
auto PyNode::apply_with_saved_impl(
const variable_list& inputs,
const std::optional<PyObject*>& compiler) -> variable_list {
const SwapSavedVariables& saved) -> variable_list {
pybind11::gil_scoped_acquire gil;
at::OptionalDeviceGuard _device_guard;
THPFunction* py_fn = (THPFunction*)obj;
@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
}
THPObjectPtr saved_tensors(unpack_saved_variables(
py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
TORCH_INTERNAL_ASSERT(
_backward_idx.has_value(),
"indices should already be set by compiled_args, called before apply_with_saved");
auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
PyObject* backward_state_idx = Py_None;
if (_backward_state_idx.has_value()) {
backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
if (maybe_bwd_state_idx.has_value()) {
backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
// this might be simplifiable now that we no longer inline
Py_CLEAR(py_fn->compiled_autograd_backward_state);
}
THPObjectPtr r(PyObject_CallMethod(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
compiler.value(),
saved.get_py_compiler(),
"proxy_call_backward",
"OOOiOO",
pyInputs.get(),
fwdInputMetadatas.get(),
saved_tensors.get(),
*_backward_idx,
bwd_idx,
obj,
backward_state_idx));
@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
}
void PyNode::compiled_args(CompiledNodeArgs& args) {
void PyNode::compiled_args(CompiledNodeArgs& args) const {
static PyObject* method_name =
PyUnicode_InternFromString("_compiled_autograd_key");
THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));
@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
args.collect(f->input_info);
Py_INCREF(obj);
_backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
c10::SafePyObject backward_obj(obj, getPyInterpreter());
std::optional<c10::SafePyObject> backward_state_obj;
PyObject* bw_state = f->compiled_autograd_backward_state;
if (args.cond(bw_state != nullptr)) {
Py_INCREF(bw_state);
_backward_state_idx = args.add_backward_state(
c10::SafePyObject(bw_state, getPyInterpreter()));
backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
}
args.collect_pynode_objs(
this, std::move(backward_obj), std::move(backward_state_obj));
}
variable_list PyNode::apply_with_saved(
@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
saved.before(f->output_info);
saved.before(f->input_info);
f->compiled_autograd_tracing = true;
variable_list result =
defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
f->compiled_autograd_tracing = false;
saved.after(f->compiled_autograd_symints);
saved.after(f->saved_variables);