Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[ca] side-effect free inital trace: compiled_args (#147804)
Make these methods const to prevent accidental mutation; the changes are mainly in the Error nodes and PyNode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804
Approved by: https://github.com/jansel
ghstack dependencies: #147242, #147796
Commit ec768d8dc0 (parent 5758743f3c), committed by PyTorch MergeBot.
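The diff below const-qualifies the compiled_args collection pass. As a rough illustration of the pattern (a minimal sketch with hypothetical names, not the actual PyTorch classes), a const virtual hook can still record per-node information by storing it on the collector side instead of mutating the node:

// Minimal sketch of the const-qualification pattern, assuming hypothetical
// Collector/NodeBase types: collection during the initial trace is declared
// const so it cannot mutate the node; per-node state is kept in the collector.
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct NodeBase;

// Hypothetical stand-in for the collector: per-node bookkeeping lives here
// instead of on the (now const) node.
struct Collector {
  std::unordered_map<const NodeBase*, std::size_t> node_state;
};

struct NodeBase {
  virtual ~NodeBase() = default;
  virtual std::string name() const { return "NodeBase"; }
  // const: the initial trace may read the node but never write to it
  virtual void compiled_args(Collector& args) const {
    throw std::runtime_error("compiled_args not implemented: " + name());
  }
};

struct MyNode : NodeBase {
  std::string name() const override { return "MyNode"; }
  void compiled_args(Collector& args) const override {
    // an index that previously lived on the node is now recorded in the
    // collector, keyed by `this`
    args.node_state.emplace(this, 0);
  }
};

int main() {
  Collector c;
  MyNode n;
  n.compiled_args(c); // collection does not mutate n
  return 0;
}

In the commit itself, the analogous collector-side storage is AutogradCompilerCall::pynode_objs, populated via CompiledNodeArgs::collect_pynode_objs and read back through SwapSavedVariables::retrieve_pynode_objs, as shown in the hunks below.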
@@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} {
   ${release_variables}
   }
   ${will_release_variables}
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
   ${saved_variables}
   ${saved_list_sizes}
@@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) {
   return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
 }
 
-void ${op}::compiled_args(CompiledNodeArgs& args) {
+void ${op}::compiled_args(CompiledNodeArgs& args) const {
   ${compiled_args}
 }
 variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
@@ -284,7 +284,7 @@ struct CppNode : public Node {
   void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
   void save_variables_to_ctx();
 
-  void compiled_args(CompiledNodeArgs& args) override {
+  void compiled_args(CompiledNodeArgs& args) const override {
     // although neither of the 2 methods below have uniqueness guarantees
     // it is unlikely for them to collide at the same time
     args.collect(static_cast<uint64_t>(typeid(T).hash_code()));
@@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return tensor_pre_hooks_;
   }
 
-  virtual std::unique_ptr<PostAccumulateGradHook>&
-  tensor_post_acc_grad_hooks() noexcept {
+  virtual std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
+      const noexcept {
     static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
     return empty;
   }
@@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // 2) Collect node information for specialization and caching
   // Implementations in subclasses should call args.collect() with all node
   // attrs. These functions are only called durring backward.
-  virtual void compiled_args(CompiledNodeArgs& args) {
+  virtual void compiled_args(CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("compiled_args not implemented: ") + name());
   }
@@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook {
   virtual ~FunctionPreHook() = default;
   virtual variable_list operator()(const variable_list& grads) = 0;
   // only implemented for python hooks, registers hook with compiled autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+  virtual void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
         typeid(*this).name());
@@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook {
       const variable_list& outputs /* grad_inputs */,
       const variable_list& inputs /* grad_outputs */) = 0;
   // only implemented for python hooks, registers hook with compiled autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+  virtual void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
         typeid(*this).name());
@@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook {
   virtual void operator()(const Variable& tensor) = 0;
   // only implemented for python hooks on nodes, registers hook with compiled
   // autograd
-  virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+  virtual void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const {
     throw std::runtime_error(
         std::string("not yet implemented for compiled autograd: ") +
         typeid(*this).name());
@@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
   return variable_list();
 }
 
-void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
+void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
   if (args.cond(variable.defined() && variable.requires_grad())) {
     args.collect(variable);
     args.collect(variable.grad());
   }
-  auto& hook = tensor_post_acc_grad_hooks();
+  const auto& hook = tensor_post_acc_grad_hooks();
   if (hook != nullptr) {
     hook->compiled_args(args);
   }
@@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node {
     return impl::hooks(variable);
   }
 
-  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
-      override {
+  std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
+      const noexcept override {
     // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
     // it can be destroyed even though the Tensor is still alive (contrary
     // to all other Nodes). So we must lazily read the Tensor hooks here.
@@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node {
     }
   }
 
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -12,11 +12,15 @@
 
 namespace torch::autograd {
 
-auto Error::apply(variable_list&& inputs) -> variable_list {
+variable_list Error::apply(variable_list&& inputs) {
+  return static_cast<const Error*>(this)->apply(std::move(inputs));
+}
+
+variable_list Error::apply(variable_list&& inputs) const {
   throw std::runtime_error(msg);
 }
 
-void Error::compiled_args(CompiledNodeArgs& args) {
+void Error::compiled_args(CompiledNodeArgs& args) const {
   // throw the error durring collect, the graph won't get compiled
   apply(variable_list());
 }
@@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list {
   return std::move(grads);
 }
 
-void GraphRoot::compiled_args(CompiledNodeArgs& args) {
+void GraphRoot::compiled_args(CompiledNodeArgs& args) const {
   args.collect(outputs);
 }
 variable_list GraphRoot::apply_with_saved(
@@ -18,8 +18,9 @@ struct TORCH_API Error : public Node {
   Error(std::string msg) : msg(std::move(msg)) {}
 
   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;
 
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node {
   }
 
   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;
 
   std::string msg;
 };
@@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node {
   }
 
   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;
 };
 
 struct TORCH_API UndefinedGradBackward : public Node {
@@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node {
   UndefinedGradBackward() = default;
 
   variable_list apply(variable_list&& inputs) override;
+  variable_list apply(variable_list&& inputs) const;
 
-  void compiled_args(CompiledNodeArgs& args) override {}
+  void compiled_args(CompiledNodeArgs& args) const override {}
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override {
@@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node {
     return outputs;
   }
 
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
       src_options);
 }
 
-void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
+void CopyBackwards::compiled_args(CompiledNodeArgs& args) const {
   args.collect(src_options);
 }
 
@@ -235,7 +235,7 @@ void CopySlices::release_variables() {
   fn = nullptr;
 }
 
-void CopySlices::compiled_args(CompiledNodeArgs& args) {
+void CopySlices::compiled_args(CompiledNodeArgs& args) const {
   TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
   TORCH_INTERNAL_ASSERT((bool)fn);
   args.collect(base);
@@ -15,7 +15,7 @@ namespace torch::autograd {
 
 struct TORCH_API CopyBackwards : public Node {
   variable_list apply(variable_list&& grads) override;
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node {
 
   variable_list apply(variable_list&& inputs) override;
   void release_variables() override;
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -34,6 +34,7 @@
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
 
+#include <autograd/function.h>
 #include <functional>
 #include <memory>
 #include <stdexcept>
@@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
   return to_variable_list(r.get(), is_variable_input);
 }
 
-auto PyNode::defer_to_dynamo(
+auto PyNode::apply_with_saved_impl(
     const variable_list& inputs,
-    const std::optional<PyObject*>& compiler) -> variable_list {
+    const SwapSavedVariables& saved) -> variable_list {
   pybind11::gil_scoped_acquire gil;
   at::OptionalDeviceGuard _device_guard;
   THPFunction* py_fn = (THPFunction*)obj;
@@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
   }
   THPObjectPtr saved_tensors(unpack_saved_variables(
       py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
-  TORCH_INTERNAL_ASSERT(
-      _backward_idx.has_value(),
-      "indices should already be set by compiled_args, called before apply_with_saved");
+
+  auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
+
   PyObject* backward_state_idx = Py_None;
-  if (_backward_state_idx.has_value()) {
-    backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
+  if (maybe_bwd_state_idx.has_value()) {
+    backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
+    // this might be simplifiable now that we no longer inline
+    Py_CLEAR(py_fn->compiled_autograd_backward_state);
   }
   THPObjectPtr r(PyObject_CallMethod(
-      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-      compiler.value(),
+      saved.get_py_compiler(),
       "proxy_call_backward",
       "OOOiOO",
       pyInputs.get(),
       fwdInputMetadatas.get(),
       saved_tensors.get(),
-      *_backward_idx,
+      bwd_idx,
       obj,
       backward_state_idx));
 
@@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
   return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
 }
 
-void PyNode::compiled_args(CompiledNodeArgs& args) {
+void PyNode::compiled_args(CompiledNodeArgs& args) const {
   static PyObject* method_name =
       PyUnicode_InternFromString("_compiled_autograd_key");
   THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));
@@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
   args.collect(f->input_info);
 
   Py_INCREF(obj);
-  _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
-
+  c10::SafePyObject backward_obj(obj, getPyInterpreter());
+  std::optional<c10::SafePyObject> backward_state_obj;
   PyObject* bw_state = f->compiled_autograd_backward_state;
   if (args.cond(bw_state != nullptr)) {
     Py_INCREF(bw_state);
-    _backward_state_idx = args.add_backward_state(
-        c10::SafePyObject(bw_state, getPyInterpreter()));
+    backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
   }
+  args.collect_pynode_objs(
+      this, std::move(backward_obj), std::move(backward_state_obj));
 }
 
 variable_list PyNode::apply_with_saved(
@@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
   saved.before(f->output_info);
   saved.before(f->input_info);
   f->compiled_autograd_tracing = true;
-  variable_list result =
-      defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
+  variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
   f->compiled_autograd_tracing = false;
   saved.after(f->compiled_autograd_symints);
   saved.after(f->saved_variables);
@@ -35,9 +35,9 @@ struct PyNode : public Node {
       const std::vector<bool>& is_variable_input);
 
   variable_list apply(variable_list&& inputs) override;
-  variable_list defer_to_dynamo(
+  variable_list apply_with_saved_impl(
       const variable_list& inputs,
-      const std::optional<PyObject*>& compiler);
+      const SwapSavedVariables& saved);
 
   void release_variables() override;
   std::string name() const override;
@@ -45,7 +45,7 @@ struct PyNode : public Node {
 
   bool is_aot_backward() const override;
 
-  void compiled_args(CompiledNodeArgs& args) override;
+  void compiled_args(CompiledNodeArgs& args) const override;
   variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) override;
@@ -53,13 +53,6 @@ struct PyNode : public Node {
   // THPFunction this Function is wrapping. Owning!
   PyObject* obj;
 
-  // The AutogradCompilerCall::hooks idx corresponding to this node's backward
-  std::optional<int> _backward_idx;
-
-  // The AutogradCompilerCall::hooks idx corresponding to this node's
-  // backward_state
-  std::optional<int> _backward_state_idx;
-
   // NOLINTNEXTLINE(bugprone-exception-escape)
   ~PyNode() override {
     // Can't use THPObjectPtr as a field in this class; destructor won't take
@@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()(
   return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
 }
 
-void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
   Py_END_CRITICAL_SECTION();
 }
 
-void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
   Py_END_CRITICAL_SECTION();
 }
 
-void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
 }
 
 void PyFunctionTensorPostAccGradHooks::compiled_args(
-    torch::dynamo::autograd::CompiledNodeArgs& args) {
+    torch::dynamo::autograd::CompiledNodeArgs& args) const {
   PyObject *key = nullptr, *value = nullptr;
   Py_ssize_t pos = 0;
   Py_BEGIN_CRITICAL_SECTION(dict);
@@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
   PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
   ~PyFunctionTensorPreHook() override;
   variable_list operator()(const variable_list& values) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
   size_t value_idx;
 };
@@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook {
   PyFunctionPreHook(PyObject* dict);
   ~PyFunctionPreHook() override;
   variable_list operator()(const variable_list& values) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
 };
 
@@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook {
   variable_list operator()(
       const variable_list& outputs,
       const variable_list& inputs) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   PyObject* dict;
 };
 
@@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
   PyFunctionTensorPostAccGradHooks(PyObject* dict);
   ~PyFunctionTensorPostAccGradHooks() override;
   void operator()(const Variable& tensor) override;
-  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
   void apply_with_saved(
       Variable& tensor,
       torch::dynamo::autograd::SwapSavedVariables& saved) override;
@@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
     return fn_(outputs, inputs);
   }
 
-  void compiled_args(CompiledNodeArgs& args) override {}
+  void compiled_args(CompiledNodeArgs& args) const override {}
 
  protected:
   std::function<variable_list(const variable_list&, const variable_list&)> fn_;
@@ -183,21 +183,23 @@ Reducer::Reducer(
 #endif
       // Hook to execute after the gradient accumulator has executed.
       hooks_.emplace_back(
-          grad_accumulator->add_post_hook(
-              std::make_unique<torch::autograd::utils::LambdaPostHook>(
-                  [this, variable_index](
-                      const torch::autograd::variable_list& outputs,
-                      const torch::autograd::variable_list& /* unused */) {
+          grad_accumulator->add_post_hook(std::make_unique<
+                                          torch::autograd::utils::
+                                              LambdaPostHook>(
+              [this, variable_index](
+                  const torch::autograd::variable_list& outputs,
+                  const torch::autograd::variable_list& /* unused */) {
 #ifndef _WIN32
-                    this->rpc_context_.set(
-                        ThreadLocalDistAutogradContext::getContextPtr());
+                this->rpc_context_.set(
+                    ThreadLocalDistAutogradContext::getContextPtr());
 #endif
-                    this->autograd_hook(variable_index);
-                    return outputs;
-                  },
-                  [=](torch::autograd::CompiledNodeArgs& args) {
-                    // Make post_hook an noop if compiled_autograds is enabled.
-                  })),
+                this->autograd_hook(variable_index);
+                return outputs;
+              },
+              [=](torch::autograd::CompiledNodeArgs& args) {
+                TORCH_INTERNAL_ASSERT(
+                    "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
+              })),
           grad_accumulator);
 
       // Map raw function pointer to parameter index.
@@ -343,6 +343,9 @@ struct AutogradCompilerCall {
   std::vector<uint32_t> size_input_origins;
   std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
       sv_to_hooks;
+  // pynode -> backward and backward state idx
+  std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
+      pynode_objs;
 };
 
 class CompiledNodeArgs {
@@ -613,12 +616,17 @@ class CompiledNodeArgs {
         typeid(*node), _specialization_key, _specialization_key_size);
   }
 
-  size_t add_backward(c10::SafePyObject&& obj) {
-    return _compiler.emplace_hook(std::move(obj));
-  }
-
-  size_t add_backward_state(c10::SafePyObject&& obj) {
-    return _compiler.emplace_hook(std::move(obj));
+  void collect_pynode_objs(
+      const Node* pynode,
+      c10::SafePyObject&& bwd,
+      std::optional<c10::SafePyObject>&& bwd_state) {
+    size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
+    std::optional<size_t> bwd_state_idx;
+    if (auto state = std::move(bwd_state); state.has_value()) {
+      bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
+    }
+    _compiler.pynode_objs.emplace(
+        pynode, std::make_pair(bwd_idx, bwd_state_idx));
   }
 
   void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
@@ -737,6 +745,13 @@ class SwapSavedVariables {
   // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
   // allows tracing to happen, then swaps them back afterwards.
  public:
+  std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
+      Node* pynode) const {
+    auto it = compiler.pynode_objs.find(pynode);
+    TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
+    return it->second;
+  }
+
   void before(at::Tensor& t) {
     TensorArg& arg = compiler.tensor_args.lookup(t);
     stashed_tensors.save(&t, std::move(t));
@@ -942,7 +957,7 @@ class SwapSavedVariables {
       const NodeCall& n)
       : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
 
-  PyObject* get_py_compiler() {
+  PyObject* get_py_compiler() const {
     return py_compiler;
   }
 