Compare commits


2 Commits

SHA1          Message                                                                               Date
6e16bfcd3e    add partition meta                                                                    2025-11-14 13:22:25 -08:00
d1f572a901    Add test for unbacked symint expression; Add backend node meta to invoke subgraph     2025-11-14 13:06:29 -08:00
6 changed files with 340 additions and 23 deletions

View File

@@ -10,6 +10,10 @@ import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward
from torch._guards import detect_fake_mode
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)
from torch._inductor.output_code import RegionalOutputCode
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
@@ -468,6 +472,86 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
        # flex in forward and flex_backward in backward
        self.assertEqual(len(codes), 2)

    @parametrize("serialize", [True, False])
    def test_invoke_subgraph_regional_compile(self, serialize):
        call_test_partitioner_ct = 0
        original_default_partitioner = torch._functorch.partitioners.default_partition

        def test_partitioner(
            *args, **kwargs
        ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
            nonlocal call_test_partitioner_ct
            call_test_partitioner_ct += 1
            return original_default_partitioner(*args, **kwargs)

        # pyrefly: ignore [not-iterable]
        if serialize:
            # Callable cannot be serialized
            torch._functorch.partitioners.default_partition = test_partitioner
            partitioner = "default_partition"
        else:
            partitioner = test_partitioner

        backend = NestedCompileRegionOptions(
            backend=NestedCompileBackend.INDUCTOR,
            inductor_configs={
                "max_autotune": True,
                "triton.cudagraphs": False,
            },
            partitioner=partitioner,
        )

        @torch.compiler.nested_compile_region(backend_options=backend)
        def gn_with_backend(x):
            return torch.sin(x)

        @torch.compiler.nested_compile_region
        def gn_without_backend(x):
            return torch.cos(x)

        def fn(x):
            return gn_with_backend(x) + gn_without_backend(x)

        backend = aot_eager_regional_inductor(serialize=serialize)
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)

        import torch._inductor.config as inductor_config

        # Hook to verify options
        original_compile = torch._inductor.standalone_compile
        captured_options = []

        def verify_options(*args, **kwargs):
            options = kwargs.get("options", {})
            captured_options.append(options)
            # Verify config is set as expected from explicit options
            assert inductor_config.max_autotune, "max_autotune should be True"
            assert not inductor_config.triton.cudagraphs, (
                "triton.cudagraphs should be False"
            )
            return original_compile(*args, **kwargs)

        torch._inductor.standalone_compile = verify_options

        try:
            x = torch.randn(8, 8, requires_grad=True)
            # opt_fn(x)
            res, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
            self.assertEqual(len(codes), 2)
            self.assertTrue("repeated_subgraph0" in codes[0])
            self.assertTrue("repeated_subgraph1" not in codes[0])
            self.assertTrue("repeated_subgraph0" in codes[1])
            self.assertTrue("repeated_subgraph1" not in codes[1])
            self.assertEqual(call_test_partitioner_ct, 1)

            true_res = fn(x)
            self.assertEqual(res, true_res)
        finally:
            torch._inductor.standalone_compile = original_compile
            torch._functorch.partitioners.default_partition = (
                original_default_partitioner
            )


@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class TestRegionalOutputCode(torch._inductor.test_case.TestCase):

View File

@@ -21,6 +21,10 @@ from torch._dynamo.testing import (
    InductorAndRecordGraphs,
    normalize_gm,
)
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch._inductor import config as inductor_config
from torch._inductor.pattern_matcher import (
@@ -1556,6 +1560,101 @@ class GraphModule(torch.nn.Module):
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_unbacked_expr(self):
        @nested_compile_region
        def gn(x):
            return x + 1

        def fn(c):
            d = torch.concat([c, c], dim=0)
            d = gn(d)
            return d

        c = torch.randn((64, 32))
        torch._dynamo.decorators.mark_unbacked(c, 0)

        ref = fn(c)
        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
        res = opt_fn(c)
        self.assertEqual(ref, res)

    def test_grad_accumulation(self):
        mod1 = torch.nn.Linear(8, 8)
        mod2 = torch.nn.Linear(8, 8)
        mod3 = torch.nn.Linear(8, 8)

        @nested_compile_region
        def gn(x):
            return mod1(x) - mod2(x)

        def fn(c):
            d = gn(c) - mod3(c)
            return d * 2

        c = torch.randn((8, 8), requires_grad=True)
        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        res = opt_fn(c)
        res.sum().backward()

        # fw_add_nodes = backend.fw_graphs[0].graph.find_nodes(op="call_function", target = torch.ops.aten.add.Tensor)
        # The gradient addition node for mod3 is not in the subgraph.
        bw_add_nodes = backend.bw_graphs[0].graph.find_nodes(
            op="call_function", target=torch.ops.aten.add.Tensor
        )
        self.assertEqual(len(bw_add_nodes), 1)

        subgraph_node = backend.bw_graphs[0].graph.find_nodes(op="get_attr")[0]
        subgraph_name = subgraph_node.target
        # The gradient addition node between mod1 and mod2 will be in the subgraph.
        bw_add_nodes = getattr(backend.bw_graphs[0], subgraph_name).graph.find_nodes(
            op="call_function", target=torch.ops.aten.add.Tensor
        )
        self.assertEqual(len(bw_add_nodes), 1)

    def test_backend_parameter(self):
        backend = NestedCompileRegionOptions(NestedCompileBackend.INDUCTOR)

        # Test that backend parameter is properly set in node.meta
        @nested_compile_region(backend_options=backend)
        def gn_with_backend(x):
            return torch.sin(x)

        @nested_compile_region
        def gn_without_backend(x):
            return torch.cos(x)

        def fn(x):
            return gn_with_backend(x) + gn_without_backend(x)

        backend = EagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        x = torch.randn(8, 8, requires_grad=False)
        opt_fn(x)

        # Check that we captured the graph
        self.assertEqual(len(backend.graphs), 1)
        graph = backend.graphs[0]

        # Find invoke_subgraph nodes and check their backend metadata
        invoke_subgraph_nodes = [
            node
            for node in graph.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.higher_order.invoke_subgraph
        ]

        # We should have 2 invoke_subgraph calls
        self.assertEqual(len(invoke_subgraph_nodes), 2)

        # First invoke_subgraph (gn_with_backend) should have backend metadata
        self.assertIn("custom", invoke_subgraph_nodes[0].meta)

        # Second invoke_subgraph (gn_without_backend) should have custom=None or no custom
        backend_value = invoke_subgraph_nodes[1].meta.get("custom", None)
        self.assertIsNone(backend_value)

    def test_complex(self):
        # Observed in Wan2.1
        @nested_compile_region

View File

@@ -20,6 +20,7 @@ their semantic behavior.
"""
import contextlib
import copy
import functools
import inspect
import itertools
@@ -42,6 +43,10 @@ from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch._higher_order_ops.invoke_subgraph import (
    NestedCompileBackend,
    NestedCompileRegionOptions,
)
from torch._ops import HigherOrderOperator
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
@@ -259,6 +264,7 @@ def _call_function_with_auto_output_flattening(
    flat_example_value: Any,
    body_r: Optional[VariableTracker],
    graph_output_vts: VariableTracker | tuple[VariableTracker, ...],
    backend_options: Optional[NestedCompileRegionOptions] = None,
) -> Optional[VariableTracker]:
    """
    Create HOP call node and reproxify output VTs for HOPs with auto output semantics.
@@ -285,14 +291,30 @@
    from .builder import wrap_fx_proxy

    # Store the invocation as a call
    proxy = tx.output.create_proxy(
        "call_function",
        fn,
        args=args,
        kwargs=kwargs,
    )

    # Set backend metadata if provided
    if backend_options is not None:
        if "custom" not in proxy.node.meta:
            proxy.node.meta["custom"] = {}

        if backend_options.backend == NestedCompileBackend.INDUCTOR:
            inductor_configs = {}
            if backend_options.inductor_configs:
                inductor_configs = copy.deepcopy(backend_options.inductor_configs)
            proxy.node.meta["custom"]["compile_with_inductor"] = {
                "inductor_configs": inductor_configs
            }

        if backend_options.partitioner is not None:
            proxy.node.meta["custom"]["partitioner"] = backend_options.partitioner

    flat_variable = wrap_fx_proxy(
        tx=tx,
        proxy=tx.output.create_proxy(
            "call_function",
            fn,
            args=args,
            kwargs=kwargs,
        ),
        proxy=proxy,
        example_value=flat_example_value,
    )
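For orientation only (not part of the diff): a sketch of the "custom" node meta the block above writes for an INDUCTOR-backed region that also specifies a string partitioner. The key names follow the code above; the concrete values are hypothetical and depend on the NestedCompileRegionOptions supplied by the user.

    # Illustrative only; expected_custom_meta is a hypothetical name.
    expected_custom_meta = {
        "compile_with_inductor": {"inductor_configs": {"max_autotune": True}},
        "partitioner": "default_partition",
    }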
@@ -324,7 +346,13 @@ def _call_function_with_auto_output_flattening(
def _call_function_and_unflatten_output(
    tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
    tx,
    fn,
    args,
    kwargs,
    flat_example_value,
    ret_spec,
    body_r,
):
    from .builder import wrap_fx_proxy
@@ -4235,6 +4263,18 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
            ],
        )

        # Extract backend from the function if it was decorated with nested_compile_region(backend=...)
        backend_options = None
        fn_var = args[0]
        if hasattr(fn_var, "get_function"):
            try:
                fn = fn_var.get_function()
                if hasattr(fn, "__marked_compile_region_backend__"):
                    backend_options = fn.__marked_compile_region_backend__
            except Exception:
                pass

        p_args = (
            p_args[0],
            body_name,
@@ -4248,6 +4288,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
            example_value,
            body_r,
            body_graph_output_vts,
            backend_options=backend_options,
        )

View File

@@ -675,6 +675,42 @@ def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
    return out


def _get_partition_fn(fw_hop_node, aot_config):
    """
    Return either the default `partition_fn` in aot_config or a HOP-specific partition
    function.

    If a HOP-specific partition function is returned, `used_hop_custom_partition` is True.
    See Note [InvokeSubgraphHOP Partitioner]
    """
    used_hop_custom_partition = False
    partition_fn: Callable[..., tuple[torch.fx.GraphModule, torch.fx.GraphModule]] = (
        aot_config.partition_fn
    )
    if (
        fw_hop_node.target == torch._higher_order_ops.invoke_subgraph
        and "custom" in fw_hop_node.meta
        and "partitioner" in fw_hop_node.meta["custom"]
    ):
        hop_partition_fn = fw_hop_node.meta["custom"]["partitioner"]
        if callable(hop_partition_fn):
            partition_fn = hop_partition_fn  # pyrefly: ignore [bad-assignment]
            used_hop_custom_partition = True
        else:
            assert isinstance(hop_partition_fn, str)
            match hop_partition_fn:
                case "default_partition":
                    partition_fn = torch._functorch.partitioners.default_partition
                case "min_cut_rematerialization_partition":
                    partition_fn = torch._functorch.partitioners.min_cut_rematerialization_partition
                case _:
                    raise ValueError(
                        f"Unknown HOP partitioner config: {hop_partition_fn}"
                    )
    return used_hop_custom_partition, partition_fn
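A minimal sketch (illustrative, not part of the diff) of the string-to-function mapping the match statement above applies; both partitioner functions live in torch/_functorch/partitioners.py, and hop_partitioner_meta is a hypothetical stand-in for the value dynamo attaches to the HOP node meta.

    import torch._functorch.partitioners as partitioners

    # Hypothetical meta value, as dynamo would attach it to the invoke_subgraph node.
    hop_partitioner_meta = "min_cut_rematerialization_partition"

    partition_fn = {
        "default_partition": partitioners.default_partition,
        "min_cut_rematerialization_partition": partitioners.min_cut_rematerialization_partition,
    }[hop_partitioner_meta]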


def run_joint_graph_passes_on_hops(
    joint_gm: torch.fx.GraphModule,
    joint_inputs: Any,
@@ -779,14 +815,27 @@ def run_joint_graph_passes_on_hops(
        # TODO: invoke_subgraph should track which of its inputs static indices
        # so it can propagate them to the partitioner (and use in cudagraphs)
        static_lifetime_input_indices: list[int] = []
        # Step 2) and 3) - Run joint graph passes and partitioner
        new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn(
            joint_hop_gm,
            [],
            num_fwd_outputs=num_fw_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        used_hop_custom_partition, partition_fn = _get_partition_fn(
            fw_hop_node, aot_config
        )

        # Step 2) and 3) - Run joint graph passes and partitioner
        try:
            new_fw_hop_gm, new_bw_hop_gm = partition_fn(
                joint_hop_gm,
                [],
                num_fwd_outputs=num_fw_outputs,
                static_lifetime_input_indices=static_lifetime_input_indices,
            )
        except Exception as e:
            if used_hop_custom_partition:
                raise RuntimeError(
                    f"Error in custom partition function for invoke_subgraph node {fw_hop_node.name}: {e}"
                ) from e
            else:
                raise

        # Save the new forward and backward graph modules
        new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm
        new_hop_graphs[identifier].new_bw_hop_gm = new_bw_hop_gm

View File

@@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import contextlib
import enum
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
@@ -36,10 +38,6 @@ from torch.fx.graph_module import GraphModule
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
if TYPE_CHECKING:
    from collections.abc import Callable
invoke_subgraph_counter = 0
@@ -53,6 +51,32 @@ class OutputMetadata:
    indexes_with_no_grad: set[int] = field(default_factory=set)


class NestedCompileBackend(enum.Enum):
    INDUCTOR = "inductor"
    DEFAULT = "default"


@dataclass
class NestedCompileRegionOptions:
    # If DEFAULT, does nothing and inherits the torch.compile backend.
    # If INDUCTOR, adds {"compile_with_inductor": {"inductor_configs": config}} to the
    # HOP node's "custom" meta. If "custom" already has "compile_with_inductor", this
    # config overrides it.
    backend: NestedCompileBackend = NestedCompileBackend.DEFAULT

    # If backend == "inductor", the inductor configs to compile with.
    inductor_configs: Optional[dict[str, Any]] = None

    # Note: [InvokeSubgraphHOP Partitioner]
    # If not None, adds "partitioner" to the HOP node meta.
    # If a Callable, the callable is assigned directly, but a callable cannot be pickled.
    # If a str, the options are "default_partition" and "min_cut_rematerialization_partition";
    # the HOP joint graph will be partitioned with the corresponding function in
    # torch/_functorch/partitioners.py.
    partitioner: Optional[Callable | str] = None

    # TODO: add decomposition function
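A minimal construction sketch (illustrative, not part of the diff); the field values mirror the options exercised in the tests above, and `options` is a hypothetical name.

    from torch._higher_order_ops.invoke_subgraph import (
        NestedCompileBackend,
        NestedCompileRegionOptions,
    )

    # Compile this region with inductor, with explicit configs and a named partitioner.
    options = NestedCompileRegionOptions(
        backend=NestedCompileBackend.INDUCTOR,
        inductor_configs={"max_autotune": True, "triton.cudagraphs": False},
        partitioner="min_cut_rematerialization_partition",
    )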
class InvokeSubgraphHOP(HigherOrderOperator):
    def __init__(self) -> None:
        # Invoke subgraph does not have any state, it is just a wrapper over a
@@ -153,7 +177,9 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
    return func(*args, **kwargs)
def mark_compile_region(fn=None):
def mark_compile_region(
    fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
"""
This wrapper instructs torch.compile to compile the wrapped region once and
reuse the compiled artifact, instead of the usual way of aggressively
@@ -161,6 +187,10 @@ def mark_compile_region(fn=None):
    Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the
    region. For PyTorch eager, this is a no-op.

    Args:
        fn: The function to wrap
        backend_options: Optional backend options to use for compiling the subgraph
    """
    def wrap(func):
@@ -172,6 +202,7 @@ def mark_compile_region(fn=None):
            return invoke_subgraph_placeholder(inner_func, *args, **kwargs)

        inner.__marked_compile_region_fn__ = func  # type: ignore[attr-defined]
        func.__marked_compile_region_backend__ = backend_options  # type: ignore[attr-defined]
        return inner

View File

@@ -1,14 +1,21 @@
# mypy: allow-untyped-defs
import io
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
import torch
from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions
from . import config
try:
    from typing import LiteralString
except ImportError:
    from typing_extensions import LiteralString

if TYPE_CHECKING:
    from ._cache import CacheInfo
@@ -635,7 +642,9 @@ def skip_all_guards_unsafe(guard_entries):
    return [False for entry in guard_entries]
def nested_compile_region(fn=None):
def nested_compile_region(
    fn=None, backend_options: Optional[NestedCompileRegionOptions] = None
):
"""
Tells **``torch.compile``** that the marked set of operations forms a nested
compile region (which is often repeated in the full model) whose code can be
@@ -644,8 +653,8 @@ def nested_compile_region(fn=None):
    During **``torch.compile``** tracing, the compiler applies *hierarchical
    compilation* with ``nested_compile_region``: it emits optimized code for the
    marked region the first time it is encountered and re-emits (or stamps
    out) the previously compiled code on every subsequent invocation. This can
    marked region the first time it is encountered and re-emits (or "stamps
    out") the previously compiled code on every subsequent invocation. This can
    substantially reduce overall compile time for deeply-stacked,
    structurally-identical components such as the transformer layers of a
    large-language-model (LLM).
@@ -659,13 +668,17 @@
    to reuse, it will transparently re-compile the region. Using it is
    therefore *safe*: correctness is always preserved, and you pay the extra
    compilation cost only when required.

    Args:
        fn: The function to wrap
        backend_options: Optional backend options to use for compiling the subgraph.
    """
    from torch._higher_order_ops.invoke_subgraph import (
        mark_compile_region as _mark_compile_region,
    )

    return _mark_compile_region(fn)
    return _mark_compile_region(fn, backend_options=backend_options)
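A short usage sketch (illustrative, not part of the diff) of the public entry point added above, mirroring the test earlier in this comparison; `opts`, `block`, and `model` are hypothetical names.

    import torch
    from torch._higher_order_ops.invoke_subgraph import (
        NestedCompileBackend,
        NestedCompileRegionOptions,
    )

    # Ask for the marked region to be compiled with inductor and these configs.
    opts = NestedCompileRegionOptions(
        backend=NestedCompileBackend.INDUCTOR,
        inductor_configs={"max_autotune": True},
    )

    @torch.compiler.nested_compile_region(backend_options=opts)
    def block(x):
        return torch.sin(x)

    def model(x):
        # The region is compiled once and stamped out for both calls.
        return block(x) + block(x)

    compiled = torch.compile(model, fullgraph=True)
    out = compiled(torch.randn(8, 8))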
def load_compiled_function(file: io.IOBase) -> Callable[..., Any]: