[ca] side-effect free initial trace: compiled_args (#147804)

Make compiled_args and related methods const to prevent accidental mutation during the initial trace; the changes are mainly in the Error nodes and PyNode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804
Approved by: https://github.com/jansel
ghstack dependencies: #147242, #147796
Commit fd1220e386 (parent 5e3069dde8), committed by PyTorch MergeBot.
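Below the metadata, the diff is one mechanical change repeated across the autograd sources: compiled_args becomes a const member function, so the initial compiled-autograd trace can only read node state and hand it to the collector rather than stash anything back on the node. A minimal, self-contained sketch of the idea follows; Collector, NodeBase, and MyNode are simplified stand-ins, not the real torch classes.

#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-in for torch::dynamo::autograd::CompiledNodeArgs.
struct Collector {
  std::vector<std::string> keys;
  void collect(const std::string& v) { keys.push_back(v); }
  void collect(int64_t v) { keys.push_back(std::to_string(v)); }
};

struct NodeBase {
  // After this PR the virtual is const: overrides may only read the node.
  virtual void compiled_args(Collector& args) const = 0;
  virtual ~NodeBase() = default;
};

struct MyNode : NodeBase {
  std::string name = "MyNode";
  int64_t dim = 0;

  void compiled_args(Collector& args) const override {
    args.collect(name);  // reading members and collecting is fine
    args.collect(dim);
    // this->dim = 1;    // would not compile: *this is const here
  }
};

int main() {
  Collector args;
  MyNode n;
  n.compiled_args(args);  // side-effect free with respect to the node itself
  return 0;
}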
@@ -67,7 +67,7 @@ struct TORCH_API ${op} : public ${superclass} {
 ${release_variables}
 }
 ${will_release_variables}
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
 ${saved_variables}
 ${saved_list_sizes}

@@ -127,7 +127,7 @@ variable_list ${op}::apply(variable_list&& grads) {
 return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
 }
 
-void ${op}::compiled_args(CompiledNodeArgs& args) {
+void ${op}::compiled_args(CompiledNodeArgs& args) const {
 ${compiled_args}
 }
 variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {

@@ -284,7 +284,7 @@ struct CppNode : public Node {
 void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
 void save_variables_to_ctx();
 
-void compiled_args(CompiledNodeArgs& args) override {
+void compiled_args(CompiledNodeArgs& args) const override {
 // although neither of the 2 methods below have uniqueness guarantees
 // it is unlikely for them to collide at the same time
 args.collect(static_cast<uint64_t>(typeid(T).hash_code()));

@@ -545,8 +545,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
 return tensor_pre_hooks_;
 }
 
-virtual std::unique_ptr<PostAccumulateGradHook>&
-tensor_post_acc_grad_hooks() noexcept {
+virtual std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
+const noexcept {
 static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
 return empty;
 }

@@ -593,7 +593,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
 // 2) Collect node information for specialization and caching
 // Implementations in subclasses should call args.collect() with all node
 // attrs. These functions are only called durring backward.
-virtual void compiled_args(CompiledNodeArgs& args) {
+virtual void compiled_args(CompiledNodeArgs& args) const {
 throw std::runtime_error(
 std::string("compiled_args not implemented: ") + name());
 }

@@ -22,7 +22,8 @@ struct TORCH_API FunctionPreHook {
 virtual ~FunctionPreHook() = default;
 virtual variable_list operator()(const variable_list& grads) = 0;
 // only implemented for python hooks, registers hook with compiled autograd
-virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+virtual void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const {
 throw std::runtime_error(
 std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
 typeid(*this).name());

@@ -35,7 +36,8 @@ struct TORCH_API FunctionPostHook {
 const variable_list& outputs /* grad_inputs */,
 const variable_list& inputs /* grad_outputs */) = 0;
 // only implemented for python hooks, registers hook with compiled autograd
-virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+virtual void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const {
 throw std::runtime_error(
 std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
 typeid(*this).name());

@@ -47,7 +49,8 @@ struct TORCH_API PostAccumulateGradHook {
 virtual void operator()(const Variable& tensor) = 0;
 // only implemented for python hooks on nodes, registers hook with compiled
 // autograd
-virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
+virtual void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const {
 throw std::runtime_error(
 std::string("not yet implemented for compiled autograd: ") +
 typeid(*this).name());

@@ -66,12 +66,12 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
 return variable_list();
 }
 
-void AccumulateGrad::compiled_args(CompiledNodeArgs& args) {
+void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
 if (args.cond(variable.defined() && variable.requires_grad())) {
 args.collect(variable);
 args.collect(variable.grad());
 }
-auto& hook = tensor_post_acc_grad_hooks();
+const auto& hook = tensor_post_acc_grad_hooks();
 if (hook != nullptr) {
 hook->compiled_args(args);
 }

@@ -50,8 +50,8 @@ struct TORCH_API AccumulateGrad : public Node {
 return impl::hooks(variable);
 }
 
-std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
-override {
+std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks()
+const noexcept override {
 // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
 // it can be destroyed even though the Tensor is still alive (contrary
 // to all other Nodes). So we must lazily read the Tensor hooks here.

@@ -262,7 +262,7 @@ struct TORCH_API AccumulateGrad : public Node {
 }
 }
 
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override;
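The AccumulateGrad hunks above also make the hook accessor const: tensor_post_acc_grad_hooks() is now const noexcept, the base default still returns a reference to a function-local static empty pointer, and the caller binds the result with const auto&. A hedged sketch of that accessor pattern, with Hook/Owner as illustrative placeholders rather than the PyTorch types:

#include <iostream>
#include <memory>

struct Hook {
  void fire() const { std::cout << "hook\n"; }
};

struct Owner {
  // Base-style default: a const member function returning a reference to a
  // function-local static null pointer, so "no hook" reads the same as a hook.
  virtual std::unique_ptr<Hook>& hooks() const noexcept {
    static std::unique_ptr<Hook> empty = nullptr;
    return empty;
  }
  virtual ~Owner() = default;
};

struct OwnerWithHook : Owner {
  // The override returns hooks owned outside the object (here a function-local
  // static), so being const it never modifies the Owner itself.
  std::unique_ptr<Hook>& hooks() const noexcept override {
    static std::unique_ptr<Hook> storage = std::make_unique<Hook>();
    return storage;
  }
};

int main() {
  OwnerWithHook o;
  const auto& hook = o.hooks();  // mirrors `const auto& hook = tensor_post_acc_grad_hooks();`
  if (hook != nullptr) {
    hook->fire();
  }
  return 0;
}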
@@ -12,11 +12,15 @@
 
 namespace torch::autograd {
 
-auto Error::apply(variable_list&& inputs) -> variable_list {
+variable_list Error::apply(variable_list&& inputs) {
+return static_cast<const Error*>(this)->apply(std::move(inputs));
+}
 
+variable_list Error::apply(variable_list&& inputs) const {
 throw std::runtime_error(msg);
 }
 
-void Error::compiled_args(CompiledNodeArgs& args) {
+void Error::compiled_args(CompiledNodeArgs& args) const {
 // throw the error durring collect, the graph won't get compiled
 apply(variable_list());
 }
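In the hunk above, Error::apply gains a const overload and the non-const virtual override simply forwards to it through a static_cast<const Error*>; that delegation is what lets the now-const compiled_args call apply without mutating the node. A small sketch of the same pattern, using a hypothetical ThrowingNode and a placeholder variable_list:

#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using variable_list = std::vector<int>;  // placeholder for torch::autograd::variable_list

struct ThrowingNode {
  std::string msg;
  explicit ThrowingNode(std::string m) : msg(std::move(m)) {}

  // Non-const entry point (what a virtual override would be): forwards to the
  // const overload so the real logic is callable from const code paths.
  variable_list apply(variable_list&& inputs) {
    return static_cast<const ThrowingNode*>(this)->apply(std::move(inputs));
  }

  // Const overload holds the behaviour; a const compiled_args may call it.
  variable_list apply(variable_list&& /*inputs*/) const {
    throw std::runtime_error(msg);
  }

  void compiled_args() const {
    // throwing during collection aborts compilation, mirroring Error::compiled_args
    apply(variable_list());
  }
};

int main() {
  ThrowingNode n("boom");
  try {
    n.compiled_args();
  } catch (const std::runtime_error&) {
    // expected
  }
  return 0;
}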
@@ -66,7 +70,7 @@ auto Identity::apply(variable_list&& grads) -> variable_list {
 return std::move(grads);
 }
 
-void GraphRoot::compiled_args(CompiledNodeArgs& args) {
+void GraphRoot::compiled_args(CompiledNodeArgs& args) const {
 args.collect(outputs);
 }
 variable_list GraphRoot::apply_with_saved(

@@ -18,8 +18,9 @@ struct TORCH_API Error : public Node {
 Error(std::string msg) : msg(std::move(msg)) {}
 
 variable_list apply(variable_list&& inputs) override;
+variable_list apply(variable_list&& inputs) const;
 
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override;

@@ -51,6 +52,7 @@ struct TORCH_API DelayedError : public Node {
 }
 
 variable_list apply(variable_list&& inputs) override;
+variable_list apply(variable_list&& inputs) const;
 
 std::string msg;
 };

@@ -61,6 +63,7 @@ struct TORCH_API UndefinedGrad : public Node {
 }
 
 variable_list apply(variable_list&& inputs) override;
+variable_list apply(variable_list&& inputs) const;
 };
 
 struct TORCH_API UndefinedGradBackward : public Node {

@@ -69,8 +72,9 @@ struct TORCH_API UndefinedGradBackward : public Node {
 UndefinedGradBackward() = default;
 
 variable_list apply(variable_list&& inputs) override;
+variable_list apply(variable_list&& inputs) const;
 
-void compiled_args(CompiledNodeArgs& args) override {}
+void compiled_args(CompiledNodeArgs& args) const override {}
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override {

@@ -93,7 +97,7 @@ struct TORCH_API GraphRoot : public Node {
 return outputs;
 }
 
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override;

@@ -60,7 +60,7 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
 src_options);
 }
 
-void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
+void CopyBackwards::compiled_args(CompiledNodeArgs& args) const {
 args.collect(src_options);
 }
 

@@ -235,7 +235,7 @@ void CopySlices::release_variables() {
 fn = nullptr;
 }
 
-void CopySlices::compiled_args(CompiledNodeArgs& args) {
+void CopySlices::compiled_args(CompiledNodeArgs& args) const {
 TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
 TORCH_INTERNAL_ASSERT((bool)fn);
 args.collect(base);

@@ -15,7 +15,7 @@ namespace torch::autograd {
 
 struct TORCH_API CopyBackwards : public Node {
 variable_list apply(variable_list&& grads) override;
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override;

@@ -168,7 +168,7 @@ struct TORCH_API CopySlices : public Node {
 
 variable_list apply(variable_list&& inputs) override;
 void release_variables() override;
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override;

@@ -34,6 +34,7 @@
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
 
+#include <autograd/function.h>
 #include <functional>
 #include <memory>
 #include <stdexcept>

@@ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
 return to_variable_list(r.get(), is_variable_input);
 }
 
-auto PyNode::defer_to_dynamo(
+auto PyNode::apply_with_saved_impl(
 const variable_list& inputs,
-const std::optional<PyObject*>& compiler) -> variable_list {
+const SwapSavedVariables& saved) -> variable_list {
 pybind11::gil_scoped_acquire gil;
 at::OptionalDeviceGuard _device_guard;
 THPFunction* py_fn = (THPFunction*)obj;

@@ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo(
 }
 THPObjectPtr saved_tensors(unpack_saved_variables(
 py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
-TORCH_INTERNAL_ASSERT(
-_backward_idx.has_value(),
-"indices should already be set by compiled_args, called before apply_with_saved");
+auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this);
 PyObject* backward_state_idx = Py_None;
-if (_backward_state_idx.has_value()) {
-backward_state_idx = THPUtils_packInt64(_backward_state_idx.value());
+if (maybe_bwd_state_idx.has_value()) {
+backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value());
 // this might be simplifiable now that we no longer inline
 Py_CLEAR(py_fn->compiled_autograd_backward_state);
 }
 THPObjectPtr r(PyObject_CallMethod(
 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-compiler.value(),
+saved.get_py_compiler(),
 "proxy_call_backward",
 "OOOiOO",
 pyInputs.get(),
 fwdInputMetadatas.get(),
 saved_tensors.get(),
-*_backward_idx,
+bwd_idx,
 obj,
 backward_state_idx));
 

@@ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const {
 return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id");
 }
 
-void PyNode::compiled_args(CompiledNodeArgs& args) {
+void PyNode::compiled_args(CompiledNodeArgs& args) const {
 static PyObject* method_name =
 PyUnicode_InternFromString("_compiled_autograd_key");
 THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr));

@@ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
 args.collect(f->input_info);
 
 Py_INCREF(obj);
-_backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
+c10::SafePyObject backward_obj(obj, getPyInterpreter());
+std::optional<c10::SafePyObject> backward_state_obj;
 PyObject* bw_state = f->compiled_autograd_backward_state;
 if (args.cond(bw_state != nullptr)) {
 Py_INCREF(bw_state);
-_backward_state_idx = args.add_backward_state(
-c10::SafePyObject(bw_state, getPyInterpreter()));
+backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter());
 }
+args.collect_pynode_objs(
+this, std::move(backward_obj), std::move(backward_state_obj));
 }
 
 variable_list PyNode::apply_with_saved(

@@ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved(
 saved.before(f->output_info);
 saved.before(f->input_info);
 f->compiled_autograd_tracing = true;
-variable_list result =
-defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
+variable_list result = apply_with_saved_impl(variable_list(inputs), saved);
 f->compiled_autograd_tracing = false;
 saved.after(f->compiled_autograd_symints);
 saved.after(f->saved_variables);

@@ -35,9 +35,9 @@ struct PyNode : public Node {
 const std::vector<bool>& is_variable_input);
 
 variable_list apply(variable_list&& inputs) override;
-variable_list defer_to_dynamo(
+variable_list apply_with_saved_impl(
 const variable_list& inputs,
-const std::optional<PyObject*>& compiler);
+const SwapSavedVariables& saved);
 
 void release_variables() override;
 std::string name() const override;

@@ -45,7 +45,7 @@ struct PyNode : public Node {
 
 bool is_aot_backward() const override;
 
-void compiled_args(CompiledNodeArgs& args) override;
+void compiled_args(CompiledNodeArgs& args) const override;
 variable_list apply_with_saved(
 const variable_list& inputs,
 SwapSavedVariables& saved) override;

@@ -53,13 +53,6 @@ struct PyNode : public Node {
 // THPFunction this Function is wrapping. Owning!
 PyObject* obj;
 
-// The AutogradCompilerCall::hooks idx corresponding to this node's backward
-std::optional<int> _backward_idx;
-
-// The AutogradCompilerCall::hooks idx corresponding to this node's
-// backward_state
-std::optional<int> _backward_state_idx;
-
 // NOLINTNEXTLINE(bugprone-exception-escape)
 ~PyNode() override {
 // Can't use THPObjectPtr as a field in this class; destructor won't take
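With compiled_args const, PyNode can no longer cache _backward_idx and _backward_state_idx on itself (the members removed above); the indices are recorded in a map owned by the compiler call, keyed by the node pointer, and looked up again during apply_with_saved. A rough sketch of that external-map pattern; CompilerState, collect, and retrieve are illustrative names, not the real AutogradCompilerCall/SwapSavedVariables API:

#include <cassert>
#include <cstddef>
#include <optional>
#include <unordered_map>
#include <utility>

struct Node {  // placeholder node base
  virtual ~Node() = default;
};

// Per-compilation state owned outside the nodes (illustrative).
struct CompilerState {
  std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>> pynode_objs;
};

// Collection phase: a const node hands its indices to the external state.
void collect(CompilerState& state, const Node* node, size_t bwd_idx,
             std::optional<size_t> bwd_state_idx) {
  state.pynode_objs.emplace(node, std::make_pair(bwd_idx, bwd_state_idx));
}

// Tracing phase: the node retrieves its indices instead of reading members.
std::pair<size_t, std::optional<size_t>> retrieve(const CompilerState& state,
                                                  const Node* node) {
  auto it = state.pynode_objs.find(node);
  assert(it != state.pynode_objs.end());
  return it->second;
}

int main() {
  CompilerState state;
  Node n;
  collect(state, &n, /*bwd_idx=*/0, /*bwd_state_idx=*/std::nullopt);
  auto [bwd_idx, maybe_state_idx] = retrieve(state, &n);
  assert(bwd_idx == 0 && !maybe_state_idx.has_value());
  return 0;
}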
@@ -176,7 +176,7 @@ auto PyFunctionPostHook::operator()(
 return unwrap_variables(PyTuple_GetItem(tup.get(), 0));
 }
 
-void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) const {
 PyObject *key = nullptr, *value = nullptr;
 Py_ssize_t pos = 0;
 Py_BEGIN_CRITICAL_SECTION(dict);

@@ -189,7 +189,7 @@ void PyFunctionTensorPreHook::compiled_args(CompiledNodeArgs& args) {
 Py_END_CRITICAL_SECTION();
 }
 
-void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) const {
 PyObject *key = nullptr, *value = nullptr;
 Py_ssize_t pos = 0;
 Py_BEGIN_CRITICAL_SECTION(dict);

@@ -200,7 +200,7 @@ void PyFunctionPreHook::compiled_args(CompiledNodeArgs& args) {
 Py_END_CRITICAL_SECTION();
 }
 
-void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) {
+void PyFunctionPostHook::compiled_args(CompiledNodeArgs& args) const {
 PyObject *key = nullptr, *value = nullptr;
 Py_ssize_t pos = 0;
 Py_BEGIN_CRITICAL_SECTION(dict);

@@ -237,7 +237,7 @@ auto PyFunctionTensorPostAccGradHooks::operator()(const Variable& tensor)
 }
 
 void PyFunctionTensorPostAccGradHooks::compiled_args(
-torch::dynamo::autograd::CompiledNodeArgs& args) {
+torch::dynamo::autograd::CompiledNodeArgs& args) const {
 PyObject *key = nullptr, *value = nullptr;
 Py_ssize_t pos = 0;
 Py_BEGIN_CRITICAL_SECTION(dict);

@@ -14,7 +14,8 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
 PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
 ~PyFunctionTensorPreHook() override;
 variable_list operator()(const variable_list& values) override;
-void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const override;
 PyObject* dict;
 size_t value_idx;
 };

@@ -23,7 +24,8 @@ struct PyFunctionPreHook : public FunctionPreHook {
 PyFunctionPreHook(PyObject* dict);
 ~PyFunctionPreHook() override;
 variable_list operator()(const variable_list& values) override;
-void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const override;
 PyObject* dict;
 };
 

@@ -33,7 +35,8 @@ struct PyFunctionPostHook : public FunctionPostHook {
 variable_list operator()(
 const variable_list& outputs,
 const variable_list& inputs) override;
-void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const override;
 PyObject* dict;
 };
 

@@ -45,7 +48,8 @@ struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
 PyFunctionTensorPostAccGradHooks(PyObject* dict);
 ~PyFunctionTensorPostAccGradHooks() override;
 void operator()(const Variable& tensor) override;
-void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
+void compiled_args(
+torch::dynamo::autograd::CompiledNodeArgs& args) const override;
 void apply_with_saved(
 Variable& tensor,
 torch::dynamo::autograd::SwapSavedVariables& saved) override;

@@ -27,7 +27,7 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
 return fn_(outputs, inputs);
 }
 
-void compiled_args(CompiledNodeArgs& args) override {}
+void compiled_args(CompiledNodeArgs& args) const override {}
 
 protected:
 std::function<variable_list(const variable_list&, const variable_list&)> fn_;

@@ -183,21 +183,23 @@ Reducer::Reducer(
 #endif
 // Hook to execute after the gradient accumulator has executed.
 hooks_.emplace_back(
-grad_accumulator->add_post_hook(
-std::make_unique<torch::autograd::utils::LambdaPostHook>(
-[this, variable_index](
-const torch::autograd::variable_list& outputs,
-const torch::autograd::variable_list& /* unused */) {
+grad_accumulator->add_post_hook(std::make_unique<
+torch::autograd::utils::
+LambdaPostHook>(
+[this, variable_index](
+const torch::autograd::variable_list& outputs,
+const torch::autograd::variable_list& /* unused */) {
 #ifndef _WIN32
 this->rpc_context_.set(
 ThreadLocalDistAutogradContext::getContextPtr());
 #endif
 this->autograd_hook(variable_index);
 return outputs;
 },
 [=](torch::autograd::CompiledNodeArgs& args) {
-// Make post_hook an noop if compiled_autograds is enabled.
-})),
+TORCH_INTERNAL_ASSERT(
+"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
+})),
 grad_accumulator);
 
 // Map raw function pointer to parameter index.

@@ -343,6 +343,9 @@ struct AutogradCompilerCall {
 std::vector<uint32_t> size_input_origins;
 std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
 sv_to_hooks;
+// pynode -> backward and backward state idx
+std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
+pynode_objs;
 };
 
 class CompiledNodeArgs {

@@ -613,12 +616,17 @@ class CompiledNodeArgs {
 typeid(*node), _specialization_key, _specialization_key_size);
 }
 
-size_t add_backward(c10::SafePyObject&& obj) {
-return _compiler.emplace_hook(std::move(obj));
-}
-size_t add_backward_state(c10::SafePyObject&& obj) {
-return _compiler.emplace_hook(std::move(obj));
+void collect_pynode_objs(
+const Node* pynode,
+c10::SafePyObject&& bwd,
+std::optional<c10::SafePyObject>&& bwd_state) {
+size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
+std::optional<size_t> bwd_state_idx;
+if (auto state = std::move(bwd_state); state.has_value()) {
+bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
+}
+_compiler.pynode_objs.emplace(
+pynode, std::make_pair(bwd_idx, bwd_state_idx));
 }
 
 void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
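collect_pynode_objs above takes the optional backward-state object by rvalue reference and registers it conditionally with a C++17 if-with-initializer, so the object is moved at most once and the index stays std::nullopt when there is no backward state. A self-contained sketch of that idiom, with std::string standing in for c10::SafePyObject:

#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

std::vector<std::string> hooks;  // stand-in for AutogradCompilerCall::hooks

size_t emplace_hook(std::string&& obj) {
  hooks.push_back(std::move(obj));
  return hooks.size() - 1;
}

// Mirrors the shape of collect_pynode_objs: always register `bwd`,
// register `bwd_state` only if it is present.
std::pair<size_t, std::optional<size_t>> collect(
    std::string&& bwd, std::optional<std::string>&& bwd_state) {
  size_t bwd_idx = emplace_hook(std::move(bwd));
  std::optional<size_t> bwd_state_idx;
  // C++17 if-with-initializer: take ownership of the optional, then test it.
  if (auto state = std::move(bwd_state); state.has_value()) {
    bwd_state_idx = emplace_hook(std::move(state.value()));
  }
  return {bwd_idx, bwd_state_idx};
}

int main() {
  auto [a, a_state] = collect("backward", std::nullopt);
  auto [b, b_state] = collect("backward", std::string("state"));
  assert(a == 0 && !a_state.has_value());
  assert(b == 1 && b_state.has_value() && b_state.value() == 2);
  return 0;
}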
@@ -737,6 +745,13 @@ class SwapSavedVariables {
 // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
 // allows tracing to happen, then swaps them back afterwards.
 public:
+std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
+Node* pynode) const {
+auto it = compiler.pynode_objs.find(pynode);
+TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
+return it->second;
+}
+
 void before(at::Tensor& t) {
 TensorArg& arg = compiler.tensor_args.lookup(t);
 stashed_tensors.save(&t, std::move(t));

@@ -942,7 +957,7 @@ class SwapSavedVariables {
 const NodeCall& n)
 : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
 
-PyObject* get_py_compiler() {
+PyObject* get_py_compiler() const {
 return py_compiler;
 }
 