Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
[autograd] Support GradientEdge as output for torch.autograd.grad (#127766)
This is useful for splitting the gradient computation into two parts while preserving intermediates:

<details>
<summary>Click to see code</summary>

```python
import collections
import weakref
from typing import Deque

import torch
from torch.autograd.graph import GradientEdge


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


def reverse_closure(roots, target_nodes):
    # Recurse until we reach a target node
    closure = set()
    actual_target_nodes = set()
    q: Deque = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes


# Enable weak pointer
class Holder():
    def __init__(self, node):
        self.node = node


# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    q: Deque = collections.deque()
    root_seen = set()
    reverse_graph_refs = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                reverse_edges = fn.metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                fn.metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs


def get_param_groups(inputs, params):
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediates
    for i, param in enumerate(params):
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": set([param]),
            "intermediates": set(intersected),
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)
    return unique_param_groups


def compute_grads_only_inputs2(roots, inps, weights):
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):
            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(param_group["intermediates"])
                    param_group["grads"][i] = grad_inputs
                return hook
            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups


def compute_grads_only_weights2(user_weights, param_groups):
    all_dweights = dict()
    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
        for w, dw in zip(param_group["params"], dweights):
            all_dweights[w] = dw
    # return grads in the original order weights were provided in
    out = []
    for w in user_weights:
        grad_acc = _get_grad_fn_or_grad_acc(w)
        out.append(all_dweights[grad_acc])
    return tuple(out)
```

</details>

```python
import torch.nn as nn
import torch.utils._python_dispatch

# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

a = torch.rand(10, requires_grad=True)

weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))


class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return rs


print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))

for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))
```

<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">

```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
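In short, the new capability underlying the recipe above is that `torch.autograd.grad` now accepts `GradientEdge` objects in its `outputs` argument, not only in `inputs`. The snippet below is not part of the original commit message; it is a minimal sketch assuming a PyTorch build that includes this change, constructing `GradientEdge(node, output_nr)` the same way the logs above show it.

```python
import torch
from torch.autograd.graph import GradientEdge

a = torch.rand(10, requires_grad=True)
w = torch.rand(10, requires_grad=True)

y = a * w      # intermediate whose grad_fn we will use as the "output"
out = y.sum()  # the usual tensor output is not needed below

# Point autograd at the node that produced `y` rather than at a tensor.
edge = GradientEdge(y.grad_fn, y.output_nr)

# [NEW] `outputs` may contain GradientEdge objects. grad_outputs is supplied
# explicitly because there is no output tensor to derive a default from.
(dw,) = torch.autograd.grad((edge,), inputs=(w,), grad_outputs=(torch.ones_like(y),))
print(torch.allclose(dw, a))  # d(a*w)/dw with an upstream grad of ones is `a`
```

Part 2 of the recipe above (`compute_grads_only_weights2`) relies on exactly this: it restarts backward from intermediate `GradientEdge`s, feeding in the gradients that were captured by prehooks during part 1.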
Committed by: PyTorch MergeBot
Parent: c1e7e40f24
Commit: 2eec02523b
```diff
@@ -210,6 +210,27 @@ PyObject* THPCppFunction_set_sequence_nr(
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
+  HANDLE_TH_ERRORS;
+  auto& fn = *((THPCppFunction*)self)->cdata;
+  const auto num_inputs =
+      fn.num_inputs(); // Assuming there's a method to get the number of inputs
+  THPObjectPtr list(PyTuple_New(num_inputs));
+  if (!list) {
+    return nullptr;
+  }
+  for (size_t i = 0; i < num_inputs; ++i) {
+    const auto& metadata = fn.input_metadata(i);
+    THPObjectPtr item(py::cast(metadata).release().ptr());
+    if (!item) {
+      return nullptr;
+    }
+    PyTuple_SET_ITEM(list.get(), i, item.release());
+  }
+  return list.release();
+  END_HANDLE_TH_ERRORS
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
 static struct PyMethodDef default_methods[] = {
     THP_FUNCTION_DEFAULT_METHODS,
```
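The hunk above adds a getter that packs the per-input `InputMetadata` of a C++ autograd node into a Python tuple. The registration of this getter, and the Python-side name it is exposed under, are not part of the hunk shown here, so the attribute name in the sketch below (`_input_metadata`) is an assumption used purely for illustration.

```python
# Hypothetical usage sketch: `_input_metadata` is an assumed attribute name;
# the getter's registration is not visible in the hunk above.
import torch

a = torch.rand(10, requires_grad=True)
out = (a * 2).sum()

node = out.grad_fn  # a C++ autograd node (SumBackward0) wrapped as THPCppFunction
meta = getattr(node, "_input_metadata", None)
if meta is not None:
    # One entry per node input, mirroring fn.input_metadata(i) in the getter.
    print(len(meta))
```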