Compare commits

...

1 Commit

Author SHA1 Message Date
6875071248 compiler pipeline (#166198)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/166198

Test Plan: added tests

Differential Revision: D85466034
2025-10-30 16:04:59 -07:00
3 changed files with 1030 additions and 1 deletion
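
For orientation, here is a condensed sketch of the training path this API enables, distilled from the tests added below; the toy module, the no-op compiler, and the use of default_partition are illustrative choices rather than requirements:

import torch
import torch.nn as nn
from torch._decomp import decomposition_table
from torch._export import CompilerPipeline
from torch._functorch.partitioners import default_partition

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

def nop_compiler(gm, args):
    # Stand-in compiler: run the captured graph as-is.
    return gm.forward

model = M()
inputs = (torch.randn(4, 3, requires_grad=True),)

pipeline = CompilerPipeline(model, inputs)  # captures the graph with Dynamo
with pipeline.stack:
    # Joint forward/backward graph in ATen IR, with descriptors.
    pipeline.generate_aten_ir_for_training(decompositions=decomposition_table)
    part = pipeline.partition(
        default_partition,
        pipeline.aot_graph_capture.graph_module,
        pipeline.aot_graph_capture.updated_flat_args,
    )
    fw = pipeline.fw_compile(
        nop_compiler,
        part.fw_module,
        part.adjusted_flat_args,
        part.num_fw_outs_saved_for_bw,
    )
    bw = pipeline.bw_compile(
        nop_compiler,
        part.bw_module,
        fw.fwd_output_strides,
        part.num_symints_saved_for_bw,
    )
    model_fn = pipeline.make_autograd_function(
        flat_args=pipeline.aot_state.flat_args,
        wrappers=pipeline.aot_graph_capture.wrappers,
        compiled_fw_func=fw.compiled_fw_func,
        compiled_bw_func=bw.compiled_bw_func,
        lazy_backward_info=bw.lazy_backward_info,
        indices_of_inps_to_detach=part.indices_of_inps_to_detach,
        num_symints_saved_for_bw=part.num_symints_saved_for_bw,
    )

# model_fn keeps the original module's calling convention.
out = model_fn(*inputs)
out.sum().backward()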

View File

@@ -0,0 +1,462 @@
# Owner(s): ["oncall: pt2"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch._decomp import decomposition_table
from torch._export import CompilerPipeline
from torch._functorch.partitioners import default_partition
from torch.testing._internal.common_utils import run_tests, TestCase
class TestCompilerPipeline(TestCase):
def test_simple_linear_stage_by_stage(self):
"""Test calling CompilerPipeline methods individually."""
class SimpleLinear(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
return self.linear(x)
model = SimpleLinear()
inputs = (torch.randn(4, 3, requires_grad=True),)
# Step 1: Create CompilerPipeline
pipeline = CompilerPipeline(model, inputs)
with pipeline.stack:
# Step 2: Generate ATen IR for training
pipeline.generate_aten_ir_for_training(decompositions=decomposition_table)
def nop_compiler(gm, args):
return gm.forward
# Step 3: Partition the joint graph
partition_output = pipeline.partition(
default_partition,
pipeline.aot_graph_capture.graph_module,
pipeline.aot_graph_capture.updated_flat_args,
)
# Step 4: Compile forward
fw_compile_output = pipeline.fw_compile(
nop_compiler,
partition_output.fw_module,
partition_output.adjusted_flat_args,
partition_output.num_fw_outs_saved_for_bw,
)
# Step 5: Compile backward
bw_compile_output = pipeline.bw_compile(
nop_compiler,
partition_output.bw_module,
fw_compile_output.fwd_output_strides,
partition_output.num_symints_saved_for_bw,
)
# Step 6: Create the final autograd function (with clean calling convention)
model_fn = pipeline.make_autograd_function(
flat_args=pipeline.aot_state.flat_args,
wrappers=pipeline.aot_graph_capture.wrappers,
compiled_fw_func=fw_compile_output.compiled_fw_func,
compiled_bw_func=bw_compile_output.compiled_bw_func,
lazy_backward_info=bw_compile_output.lazy_backward_info,
indices_of_inps_to_detach=partition_output.indices_of_inps_to_detach,
num_symints_saved_for_bw=partition_output.num_symints_saved_for_bw,
)
# Test functional correctness: model_fn should produce same results as original model
# Test forward
expected_output = model(*inputs)
actual_output = model_fn(*inputs)
torch.testing.assert_close(actual_output, expected_output)
# Test backward - check that gradients match
# Create fresh inputs for both eager and compiled
inputs_eager = (torch.randn(4, 3, requires_grad=True),)
inputs_compiled = (inputs_eager[0].detach().clone().requires_grad_(True),)
# Run eager backward
out_eager = model(*inputs_eager)
out_eager.sum().backward()
# Run compiled backward
out_compiled = model_fn(*inputs_compiled)
out_compiled.sum().backward()
# Compare gradients for input
torch.testing.assert_close(inputs_eager[0].grad, inputs_compiled[0].grad)
# Compare gradients for parameters (note: pipeline.graph_module has the parameters)
for (name_eager, param_eager), (name_compiled, param_compiled) in zip(
model.named_parameters(), pipeline.graph_module.named_parameters()
):
self.assertEqual(name_eager, name_compiled)
torch.testing.assert_close(param_eager.grad, param_compiled.grad)
def test_conv_bn_stage_by_stage(self):
"""Test CompilerPipeline stage-by-stage with conv+batchnorm model."""
class ConvBN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 3, 3, padding=1)
self.bn = nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return torch.relu(x)
model = ConvBN()
inputs = (torch.randn(2, 1, 4, 4, requires_grad=True),)
# Step 1: Create CompilerPipeline
pipeline = CompilerPipeline(model, inputs)
with pipeline.stack:
# Step 2: Generate ATen IR for training
pipeline.generate_aten_ir_for_training(decompositions=decomposition_table)
# Step 3: Partition
partition_output = pipeline.partition(
default_partition,
pipeline.aot_graph_capture.graph_module,
pipeline.aot_graph_capture.updated_flat_args,
)
from torch._functorch.aot_autograd import boxed_nop_preserve_node_meta
# Step 4: Compile forward and backward
fw_compile_output = pipeline.fw_compile(
boxed_nop_preserve_node_meta,
partition_output.fw_module,
partition_output.adjusted_flat_args,
partition_output.num_fw_outs_saved_for_bw,
)
bw_compile_output = pipeline.bw_compile(
boxed_nop_preserve_node_meta,
partition_output.bw_module,
fw_compile_output.fwd_output_strides,
partition_output.num_symints_saved_for_bw,
)
# Step 5: Create the final autograd function (with clean calling convention)
model_fn = pipeline.make_autograd_function(
flat_args=pipeline.aot_state.flat_args,
wrappers=pipeline.aot_graph_capture.wrappers,
compiled_fw_func=fw_compile_output.compiled_fw_func,
compiled_bw_func=bw_compile_output.compiled_bw_func,
lazy_backward_info=bw_compile_output.lazy_backward_info,
indices_of_inps_to_detach=partition_output.indices_of_inps_to_detach,
num_symints_saved_for_bw=partition_output.num_symints_saved_for_bw,
)
# Test forward
expected_output = model(*inputs)
actual_output = model_fn(*inputs)
torch.testing.assert_close(actual_output, expected_output)
# Test backward - check that gradients match
# Create fresh inputs for both eager and compiled
inputs_eager = (torch.randn(2, 1, 4, 4, requires_grad=True),)
inputs_compiled = (inputs_eager[0].detach().clone().requires_grad_(True),)
# Run eager backward
out_eager = model(*inputs_eager)
out_eager.sum().backward()
# Run compiled backward
out_compiled = model_fn(*inputs_compiled)
out_compiled.sum().backward()
# Compare gradients for input
torch.testing.assert_close(inputs_eager[0].grad, inputs_compiled[0].grad)
# Compare gradients for parameters (note: pipeline.graph_module has the parameters)
for (name_eager, param_eager), (name_compiled, param_compiled) in zip(
model.named_parameters(), pipeline.graph_module.named_parameters()
):
self.assertEqual(name_eager, name_compiled)
torch.testing.assert_close(param_eager.grad, param_compiled.grad)
def test_simple_linear_with_structured_io(self):
"""Test calling CompilerPipeline with structured dict input and tuple output."""
class SimpleLinearStructuredIO(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(3, 2)
self.linear2 = nn.Linear(4, 2)
def forward(self, inputs):
# Take a dict with two tensors and return a tuple
x = self.linear1(inputs["x"])
y = self.linear2(inputs["y"])
return (x + y, x - y, x * y)
model = SimpleLinearStructuredIO()
inputs = (
{
"x": torch.randn(4, 3, requires_grad=True),
"y": torch.randn(4, 4, requires_grad=True),
},
)
# Step 1: Create CompilerPipeline
pipeline = CompilerPipeline(model, inputs)
with pipeline.stack:
# Step 2: Generate ATen IR for training
pipeline.generate_aten_ir_for_training(decompositions=decomposition_table)
def nop_compiler(gm, args):
return gm.forward
# Step 3: Partition the joint graph
partition_output = pipeline.partition(
default_partition,
pipeline.aot_graph_capture.graph_module,
pipeline.aot_graph_capture.updated_flat_args,
)
# Step 4: Compile forward
fw_compile_output = pipeline.fw_compile(
nop_compiler,
partition_output.fw_module,
partition_output.adjusted_flat_args,
partition_output.num_fw_outs_saved_for_bw,
)
# Step 5: Compile backward
bw_compile_output = pipeline.bw_compile(
nop_compiler,
partition_output.bw_module,
fw_compile_output.fwd_output_strides,
partition_output.num_symints_saved_for_bw,
)
# Step 6: Create the final autograd function (with clean calling convention)
model_fn = pipeline.make_autograd_function(
flat_args=pipeline.aot_state.flat_args,
wrappers=pipeline.aot_graph_capture.wrappers,
compiled_fw_func=fw_compile_output.compiled_fw_func,
compiled_bw_func=bw_compile_output.compiled_bw_func,
lazy_backward_info=bw_compile_output.lazy_backward_info,
indices_of_inps_to_detach=partition_output.indices_of_inps_to_detach,
num_symints_saved_for_bw=partition_output.num_symints_saved_for_bw,
)
# Test functional correctness: model_fn should preserve dict calling convention and tuple output
# Test forward with dict input and tuple output
expected_output = model(*inputs)
actual_output = model_fn(*inputs)
# Verify we got a tuple with 3 elements
self.assertIsInstance(expected_output, tuple)
self.assertIsInstance(actual_output, tuple)
self.assertEqual(len(expected_output), 3)
self.assertEqual(len(actual_output), 3)
# Verify each element of the tuple matches
for actual, expected in zip(actual_output, expected_output):
torch.testing.assert_close(actual, expected)
# Test backward - check that gradients match
# Create fresh inputs for both eager and compiled
inputs_eager = (
{
"x": torch.randn(4, 3, requires_grad=True),
"y": torch.randn(4, 4, requires_grad=True),
},
)
inputs_compiled = (
{
"x": inputs_eager[0]["x"].detach().clone().requires_grad_(True),
"y": inputs_eager[0]["y"].detach().clone().requires_grad_(True),
},
)
# Run eager backward (sum over all tuple outputs)
out_eager = model(*inputs_eager)
sum(x.sum() for x in out_eager).backward()
# Run compiled backward (sum over all tuple outputs)
out_compiled = model_fn(*inputs_compiled)
sum(x.sum() for x in out_compiled).backward()
# Compare gradients for inputs
for input_eager, input_compiled in zip(
inputs_eager[0].values(), inputs_compiled[0].values()
):
torch.testing.assert_close(input_eager.grad, input_compiled.grad)
# Compare gradients for parameters (note: pipeline.graph_module has the parameters)
for (name_eager, param_eager), (name_compiled, param_compiled) in zip(
model.named_parameters(), pipeline.graph_module.named_parameters()
):
self.assertEqual(name_eager, name_compiled)
torch.testing.assert_close(param_eager.grad, param_compiled.grad)
def test_simple_linear_inference_with_structured_io(self):
"""Test inference compilation path with structured dict input and tuple output."""
class SimpleLinearStructuredIO(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(3, 2)
self.linear2 = nn.Linear(4, 2)
def forward(self, inputs):
# Take a dict with two tensors and return a tuple
x = self.linear1(inputs["x"])
y = self.linear2(inputs["y"])
return (x + y, x - y, x * y)
model = SimpleLinearStructuredIO()
# NOTE: inputs do NOT require grad for inference path
inputs = ({"x": torch.randn(4, 3), "y": torch.randn(4, 4)},)
# Step 1: Create CompilerPipeline
pipeline = CompilerPipeline(model, inputs)
with pipeline.stack:
# Step 2: Generate ATen IR for inference
# NOTE: params and buffers do NOT require grad for inference path
pipeline.generate_aten_ir_for_inference(decompositions=decomposition_table)
def nop_compiler(gm, args):
return gm.forward
# Step 3: Compile for inference (no partitioning needed)
compiled_fw = pipeline.inference_compile(
nop_compiler,
pipeline.aot_graph_capture.graph_module,
pipeline.aot_graph_capture.updated_flat_args,
)
# Step 4: Create the final inference function (with clean calling convention)
model_fn = pipeline.make_inference_function(
compiled_fw=compiled_fw,
wrappers=pipeline.aot_graph_capture.wrappers,
entry=None,
)
# Test functional correctness: model_fn should preserve dict calling convention and tuple output
# Test forward with dict input and tuple output
expected_output = model(*inputs)
actual_output = model_fn(*inputs)
# Verify we got a tuple with 3 elements
self.assertIsInstance(expected_output, tuple)
self.assertIsInstance(actual_output, tuple)
self.assertEqual(len(expected_output), 3)
self.assertEqual(len(actual_output), 3)
# Verify each element of the tuple matches
for actual, expected in zip(actual_output, expected_output):
torch.testing.assert_close(actual, expected)
def test_simple_linear_with_identity_passes(self):
"""Test applying identity passes between compilation stages."""
class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
return self.linear(x)
model = SimpleModule()
inputs = (torch.randn(4, 3, requires_grad=True),)
# Identity pass that returns the graph unchanged
def identity_pass(gm):
"""An identity pass that marks the graph as processed."""
gm._identity_pass_applied = True
return gm
# Create CompilerPipeline
pipeline = CompilerPipeline(model, inputs)
with pipeline.stack:
pipeline.generate_aten_ir_for_training(decompositions=decomposition_table)
# Apply identity pass to joint graph
pipeline.aot_graph_capture.graph_module = identity_pass(
pipeline.aot_graph_capture.graph_module
)
self.assertTrue(
hasattr(
pipeline.aot_graph_capture.graph_module, "_identity_pass_applied"
)
)
def nop_compiler(gm, args):
return gm.forward
# Partition
partition_output = pipeline.partition(
default_partition,
pipeline.aot_graph_capture.graph_module,
pipeline.aot_graph_capture.updated_flat_args,
)
# Apply identity pass to forward and backward modules
partition_output.fw_module = identity_pass(partition_output.fw_module)
partition_output.bw_module = identity_pass(partition_output.bw_module)
self.assertTrue(
hasattr(partition_output.fw_module, "_identity_pass_applied")
)
self.assertTrue(
hasattr(partition_output.bw_module, "_identity_pass_applied")
)
# Continue compilation
fw_compile_output = pipeline.fw_compile(
nop_compiler,
partition_output.fw_module,
partition_output.adjusted_flat_args,
partition_output.num_fw_outs_saved_for_bw,
)
bw_compile_output = pipeline.bw_compile(
nop_compiler,
partition_output.bw_module,
fw_compile_output.fwd_output_strides,
partition_output.num_symints_saved_for_bw,
)
# Create the final autograd function (with clean calling convention)
model_fn = pipeline.make_autograd_function(
flat_args=pipeline.aot_state.flat_args,
wrappers=pipeline.aot_graph_capture.wrappers,
compiled_fw_func=fw_compile_output.compiled_fw_func,
compiled_bw_func=bw_compile_output.compiled_bw_func,
lazy_backward_info=bw_compile_output.lazy_backward_info,
indices_of_inps_to_detach=partition_output.indices_of_inps_to_detach,
num_symints_saved_for_bw=partition_output.num_symints_saved_for_bw,
)
# Test functional correctness: model_fn should produce same results as original model
expected_output = model(*inputs)
actual_output = model_fn(*inputs)
torch.testing.assert_close(actual_output, expected_output)
if __name__ == "__main__":
run_tests()

View File

@@ -816,3 +816,77 @@ def _dynamo_graph_capture_for_export(
return transformed_graph
return inner
def _clear_traced_params_buffers(
traced_module: torch.fx.GraphModule, const_keys: set[str]
) -> None:
"""Remove all parameters and buffers from traced module before restoring."""
for key in const_keys:
assert key in traced_module._buffers.keys()
# We don't want constants to show up as a buffer in the state dict.
# Instead they should just be a direct attribute.
buffer = getattr(traced_module, key)
torch.fx.graph_module._del_attr(traced_module, key)
setattr(traced_module, key, buffer)
def _restore_state_dict(
original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
"""
TODO: move this into torch.export
Restores the state dict of the traced module to match the original module exactly.
Preserves the original FQNs with dots, creating intermediate empty modules as needed.
Ensures that the ordering of parameters/buffers matches the original module.
"""
# Build ID-based lookups for traced module params/buffers
traced_params: dict[int, tuple[str, torch.nn.Parameter]] = {}
for name, param in traced_module.named_parameters(remove_duplicate=False):
traced_params[id(param)] = (name, param)
traced_buffers: dict[int, tuple[str, torch.Tensor]] = {}
for name, buffer in traced_module.named_buffers(remove_duplicate=False):
traced_buffers[id(buffer)] = (name, buffer)
# Build mapping from old names to new names for graph node updates
name_mapping: dict[str, str] = {}
# Restore parameters in the order they appear in original module
for orig_name, orig_param in original_module.named_parameters(
remove_duplicate=False
):
if id(orig_param) in traced_params:
# This param exists in traced module - restore it with original FQN
traced_name, traced_param = traced_params[id(orig_param)]
torch.fx.graph_module._assign_attr(traced_param, traced_module, orig_name)
torch.fx.graph_module._del_attr(traced_module, traced_name)
name_mapping[traced_name] = orig_name
else:
# This param doesn't exist in traced module - add it
torch.fx.graph_module._assign_attr(orig_param, traced_module, orig_name)
# Restore buffers in the order they appear in original module
for orig_name, orig_buffer in original_module.named_buffers(remove_duplicate=False):
if id(orig_buffer) in traced_buffers:
# This buffer exists in traced module - restore it with original FQN
traced_name, traced_buffer = traced_buffers[id(orig_buffer)]
torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name)
name_mapping[traced_name] = orig_name
torch.fx.graph_module._del_attr(traced_module, traced_name)
else:
# This buffer doesn't exist in traced module - add it
torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name)
param_names = [v[0] for v in traced_params.values()]
buffer_names = [v[0] for v in traced_buffers.values()]
const_keys = set(param_names + buffer_names).difference(set(name_mapping.keys()))
_clear_traced_params_buffers(traced_module, const_keys)
# Update get_attr nodes in the graph to use the correct FQNs
for node in traced_module.graph.nodes:
if node.op == "get_attr" and node.target in name_mapping:
node.target = name_mapping[node.target]
traced_module.recompile()
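
As a minimal illustration of the intended effect (the module and names here are hypothetical, not taken from this diff): after restoration, the traced GraphModule exposes the same dotted parameter FQNs, in the same order, as the eager module, and its get_attr nodes target those FQNs.

import torch
import torch.nn as nn
from torch._dynamo.functional_export import (
    _dynamo_graph_capture_for_export,
    _restore_state_dict,
)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(3, 3))

    def forward(self, x):
        return self.block(x)

m = M()
with torch._dynamo.config.patch(install_free_tensors=True):
    gm = _dynamo_graph_capture_for_export(m)(torch.randn(2, 3))
_restore_state_dict(m, gm)

# Parameter FQNs (and their order) should now match the eager module,
# e.g. ["block.0.weight", "block.0.bias"].
assert [n for n, _ in gm.named_parameters()] == [n for n, _ in m.named_parameters()]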

View File

@@ -13,7 +13,7 @@ import warnings
import weakref
import zipfile
from collections import OrderedDict
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from functools import lru_cache
from typing import Any, Optional, TYPE_CHECKING, Union
@@ -55,6 +55,499 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# Dataclasses for compiler pipeline outputs
@dataclasses.dataclass
class PartitionOutput:
"""Output from the partition stage of the joint graph compiler."""
fw_module: torch.fx.GraphModule
bw_module: torch.fx.GraphModule
num_fw_outs_saved_for_bw: int
num_symints_saved_for_bw: int
indices_of_inps_to_detach: list[int]
adjusted_flat_args: list[Any]
@dataclasses.dataclass
class ForwardCompileOutput:
"""Output from the forward compilation stage of the joint graph compiler."""
fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]]
compiled_fw_func: Callable
@dataclasses.dataclass
class BackwardCompileOutput:
"""Output from the backward compilation stage of the joint graph compiler."""
lazy_backward_info: Any # AutogradLazyBackwardCompileInfo - avoid circular import
compiled_bw_func: Optional[Callable]
def _make_callable(
compiled_fn: Callable,
gm: torch.fx.GraphModule,
params_spec: list[str],
buffers_spec: list[str],
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
) -> Callable:
"""
Wrap the compiled function to provide a cleaner calling convention.
The compiled function expects flat args: [*params, *buffers, *flat_inputs]
This wrapper allows calling with just: (*inputs) where inputs can be structured
Args:
compiled_fn: The compiled function from make_autograd_function
gm: The graph module from capture_graph (should have _restore_state_dict called on it)
params_spec: List of parameter FQNs in order
buffers_spec: List of buffer FQNs in order
in_spec: Input pytree spec for flattening structured inputs
out_spec: Output pytree spec for unflattening structured outputs
Returns:
A callable that takes structured user inputs and returns structured outputs
"""
# Get parameter and buffer dictionaries from graph module
params_dict = dict(gm.named_parameters())
buffers_dict = dict(gm.named_buffers())
# Look up params and buffers by FQN in the order specified by specs
params = [params_dict[fqn] for fqn in params_spec]
buffers = [buffers_dict[fqn] for fqn in buffers_spec]
def wrapper(*args, **kwargs):
# Flatten the inputs using in_spec to handle structured inputs (dicts, tuples, etc.)
# The in_spec includes params/buffers/inputs, so we reconstruct the full input structure
# and flatten just the user inputs portion
user_inputs_flat, _ = pytree.tree_flatten((args, kwargs))
# Construct the full flat args list
flat_args = [*params, *buffers, *user_inputs_flat]
# Call the compiled function
flat_outputs = compiled_fn(flat_args)
# Unflatten outputs using out_spec to handle structured outputs
return pytree.tree_unflatten(flat_outputs, out_spec)
return wrapper
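# Illustration of the convention above (hypothetical shapes and names): if the
# module was traced from a call like model({"x": x}) and owns parameters
# [w, b], the underlying compiled_fn is invoked with the flat list
#     compiled_fn([w, b, x])
# while the wrapper returned here is called with the original structure,
#     wrapper({"x": x})
# and its flat outputs are re-nested according to out_spec.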
class CompilerPipeline:
"""
A unified pipeline for graph capture, joint graph generation, and compilation.
This class provides an end-to-end API for:
1. Capturing a graph from an nn.Module using Dynamo
2. Generating a joint forward-backward graph with descriptors
3. Partitioning the joint graph into forward and backward graphs
4. Compiling forward and backward graphs
5. Creating autograd or inference functions
"""
def __init__(
self,
model: torch.nn.Module,
inputs: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
):
"""
Initialize the CompilerPipeline.
Args:
model: The nn.Module to capture and compile
inputs: Example input tensors for tracing
kwargs: Optional keyword arguments for the model
NOTE: You must use the following pattern to use this class:
pipeline = CompilerPipeline(...) # capture graph
with pipeline.stack:
# call pipeline methods to compile
"""
self.model = model
self.inputs = inputs
self.kwargs = kwargs if kwargs is not None else {}
self._gm = self.capture_graph()
self.stack = ExitStack()
self.joint_with_descriptors: Optional[Any] = None # JointWithDescriptors
# Internal state for compilation phases
self._aot_config: Optional[Any] = None
self._fw_metadata: Optional[Any] = None
self._maybe_subclass_meta: Optional[Any] = None
def capture_graph(self) -> torch.fx.GraphModule:
"""
Capture the graph from the model using Dynamo graph capture.
Returns:
The captured GraphModule
"""
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export, _restore_state_dict
# Capture graph for export
with torch._dynamo.config.patch(install_free_tensors=True):
assert isinstance(self.model, torch.nn.Module)
gm = _dynamo_graph_capture_for_export(self.model)(*self.inputs, **self.kwargs)
# Restore state dict to the captured graph module
_restore_state_dict(self.model, gm)
return gm
def generate_aten_ir_for_training(
self,
decompositions: Optional[dict] = None,
keep_inference_input_mutations: bool = False,
ignore_shape_env: bool = False,
disable_functionalization: bool = False,
) -> Any: # JointWithDescriptors
"""
Generate ATen IR for training (joint forward-backward graph with descriptors).
Args:
decompositions: Optional decomposition table for operations
keep_inference_input_mutations: Whether to keep input mutations in inference mode
ignore_shape_env: Whether to ignore shape environment
disable_functionalization: Whether to disable functionalization
Returns:
JointWithDescriptors object containing the joint graph and metadata
"""
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
if self._gm is None:
self._gm = self.capture_graph()
self.stack.enter_context(self._gm.meta["fake_mode"]) # type: ignore[union-attr]
# Export joint graph with descriptors
self.joint_with_descriptors = aot_export_joint_with_descriptors(
self.stack,
self._gm, # type: ignore[arg-type]
self.inputs,
self.kwargs,
decompositions=decompositions,
keep_inference_input_mutations=keep_inference_input_mutations,
ignore_shape_env=ignore_shape_env,
disable_functionalization=disable_functionalization,
)
return self.joint_with_descriptors
def generate_aten_ir_for_inference(
self,
decompositions: Optional[dict] = None,
keep_inference_input_mutations: bool = False,
ignore_shape_env: bool = False,
disable_functionalization: bool = False,
) -> Any: # JointWithDescriptors
"""
Generate ATen IR for inference (forward graph with descriptors).
Args:
decompositions: Optional decomposition table for operations
keep_inference_input_mutations: Whether to keep input mutations in inference mode
ignore_shape_env: Whether to ignore shape environment
disable_functionalization: Whether to disable functionalization
Returns:
JointWithDescriptors object containing the forward graph and metadata
"""
with torch.no_grad():
return self.generate_aten_ir_for_training(
decompositions=decompositions,
keep_inference_input_mutations=keep_inference_input_mutations,
ignore_shape_env=ignore_shape_env,
disable_functionalization=disable_functionalization,
)
def _ensure_state_initialized(self):
"""Initialize internal state from joint_with_descriptors if not already done."""
if self._aot_config is None:
if self.joint_with_descriptors is None:
raise RuntimeError("Must call generate_aten_ir_for_training() or generate_aten_ir_for_inference() first")
aot_state = self.joint_with_descriptors._aot_state
aot_graph_capture = self.joint_with_descriptors._aot_graph_capture
self._aot_config = aot_state.aot_config
self._fw_metadata = aot_state.fw_metadata
self._maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta
def partition(
self,
partition_fn: Callable,
fx_g: torch.fx.GraphModule,
joint_inputs: Union[list[Any], tuple[list[Any], list[Any]]],
) -> PartitionOutput:
"""
Partition the joint graph into a forward graph and a backward graph.
Args:
partition_fn: Partition function to use
fx_g: The joint graph module to partition
joint_inputs: Flattened inputs to the joint graph
Returns:
PartitionOutput containing the partitioned forward and backward modules
"""
self._ensure_state_initialized()
self._aot_config.partition_fn = partition_fn # type: ignore[union-attr]
from torch._functorch._aot_autograd.graph_compile import _aot_stage2a_partition
result = _aot_stage2a_partition(
fx_g, joint_inputs, self._maybe_subclass_meta, self._fw_metadata, self._aot_config # type: ignore[arg-type]
)
return PartitionOutput(
fw_module=result[0],
bw_module=result[1],
num_fw_outs_saved_for_bw=result[2],
num_symints_saved_for_bw=result[3],
indices_of_inps_to_detach=result[4],
adjusted_flat_args=result[5],
)
def fw_compile(
self,
fw_compiler: Callable,
fw_module: torch.fx.GraphModule,
adjusted_flat_args: list[Any],
num_fw_outs_saved_for_bw: int,
) -> ForwardCompileOutput:
"""
Compile the forward graph.
Args:
fw_compiler: Compiler function to use for forward graph
fw_module: The forward graph module to compile
adjusted_flat_args: Flattened arguments after adjustments
num_fw_outs_saved_for_bw: Number of forward outputs saved for backward
Returns:
ForwardCompileOutput containing strides and compiled forward function
"""
self._ensure_state_initialized()
self._aot_config.fw_compiler = fw_compiler # type: ignore[union-attr]
from torch._functorch._aot_autograd.graph_compile import _aot_stage2b_fw_compile
result = _aot_stage2b_fw_compile(
fw_module,
adjusted_flat_args,
self._maybe_subclass_meta,
self._fw_metadata, # type: ignore[arg-type]
num_fw_outs_saved_for_bw,
self._aot_config, # type: ignore[arg-type]
)
return ForwardCompileOutput(
fwd_output_strides=result[0],
compiled_fw_func=result[1],
)
def bw_compile(
self,
bw_compiler: Callable,
bw_module: torch.fx.GraphModule,
fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]],
num_symints_saved_for_bw: int,
) -> BackwardCompileOutput:
"""
Compile the backward graph.
Args:
bw_compiler: Compiler function to use for backward graph
bw_module: The backward graph module to compile
fwd_output_strides: Output strides from forward compilation
num_symints_saved_for_bw: Number of symbolic ints saved for backward
Returns:
BackwardCompileOutput containing lazy backward info and compiled backward function
"""
self._ensure_state_initialized()
self._aot_config.bw_compiler = bw_compiler # type: ignore[union-attr]
from torch._functorch._aot_autograd.graph_compile import _aot_stage2b_bw_compile
result = _aot_stage2b_bw_compile(
bw_module,
self._maybe_subclass_meta,
self._fw_metadata, # type: ignore[arg-type]
fwd_output_strides,
num_symints_saved_for_bw,
self._aot_config, # type: ignore[arg-type]
)
return BackwardCompileOutput(
lazy_backward_info=result[0],
compiled_bw_func=result[1],
)
def inference_compile(
self,
inference_compiler: Callable,
fw_module: torch.fx.GraphModule,
updated_flat_args: list[Any],
) -> Callable:
"""
Compile the inference graph (no autograd).
Args:
inference_compiler: Compiler function to use for inference graph
fw_module: The forward/inference graph module to compile
updated_flat_args: Flattened arguments after adjustments
Returns:
Compiled inference function
"""
self._ensure_state_initialized()
self._aot_config.inference_compiler = inference_compiler # type: ignore[union-attr]
from torch._functorch._aot_autograd.graph_compile import _aot_stage2b_inference_compile
return _aot_stage2b_inference_compile(
fw_module,
updated_flat_args,
self._maybe_subclass_meta,
self._fw_metadata, # type: ignore[arg-type]
self._aot_config,
)
def make_inference_function(
self,
compiled_fw: Callable,
wrappers: list[Any], # list[CompilerWrapper]
entry: Optional[Any], # Optional[GenericAOTAutogradCacheEntry]
) -> Callable:
"""
Make the final inference function with clean calling convention.
Args:
compiled_fw: Compiled forward function
wrappers: List of compiler wrappers to apply
entry: Optional cache entry
Returns:
Callable with clean calling convention (takes structured user inputs)
"""
self._ensure_state_initialized()
from torch._functorch._aot_autograd.graph_compile import _aot_stage2c_make_inference_function
compiled_fn, _ = _aot_stage2c_make_inference_function(
self._aot_config,
self._fw_metadata,
compiled_fw,
wrappers,
entry,
)
return _make_callable(
compiled_fn,
self._gm,
self.joint_with_descriptors.params_spec, # type: ignore[union-attr]
self.joint_with_descriptors.buffers_spec, # type: ignore[union-attr]
self.joint_with_descriptors.in_spec, # type: ignore[union-attr]
self.joint_with_descriptors.out_spec, # type: ignore[union-attr]
)
def make_autograd_function(
self,
flat_args: list[Any],
wrappers: list[Any], # list[CompilerWrapper]
compiled_fw_func: Callable,
compiled_bw_func: Optional[Callable],
lazy_backward_info: Any, # AutogradLazyBackwardCompileInfo
indices_of_inps_to_detach: list[int],
num_symints_saved_for_bw: int,
try_save_cache_entry: Optional[Callable] = None,
entry: Optional[Any] = None, # Optional[GenericAOTAutogradCacheEntry]
) -> Callable:
"""
Make the final autograd function with clean calling convention.
Args:
flat_args: Flattened input arguments
wrappers: List of compiler wrappers to apply
compiled_fw_func: Compiled forward function
compiled_bw_func: Optional compiled backward function
lazy_backward_info: Information for lazy backward compilation
indices_of_inps_to_detach: Indices of inputs to detach
num_symints_saved_for_bw: Number of symbolic ints saved for backward
try_save_cache_entry: Optional callback to save cache entry
entry: Optional existing cache entry
Returns:
Callable with clean calling convention (takes structured user inputs)
"""
self._ensure_state_initialized()
from torch._functorch._aot_autograd.graph_compile import _aot_stage2c_make_autograd_function
compiled_fn, _ = _aot_stage2c_make_autograd_function(
self._aot_config,
flat_args,
self._fw_metadata,
self._maybe_subclass_meta,
wrappers,
compiled_fw_func,
compiled_bw_func,
lazy_backward_info,
try_save_cache_entry,
entry,
indices_of_inps_to_detach,
num_symints_saved_for_bw,
)
return _make_callable(
compiled_fn,
self._gm,
self.joint_with_descriptors.params_spec, # type: ignore[union-attr]
self.joint_with_descriptors.buffers_spec, # type: ignore[union-attr]
self.joint_with_descriptors.in_spec, # type: ignore[union-attr]
self.joint_with_descriptors.out_spec, # type: ignore[union-attr]
)
@property
def graph_module(self) -> Optional[torch.fx.GraphModule]:
"""Get the captured graph module."""
return self._gm
@property
def aot_state(self) -> Optional[Any]:
"""Get the AOT state from the joint graph."""
return self.joint_with_descriptors._aot_state if self.joint_with_descriptors else None
@property
def aot_graph_capture(self) -> Optional[Any]:
"""Get the AOT graph capture from the joint graph."""
return self.joint_with_descriptors._aot_graph_capture if self.joint_with_descriptors else None
@property
def params_spec(self) -> Optional[list[str]]:
"""Get the parameter specification from the joint graph."""
return self.joint_with_descriptors.params_spec if self.joint_with_descriptors else None
@property
def buffers_spec(self) -> Optional[list[str]]:
"""Get the buffer specification from the joint graph."""
return self.joint_with_descriptors.buffers_spec if self.joint_with_descriptors else None
@property
def in_spec(self) -> Optional[pytree.TreeSpec]:
"""Get the input tree specification from the joint graph."""
return self.joint_with_descriptors.in_spec if self.joint_with_descriptors else None
@property
def out_spec(self) -> Optional[pytree.TreeSpec]:
"""Get the output tree specification from the joint graph."""
return self.joint_with_descriptors.out_spec if self.joint_with_descriptors else None
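# Illustrative end-to-end inference usage, mirroring the inference test in this
# PR (the module, inputs, and no-op compiler are placeholders):
#
#     pipeline = CompilerPipeline(model, ({"x": x, "y": y},))
#     with pipeline.stack:
#         pipeline.generate_aten_ir_for_inference(decompositions=decomposition_table)
#         compiled_fw = pipeline.inference_compile(
#             nop_compiler,
#             pipeline.aot_graph_capture.graph_module,
#             pipeline.aot_graph_capture.updated_flat_args,
#         )
#         model_fn = pipeline.make_inference_function(
#             compiled_fw=compiled_fw,
#             wrappers=pipeline.aot_graph_capture.wrappers,
#             entry=None,
#         )
#     outs = model_fn({"x": x, "y": y})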
@dataclasses.dataclass
class ExportDynamoConfig:
"""