Fullgraph graph capture with dynamo. (#159749)

Summary:
Following up on Avik's doc https://docs.google.com/document/d/11RW0Bbkp1QwFbEu8rCNW5d7wUFaEkxbL0uLyqcc2jTk/edit?tab=t.0

We are experimenting with a new API which utilizes torch.compile(fullgraph=True) and intend to use it to replace the old dynamo.export() API.

This PR adds a prototype for the API described in the doc.

Test Plan:
test_misc -- -k test_aot_capture

Rollback Plan:

Differential Revision: D79534608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159749
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen
2025-08-12 22:06:18 +00:00
committed by PyTorch MergeBot
parent 101276f81b
commit 16d15445f8
3 changed files with 138 additions and 4 deletions

View File

@ -16,11 +16,13 @@ import logging
import math import math
import operator import operator
import os import os
import pickle
import random import random
import sys import sys
import tempfile import tempfile
import threading import threading
import traceback import traceback
import types
import typing import typing
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
@ -8520,6 +8522,50 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertEqual(seen_frames[0].name, "fn") self.assertEqual(seen_frames[0].name, "fn")
self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
def test_fullgraph_capture(self):
    """Prototype test for the fullgraph_capture API: capture dynamo
    artifacts from two compilations and round-trip one of them into an
    ExportedProgram via the serialized bytecode + guards."""

    # The output depends on a symbolic input shape, so calling with two
    # different shapes below triggers two separate compilations.
    def foo(x):
        return x + x.shape[0]

    compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo)
    compiled_foo(torch.randn(3, 2))
    compiled_foo(torch.randn(4))
    artifacts = compiled_foo.get_artifacts()
    guarded_codes = artifacts.dynamo_artifacts.guarded_codes
    backend_ids = list(artifacts.backend_inputs.keys())
    gms = [b.graph_module for b in artifacts.backend_inputs.values()]

    def _convert_to_ep_demo(code, backend_id, gm, args):
        # Inject compiled function as the original gm
        new_globals = copy.copy(globals())
        new_globals[backend_id] = gm
        # Minimal boilerplate to setup a callable.
        SerializedCode = type(code.dynamo_code)
        dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code)
        # NOTE(review): pickle here is on dynamo's own serialized guard
        # state (trusted, produced in-process), not external input.
        guards_state = pickle.loads(code.guards_state)
        # Rebuild the guard manager from the serialized state; "load"
        # mode presumably skips re-deriving guards — confirm against
        # CheckFunctionManager docs.
        guard_manager = torch._dynamo.guards.CheckFunctionManager(
            foo.__code__,
            guards_state.output_graph,
            guards_serialization_mode="load",
            shape_code_parts=guards_state.shape_code_parts,
            runtime_global_scope=new_globals,
        ).guard_manager

        class ModuleForExport(torch.nn.Module):
            def forward(self, x):
                # Run the dynamo-generated bytecode with the patched
                # globals so the injected graph module is reachable.
                return types.FunctionType(dynamo_bytecode, new_globals)(x)

        m = ModuleForExport()
        return guard_manager, torch.export.export(m, args)

    guards0, ep0 = _convert_to_ep_demo(
        guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),)
    )
    # Guards from the first compilation accept the original shape and
    # reject the shape that caused the recompile.
    self.assertTrue(guards0.check({"x": torch.randn(3, 2)}))
    self.assertFalse(guards0.check({"x": torch.randn(4)}))
    input0 = torch.randn(3, 2)
    self.assertEqual(ep0.module()(input0), foo(input0))
def test_torch_guards_stack_frame_register_inlining_deep(self): def test_torch_guards_stack_frame_register_inlining_deep(self):
x = torch.tensor([0.5, 0.5]) x = torch.tensor([0.5, 0.5])
y = torch.tensor([0.75, 0.75, 0.75, 0.75]) y = torch.tensor([0.75, 0.75, 0.75, 0.75])

View File

