[ca] introduce RuntimeState to support C++ hooks via graph breaks (#149987)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149987
Approved by: https://github.com/jansel
ghstack dependencies: #149647, #149709, #149651, #149897
Simon Fan
2025-03-25 15:16:44 -07:00
committed by PyTorch MergeBot
parent dcb378cff2
commit 748252378d
7 changed files with 143 additions and 23 deletions
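
In user-facing terms, this change lets C++ tensor pre-hooks, such as the retains_grad hooks installed by Tensor.retain_grad() that compiled autograd previously rejected outright, run during a compiled backward by calling back into a per-pass RuntimeState. A minimal usage sketch, assuming the usual compiled autograd context manager; the compiler function passed to enable() is illustrative:

import torch
import torch._dynamo.compiled_autograd as compiled_autograd

x = torch.randn(4, requires_grad=True)
y = x.sin()
y.retain_grad()  # installs a C++ single-tensor pre-hook on y's grad_fn
loss = y.cos().sum()

# The hook is collected into RuntimeState during compilation and called back
# at runtime through torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks.
with compiled_autograd.enable(torch.compile):
    loss.backward()

print(y.grad)  # populated because the retain_grad hook ran under compiled autograd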

View File

@@ -4146,6 +4146,17 @@ known_graph_breaks_tests = {
"test_checkpointing_without_reentrant_memory_savings", # reentrant .backward
"test_dtensor_basic", # torch._dynamo.exc.Unsupported: Failed to convert args/kwargs to proxy
"test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent", # subclass constructor
"test_retain_grad", # retains_grad_hooks
"test_retain_grad_cycle", # retains_grad_hooks
"test_retain_grad_inplace", # retains_grad_hooks
"test_retain_grad_inplace_over_view", # retains_grad_hooks
"test_retains_grad_can_always_observe_tensor_prehook", # retains_grad_hooks
"test_retains_grad_inplace_multiple_outputs", # retains_grad_hooks
"test_hook_edge_case_when_called_with_grad", # retains_grad_hooks
"test_multi_grad_all_hooks", # retains_grad_hooks
"test_prehook_ordering", # retains_grad_hooks
"test_will_engine_execute_node", # retains_grad_hooks
"test_backward_to_node", # retains_grad_hooks
}
test_contexts = {
@@ -4173,11 +4184,6 @@ known_failing_tests = {
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_current_node", # TorchDispatchMode not yet implemented for compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error
"test_retain_grad_cycle", # retains_grad_hooks
"test_retain_grad_inplace", # retains_grad_hooks
"test_retain_grad_inplace_over_view", # retains_grad_hooks
"test_retains_grad_can_always_observe_tensor_prehook", # retains_grad_hooks
"test_retains_grad_inplace_multiple_outputs", # retains_grad_hooks
"test_accumulate_grad", # create_graph
"test_anomaly_assign_parent_cleanup", # create_graph
"test_backward_create_graph_warns", # create_graph
@@ -4198,19 +4204,13 @@ known_failing_tests = {
"test_grad_nonleaf", # create_graph
"test_grad_nonleaf_many_outputs", # create_graph
"test_hessian_vector", # create_graph
"test_hook_edge_case_when_called_with_grad", # retains_grad_hooks
"test_inplace_on_view_backward", # create_graph
"test_multi_grad_any_hooks", # register_multi_grad_hook
"test_multi_grad_all_hooks", # retains_grad_hooks
"test_nested_anomaly_detect_nan", # create_graph
"test_nested_anomaly_printstack_cleanup", # create_graph
"test_once_differentiable", # create_graph
"test_prehook_ordering", # retains_grad_hooks
"test_retain_grad", # retains_grad_hooks
"test_saved_variable_packing_unpacking_saved_original_with_hooks", # create_graph
"test_select_sum", # create_graph, also needs graph breaks
"test_will_engine_execute_node", # retains_grad_hooks
"test_backward_to_node", # retains_grad_hooks NYI
"test_custom_autograd_no_early_free", # create_graph
"test_custom_function_error", # vjp
"test_custom_function_save_for_forward", # vjp

View File

@@ -1,5 +1,6 @@
from typing import Callable
from torch import Tensor
from torch._dynamo.compiled_autograd import AutogradCompilerInstance
def set_autograd_compiler(
@@ -9,3 +10,4 @@ def set_autograd_compiler(
def clear_cache() -> None: ...
def is_cache_empty() -> bool: ...
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...

View File

@@ -734,6 +734,18 @@ class AutogradCompilerInstance:
self.bind_objects_to_proxies([inputs[i]], [proxy])
return inputs
def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int):
proxy = self.fx_tracer.create_proxy(
"call_function",
torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks,
(hook_id, self.to_proxy(inputs[i])),
{},
)
with disable_proxy_modes_tracing():
inputs[i] = maybe_clone(inputs[i])
self.bind_objects_to_proxies([inputs[i]], [proxy])
return inputs
def pre_hook(self, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
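
At runtime, the proxy created by cpp_tensor_pre_hook above lowers to a plain call of the new call_cpp_tensor_pre_hooks binding. A hedged eager-mode sketch of what that node does; the helper name is made up, and tracer bookkeeping (maybe_clone, bind_objects_to_proxies) is omitted:

import torch
from torch._C._dynamo import compiled_autograd as ca_C

def apply_cpp_tensor_pre_hook(inputs: list, hook_id: int, i: int) -> list:
    # Replace the i-th gradient with the hook's output, mirroring the traced
    # call_function node. Only valid while a compiled backward pass has an
    # active RuntimeState on the C++ side.
    inputs[i] = ca_C.call_cpp_tensor_pre_hooks(hook_id, inputs[i])
    return inputs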

View File

@@ -64,4 +64,9 @@ variable_list CppFunctionSingleTensorPreHook::operator()(
return results;
}
void CppFunctionSingleTensorPreHook::compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const {
args.add_cpp_single_tensor_pre_hook(hook_, value_idx_);
}
} // namespace torch::autograd

View File

@@ -22,6 +22,9 @@ struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
size_t value_idx);
variable_list operator()(const variable_list& values) override;
void compiled_args(
torch::dynamo::autograd::CompiledNodeArgs& args) const override;
std::function<at::TensorBase(const at::TensorBase&)> hook_;
size_t value_idx_;
};

View File

@@ -164,6 +164,7 @@ struct NodeCall {
uint32_t id;
std::shared_ptr<Node> node;
std::vector<std::pair<int, int>> tensor_pre_hooks;
std::vector<std::pair<int, int>> cpp_tensor_pre_hooks;
std::vector<int> pre_hooks;
std::vector<int> post_hooks;
std::vector<int> post_acc_grad_hooks;
@@ -333,6 +334,12 @@ struct AutogradCompilerCall {
return hooks.size() - 1;
}
size_t emplace_cpp_tensor_pre_hook(
std::function<at::TensorBase(const at::TensorBase&)>&& fn) {
cpp_tensor_pre_hooks.emplace_back(std::move(fn));
return cpp_tensor_pre_hooks.size() - 1;
}
size_t emplace_packed_input(c10::SafePyObject&& input) {
packed_inputs.emplace_back(std::move(input));
return packed_inputs.size() - 1;
@@ -348,6 +355,8 @@ struct AutogradCompilerCall {
LiftedIValueArgs lifted_ivalue_args;
std::vector<int64_t> dyn_size_inputs;
std::vector<c10::SafePyObject> hooks;
std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
cpp_tensor_pre_hooks;
std::vector<c10::SafePyObject> packed_inputs;
NodeCalls node_calls;
SizeInput::DynType default_dyn_type;
@@ -602,12 +611,12 @@ class CompiledNodeArgs {
#undef COLLECT_AS_BYTES
void collect_hooks_from(Node* fn) {
TORCH_CHECK(
fn->retains_grad_hooks().empty(),
"retains_grad_hooks not implemented for compiled autograd");
for (auto& i : fn->tensor_pre_hooks()) {
i->compiled_args(*this);
}
for (auto& [_, i] : fn->retains_grad_hooks()) {
i->compiled_args(*this);
}
for (auto& i : fn->pre_hooks()) {
i->compiled_args(*this);
}
@@ -647,6 +656,23 @@
_node_call.tensor_pre_hooks.emplace_back(fn_id, index);
}
void add_cpp_single_tensor_pre_hook(
const std::function<at::TensorBase(const at::TensorBase&)>& hook,
size_t idx) {
auto wrapper = [hook](const at::TensorBase& grad) {
// if the hook did not return a tensor (undefined), keep the incoming grad
auto out = hook(grad);
if (!out.defined()) {
return grad;
}
return out;
};
auto hook_id = _compiler.emplace_cpp_tensor_pre_hook(std::move(wrapper));
collect_size(hook_id);
_node_call.cpp_tensor_pre_hooks.emplace_back(hook_id, idx);
}
void add_pre_hook(c10::SafePyObject&& obj) {
auto fn_id = _compiler.emplace_hook(std::move(obj));
collect_size(fn_id);
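
The wrapper built in add_cpp_single_tensor_pre_hook keeps the usual pre-hook convention: a hook that returns an undefined tensor leaves the incoming gradient unchanged. A rough Python analogue, for illustration only:

def wrap_tensor_pre_hook(hook):
    # Mirrors the C++ wrapper: no replacement gradient (None here, an
    # undefined TensorBase in C++) means "keep the incoming grad".
    def wrapper(grad):
        out = hook(grad)
        return grad if out is None else out
    return wrapper

# A hook used only for observation still yields a usable gradient:
observe = wrap_tensor_pre_hook(lambda g: print("grad shape:", g.shape))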

View File

@@ -58,6 +58,63 @@ int default_dyn_type_int = 0;
PyObject* python_verbose_logger = nullptr;
} // namespace
// see https://github.com/pytorch/pytorch/pull/34845
static void throw_python_error() {
python_error err;
err.persist();
throw std::move(err);
}
// RuntimeState holds arbitrary callables registered during the forward pass,
// e.g. the hooks installed by Tensor.retain_grad(). It is populated during the
// compiled_args stage and invoked at runtime; its lifetime is a single backward pass.
struct RuntimeState {
at::TensorBase call_cpp_tensor_pre_hooks(
size_t idx,
const at::TensorBase& grad) {
TORCH_INTERNAL_ASSERT(
cpp_tensor_pre_hooks.size() > static_cast<size_t>(idx));
return cpp_tensor_pre_hooks[idx](grad);
}
std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
cpp_tensor_pre_hooks;
size_t next_id = 0;
};
static RuntimeState* active_rstate;
struct RuntimeStateGuard {
RuntimeStateGuard() : _state(std::make_unique<RuntimeState>()) {
active_rstate = _state.get();
}
RuntimeStateGuard(const RuntimeStateGuard&) = delete;
RuntimeStateGuard& operator=(const RuntimeStateGuard&) = delete;
RuntimeStateGuard(RuntimeStateGuard&&) = delete;
RuntimeStateGuard& operator=(RuntimeStateGuard&&) = delete;
~RuntimeStateGuard() {
active_rstate = nullptr;
}
std::unique_ptr<RuntimeState> _state;
};
static PyObject* call_cpp_tensor_pre_hooks(PyObject* dummy, PyObject* args) {
HANDLE_TH_ERRORS;
int idx = -1;
PyObject* grad = nullptr;
if (!PyArg_ParseTuple(args, "iO", &idx, &grad)) {
throw_python_error();
}
TORCH_INTERNAL_ASSERT(idx > -1);
TORCH_INTERNAL_ASSERT(grad != nullptr);
TORCH_INTERNAL_ASSERT(active_rstate != nullptr);
auto res = active_rstate->call_cpp_tensor_pre_hooks(
static_cast<size_t>(idx), THPVariable_Unpack(grad));
return THPVariable_Wrap(res);
END_HANDLE_TH_ERRORS;
}
// List[Optional[Tensor]] in Python can't be directly parsed into a
// List[Tensor], so we need to do this conversion manually.
static std::vector<at::Tensor> toTensorList(
@@ -253,13 +310,6 @@ static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
return pyinput;
}
// see https://github.com/pytorch/pytorch/pull/34845
static void throw_python_error() {
python_error err;
err.persist();
throw std::move(err);
}
static PyObject* check(PyObject* pyresult) {
if (C10_UNLIKELY(pyresult == nullptr)) {
throw_python_error();
@@ -608,6 +658,10 @@ static PyMethodDef _methods[] = {
{"clear_cache", clear_cache, METH_NOARGS, nullptr},
{"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
{"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
{"call_cpp_tensor_pre_hooks",
call_cpp_tensor_pre_hooks,
METH_VARARGS,
nullptr},
{nullptr, nullptr, 0, nullptr}};
static struct PyModuleDef _module = {
@@ -827,7 +881,8 @@ static CacheNode* _compiled_autograd_impl(
THPObjectPtr* graph_arg_sizes,
THPObjectPtr* graph_arg_ivalue_args,
THPObjectPtr* graph_arg_hooks,
THPObjectPtr* graph_arg_packed_inputs) {
THPObjectPtr* graph_arg_packed_inputs,
RuntimeState* rstate) {
const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
std::unordered_map<Node*, int> visited_dependencies;
visited_dependencies.reserve(dependencies.size());
@@ -963,6 +1018,20 @@ static CacheNode* _compiled_autograd_impl(
}
inputs = THPVariable_UnpackList(pyinputs);
}
if (!call.cpp_tensor_pre_hooks.empty()) {
// proxy a call into RuntimeState
THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
for (const auto& [hook_id, idx] : call.cpp_tensor_pre_hooks) {
pyinputs = check(PyObject_CallMethod(
py_compiler,
"cpp_tensor_pre_hook",
"Oii",
pyinputs.get(),
hook_id,
idx));
}
inputs = THPVariable_UnpackList(pyinputs);
}
for (const auto& graph_output : call.graph_output) {
int input_nr = graph_output.first;
int output_index = graph_output.second;
@@ -1090,6 +1159,7 @@ static CacheNode* _compiled_autograd_impl(
wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
*graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
*graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
rstate->cpp_tensor_pre_hooks = std::move(compiler_call.cpp_tensor_pre_hooks);
return cache;
}
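
Taken together, RuntimeState, RuntimeStateGuard, the call_cpp_tensor_pre_hooks entry point, and the hand-off of compiler_call.cpp_tensor_pre_hooks above behave roughly like this Python sketch; it mirrors only the lifetime and index-based dispatch, not the real implementation:

import contextlib

_cpp_tensor_pre_hooks = None  # stands in for the file-level active_rstate pointer

@contextlib.contextmanager
def runtime_state_guard():
    # RuntimeStateGuard(): install fresh per-backward-pass storage ...
    global _cpp_tensor_pre_hooks
    _cpp_tensor_pre_hooks = []
    try:
        yield _cpp_tensor_pre_hooks
    finally:
        # ... ~RuntimeStateGuard(): tear it down when the backward pass ends.
        _cpp_tensor_pre_hooks = None

def call_cpp_tensor_pre_hooks(idx, grad):
    # Index-based dispatch, mirroring RuntimeState::call_cpp_tensor_pre_hooks.
    assert _cpp_tensor_pre_hooks is not None, "only valid during a backward pass"
    return _cpp_tensor_pre_hooks[idx](grad)
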
@@ -1125,6 +1195,7 @@ static variable_list compiled_autograd(
LockGuardWithErrorLogs lock_guard(mtx);
pybind11::gil_scoped_acquire gil;
at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_);
RuntimeStateGuard rstate_guard;
THPObjectPtr inputs;
THPObjectPtr sizes;
@@ -1140,7 +1211,8 @@ static variable_list compiled_autograd(
&sizes,
&ivalue_args,
&hooks,
&packed_inputs);
&packed_inputs,
active_rstate);
THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
cache->runtime_wrapper.get(),