mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fullgraph graph capture with dynamo. (#159749)
Summary: Following up on Avik's doc (https://docs.google.com/document/d/11RW0Bbkp1QwFbEu8rCNW5d7wUFaEkxbL0uLyqcc2jTk/edit?tab=t.0), we are experimenting with a new API which utilizes torch.compile(fullgraph=True) and intend to use it to replace the old dynamo.export() API. This PR adds a prototype for the API described in the doc. Test Plan: test_misc -- -k test_aot_capture Rollback Plan: Differential Revision: D79534608 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159749 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
101276f81b
commit
16d15445f8
@ -16,11 +16,13 @@ import logging
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import traceback
|
||||
import types
|
||||
import typing
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
@ -8520,6 +8522,50 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||
self.assertEqual(seen_frames[0].name, "fn")
|
||||
self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
|
||||
|
||||
def test_fullgraph_capture(self):
    """Smoke-test torch._dynamo.eval_frame.fullgraph_capture.

    Compiles a small function under two input shapes, pulls the captured
    artifacts (guarded codes + per-backend fx graphs), and demonstrates
    converting one captured entry back into a runnable ExportedProgram
    whose guards and outputs match the original function.
    """

    def foo(x):
        # Mixes tensor data with a shape access so dynamo records a
        # shape-dependent guard (checked below via guards0.check).
        return x + x.shape[0]

    compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo)
    # Two different ranks/shapes -> two guarded code entries / backends.
    compiled_foo(torch.randn(3, 2))
    compiled_foo(torch.randn(4))
    artifacts = compiled_foo.get_artifacts()

    guarded_codes = artifacts.dynamo_artifacts.guarded_codes
    backend_ids = list(artifacts.backend_inputs.keys())
    gms = [b.graph_module for b in artifacts.backend_inputs.values()]

    def _convert_to_ep_demo(code, backend_id, gm, args):
        # Inject compiled function as the original gm
        new_globals = copy.copy(globals())
        new_globals[backend_id] = gm
        # Minimal boilerplate to setup a callable.
        SerializedCode = type(code.dynamo_code)
        dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code)
        # NOTE(review): pickle.loads on serialized guards is fine here
        # because the payload was produced in-process by dynamo itself.
        guards_state = pickle.loads(code.guards_state)
        guard_manager = torch._dynamo.guards.CheckFunctionManager(
            foo.__code__,
            guards_state.output_graph,
            guards_serialization_mode="load",
            shape_code_parts=guards_state.shape_code_parts,
            runtime_global_scope=new_globals,
        ).guard_manager

        class ModuleForExport(torch.nn.Module):
            def forward(self, x):
                # Re-materialize the dynamo bytecode as a plain function
                # bound to globals that resolve backend_id -> gm.
                return types.FunctionType(dynamo_bytecode, new_globals)(x)

        m = ModuleForExport()
        return guard_manager, torch.export.export(m, args)

    guards0, ep0 = _convert_to_ep_demo(
        guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),)
    )
    # The first entry was captured for shape (3, 2): its guards must
    # accept a (3, 2) input and reject the rank-1 input.
    self.assertTrue(guards0.check({"x": torch.randn(3, 2)}))
    self.assertFalse(guards0.check({"x": torch.randn(4)}))
    input0 = torch.randn(3, 2)
    self.assertEqual(ep0.module()(input0), foo(input0))
