mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 08:00:58 +08:00
[autograd] Support GradientEdge as output for torch.autograd.grad (#127766)
This is useful for splitting grad to run in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from torch.autograd.graph import GradientEdge
def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
def reverse_closure(roots, target_nodes):
# Recurse until we reach a target node
closure = set()
actual_target_nodes = set()
q: Deque = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
q.append(node)
while q:
node = q.popleft()
reverse_edges = node.metadata.get("reverse_edges", [])
for holder_ref, idx in reverse_edges:
ref = holder_ref()
if ref is not None:
raise RuntimeError("Reverse graph is no longer alive")
fn = ref.node
if fn in closure or fn is None:
continue
if fn in target_nodes:
actual_target_nodes.add(fn)
continue
closure.add(fn)
q.append(fn)
return closure, actual_target_nodes
# Enable weak pointer
class Holder():
def __init__(self, node):
self.node = node
# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
q: Deque = collections.deque()
root_seen = set()
reverse_graph_refs = []
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
root_seen.add(node)
while q:
node = q.popleft()
for fn, idx in node.next_functions:
if fn is not None:
# Don't necessarily need to store on the graph
reverse_edges = fn.metadata.get("reverse_edges", [])
if len(reverse_edges) == 0:
q.append(fn)
holder = Holder(node)
holder_ref = weakref.ref(holder)
reverse_graph_refs.append(holder)
reverse_edges.append((holder_ref, idx))
fn.metadata["reverse_edges"] = reverse_edges
return reverse_graph_refs
def get_param_groups(inputs, params):
inputs_closure, _ = reverse_closure(inputs, set())
param_groups = dict() # keyed on intermediates
for i, param in enumerate(params):
closure, intersected = reverse_closure([param], inputs_closure)
param_group = {
"params": set([param]),
"intermediates": set(intersected),
}
for input_node in intersected:
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
param_group = existing
else:
param_groups[input_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params = set()
seen_ids = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
seen_ids.add(id(param_group))
unique_param_groups.append(param_group)
union_params = union_params.union(param_group["params"])
assert union_params == set(params)
return unique_param_groups
def compute_grads_only_inputs2(roots, inps, weights):
root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))
reverse_graph_refs = construct_reverse_graph(root_grad_fns)
param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
del reverse_graph_refs
for param_group in param_groups:
for i, intermediate in enumerate(param_group["intermediates"]):
def get_hook(param_group, i):
def hook(grad_inputs):
if param_group.get("grads", None) is None:
param_group["grads"] = [None] * len(param_group["intermediates"])
param_group["grads"][i] = grad_inputs
return hook
# These are always "split" nodes that we need to recompute, so
# save their inputs.
intermediate.register_prehook(get_hook(param_group, i))
dinputs = torch.autograd.grad((out,), inputs=tuple(inps), grad_outputs=(torch.ones_like(out),), retain_graph=True)
return dinputs, param_groups
def compute_grads_only_weights2(user_weights, param_groups):
all_dweights = dict()
for param_group in param_groups:
# TODO: Handle case where intermediate can have multiple outputs
intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
# We do not need to retain_graph because... guarantee no overlap?
print("trying to execute: ", intermediate_edges, weights_edges)
dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
for w, dw in zip(param_group["params"], dweights):
all_dweights[w] = dw
# return grads in the original order weights were provided in
out = []
for w in user_weights:
grad_acc = _get_grad_fn_or_grad_acc(w)
out.append(all_dweights[grad_acc])
return tuple(out)
```
</details>
```python
import torch.nn as nn
# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)
a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)
out = mod2(mod1(a))
class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
rs = func(*args, **kwargs)
print(f"{func.__module__}.{func.__name__}")
return rs
print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
print("PART 1")
dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
print("PART 2")
dweights = compute_grads_only_weights2(weights, state)
out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c1e7e40f24
commit
2eec02523b
@ -166,6 +166,24 @@ c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
|
||||
|
||||
PyObject* THPEngineClass = nullptr;
|
||||
|
||||
inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
|
||||
PyObject* grad_fn = PyTuple_GetItem(obj, 0);
|
||||
auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
|
||||
std::shared_ptr<torch::autograd::Node> grad_fn_sp;
|
||||
if (THPFunction_Check(grad_fn)) {
|
||||
grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
|
||||
} else if (THPCppFunction_Check(grad_fn)) {
|
||||
grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"GradientEdge's first object must be an autograd.graph.Node "
|
||||
"but got ",
|
||||
THPUtils_typename(grad_fn));
|
||||
}
|
||||
return Edge(grad_fn_sp, output_nr);
|
||||
}
|
||||
|
||||
// Implementation of torch._C._EngineBase.run_backward
|
||||
PyObject* THPEngine_run_backward(
|
||||
PyObject* self,
|
||||
@ -239,22 +257,29 @@ PyObject* THPEngine_run_backward(
|
||||
grads.reserve(num_tensors);
|
||||
for (const auto i : c10::irange(num_tensors)) {
|
||||
PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
|
||||
TORCH_CHECK(
|
||||
THPVariable_Check(_tensor),
|
||||
"element ",
|
||||
i,
|
||||
" of tensors tuple is not a Tensor");
|
||||
const auto& variable = THPVariable_Unpack(_tensor);
|
||||
TORCH_CHECK(
|
||||
!isBatchedTensor(variable),
|
||||
"torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
|
||||
"torch.vmap. We do not support the case where any outputs are ",
|
||||
"vmapped tensors (output ",
|
||||
i,
|
||||
" is being vmapped over). Please "
|
||||
"call autograd.grad() outside torch.vmap or file a bug report "
|
||||
"with your use case.")
|
||||
auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
|
||||
Edge gradient_edge; // Temporary variable to hold the gradient edge
|
||||
c10::optional<at::Tensor> mb_output;
|
||||
if (THPVariable_Check(_tensor)) {
|
||||
mb_output = THPVariable_Unpack(_tensor);
|
||||
TORCH_CHECK(
|
||||
!isBatchedTensor(mb_output.value()),
|
||||
"torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
|
||||
"torch.vmap. We do not support the case where any outputs are ",
|
||||
"vmapped tensors (output ",
|
||||
i,
|
||||
" is being vmapped over). Please "
|
||||
"call autograd.grad() outside torch.vmap or file a bug report "
|
||||
"with your use case.");
|
||||
gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
|
||||
} else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
|
||||
gradient_edge = parseGradientEdge(_tensor, i);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"element ",
|
||||
i,
|
||||
" of tensors tuple is neither a Tensor nor a GradientEdge");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
gradient_edge.function,
|
||||
"element ",
|
||||
@ -281,7 +306,13 @@ PyObject* THPEngine_run_backward(
|
||||
i,
|
||||
" of gradients tuple is not a Tensor or None");
|
||||
TORCH_CHECK(
|
||||
!variable.requires_grad(),
|
||||
mb_output.has_value(),
|
||||
"element ",
|
||||
i,
|
||||
" of gradients tuple is None, but the corresponding output is a GradientEdge."
|
||||
"This is not supported.");
|
||||
TORCH_CHECK(
|
||||
!mb_output.value().requires_grad(),
|
||||
"element ",
|
||||
i,
|
||||
" of gradients tuple is None, but the corresponding Tensor requires grad");
|
||||
@ -330,23 +361,7 @@ PyObject* THPEngine_run_backward(
|
||||
output_edges.emplace_back(grad_fn, output_nr);
|
||||
}
|
||||
} else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
|
||||
auto node = PyTuple_GetItem(input, 0);
|
||||
bool isTHPFunction = THPFunction_Check(node);
|
||||
bool isTHPCppFunction = THPCppFunction_Check(node);
|
||||
TORCH_CHECK(
|
||||
isTHPFunction || isTHPCppFunction,
|
||||
"GradientEdge first object must be an autograd.graph.Node "
|
||||
"but got ",
|
||||
THPUtils_typename(node));
|
||||
std::shared_ptr<torch::autograd::Node> node_sp;
|
||||
if (isTHPFunction) {
|
||||
node_sp = ((THPFunction*)node)->cdata.lock();
|
||||
} else {
|
||||
node_sp = ((torch::autograd::THPCppFunction*)node)->cdata;
|
||||
}
|
||||
|
||||
auto output_nr = THPUtils_unpackUInt32(PyTuple_GetItem(input, 1));
|
||||
output_edges.emplace_back(node_sp, output_nr);
|
||||
output_edges.emplace_back(parseGradientEdge(input, i));
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
||||
Reference in New Issue
Block a user