@ -113,7 +113,7 @@ from .utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from torch._dynamo.package import CompilePackage from torch._dynamo.package import CompilePackage, DynamoCaptureOutput
from torch._dynamo.repro.after_dynamo import WrapBackendDebug from torch._dynamo.repro.after_dynamo import WrapBackendDebug
from torch._subclasses import fake_tensor from torch._subclasses import fake_tensor
from torch.fx.node import Argument, Node, Target from torch.fx.node import Argument, Node, Target
@ -2288,3 +2288,83 @@ def skip_code(code: types.CodeType) -> None:
set_code_exec_strategy( set_code_exec_strategy(
code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
) )
@dataclass
class BackendInput:
    """Data dynamo hands to a backend for one compiled graph."""

    # The fx graph captured by dynamo and passed to the backend.
    graph_module: torch.fx.GraphModule
    # Example inputs the backend was invoked with.
    example_inputs: tuple[Any, ...]
    # Fake tensor mode that was active while tracing this graph.
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
@dataclass
class CaptureOutput:
    """
    Core data structure that contains all the information dynamo generates
    from fullgraph=True. Ideally, this should be the "return" type if dynamo
    has a standard API to return compilation artifacts.
    """

    # Dynamo-side artifacts: guards, generated bytecode, source info.
    dynamo_artifacts: DynamoCaptureOutput
    # Backend-side artifacts, keyed by unique backend id.
    backend_inputs: dict[str, BackendInput]
def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]:
    """
    Wrap *model* in a callable that behaves like an ``optimize()``-ed
    function while recording fullgraph=True compilation artifacts.

    After one or more invocations, call ``get_artifacts()`` on the returned
    wrapper to obtain a :class:`CaptureOutput`, which splits into:

    1. Dynamo-specific information (``DynamoCaptureOutput``):
       - guards
       - generated bytecode
       - python source information
    2. Backend-specific information (indexed by unique backend id):
       - fx graph
       - example inputs

    Example:
        def fn(*args):
            ...

        compiled_fn = fullgraph_capture(fn)
        compiled_fn(args)
        compiled_fn(another_args)
        artifacts = compiled_fn.get_artifacts()
    """
    from torch._dynamo.package import CompilePackage

    package = CompilePackage(model)
    recorded_inputs: dict[str, BackendInput] = {}

    def _recording_backend(
        gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...]
    ) -> torch.fx.GraphModule:
        # Identity backend: stash the graph plus its tracing context and
        # hand the graph back unchanged.
        from torch._guards import TracingContext

        mode = TracingContext.get().fake_mode
        assert mode is not None
        identifier = gm._backend_id
        assert isinstance(identifier, str)
        recorded_inputs[identifier] = BackendInput(gm, example_inputs, mode)
        return gm

    # TODO eval_frame currently supplies the frame; this could later be
    # simplified to a manual frame-creation helper.
    optimized_model = optimize(
        nopython=True, backend=_recording_backend, package=package
    )(model)

    @functools.wraps(model)
    def capture_context(*args: Any, **kwargs: Any) -> Any:
        return optimized_model(*args, **kwargs)

    def get_artifacts() -> CaptureOutput:
        entry = package.cache_entry()
        # fullgraph=True implies exactly one compiled code object.
        assert len(entry.codes) == 1
        return CaptureOutput(
            dynamo_artifacts=entry.codes[0],
            backend_inputs=recorded_inputs,
        )

    capture_context.get_artifacts = get_artifacts  # type: ignore[attr-defined]
    return capture_context

View File

@ -112,7 +112,17 @@ class InlinedSource:
@dataclasses.dataclass
class DynamoCaptureOutput:
    """
    Core information generated from Dynamo for fullgraph=True.
    """

    # Guarded code entries produced for the compiled code object.
    guarded_codes: list[_GuardedCodeCacheEntry]
    # Identifiers of the backends invoked during compilation.
    backend_ids: list[_BackendId]
@dataclasses.dataclass
class _DynamoCodeCacheEntry(DynamoCaptureOutput):
""" """
Contains the serializable information associated with a single code object Contains the serializable information associated with a single code object
in dynamo. To restore an execution of compiled code, we will need the following in dynamo. To restore an execution of compiled code, we will need the following
@ -135,9 +145,7 @@ class _DynamoCodeCacheEntry:
python_code: SerializedCode python_code: SerializedCode
python_module: str python_module: str
function_names: list[_FunctionId] function_names: list[_FunctionId]
guarded_codes: list[_GuardedCodeCacheEntry]
import_sources: dict[str, str] import_sources: dict[str, str]
backend_ids: list[_BackendId]
code_source: Optional[str] code_source: Optional[str]
install_to_global: bool install_to_global: bool
has_compile_id: bool = False has_compile_id: bool = False