mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 00:54:52 +08:00
This PR adds support for torch.autograd.Function subclasses in compiled autograd. We do this by:
- Creating a uid for all torch.autograd.Function via its metaclass. This uid is used in the compiled autograd key, which is a subset of the cache key to the compiled graph
- "Lifting" the backward/saved_tensors, having them as input arguments in the compiled graph
- Creating proxies to track the backward's inputs and outputs. Since the backward's outputs (grads) have to match the forward's inputs, we pass the node's `input_info` (forward's input sizes) to build the proxies tracking the backward's outputs.
- Using a `FakeContext` class as a replacement for the autograd node's context object (`BackwardCFunction`) during tracing; it only supports passing saved_tensors from the forward to the backward
- Index each backward, to support multiple torch.autograd.Functions in the same graph
- Special-casing `CompiledFunctionBackward`: lifting CompiledFunction would fail 4 tests and require some skipfiles changes that I'd rather do in a separate PR
Example graph: test_custom_fn_saved_multiple_tensors (eager fw + compiled autograd)
```python
class MyFn(torch.autograd.Function):
    """Element-wise sine of two tensors with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, y):
        # Save both inputs: backward needs them to evaluate cos(x) and cos(y).
        ctx.save_for_backward(x, y)
        return torch.sin(x), torch.sin(y)

    @staticmethod
    def backward(ctx, gO_x, gO_y):
        # Returns one gradient per forward input, in the same order (x, y).
        (x, y) = ctx.saved_tensors
        return gO_x * torch.cos(x), gO_y * torch.cos(y)
```
The backward is lifted via `getitem_5` and `call_backward`:
```python
# Compiled autograd graph
===== Compiled autograd graph =====
<eval_with_key>.0 class CompiledAutograd(torch.nn.Module):
def forward(self, inputs, sizes, hooks):
# No stacktrace found for following nodes
getitem: "f32[]" = inputs[0]
getitem_1: "f32[10]" = inputs[1]
getitem_2: "f32[10]" = inputs[2]
getitem_3: "f32[10]" = inputs[3]
getitem_4: "f32[10]" = inputs[4]; inputs = None
expand: "f32[10]" = torch.ops.aten.expand.default(getitem, [10]); getitem = None
mul: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
getitem_5 = hooks[0]; hooks = None
call_backward = torch__dynamo_external_utils_call_backward(getitem_5, (getitem_3, getitem_4), mul_1, mul); getitem_5 = mul_1 = mul = None
getitem_6: "f32[10]" = call_backward[0]
getitem_7: "f32[10]" = call_backward[1]; call_backward = None
accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
return []
```
It is then later inlined by Dynamo:
```python
# Dynamo graph
===== __compiled_fn_0 =====
<eval_with_key>.1 class GraphModule(torch.nn.Module):
def forward(self, L_inputs_0_ : torch.Tensor, L_inputs_1_ : torch.Tensor, L_inputs_2_ : torch.Tensor, L_inputs_3_ : torch.Tensor, L_inputs_4_ : torch.Tensor):
getitem = L_inputs_0_
getitem_1 = L_inputs_1_
getitem_2 = L_inputs_2_
x = L_inputs_3_
y = L_inputs_4_
# File: <eval_with_key>.0:10, code: expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
expand = torch.ops.aten.expand.default(getitem, [10]); getitem = None
# File: <eval_with_key>.0:11, code: mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
mul = torch.ops.aten.mul.Tensor(expand, getitem_2); getitem_2 = None
# File: <eval_with_key>.0:12, code: mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
mul_1 = torch.ops.aten.mul.Tensor(expand, getitem_1); expand = getitem_1 = None
# File: /data/users/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py:412, code: return gO_x * torch.cos(x), gO_y * torch.cos(y)
cos = torch.cos(x)
getitem_6 = mul_1 * cos; mul_1 = cos = None
cos_1 = torch.cos(y)
getitem_7 = mul * cos_1; mul = cos_1 = None
# File: <eval_with_key>.0:17, code: accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_4, getitem_7); getitem_4 = getitem_7 = None
accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(y, getitem_7); y = getitem_7 = None
# File: <eval_with_key>.0:18, code: accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, getitem_6); getitem_3 = getitem_6 = None
accumulate_grad__default_1 = torch.ops.inductor.accumulate_grad_.default(x, getitem_6); x = getitem_6 = None
return ()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115573
Approved by: https://github.com/jansel
156 lines
5.2 KiB
C++
156 lines
5.2 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/python_headers.h>
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/autograd/custom_function.h>
|
|
#include <torch/csrc/autograd/function.h>
|
|
#include <torch/csrc/autograd/saved_variable.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
|
|
#include <c10/core/DeviceGuard.h>
|
|
#include <c10/util/Optional.h>
|
|
|
|
#include <memory>
|
|
#include <optional>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch::jit {
|
|
struct Graph;
|
|
}
|
|
|
|
namespace torch::autograd {
|
|
|
|
// A Function which is implemented by a Python object (i.e., a THPFunction).
|
|
// Calls to 'apply' are forwarded to the Python method implementation.
|
|
struct PyNode : public Node {
|
|
PyNode(THPObjectPtr obj) : obj(obj.release()) {}
|
|
|
|
PyObject* to_py_args(
|
|
const variable_list& inputs,
|
|
at::OptionalDeviceGuard* device_guard);
|
|
variable_list to_variable_list(
|
|
const PyObject* r,
|
|
const std::vector<bool>& is_variable_input);
|
|
|
|
variable_list apply(variable_list&& inputs) override;
|
|
variable_list compiled_apply(
|
|
variable_list&& inputs,
|
|
std::optional<PyObject*> compiler);
|
|
|
|
void release_variables() override;
|
|
std::string name() const override;
|
|
bool is_traceable() override;
|
|
|
|
void compiled_args(CompiledNodeArgs& args) override;
|
|
variable_list apply_with_saved(
|
|
const variable_list& inputs,
|
|
SwapSavedVariables& saved) override;
|
|
|
|
bool compiled_autograd_should_lift() const;
|
|
|
|
// THPFunction this Function is wrapping. Owning!
|
|
PyObject* obj;
|
|
|
|
// The AutogradCompilerCall::hooks idx corresponding to this node's backward
|
|
std::optional<int> _backward_idx;
|
|
|
|
~PyNode() override {
|
|
// Can't use THPObjectPtr as a field in this class; destructor won't take
|
|
// out GIL! When I forgot to do this by hand
|
|
// TestAutograd.test_inplace_view_python called me out about it.
|
|
// If python is already dead, leak the wrapped python objects
|
|
if (Py_IsInitialized()) {
|
|
pybind11::gil_scoped_acquire gil;
|
|
Py_DECREF(obj);
|
|
}
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Cast an object into a tuple, if it is not a tuple already. Returns true
|
|
* if the original object was not a tuple.
|
|
*/
|
|
inline bool ensure_tuple(THPObjectPtr& obj) {
|
|
if (PyTuple_Check(obj.get()))
|
|
return false;
|
|
|
|
PyObject* tuple = PyTuple_New(1);
|
|
if (!tuple)
|
|
throw python_error();
|
|
PyTuple_SET_ITEM(tuple, 0, obj.release());
|
|
obj = tuple;
|
|
return true;
|
|
}
|
|
|
|
} // namespace torch::autograd
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
struct THPFunction {
|
|
PyObject_HEAD
|
|
|
|
PyObject* needs_input_grad;
|
|
|
|
// Python tuple of tensors whose variables we should save. Set
|
|
// by Python with 'save_for_backward'. If nullptr, no tensors were
|
|
// saved.
|
|
PyObject* to_save;
|
|
// Python tuple of tensors which are not differentiable. Set by
|
|
// Python with 'mark_non_differentiable'. If nullptr, no tensors were
|
|
// non-differentiable.
|
|
PyObject* non_differentiable;
|
|
// Python tuple of tensors which had inplace updates in the forward()
|
|
// pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
|
|
// modified inplace.
|
|
PyObject* dirty_tensors;
|
|
|
|
// boolean indicating whether to materialize undefined output grad tensors
|
|
// into tensors full of zeros. Set by Python with 'set_materialize_grads'.
|
|
// Default is true.
|
|
bool materialize_grads;
|
|
|
|
// boolean indicating whether to materialize output grad tensors
|
|
// corresponding to non-differentiable outputs. Normally, someone would
|
|
// already get this behavior by switching off materialize_grads,
|
|
// but there are certain use cases where that is not feasible:
|
|
// https://github.com/pytorch/pytorch/pull/98659#pullrequestreview-1376822560
|
|
bool materialize_non_diff_grads;
|
|
|
|
// This is enabled by compiled autograd as a way to signal to AotAutograd it
|
|
// should call the original FX graph rather than compiling.
|
|
bool compiled_autograd_tracing;
|
|
std::vector<c10::SymInt> compiled_autograd_symints;
|
|
|
|
std::vector<torch::autograd::VariableInfo> output_info;
|
|
std::vector<torch::autograd::VariableInfo> input_info;
|
|
std::vector<torch::autograd::SavedVariable> saved_variables;
|
|
// For each input, true if the input is a THPVariable
|
|
std::vector<bool> is_variable_input;
|
|
char has_freed_buffers;
|
|
|
|
PyObject* saved_for_forward;
|
|
// The actual PyNode (in the autograd graph) that this data was
|
|
// saved for. This field may be NULL (because a user can construct
|
|
// a THPFunction directly from Python), but when this field is non-NULL,
|
|
// it is guaranteed that cdata.lock()->obj == this
|
|
//
|
|
// In most ordinary use, this field should always be non-NULL; e.g.,
|
|
// when we allocate a THPFunction because we are running Node.apply,
|
|
// after constructing a THPFunction, we immediately allocate a PyNode
|
|
// for it. We can't enforce this directly in the constructor of
|
|
// THPFunction though, because there's no way to keep it live long enough
|
|
// to save an owning reference to PyNode into the grad_fn of a Variable.
|
|
std::weak_ptr<torch::autograd::PyNode> cdata;
|
|
};
|
|
|
|
bool THPFunction_initModule(PyObject* module);
|
|
extern PyTypeObject THPFunctionType;
|
|
extern PyObject* THPFunctionClass;
|
|
extern PyObject* THPGradientEdgeClass;
|
|
|
|
// Returns true if `obj` is an instance of THPFunctionType, false if it is
// not.
//
// PyObject_IsInstance reports failure as -1 with a Python exception set.
// Converting that int straight to bool would turn a failed check into
// `true`, so the error case is detected explicitly and the pending Python
// exception is surfaced as a python_error instead.
inline bool THPFunction_Check(PyObject* obj) {
  const int is_instance = PyObject_IsInstance(
      obj, reinterpret_cast<PyObject*>(&THPFunctionType));
  if (is_instance < 0) {
    throw python_error();
  }
  return is_instance != 0;
}
|