Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-03 15:35:04 +08:00
Revert "[ca] side-effect free inital trace: compiled_args (#147804)"
This reverts commit ec768d8dc04b334e01db1a90e4e6646e4e867e67. Reverted https://github.com/pytorch/pytorch/pull/147804 on behalf of https://github.com/wdvr due to failing tests in the slow workflow, see below ([comment](https://github.com/pytorch/pytorch/pull/147804#issuecomment-2683594740))
@@ -34,7 +34,6 @@
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
 
-#include <autograd/function.h>
 #include <functional>
 #include <memory>
 #include <stdexcept>
@@ -186,9 +185,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
   return to_variable_list(r.get(), is_variable_input);
 }
 
-auto PyNode::apply_with_saved_impl(
+auto PyNode::defer_to_dynamo(
     const variable_list& inputs,
-    const SwapSavedVariables& saved) -> variable_list {
+    const std::optional<PyObject*>& compiler) -> variable_list {
   pybind11::gil_scoped_acquire gil;
   at::OptionalDeviceGuard _device_guard;
   THPFunction* py_fn = (THPFunction*)obj;
@@ -236,24 +235,24 @@ auto PyNode::apply_with_saved_impl(
   }
   THPObjectPtr saved_tensors(unpack_saved_variables(
       py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
-
-  auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
-
+  TORCH_INTERNAL_ASSERT(
+      _backward_idx.has_value(),
+      "indices should already be set by compiled_args, called before apply_with_saved");
   PyObject* backward_state_idx = Py_None;
-  if (maybe_bwd_state_idx.has_value()) {
-    backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
+  if (_backward_state_idx.has_value()) {
+    backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
     // this might be simplifiable now that we no longer inline
     Py_CLEAR(py_fn->compiled_autograd_backward_state);
   }
   THPObjectPtr r(PyObject_CallMethod(
       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-      saved.get_py_compiler(),
+      compiler.value(),
       "proxy_call_backward",
       "OOOiOO",
       pyInputs.get(),
       fwdInputMetadatas.get(),
       saved_tensors.get(),
-      bwd_idx,
+      *_backward_idx,
       obj,
       backward_state_idx));
 
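For readers unfamiliar with the call in the hunk above: PyObject_CallMethod is plain CPython API. It builds the argument tuple from a Py_BuildValue-style format string ("O" passes a PyObject* without stealing a reference, "i" packs a C int) and returns a new reference that the caller owns. Below is a minimal, self-contained sketch of the same pattern with a shorter "OiO" format and hypothetical names; it is an illustration, not PyTorch code.

```cpp
// Standalone sketch of the PyObject_CallMethod pattern (hypothetical names,
// not PyTorch code). "OiO" packs two PyObject* around a C int, analogous to
// the "OOOiOO" format used for proxy_call_backward above.
#include <Python.h>

int main() {
  Py_Initialize();

  // Define a tiny object with a method to call.
  PyObject* ns = PyDict_New();
  PyObject* run = PyRun_String(
      "class Compiler:\n"
      "    def call_backward(self, inputs, idx, state):\n"
      "        return (inputs, idx, state)\n"
      "compiler = Compiler()\n",
      Py_file_input, ns, ns);
  Py_XDECREF(run);

  PyObject* compiler = PyDict_GetItemString(ns, "compiler");  // borrowed ref
  PyObject* inputs = PyTuple_New(0);

  // "O" arguments are not stolen; the result is a new reference we own.
  PyObject* result = PyObject_CallMethod(
      compiler, "call_backward", "OiO", inputs, 3, Py_None);
  if (result == nullptr) {
    PyErr_Print();
  }

  Py_XDECREF(result);
  Py_DECREF(inputs);
  Py_DECREF(ns);
  return Py_FinalizeEx();
}
```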
@@ -302,7 +301,7 @@ bool PyNode::is_aot_backward() const {
   return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
 }
 
-void PyNode::compiled_args(CompiledNodeArgs& args) const {
+void PyNode::compiled_args(CompiledNodeArgs& args) {
   static PyObject* method_name =
       PyUnicode_InternFromString("_compiled_autograd_key");
   THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));
@@ -347,15 +346,14 @@ void PyNode::compiled_args(CompiledNodeArgs& args) const {
   args.collect(f->input_info);
 
   Py_INCREF(obj);
-  c10::SafePyObject backward_obj(obj, getPyInterpreter());
-  std::optional<c10::SafePyObject> backward_state_obj;
+  _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
+
   PyObject* bw_state = f->compiled_autograd_backward_state;
   if (args.cond(bw_state != nullptr)) {
     Py_INCREF(bw_state);
-    backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
+    _backward_state_idx = args.add_backward_state(
+        c10::SafePyObject(bw_state, getPyInterpreter()));
   }
-  args.collect_pynode_objs(
-      this, std::move(backward_obj), std::move(backward_state_obj));
 }
 
 variable_list PyNode::apply_with_saved(
@@ -370,7 +368,8 @@ variable_list PyNode::apply_with_saved(
   saved.before(f->output_info);
   saved.before(f->input_info);
   f->compiled_autograd_tracing = true;
-  variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
+  variable_list result =
+      defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
   f->compiled_autograd_tracing = false;
   saved.after(f->compiled_autograd_symints);
   saved.after(f->saved_variables);
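Taken together, the hunks toggle between two ways of carrying the backward object's index from the initial trace (compiled_args) to the later apply_with_saved call. The code restored by this revert caches the index in mutable PyNode members (_backward_idx, _backward_state_idx); the reverted #147804 kept compiled_args const and handed the objects to the collector, reading the indices back from the saved state in apply_with_saved_impl. The sketch below contrasts the two shapes with hypothetical types; it is not PyTorch's actual classes.

```cpp
// Minimal sketch of the two patterns the diff toggles between
// (hypothetical types, not PyTorch's actual classes).
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

struct Collector {
  std::vector<int> backwards;
  // Registers a backward object and returns its index; the caller decides
  // where that index is remembered.
  std::size_t add_backward(int obj) {
    backwards.push_back(obj);
    return backwards.size() - 1;
  }
};

// Pattern restored by the revert: the node caches the index in a mutable
// member during the initial trace and reads it back when applied.
struct NodeWithMember {
  std::optional<std::size_t> backward_idx;
  void compiled_args(Collector& c) {  // non-const: mutates the node
    backward_idx = c.add_backward(/*obj=*/1);
  }
  std::size_t apply_with_saved() const { return *backward_idx; }
};

// Pattern removed by the revert: compiled_args stays const and side-effect
// free; the index lives in the saved state, keyed by the node pointer.
struct SavedState {
  std::vector<std::pair<const void*, std::size_t>> pynode_objs;
  std::size_t retrieve(const void* node) const {
    for (const auto& [key, idx] : pynode_objs) {
      if (key == node) {
        return idx;
      }
    }
    return 0;  // unreachable in this sketch
  }
};

struct NodeSideEffectFree {
  void compiled_args(Collector& c, SavedState& s) const {
    s.pynode_objs.emplace_back(this, c.add_backward(/*obj=*/2));
  }
  std::size_t apply_with_saved(const SavedState& s) const {
    return s.retrieve(this);
  }
};

int main() {
  Collector c;
  SavedState s;
  NodeWithMember a;
  a.compiled_args(c);
  NodeSideEffectFree b;
  b.compiled_args(c, s);
  return static_cast<int>(a.apply_with_saved()) +
      static_cast<int>(b.apply_with_saved(s));
}
```

The member-caching form is what this revert puts back; the const, side-effect-free form is what #147804 had introduced before it was backed out over the slow-workflow test failures.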