[invoke_subgraph] Lazy backward (#150666)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150666
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 78fe079c97
Commit: 6ea5514e04
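The diff below defers invoke_subgraph's joint (backward) graph tracing from forward time into the autograd.Function's backward, and caches the traced graph per identifier so it is built at most once. As a rough, standalone sketch of that "lazy backward" pattern, assuming a toy callable and made-up names (LazyBackwardFn, _BWD_CACHE, "toy_id"; none of these appear in the PR), the idea looks like this:

# Illustrative sketch only -- not the code in this diff.
from typing import Callable

import torch

_BWD_CACHE: dict[str, Callable] = {}  # identifier -> cached backward callable


class LazyBackwardFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, subgraph, identifier, *operands):
        # Forward just runs the callable; no backward graph is built here.
        ctx.subgraph = subgraph
        ctx.identifier = identifier
        ctx.save_for_backward(*operands)
        with torch.no_grad():
            return subgraph(*operands)

    @staticmethod
    def backward(ctx, *grad_outs):
        primals = ctx.saved_tensors
        subgraph = ctx.subgraph
        bwd = _BWD_CACHE.get(ctx.identifier)
        if bwd is None:
            # Built only now, when the real grad_outs (their strides,
            # subclass-ness, etc.) are known; then cached per identifier.
            def bwd(primals, grad_outs, subgraph=subgraph):
                with torch.enable_grad():
                    inputs = [p.detach().requires_grad_(True) for p in primals]
                    outs = subgraph(*inputs)
                return torch.autograd.grad(outs, inputs, grad_outs)

            _BWD_CACHE[ctx.identifier] = bwd
        grads = bwd(primals, grad_outs)
        # One None per non-tensor forward argument (subgraph, identifier).
        return None, None, *grads


# Toy usage: the backward callable is constructed lazily on the first
# .backward() call and reused from the cache afterwards.
x = torch.randn(3, requires_grad=True)
out = LazyBackwardFn.apply(lambda t: (t * 2.0).sum(), "toy_id", x)
out.backward()
print(x.grad)  # tensor([2., 2., 2.])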
@@ -672,12 +672,19 @@ class HopSubgraphCache:
     @abstractmethod
     def get_proxy_dispatch_entry(self, identifier: str): ...

+    @abstractmethod
+    def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule): ...
+
+    @abstractmethod
+    def get_lazy_bwd_entry(self, identifier: str): ...
+

 class InvokeSubgraphCache(HopSubgraphCache):
     def __init__(self) -> None:
         self.autograd_cache: dict[str, Callable] = {}
         self.proxy_dispatch_cache: dict[str, Callable] = {}
         self.dynamo_identifiers: dict[str, str] = {}
+        self.lazy_bwd_cache: dict[str, torch.fx.GraphModule] = {}

     def add_dynamo_identifier(self, cache_key: str, identifier: str):
         self.dynamo_identifiers[cache_key] = identifier
@@ -697,6 +704,12 @@ class InvokeSubgraphCache(HopSubgraphCache):
     def get_proxy_dispatch_entry(self, identifier: str):
         return self.proxy_dispatch_cache.get(identifier, None)

+    def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule):
+        self.lazy_bwd_cache[identifier] = gmod
+
+    def get_lazy_bwd_entry(self, identifier: str):
+        return self.lazy_bwd_cache.get(identifier, None)
+

 class HopDispatchSetCache:
     def __init__(self) -> None:
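A hedged, self-contained sketch of how the new lazy-bwd cache entries are meant to be exercised (the cache class is re-declared in miniature here only so the snippet runs on its own; the identifier string and the toy traced module are made up):

import torch


class ToyLazyBwdCache:
    """Miniature re-declaration of the two methods added above, for illustration."""

    def __init__(self) -> None:
        self.lazy_bwd_cache: dict[str, torch.fx.GraphModule] = {}

    def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule):
        self.lazy_bwd_cache[identifier] = gmod

    def get_lazy_bwd_entry(self, identifier: str):
        return self.lazy_bwd_cache.get(identifier, None)


cache = ToyLazyBwdCache()
assert cache.get_lazy_bwd_entry("subgraph_0") is None  # miss: nothing traced yet
gm = torch.fx.symbolic_trace(lambda x: x * 2)          # stand-in for a traced bwd graph
cache.add_lazy_bwd_entry("subgraph_0", gm)             # store after the first backward
assert cache.get_lazy_bwd_entry("subgraph_0") is gm    # later backwards reuse it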
@@ -151,9 +151,11 @@ class BaseHOPFunction(torch.autograd.Function):
         from .utils import _from_fun

         fw_inputs = pytree.tree_map(_from_fun, operands)
-        _, joint_graph, _, _ = create_fw_bw_graph(
-            subgraph, fw_inputs, grad_outputs
-        )
+        (
+            _,
+            joint_graph,
+            _,
+        ) = create_fw_bw_graph(subgraph, fw_inputs, grad_outputs)

         # The joint graph returns (*grad_inputs, *fwd_outputs).
         # We only need the grad_inputs.
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs


+import contextlib
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from typing import Optional, Union
@@ -42,7 +43,8 @@ invoke_subgraph_counter = 0
 # used to filter out grad_outs/tangents in the `backward` method of
 # InvokeSubgraphAutogradOp.
 @dataclass
-class FilterTangentInfo:
+class OutputMetadata:
+    num_fw_outs: Optional[int] = None
     indexes_with_none: set[int] = field(default_factory=set)
     indexes_with_no_grad: set[int] = field(default_factory=set)

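For illustration, a minimal runnable sketch of how these fields end up populated for a toy forward output; the dataclass is copied here only to keep the snippet self-contained, and the example outputs are invented:

from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class OutputMetadata:
    num_fw_outs: Optional[int] = None
    indexes_with_none: set[int] = field(default_factory=set)
    indexes_with_no_grad: set[int] = field(default_factory=set)


fw_outs = (
    torch.randn(2, requires_grad=True),  # ordinary differentiable output
    None,                                # an output slot that is simply None
    torch.randn(2),                      # output that does not require grad
)

meta = OutputMetadata(num_fw_outs=len(fw_outs))
for idx, fw_out in enumerate(fw_outs):
    if fw_out is None:
        meta.indexes_with_none.add(idx)
    elif not fw_out.requires_grad:
        meta.indexes_with_no_grad.add(idx)

assert meta.indexes_with_none == {1} and meta.indexes_with_no_grad == {2}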
@@ -144,6 +146,7 @@ def get_invoke_subgraph_cache():
     return cache


+# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
 def trace_joint_graph(fn, fw_inputs, fw_outputs):
     """
     Naively trace out a joint graph. This simplifies the reconstruction of joint
@@ -184,6 +187,7 @@ def trace_joint_graph(fn, fw_inputs, fw_outputs):
     return _maybe_reenter_make_fx(joint_fn)(*joint_operands)


+# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
 def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
@@ -209,13 +213,14 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
             # performed in the autograd.Function - InvokeSubgraphAutogradOp.
             # Also collect the indexes of no_grad in the output to filter out
             # the grad_outs in the `backward` method.
-            filter_tangent_info = FilterTangentInfo()
+            output_metadata = OutputMetadata()

+            output_metadata.num_fw_outs = num_fw_outs
             for idx, fw_out in enumerate(fw_outs):
                 if fw_out is None:
-                    filter_tangent_info.indexes_with_none.add(idx)
+                    output_metadata.indexes_with_none.add(idx)
                 elif not fw_out.requires_grad:
-                    filter_tangent_info.indexes_with_no_grad.add(idx)
+                    output_metadata.indexes_with_no_grad.add(idx)

             if grad_outputs is None:
                 # Infer grad_outputs to be the same properties as the fw_outputs
@@ -253,88 +258,182 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
                 fw_inputs,
                 grad_outputs,
             )
-            return fw_graph, bw_graph, num_fw_outs, filter_tangent_info
+            return fw_graph, bw_graph, output_metadata
+
+
+def get_output_metadata(subgraph, operands):
+    with suspend_functionalization(), disable_functional_mode():
+        with disable_proxy_modes_tracing():
+            # args are functional tensors, generate some example tensors
+            fw_inputs = pytree.tree_map(_from_fun, operands)
+
+            from torch._guards import detect_fake_mode
+
+            fake_mode = detect_fake_mode(fw_inputs)
+            context = (
+                nullcontext()
+                if fake_mode is None or fake_mode.shape_env is None
+                else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
+            )
+
+            with context:
+                fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
+
+            num_fw_outs = len(fw_outs)
+
+            # Collect the indexes of none in the output to check that the grad
+            # is None at the corresponding index in the backward. This check is
+            # performed in the autograd.Function - InvokeSubgraphAutogradOp.
+            # Also collect the indexes of no_grad in the output to filter out
+            # the grad_outs in the `backward` method.
+            output_metadata = OutputMetadata()
+
+            output_metadata.num_fw_outs = num_fw_outs
+            for idx, fw_out in enumerate(fw_outs):
+                if fw_out is None:
+                    output_metadata.indexes_with_none.add(idx)
+                elif not fw_out.requires_grad:
+                    output_metadata.indexes_with_no_grad.add(idx)
+            return output_metadata
+
+
+def trace_joint_graph_as_bwd(
+    fn, num_primals, joint_operands, include_key_set, exclude_key_set
+):
+    """
+    Naively trace out a joint graph. This simplifies the reconstruction of joint
+    graph in the min-cut partitioner later on.
+    """
+    from torch._functorch.aot_autograd import create_joint
+
+    dummy_aot_config = get_dummy_aot_autograd_config()
+
+    # This joint_fn is inserted as the backward graph as is. This simplifies the
+    # min-cut partitioner work later on.
+    # Input signature - (*primals, *tangents)
+    # Output signature - (*grads, *fw_outs)
+    # The output signature is deliberately kept grads first and fw_outs second.
+    # Having grads first makes the min-cut partitioner HOP graph stitching
+    # easier.
+    def joint_fn(*primals_and_tangents):
+        primals = primals_and_tangents[:num_primals]
+        tangents = primals_and_tangents[num_primals:]
+
+        fw_outs, grads = create_joint(
+            prepare_fw_with_masks(fn), aot_config=dummy_aot_config
+        )(primals, tangents)
+
+        maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
+
+        # return signature is deliberately kept (*grads, *fw_outs). This
+        # simplifies partitioning work later on.
+        return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs)))
+
+    with suspend_functionalization(), disable_functional_mode():
+        with disable_proxy_modes_tracing():
+            joint_operands = [_from_fun(arg) for arg in joint_operands]
+            with contextlib.ExitStack() as stack:
+                stack.enter_context(
+                    torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
+                )
+                with torch.enable_grad():
+                    return _maybe_reenter_make_fx(joint_fn)(*joint_operands)


 class InvokeSubgraphAutogradOp(torch.autograd.Function):
     """
-    This autograd function op is to stash the backward graph in the ctx while
-    running forward.
+    Saves the subgraph, i.e. original callable, in the forward method. And then
+    traces out a joint graph in the backward. This delaying of tracing in
+    backward, also called as lazy backward, ensures that the assumptions about
+    the grad_out strides and tensor-subclass-ness are already accounted for.
     """

     @staticmethod
     def forward(
         ctx,
-        fw_graph,
-        bw_graph,
+        subgraph,
         identifier,
-        num_fw_outs,
-        filter_tangent_info,
+        output_metadata,
         *operands,
     ):
-        ctx._fw_graph = fw_graph
-        ctx._bw_graph = bw_graph
+        # We want to delay the backward graph construction until the backward.
+        # So in forward, we just run the fw callable as is. And save all the
+        # information necessary to construct the backward graph in the ctx.
+        ctx._subgraph = subgraph
         ctx._identifier = identifier
-        ctx._num_fw_outs = num_fw_outs
-        ctx._filter_tangent_info = filter_tangent_info
+        ctx._output_metadata = output_metadata
+        # We snapshot the dispatch keys in forward for materializing the
+        # the bw_graph in backward.
+        ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
+        ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
+
+        save_tensors_and_symints_for_backward(ctx, operands)

         with torch._C._AutoDispatchBelowAutograd():
             out = invoke_subgraph(
-                fw_graph,
+                subgraph,
                 f"___forward_{identifier}",
                 operands,
             )

-        save_tensors_and_symints_for_backward(ctx, operands)
-
         # Check that None is at expected indexes.
         for idx, o in enumerate(out):
             if o is None:
-                assert idx in filter_tangent_info.indexes_with_none
+                assert idx in output_metadata.indexes_with_none

         return out

     @staticmethod
-    def backward(ctx, *grad_outs):
-        bw_graph = ctx._bw_graph
+    def backward(
+        ctx,
+        *grad_outs,
+    ):
+        subgraph = ctx._subgraph
         identifier = ctx._identifier
+        output_metadata = ctx._output_metadata
         primals = saved_tensors_and_symints(ctx)
-        num_fw_outs = ctx._num_fw_outs
-        filter_tangent_info = ctx._filter_tangent_info

-        # While tracing we made the assumption that tangents are contiguous. So,
-        # force the grad_outs to be contiguous.
-        # Also filter out grads that are None or do not require_grad. This was
+        # Filter out grads that are None or do not require_grad. This was
         # the assumption we made during the tracing of joint_graph.
-        contiguous_grad_outs = []
+        filtered_grad_outs = []
         for idx, o in enumerate(grad_outs):
             if o is None:
-                assert idx in filter_tangent_info.indexes_with_none
-            elif idx in filter_tangent_info.indexes_with_no_grad:
+                assert idx in output_metadata.indexes_with_none
+            elif idx in output_metadata.indexes_with_no_grad:
                 # Deliberately skip over the grad_outs which we know should be
                 # None because the corresponding fwd_out does not require_grad.
                 pass
             else:
-                contiguous_grad_outs.append(o.contiguous())
-        contiguous_grad_outs = tuple(contiguous_grad_outs)
+                filtered_grad_outs.append(o)
+        filtered_grad_outs = tuple(filtered_grad_outs)

         # bw_graph is a joint graph with signature (*primals_and_tangents) and
         # returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs
         # to extract the grads.
-        primals_and_tangents = primals + contiguous_grad_outs
+        primals_and_tangents = primals + filtered_grad_outs
+
+        # Check if we have already traced the bwd subgraph.
+        bw_graph = None
+        invoke_subgraph_cache = get_invoke_subgraph_cache()
+        if invoke_subgraph_cache:
+            bw_graph = invoke_subgraph_cache.get_lazy_bwd_entry(identifier)
+
+        if bw_graph is None:
+            bw_graph = trace_joint_graph_as_bwd(
+                subgraph,
+                len(primals),
+                primals_and_tangents,
+                ctx._fw_include_key_set,
+                ctx._fw_exclude_key_set,
+            )
+
+            if invoke_subgraph_cache:
+                invoke_subgraph_cache.add_lazy_bwd_entry(identifier, bw_graph)
+
         grads = invoke_subgraph(
             bw_graph, f"___backward_{identifier}", primals_and_tangents
-        )[:-num_fw_outs]
-        return None, None, None, None, None, *grads
-
-
-@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
-def _(subgraph, identifier, operands):
-    from torch.utils._python_dispatch import _get_current_dispatch_mode
-
-    mode = _get_current_dispatch_mode()
-    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-    return subgraph(*operands)
+        )[: -output_metadata.num_fw_outs]
+        return None, None, None, *grads


 @invoke_subgraph.py_impl(DispatchKey.Autograd)
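To make the slicing in the backward above concrete, here is a small sketch (plain Python stand-ins, not the real graphs) of the (*grads, *fw_outs) joint-graph return convention and how slicing with the negative num_fw_outs recovers just the grads:

num_primals, num_fw_outs = 2, 3
grads = [f"grad_{i}" for i in range(num_primals)]
fw_outs = [f"fw_out_{i}" for i in range(num_fw_outs)]

joint_out = (*grads, *fw_outs)          # what the traced joint graph returns
only_grads = joint_out[: -num_fw_outs]  # what backward keeps and returns
assert list(only_grads) == grads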
@@ -361,13 +460,11 @@ def _(subgraph, identifier, operands):
     ):
         return saved_autograd_fn(*operands)

-    fw_graph, bw_graph, num_fw_outs, filter_tangent_info = create_fw_bw_graph(
-        subgraph, operands
-    )
+    output_metadata = get_output_metadata(subgraph, operands)

     def autograd_fn_callable(*args):
         return InvokeSubgraphAutogradOp.apply(
-            fw_graph, bw_graph, identifier, num_fw_outs, filter_tangent_info, *args
+            subgraph, identifier, output_metadata, *args
         )

     # Save the autograd_fn_callable in the dispatch set cache.
@@ -377,6 +474,15 @@ def _(subgraph, identifier, operands):
     return autograd_fn_callable(*operands)


+@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
+def _(subgraph, identifier, operands):
+    from torch.utils._python_dispatch import _get_current_dispatch_mode
+
+    mode = _get_current_dispatch_mode()
+    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
+    return subgraph(*operands)
+
+
 @invoke_subgraph.py_functionalize_impl
 def _(ctx, subgraph, identifier, operands):
     unwrapped_operands = ctx.unwrap_tensors(operands)
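As a small, hedged demonstration of why tracing the backward lazily matters: the grad_outs actually seen at backward time need not be contiguous (or even plain tensors), which is exactly the assumption the removed "force the grad_outs to be contiguous" path had to paper over. A toy example with a non-contiguous gradient:

import torch

x = torch.randn(4, 4, requires_grad=True)
y = x * 2
# Transposing a contiguous 4x4 tensor yields a non-contiguous view; real
# tangents reaching backward can look like this.
g = torch.ones(4, 4).t()
assert not g.is_contiguous()
y.backward(gradient=g)
print(x.grad)  # all 2s, even though the tangent was non-contiguous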