Fullgraph graph capture with dynamo. (#159749)

Summary:
Following up on Avik's doc https://docs.google.com/document/d/11RW0Bbkp1QwFbEu8rCNW5d7wUFaEkxbL0uLyqcc2jTk/edit?tab=t.0

We are experimenting with a new API which utilizes torch.compile(fullgraph=True) and intend to use it to replace the old dynamo.export() API.

This PR adds a prototype for the API described in the doc.

Test Plan:
test_misc -- -k test_fullgraph_capture

Rollback Plan:

Differential Revision: D79534608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159749
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen
2025-08-12 22:06:18 +00:00
committed by PyTorch MergeBot
parent 101276f81b
commit 16d15445f8
3 changed files with 138 additions and 4 deletions

View File

@ -16,11 +16,13 @@ import logging
import math
import operator
import os
import pickle
import random
import sys
import tempfile
import threading
import traceback
import types
import typing
import unittest
import unittest.mock as mock
@ -8520,6 +8522,50 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertEqual(seen_frames[0].name, "fn")
self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
def test_fullgraph_capture(self):
    """
    Capture `foo` through fullgraph_capture(), then rebuild a guard manager
    and an exported program from the first captured entry and compare both
    against the eager function.
    """

    def foo(x):
        return x + x.shape[0]

    captured_foo = torch._dynamo.eval_frame.fullgraph_capture(foo)
    # Run the captured callable on a (3, 2) input and on a (4,) input.
    captured_foo(torch.randn(3, 2))
    captured_foo(torch.randn(4))

    artifacts = captured_foo.get_artifacts()
    guarded_codes = artifacts.dynamo_artifacts.guarded_codes
    recorded_inputs = artifacts.backend_inputs
    backend_ids = [*recorded_inputs]
    gms = [entry.graph_module for entry in recorded_inputs.values()]

    def _convert_to_ep_demo(code, backend_id, gm, args):
        # Expose the captured graph module under its backend id in a private
        # copy of the module globals, so the dynamo bytecode can look it up.
        new_globals = dict(globals())
        new_globals[backend_id] = gm

        # Minimal boilerplate to turn the serialized bytecode into a code object.
        serialized = code.dynamo_code
        dynamo_bytecode = type(serialized).to_code_object(serialized)

        # The pickled guards state was produced by the capture above.
        guards_state = pickle.loads(code.guards_state)
        check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
            foo.__code__,
            guards_state.output_graph,
            guards_serialization_mode="load",
            shape_code_parts=guards_state.shape_code_parts,
            runtime_global_scope=new_globals,
        )

        class ModuleForExport(torch.nn.Module):
            def forward(self, x):
                return types.FunctionType(dynamo_bytecode, new_globals)(x)

        exported = torch.export.export(ModuleForExport(), args)
        return check_fn_manager.guard_manager, exported

    guards0, ep0 = _convert_to_ep_demo(
        guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),)
    )

    # The first entry's guards accept a (3, 2) input and reject a (4,) input.
    self.assertTrue(guards0.check({"x": torch.randn(3, 2)}))
    self.assertFalse(guards0.check({"x": torch.randn(4)}))

    # The exported program must agree with the eager function.
    input0 = torch.randn(3, 2)
    self.assertEqual(ep0.module()(input0), foo(input0))
def test_torch_guards_stack_frame_register_inlining_deep(self):
x = torch.tensor([0.5, 0.5])
y = torch.tensor([0.75, 0.75, 0.75, 0.75])

View File