|
||||
|
||||
def test_torch_guards_stack_frame_register_inlining_deep(self):
|
||||
x = torch.tensor([0.5, 0.5])
|
||||
y = torch.tensor([0.75, 0.75, 0.75, 0.75])
|
||||
|
@ -113,7 +113,7 @@ from .utils import (
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from torch._dynamo.package import CompilePackage
|
||||
from torch._dynamo.package import CompilePackage, DynamoCaptureOutput
|
||||
from torch._dynamo.repro.after_dynamo import WrapBackendDebug
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
@ -2288,3 +2288,83 @@ def skip_code(code: types.CodeType) -> None:
|
||||
set_code_exec_strategy(
|
||||
code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class BackendInput:
    """Per-graph payload recorded when dynamo invokes a backend.

    `fullgraph_capture` installs a backend that stores one of these per
    unique backend id, so callers can retrieve the traced FX graph, the
    example inputs it was compiled with, and the fake-tensor mode that
    was active during tracing.
    """

    # FX graph produced by dynamo for a single captured frame.
    graph_module: torch.fx.GraphModule
    # Example inputs the backend was called with for this graph.
    example_inputs: tuple[Any, ...]
    # Fake tensor mode pulled from TracingContext while compiling.
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
|
||||
|
||||
|
||||
@dataclass
class CaptureOutput:
    """
    Core data structure that contains all the information dynamo generates
    from fullgraph=True. Ideally, this should be the "return" type if dynamo
    had a standard API to return compilation artifacts.
    """

    # Dynamo-side artifacts (guards, generated bytecode, source info).
    # Forward-reference string: DynamoCaptureOutput is imported only
    # under TYPE_CHECKING.
    dynamo_artifacts: "DynamoCaptureOutput"
    # Backend-side artifacts keyed by unique backend id.
    backend_inputs: "dict[str, BackendInput]"
|
||||
|
||||
|
||||
def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]:
    """
    A helper function which wraps a model and returns a callable like optimize().
    The callable can be called with normal inputs like torch.compile()-ed functions
    and user can dump dynamo compilation artifacts through `get_artifacts()` call.

    The CaptureOutput is separated into two parts:
    1. Dynamo specific information from DynamoCaptureOutput, which includes:
      - guards
      - generated bytecode
      - python source information
    2. Backend specific information (indexed by unique backend id) such as:
      - fx graph
      - example inputs

    Example:
        def fn(*args):
            ...

        compiled_fn = fullgraph_capture(fn)
        compiled_fn(args)
        compiled_fn(another_args)
        artifacts = compiled_fn.get_artifacts()
    """
    # Local import to avoid a module-level import cycle with the package module.
    from torch._dynamo.package import CompilePackage

    package = CompilePackage(model)

    # Filled in by _backend each time dynamo compiles a graph.
    backend_inputs: dict[str, BackendInput] = {}

    def _backend(
        gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...]
    ) -> torch.fx.GraphModule:
        # Identity backend: record the graph + inputs + fake mode, run eagerly.
        from torch._guards import TracingContext

        fake_mode = TracingContext.get().fake_mode
        assert fake_mode is not None
        backend_id = gm._backend_id
        assert isinstance(backend_id, str)
        backend_inputs[backend_id] = BackendInput(gm, example_inputs, fake_mode)
        return gm

    # TODO For now we use eval_frame to give us the frame. This can be simplified to
    # a manual frame creation helper.
    optimized_model = optimize(nopython=True, backend=_backend, package=package)(model)

    @functools.wraps(model)
    def capture_context(*args: Any, **kwargs: Any) -> Any:
        return optimized_model(*args, **kwargs)

    def get_artifacts() -> CaptureOutput:
        # fullgraph=True implies exactly one code entry in the package.
        cache_entry = package.cache_entry()
        assert len(cache_entry.codes) == 1
        return CaptureOutput(
            dynamo_artifacts=cache_entry.codes[0], backend_inputs=backend_inputs
        )

    capture_context.get_artifacts = get_artifacts  # type: ignore[attr-defined]
    return capture_context
|
||||
|
@ -112,7 +112,17 @@ class InlinedSource:
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DynamoCaptureOutput:
    """
    Core information generated from Dynamo for fullgraph=True.
    """

    # Guarded code entries (guards + serialized dynamo bytecode) produced
    # for this code object.
    guarded_codes: list[_GuardedCodeCacheEntry]
    # Unique ids of the backends invoked while compiling this code object.
    backend_ids: list[_BackendId]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynamoCodeCacheEntry(DynamoCaptureOutput):
|
||||
"""
|
||||
Contains the serializable information associated with a single code object
|
||||
in dynamo. To restore an execution of compiled code, we will need the following
|
||||
@ -135,9 +145,7 @@ class _DynamoCodeCacheEntry:
|
||||
python_code: SerializedCode
|
||||
python_module: str
|
||||
function_names: list[_FunctionId]
|
||||
guarded_codes: list[_GuardedCodeCacheEntry]
|
||||
import_sources: dict[str, str]
|
||||
backend_ids: list[_BackendId]
|
||||
code_source: Optional[str]
|
||||
install_to_global: bool
|
||||
has_compile_id: bool = False
|
||||
|
Reference in New Issue
Block a user