[ca] introduce RuntimeState to support c++ hooks via graph breaks (#149987)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149987
Approved by: https://github.com/jansel
ghstack dependencies: #149647, #149709, #149651, #149897
Commit 748252378d (parent dcb378cff2), committed by PyTorch MergeBot.
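In short: C++ tensor pre-hooks (for example, the one Tensor.retain_grad() installs) previously aborted compiled autograd outright; with this change they are collected by id at compile time and replayed through a per-backward RuntimeState, at the cost of a graph break. A minimal illustration of the newly supported pattern, assuming a build containing this change (the backend choice is arbitrary):

import torch
from torch._dynamo import compiled_autograd

x = torch.randn(3, requires_grad=True)
y = x * 2
y.retain_grad()  # installs a C++ tensor pre-hook on y's grad_fn

# Before this PR, compiled autograd rejected retains_grad hooks; now the
# backward runs, with a graph break around the opaque hook call.
with compiled_autograd.enable(torch.compile(backend="eager")):
    y.sum().backward()

print(y.grad)  # tensor([1., 1., 1.])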
@@ -4146,6 +4146,17 @@ known_graph_breaks_tests = {
     "test_checkpointing_without_reentrant_memory_savings",  # reentrant .backward
     "test_dtensor_basic",  # torch._dynamo.exc.Unsupported: Failed to convert args/kwargs to proxy
     "test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",  # subclass constructor
+    "test_retain_grad",  # retains_grad_hooks
+    "test_retain_grad_cycle",  # retains_grad_hooks
+    "test_retain_grad_inplace",  # retains_grad_hooks
+    "test_retain_grad_inplace_over_view",  # retains_grad_hooks
+    "test_retains_grad_can_always_observe_tensor_prehook",  # retains_grad_hooks
+    "test_retains_grad_inplace_multiple_outputs",  # retains_grad_hooks
+    "test_hook_edge_case_when_called_with_grad",  # retains_grad_hooks
+    "test_multi_grad_all_hooks",  # retains_grad_hooks
+    "test_prehook_ordering",  # retains_grad_hooks
+    "test_will_engine_execute_node",  # retains_grad_hooks
+    "test_backward_to_node",  # retains_grad_hooks
 }
 
 test_contexts = {
@@ -4173,11 +4184,6 @@ known_failing_tests = {
     "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
     "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
     "test_post_accumulate_grad_hook_ordering",  # accuracy error
-    "test_retain_grad_cycle",  # retains_grad_hooks
-    "test_retain_grad_inplace",  # retains_grad_hooks
-    "test_retain_grad_inplace_over_view",  # retains_grad_hooks
-    "test_retains_grad_can_always_observe_tensor_prehook",  # retains_grad_hooks
-    "test_retains_grad_inplace_multiple_outputs",  # retains_grad_hooks
     "test_accumulate_grad",  # create_graph
     "test_anomaly_assign_parent_cleanup",  # create_graph
     "test_backward_create_graph_warns",  # create_graph
@@ -4198,19 +4204,13 @@ known_failing_tests = {
     "test_grad_nonleaf",  # create_graph
     "test_grad_nonleaf_many_outputs",  # create_graph
     "test_hessian_vector",  # create_graph
-    "test_hook_edge_case_when_called_with_grad",  # retains_grad_hooks
     "test_inplace_on_view_backward",  # create_graph
     "test_multi_grad_any_hooks",  # register_multi_grad_hook
-    "test_multi_grad_all_hooks",  # retains_grad_hooks
     "test_nested_anomaly_detect_nan",  # create_graph
     "test_nested_anomaly_printstack_cleanup",  # create_graph
     "test_once_differentiable",  # create_graph
-    "test_prehook_ordering",  # retains_grad_hooks
-    "test_retain_grad",  # retains_grad_hooks
     "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # create_graph
     "test_select_sum",  # create_graph, also needs graph breaks
-    "test_will_engine_execute_node",  # retains_grad_hooks
-    "test_backward_to_node",  # retains_grad_hooks NYI
     "test_custom_autograd_no_early_free",  # create_graph
     "test_custom_function_error",  # vjp
     "test_custom_function_save_for_forward",  # vjp
@@ -1,5 +1,6 @@
 from typing import Callable
 
+from torch import Tensor
 from torch._dynamo.compiled_autograd import AutogradCompilerInstance
 
 def set_autograd_compiler(
@@ -9,3 +10,4 @@ def set_autograd_compiler(
 def clear_cache() -> None: ...
 def is_cache_empty() -> bool: ...
 def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
+def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...
@@ -734,6 +734,18 @@ class AutogradCompilerInstance:
             self.bind_objects_to_proxies([inputs[i]], [proxy])
         return inputs
 
+    def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int):
+        proxy = self.fx_tracer.create_proxy(
+            "call_function",
+            torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks,
+            (hook_id, self.to_proxy(inputs[i])),
+            {},
+        )
+        with disable_proxy_modes_tracing():
+            inputs[i] = maybe_clone(inputs[i])
+            self.bind_objects_to_proxies([inputs[i]], [proxy])
+        return inputs
+
     def pre_hook(self, inputs, hook_id):
         assert self.hooks_proxy is not None
         hook = self.hooks_proxy[hook_id]  # type: ignore[index]
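For orientation, the new cpp_tensor_pre_hook method above mirrors the existing tensor_pre_hook, but the C++ callable itself cannot be lifted into the graph as a Python object, so the emitted FX node is an opaque call back into the runtime, keyed only by the hook's index. Roughly, the generated backward graph contains a call of this shape (a hand-written approximation, not actual codegen output; the index 0 is illustrative):

import torch

def traced_fragment(grad: torch.Tensor) -> torch.Tensor:
    # Only valid while a compiled backward is executing, i.e. while the
    # C++ side has an active RuntimeState holding the hook at index 0.
    return torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks(0, grad)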
@@ -64,4 +64,9 @@ variable_list CppFunctionSingleTensorPreHook::operator()(
   return results;
 }
 
+void CppFunctionSingleTensorPreHook::compiled_args(
+    torch::dynamo::autograd::CompiledNodeArgs& args) const {
+  args.add_cpp_single_tensor_pre_hook(hook_, value_idx_);
+}
+
 } // namespace torch::autograd
@@ -22,6 +22,9 @@ struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
       size_t value_idx);
   variable_list operator()(const variable_list& values) override;
 
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
+
   std::function<at::TensorBase(const at::TensorBase&)> hook_;
   size_t value_idx_;
 };
@@ -164,6 +164,7 @@ struct NodeCall {
   uint32_t id;
   std::shared_ptr<Node> node;
   std::vector<std::pair<int, int>> tensor_pre_hooks;
+  std::vector<std::pair<int, int>> cpp_tensor_pre_hooks;
   std::vector<int> pre_hooks;
   std::vector<int> post_hooks;
   std::vector<int> post_acc_grad_hooks;
@@ -333,6 +334,12 @@ struct AutogradCompilerCall {
     return hooks.size() - 1;
   }
 
+  size_t emplace_cpp_tensor_pre_hook(
+      std::function<at::TensorBase(const at::TensorBase&)>&& fn) {
+    cpp_tensor_pre_hooks.emplace_back(std::move(fn));
+    return cpp_tensor_pre_hooks.size() - 1;
+  }
+
   size_t emplace_packed_input(c10::SafePyObject&& input) {
     packed_inputs.emplace_back(std::move(input));
     return packed_inputs.size() - 1;
@@ -348,6 +355,8 @@ struct AutogradCompilerCall {
   LiftedIValueArgs lifted_ivalue_args;
   std::vector<int64_t> dyn_size_inputs;
   std::vector<c10::SafePyObject> hooks;
+  std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
+      cpp_tensor_pre_hooks;
   std::vector<c10::SafePyObject> packed_inputs;
   NodeCalls node_calls;
   SizeInput::DynType default_dyn_type;
@@ -602,12 +611,12 @@ class CompiledNodeArgs {
 #undef COLLECT_AS_BYTES
 
   void collect_hooks_from(Node* fn) {
-    TORCH_CHECK(
-        fn->retains_grad_hooks().empty(),
-        "retains_grad_hooks not implemented for compiled autograd");
     for (auto& i : fn->tensor_pre_hooks()) {
       i->compiled_args(*this);
     }
+    for (auto& [_, i] : fn->retains_grad_hooks()) {
+      i->compiled_args(*this);
+    }
     for (auto& i : fn->pre_hooks()) {
       i->compiled_args(*this);
     }
@@ -647,6 +656,23 @@ class CompiledNodeArgs {
     _node_call.tensor_pre_hooks.emplace_back(fn_id, index);
   }
 
+  void add_cpp_single_tensor_pre_hook(
+      const std::function<at::TensorBase(const at::TensorBase&)>& hook,
+      size_t idx) {
+    auto wrapper = [hook](const at::TensorBase& grad) {
+      // handle when hook returns nothing
+      auto out = hook(grad);
+      if (!out.defined()) {
+        return grad;
+      }
+      return out;
+    };
+
+    auto hook_id = _compiler.emplace_cpp_tensor_pre_hook(std::move(wrapper));
+    collect_size(hook_id);
+    _node_call.cpp_tensor_pre_hooks.emplace_back(hook_id, idx);
+  }
+
   void add_pre_hook(c10::SafePyObject&& obj) {
     auto fn_id = _compiler.emplace_hook(std::move(obj));
     collect_size(fn_id);
@@ -58,6 +58,63 @@ int default_dyn_type_int = 0;
 PyObject* python_verbose_logger = nullptr;
 } // namespace
 
+// see https://github.com/pytorch/pytorch/pull/34845
+static void throw_python_error() {
+  python_error err;
+  err.persist();
+  throw std::move(err);
+}
+
+// RuntimeState contains arbitrary callables created during the forward pass.
+// e.g. .retains_grad(). It is created during the compiled_args stage, and is
+// used at runtime. The lifetime of RuntimeState is a single backward pass.
+struct RuntimeState {
+  at::TensorBase call_cpp_tensor_pre_hooks(
+      size_t idx,
+      const at::TensorBase& grad) {
+    TORCH_INTERNAL_ASSERT(
+        cpp_tensor_pre_hooks.size() > static_cast<size_t>(idx));
+    return cpp_tensor_pre_hooks[idx](grad);
+  }
+
+  std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
+      cpp_tensor_pre_hooks;
+  size_t next_id = 0;
+};
+
+static RuntimeState* active_rstate;
+struct RuntimeStateGuard {
+  RuntimeStateGuard() : _state(std::make_unique<RuntimeState>()) {
+    active_rstate = _state.get();
+  }
+  RuntimeStateGuard(const RuntimeStateGuard&) = delete;
+  RuntimeStateGuard& operator=(const RuntimeStateGuard&) = delete;
+  RuntimeStateGuard(RuntimeStateGuard&&) = delete;
+  RuntimeStateGuard& operator=(RuntimeStateGuard&&) = delete;
+
+  ~RuntimeStateGuard() {
+    active_rstate = nullptr;
+  }
+
+  std::unique_ptr<RuntimeState> _state;
+};
+
+static PyObject* call_cpp_tensor_pre_hooks(PyObject* dummy, PyObject* args) {
+  HANDLE_TH_ERRORS;
+  int idx = -1;
+  PyObject* grad = nullptr;
+  if (!PyArg_ParseTuple(args, "iO", &idx, &grad)) {
+    throw_python_error();
+  }
+  TORCH_INTERNAL_ASSERT(idx > -1);
+  TORCH_INTERNAL_ASSERT(grad != nullptr);
+  TORCH_INTERNAL_ASSERT(active_rstate != nullptr);
+  auto res = active_rstate->call_cpp_tensor_pre_hooks(
+      static_cast<size_t>(idx), THPVariable_Unpack(grad));
+  return THPVariable_Wrap(res);
+  END_HANDLE_TH_ERRORS;
+}
+
 // List[Optional[Tensor]] in Python can't be directly parsed into a
 // List[Tensor], so we need to do this conversion manually.
 static std::vector<at::Tensor> toTensorList(
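The lifetime contract in the comment above (one RuntimeState per backward pass, installed by RuntimeStateGuard and reached through the module-level active_rstate pointer) is a plain RAII-guard-plus-registry pattern. A runnable, torch-free sketch of the same shape (all names here are stand-ins, not the actual classes):

from contextlib import contextmanager
from typing import Callable, Iterator, Optional

class RuntimeStateSketch:
    """Stand-in for RuntimeState: hooks gathered at compile time, called by id."""

    def __init__(self) -> None:
        self.cpp_tensor_pre_hooks: list[Callable[[float], float]] = []

    def call(self, idx: int, grad: float) -> float:
        assert idx < len(self.cpp_tensor_pre_hooks)
        return self.cpp_tensor_pre_hooks[idx](grad)

_active: Optional[RuntimeStateSketch] = None

@contextmanager
def runtime_state_guard() -> Iterator[RuntimeStateSketch]:
    # Mirrors RuntimeStateGuard: install for exactly one backward pass,
    # then tear down so stale hooks can never be reached.
    global _active
    _active = RuntimeStateSketch()
    try:
        yield _active
    finally:
        _active = None

with runtime_state_guard() as state:
    state.cpp_tensor_pre_hooks.append(lambda g: 2 * g)  # "compiled_args" stage
    assert state.call(0, 1.5) == 3.0                    # "runtime" stage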
@@ -253,13 +310,6 @@ static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
   return pyinput;
 }
 
-// see https://github.com/pytorch/pytorch/pull/34845
-static void throw_python_error() {
-  python_error err;
-  err.persist();
-  throw std::move(err);
-}
-
 static PyObject* check(PyObject* pyresult) {
   if (C10_UNLIKELY(pyresult == nullptr)) {
     throw_python_error();
@@ -608,6 +658,10 @@ static PyMethodDef _methods[] = {
     {"clear_cache", clear_cache, METH_NOARGS, nullptr},
     {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
     {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
+    {"call_cpp_tensor_pre_hooks",
+     call_cpp_tensor_pre_hooks,
+     METH_VARARGS,
+     nullptr},
     {nullptr, nullptr, 0, nullptr}};
 
 static struct PyModuleDef _module = {
@@ -827,7 +881,8 @@ static CacheNode* _compiled_autograd_impl(
     THPObjectPtr* graph_arg_sizes,
     THPObjectPtr* graph_arg_ivalue_args,
     THPObjectPtr* graph_arg_hooks,
-    THPObjectPtr* graph_arg_packed_inputs) {
+    THPObjectPtr* graph_arg_packed_inputs,
+    RuntimeState* rstate) {
   const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
   std::unordered_map<Node*, int> visited_dependencies;
   visited_dependencies.reserve(dependencies.size());
@@ -963,6 +1018,20 @@ static CacheNode* _compiled_autograd_impl(
       }
       inputs = THPVariable_UnpackList(pyinputs);
     }
+    if (!call.cpp_tensor_pre_hooks.empty()) {
+      // proxy a call to runtimestate
+      THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
+      for (const auto& [hook_id, idx] : call.cpp_tensor_pre_hooks) {
+        pyinputs = check(PyObject_CallMethod(
+            py_compiler,
+            "cpp_tensor_pre_hook",
+            "Oii",
+            pyinputs.get(),
+            hook_id,
+            idx));
+      }
+      inputs = THPVariable_UnpackList(pyinputs);
+    }
     for (const auto& graph_output : call.graph_output) {
       int input_nr = graph_output.first;
       int output_index = graph_output.second;
@@ -1090,6 +1159,7 @@ static CacheNode* _compiled_autograd_impl(
   wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
   *graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
   *graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
+  rstate->cpp_tensor_pre_hooks = std::move(compiler_call.cpp_tensor_pre_hooks);
   return cache;
 }
 
@@ -1125,6 +1195,7 @@ static variable_list compiled_autograd(
   LockGuardWithErrorLogs lock_guard(mtx);
   pybind11::gil_scoped_acquire gil;
   at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_);
+  RuntimeStateGuard rstate_guard;
 
   THPObjectPtr inputs;
   THPObjectPtr sizes;
@@ -1140,7 +1211,8 @@ static variable_list compiled_autograd(
       &sizes,
       &ivalue_args,
       &hooks,
-      &packed_inputs);
+      &packed_inputs,
+      active_rstate);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
       cache->runtime_wrapper.get(),