@ -113,7 +113,7 @@ from .utils import (
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from torch._dynamo.package import CompilePackage
from torch._dynamo.package import CompilePackage, DynamoCaptureOutput
from torch._dynamo.repro.after_dynamo import WrapBackendDebug
from torch._subclasses import fake_tensor
from torch.fx.node import Argument, Node, Target
@ -2288,3 +2288,83 @@ def skip_code(code: types.CodeType) -> None:
set_code_exec_strategy(
code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
)
@dataclass
class BackendInput:
    """
    What the backend was called with for one compiled graph, as recorded by
    the backend that `fullgraph_capture()` installs.
    """

    # FX graph module handed to the backend.
    graph_module: torch.fx.GraphModule
    # Example inputs handed to the backend together with the graph.
    example_inputs: tuple[Any, ...]
    # Fake tensor mode taken from the tracing context when the backend ran.
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
@dataclass
class CaptureOutput:
    """
    Bundle of everything dynamo generates when capturing with fullgraph=True.

    Ideally this would be the "return" type of dynamo, if dynamo had a
    standard API for returning compilation artifacts.
    """

    # Dynamo-side artifacts.
    dynamo_artifacts: DynamoCaptureOutput
    # Backend-side artifacts, keyed by backend id.
    backend_inputs: dict[str, BackendInput]
def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]:
    """
    Wrap `model` and return a callable, like optimize() does, that also
    records dynamo compilation artifacts.

    The returned callable takes normal inputs, just like a torch.compile()-ed
    function. The recorded artifacts are retrieved by calling its
    `get_artifacts()` attribute, which returns a CaptureOutput. That output
    has two parts:

      1. Dynamo specific information from DynamoCaptureOutput, which includes:
        - guards
        - generated bytecode
        - python source information
      2. Backend specific information (indexed by unique backend id) such as:
        - fx graph
        - example inputs

    `get_artifacts()` asserts that the package cache entry holds exactly one
    code entry. The `backend_inputs` dict it returns is the same dict the
    capturing backend writes into.

    Example:

        def fn(*args):
            ...

        compiled_fn = fullgraph_capture(fn)
        compiled_fn(args)
        compiled_fn(another_args)
        artifacts = compiled_fn.get_artifacts()
    """
    from torch._dynamo.package import CompilePackage

    package = CompilePackage(model)
    backend_inputs: dict[str, BackendInput] = {}

    def _backend(
        gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...]
    ) -> torch.fx.GraphModule:
        # Pass-through backend: record what dynamo hands over and return the
        # graph module unchanged.
        from torch._guards import TracingContext

        tracing_context = TracingContext.get()
        mode = tracing_context.fake_mode
        assert mode is not None
        key = gm._backend_id
        assert isinstance(key, str)
        backend_inputs[key] = BackendInput(
            graph_module=gm, example_inputs=example_inputs, fake_mode=mode
        )
        return gm

    # TODO For now we use eval_frame to give us the frame. This can be simplified
    # to a manual frame creation helper.
    optimized_model = optimize(nopython=True, backend=_backend, package=package)(model)

    @functools.wraps(model)
    def capture_context(*args: Any, **kwargs: Any) -> Any:
        return optimized_model(*args, **kwargs)

    def get_artifacts() -> CaptureOutput:
        codes = package.cache_entry().codes
        assert len(codes) == 1
        return CaptureOutput(dynamo_artifacts=codes[0], backend_inputs=backend_inputs)

    capture_context.get_artifacts = get_artifacts  # type: ignore[attr-defined]
    return capture_context

View File

@ -112,7 +112,17 @@ class InlinedSource:
@dataclasses.dataclass
class _DynamoCodeCacheEntry:
class DynamoCaptureOutput:
    """
    The core pieces of information Dynamo generates for fullgraph=True.
    """

    # Guarded code entries generated by Dynamo.
    guarded_codes: list[_GuardedCodeCacheEntry]
    # Backend ids associated with this capture.
    backend_ids: list[_BackendId]
@dataclasses.dataclass
class _DynamoCodeCacheEntry(DynamoCaptureOutput):
"""
Contains the serializable information associated with a single code object
in dynamo. To restore an execution of compiled code, we will need the following
@ -135,9 +145,7 @@ class _DynamoCodeCacheEntry:
python_code: SerializedCode
python_module: str
function_names: list[_FunctionId]
guarded_codes: list[_GuardedCodeCacheEntry]
import_sources: dict[str, str]
backend_ids: list[_BackendId]
code_source: Optional[str]
install_to_global: bool
has_compile_id: bool = False