[invoke_subgraph] Lazy backward (#150666)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150666
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
Author: Animesh Jain
Date: 2025-04-07 12:35:57 -07:00
Committed by: PyTorch MergeBot
parent 78fe079c97
commit 6ea5514e04
3 changed files with 171 additions and 50 deletions


@@ -672,12 +672,19 @@ class HopSubgraphCache:
@abstractmethod
def get_proxy_dispatch_entry(self, identifier: str): ...
@abstractmethod
def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule): ...
@abstractmethod
def get_lazy_bwd_entry(self, identifier: str): ...
class InvokeSubgraphCache(HopSubgraphCache):
def __init__(self) -> None:
self.autograd_cache: dict[str, Callable] = {}
self.proxy_dispatch_cache: dict[str, Callable] = {}
self.dynamo_identifiers: dict[str, str] = {}
self.lazy_bwd_cache: dict[str, torch.fx.GraphModule] = {}
def add_dynamo_identifier(self, cache_key: str, identifier: str):
self.dynamo_identifiers[cache_key] = identifier
@@ -697,6 +704,12 @@ class InvokeSubgraphCache(HopSubgraphCache):
def get_proxy_dispatch_entry(self, identifier: str):
return self.proxy_dispatch_cache.get(identifier, None)
def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule):
self.lazy_bwd_cache[identifier] = gmod
def get_lazy_bwd_entry(self, identifier: str):
return self.lazy_bwd_cache.get(identifier, None)
class HopDispatchSetCache:
def __init__(self) -> None:
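
The two lazy_bwd hooks give InvokeSubgraphCache a place to stash the backward graph that is traced lazily for a given invoke_subgraph identifier, so later backward calls can reuse it. A minimal sketch of the intended check-then-populate pattern, mirroring the backward pass later in this diff (get_or_trace_bwd and trace_fn are illustrative stand-ins, not code from the PR):

def get_or_trace_bwd(cache, identifier: str, trace_fn):
    # `cache` is an InvokeSubgraphCache; `trace_fn` is a zero-argument callable
    # that produces the traced joint graph (a torch.fx.GraphModule).
    bw_gmod = cache.get_lazy_bwd_entry(identifier)   # None on the first backward
    if bw_gmod is None:
        bw_gmod = trace_fn()                         # trace lazily, only when needed
        cache.add_lazy_bwd_entry(identifier, bw_gmod)
    return bw_gmod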


@@ -151,9 +151,11 @@ class BaseHOPFunction(torch.autograd.Function):
from .utils import _from_fun
fw_inputs = pytree.tree_map(_from_fun, operands)
_, joint_graph, _, _ = create_fw_bw_graph(
subgraph, fw_inputs, grad_outputs
)
(
_,
joint_graph,
_,
) = create_fw_bw_graph(subgraph, fw_inputs, grad_outputs)
# The joint graph returns (*grad_inputs, *fwd_outputs).
# We only need the grad_inputs.
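
The slicing that the comment above describes can be pictured as follows; split_joint_outputs is an illustrative helper (not part of the PR) that assumes the joint graph's (*grad_inputs, *fwd_outputs) output convention:

def split_joint_outputs(joint_outs, num_primals):
    # Joint graph output convention: (*grad_inputs, *fwd_outputs), with one
    # (possibly None) grad per primal input.
    grad_inputs = joint_outs[:num_primals]
    fwd_outputs = joint_outs[num_primals:]
    return grad_inputs, fwd_outputs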


@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Optional, Union
@@ -42,7 +43,8 @@ invoke_subgraph_counter = 0
# used to filter out grad_outs/tangents in the `backward` method of
# InvokeSubgraphAutogradOp.
@dataclass
class FilterTangentInfo:
class OutputMetadata:
num_fw_outs: Optional[int] = None
indexes_with_none: set[int] = field(default_factory=set)
indexes_with_no_grad: set[int] = field(default_factory=set)
@@ -144,6 +146,7 @@ def get_invoke_subgraph_cache():
return cache
# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
def trace_joint_graph(fn, fw_inputs, fw_outputs):
"""
Naively trace out a joint graph. This simplifies the reconstruction of joint
@@ -184,6 +187,7 @@ def trace_joint_graph(fn, fw_inputs, fw_outputs):
return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
@@ -209,13 +213,14 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
# performed in the autograd.Function - InvokeSubgraphAutogradOp.
# Also collect the indexes of no_grad in the output to filter out
# the grad_outs in the `backward` method.
filter_tangent_info = FilterTangentInfo()
output_metadata = OutputMetadata()
output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if fw_out is None:
filter_tangent_info.indexes_with_none.add(idx)
output_metadata.indexes_with_none.add(idx)
elif not fw_out.requires_grad:
filter_tangent_info.indexes_with_no_grad.add(idx)
output_metadata.indexes_with_no_grad.add(idx)
if grad_outputs is None:
# Infer grad_outputs to have the same properties as the fw_outputs
@@ -253,88 +258,182 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
fw_inputs,
grad_outputs,
)
return fw_graph, bw_graph, num_fw_outs, filter_tangent_info
return fw_graph, bw_graph, output_metadata
def get_output_metadata(subgraph, operands):
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
# args are functional tensors, generate some example tensors
fw_inputs = pytree.tree_map(_from_fun, operands)
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(fw_inputs)
context = (
nullcontext()
if fake_mode is None or fake_mode.shape_env is None
else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
)
with context:
fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
num_fw_outs = len(fw_outs)
# Collect the indexes of none in the output to check that the grad
# is None at the corresponding index in the backward. This check is
# performed in the autograd.Function - InvokeSubgraphAutogradOp.
# Also collect the indexes of no_grad in the output to filter out
# the grad_outs in the `backward` method.
output_metadata = OutputMetadata()
output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if fw_out is None:
output_metadata.indexes_with_none.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
return output_metadata
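
# --- Illustration (not part of this diff): a small, hypothetical subgraph to
# --- make the OutputMetadata fields concrete. `gn`, `x`, and `y` are made up.
import torch

def gn(x, y):
    a = x * y                     # requires grad -> gets a real tangent
    with torch.no_grad():
        b = x + y                 # requires_grad=False -> tangent is filtered out
    return a, None, b             # None output -> grad_out is asserted to be None

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
# get_output_metadata(gn, (x, y)) would conceptually record:
#   num_fw_outs = 3, indexes_with_none = {1}, indexes_with_no_grad = {2}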
def trace_joint_graph_as_bwd(
fn, num_primals, joint_operands, include_key_set, exclude_key_set
):
"""
Naively trace out a joint graph. This simplifies the reconstruction of the
joint graph in the min-cut partitioner later on.
"""
from torch._functorch.aot_autograd import create_joint
dummy_aot_config = get_dummy_aot_autograd_config()
# This joint_fn is inserted as the backward graph as is. This simplifies the
# min-cut partitioner work later on.
# Input signature - (*primals, *tangents)
# Output signature - (*grads, *fw_outs)
# The output signature is deliberately kept grads first and fw_outs second.
# Having grads first makes the min-cut partitioner HOP graph stitching
# easier.
def joint_fn(*primals_and_tangents):
primals = primals_and_tangents[:num_primals]
tangents = primals_and_tangents[num_primals:]
fw_outs, grads = create_joint(
prepare_fw_with_masks(fn), aot_config=dummy_aot_config
)(primals, tangents)
maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
# return signature is deliberately kept (*grads, *fw_outs). This
# simplifies partitioning work later on.
return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs)))
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
joint_operands = [_from_fun(arg) for arg in joint_operands]
with contextlib.ExitStack() as stack:
stack.enter_context(
torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
)
with torch.enable_grad():
return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
class InvokeSubgraphAutogradOp(torch.autograd.Function):
"""
This autograd function op is to stash the backward graph in the ctx while
running forward.
Saves the subgraph, i.e. the original callable, in the forward method, and
then traces out a joint graph in the backward. Deferring the tracing to the
backward, also called lazy backward, ensures that assumptions about the
grad_out strides and tensor-subclass-ness are already accounted for.
"""
@staticmethod
def forward(
ctx,
fw_graph,
bw_graph,
subgraph,
identifier,
num_fw_outs,
filter_tangent_info,
output_metadata,
*operands,
):
ctx._fw_graph = fw_graph
ctx._bw_graph = bw_graph
# We want to delay the backward graph construction until the backward.
# So in forward, we just run the fw callable as is. And save all the
# information necessary to construct the backward graph in the ctx.
ctx._subgraph = subgraph
ctx._identifier = identifier
ctx._num_fw_outs = num_fw_outs
ctx._filter_tangent_info = filter_tangent_info
ctx._output_metadata = output_metadata
# We snapshot the dispatch keys in forward for materializing the
# bw_graph in backward.
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
save_tensors_and_symints_for_backward(ctx, operands)
with torch._C._AutoDispatchBelowAutograd():
out = invoke_subgraph(
fw_graph,
subgraph,
f"___forward_{identifier}",
operands,
)
save_tensors_and_symints_for_backward(ctx, operands)
# Check that None is at expected indexes.
for idx, o in enumerate(out):
if o is None:
assert idx in filter_tangent_info.indexes_with_none
assert idx in output_metadata.indexes_with_none
return out
@staticmethod
def backward(ctx, *grad_outs):
bw_graph = ctx._bw_graph
def backward(
ctx,
*grad_outs,
):
subgraph = ctx._subgraph
identifier = ctx._identifier
output_metadata = ctx._output_metadata
primals = saved_tensors_and_symints(ctx)
num_fw_outs = ctx._num_fw_outs
filter_tangent_info = ctx._filter_tangent_info
# While tracing we made the assumption that tangents are contiguous. So,
# force the grad_outs to be contiguous.
# Also filter out grads that are None or do not require_grad. This was
# Filter out grads that are None or do not require_grad. This was
# the assumption we made during the tracing of joint_graph.
contiguous_grad_outs = []
filtered_grad_outs = []
for idx, o in enumerate(grad_outs):
if o is None:
assert idx in filter_tangent_info.indexes_with_none
elif idx in filter_tangent_info.indexes_with_no_grad:
assert idx in output_metadata.indexes_with_none
elif idx in output_metadata.indexes_with_no_grad:
# Deliberately skip over the grad_outs which we know should be
# None because the corresponding fwd_out does not require_grad.
pass
else:
contiguous_grad_outs.append(o.contiguous())
contiguous_grad_outs = tuple(contiguous_grad_outs)
filtered_grad_outs.append(o)
filtered_grad_outs = tuple(filtered_grad_outs)
# bw_graph is a joint graph with signature (*primals_and_tangents) and
# returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs
# to extract the grads.
primals_and_tangents = primals + contiguous_grad_outs
primals_and_tangents = primals + filtered_grad_outs
# Check if we have already traced the bwd subgraph.
bw_graph = None
invoke_subgraph_cache = get_invoke_subgraph_cache()
if invoke_subgraph_cache:
bw_graph = invoke_subgraph_cache.get_lazy_bwd_entry(identifier)
if bw_graph is None:
bw_graph = trace_joint_graph_as_bwd(
subgraph,
len(primals),
primals_and_tangents,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
)
if invoke_subgraph_cache:
invoke_subgraph_cache.add_lazy_bwd_entry(identifier, bw_graph)
grads = invoke_subgraph(
bw_graph, f"___backward_{identifier}", primals_and_tangents
)[:-num_fw_outs]
return None, None, None, None, None, *grads
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
def _(subgraph, identifier, operands):
from torch.utils._python_dispatch import _get_current_dispatch_mode
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return subgraph(*operands)
)[: -output_metadata.num_fw_outs]
return None, None, None, *grads
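
# --- Illustration (not part of this diff): the lazy-backward pattern above,
# --- reduced to a plain autograd.Function with all the HOP machinery stripped
# --- out. The forward only stashes the callable and its inputs; the backward
# --- computation is built the first time backward actually runs (when the real
# --- tangents exist) and is cached by identifier. All names are illustrative.
from typing import Callable
import torch

_lazy_bwd_cache: dict[str, Callable] = {}

class LazyBackwardSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, subgraph, identifier, *operands):
        ctx.subgraph = subgraph
        ctx.identifier = identifier
        ctx.save_for_backward(*operands)
        return subgraph(*operands)

    @staticmethod
    def backward(ctx, *grad_outs):
        primals = ctx.saved_tensors
        subgraph = ctx.subgraph
        bwd_fn = _lazy_bwd_cache.get(ctx.identifier)
        if bwd_fn is None:
            # Built lazily: at this point the tangents' strides and
            # subclass-ness are known, which is the point of deferring.
            def bwd_fn(primals, tangents):
                with torch.enable_grad():
                    detached = [p.detach().requires_grad_(True) for p in primals]
                    out = subgraph(*detached)
                    return torch.autograd.grad(out, detached, grad_outputs=tangents)
            _lazy_bwd_cache[ctx.identifier] = bwd_fn
        grads = bwd_fn(primals, grad_outs)
        return None, None, *grads

# Usage sketch:
#   x = torch.randn(3, requires_grad=True)
#   out = LazyBackwardSketch.apply(lambda t: (t * t).sum(), "sg0", x)
#   out.backward()   # the backward fn is constructed and cached only here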
@invoke_subgraph.py_impl(DispatchKey.Autograd)
@@ -361,13 +460,11 @@ def _(subgraph, identifier, operands):
):
return saved_autograd_fn(*operands)
fw_graph, bw_graph, num_fw_outs, filter_tangent_info = create_fw_bw_graph(
subgraph, operands
)
output_metadata = get_output_metadata(subgraph, operands)
def autograd_fn_callable(*args):
return InvokeSubgraphAutogradOp.apply(
fw_graph, bw_graph, identifier, num_fw_outs, filter_tangent_info, *args
subgraph, identifier, output_metadata, *args
)
# Save the autograd_fn_callable in the dispatch set cache.
@@ -377,6 +474,15 @@ def _(subgraph, identifier, operands):
return autograd_fn_callable(*operands)
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
def _(subgraph, identifier, operands):
from torch.utils._python_dispatch import _get_current_dispatch_mode
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return subgraph(*operands)
@invoke_subgraph.py_functionalize_impl
def _(ctx, subgraph, identifier, operands):
unwrapped_operands = ctx.unwrap_tensors(operands)