mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fullgraph graph capture with dynamo. (#159749)
Summary: Following up on Avik's doc (https://docs.google.com/document/d/11RW0Bbkp1QwFbEu8rCNW5d7wUFaEkxbL0uLyqcc2jTk/edit?tab=t.0), we are experimenting with a new API that utilizes torch.compile(fullgraph=True), which we intend to use to replace the old dynamo.export() API. This PR adds a prototype of the API described in the doc.

Test Plan: test_misc -- -k test_aot_capture

Differential Revision: D79534608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159749
Approved by: https://github.com/tugsbayasgalan
committed by PyTorch MergeBot
parent 101276f81b
commit 16d15445f8
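
The intended usage, per the docstring example added in this diff (a rough sketch of an experimental API; fullgraph_capture currently lives under torch._dynamo.eval_frame, and the artifact-inspection calls below mirror the new test):

    import torch
    import torch._dynamo.eval_frame

    def fn(x):
        return x + x.shape[0]

    # Each call traces under fullgraph=True and records compilation artifacts.
    compiled_fn = torch._dynamo.eval_frame.fullgraph_capture(fn)
    compiled_fn(torch.randn(3, 2))
    compiled_fn(torch.randn(4))  # a different rank triggers a second capture

    # CaptureOutput bundles dynamo guards/bytecode with per-backend FX graphs.
    artifacts = compiled_fn.get_artifacts()
    for backend_id, backend_input in artifacts.backend_inputs.items():
        backend_input.graph_module.print_readable()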
test/dynamo/test_misc.py
@@ -16,11 +16,13 @@ import logging
 import math
 import operator
 import os
+import pickle
 import random
 import sys
 import tempfile
 import threading
 import traceback
+import types
 import typing
 import unittest
 import unittest.mock as mock
@@ -8520,6 +8522,50 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         self.assertEqual(seen_frames[0].name, "fn")
         self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
 
+    def test_fullgraph_capture(self):
+        def foo(x):
+            return x + x.shape[0]
+
+        compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo)
+        compiled_foo(torch.randn(3, 2))
+        compiled_foo(torch.randn(4))
+        artifacts = compiled_foo.get_artifacts()
+
+        guarded_codes = artifacts.dynamo_artifacts.guarded_codes
+        backend_ids = list(artifacts.backend_inputs.keys())
+        gms = [b.graph_module for b in artifacts.backend_inputs.values()]
+
+        def _convert_to_ep_demo(code, backend_id, gm, args):
+            # Inject compiled function as the original gm
+            new_globals = copy.copy(globals())
+            new_globals[backend_id] = gm
+            # Minimal boilerplate to set up a callable.
+            SerializedCode = type(code.dynamo_code)
+            dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code)
+            guards_state = pickle.loads(code.guards_state)
+            guard_manager = torch._dynamo.guards.CheckFunctionManager(
+                foo.__code__,
+                guards_state.output_graph,
+                guards_serialization_mode="load",
+                shape_code_parts=guards_state.shape_code_parts,
+                runtime_global_scope=new_globals,
+            ).guard_manager
+
+            class ModuleForExport(torch.nn.Module):
+                def forward(self, x):
+                    return types.FunctionType(dynamo_bytecode, new_globals)(x)
+
+            m = ModuleForExport()
+            return guard_manager, torch.export.export(m, args)
+
+        guards0, ep0 = _convert_to_ep_demo(
+            guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),)
+        )
+        self.assertTrue(guards0.check({"x": torch.randn(3, 2)}))
+        self.assertFalse(guards0.check({"x": torch.randn(4)}))
+        input0 = torch.randn(3, 2)
+        self.assertEqual(ep0.module()(input0), foo(input0))
+
     def test_torch_guards_stack_frame_register_inlining_deep(self):
         x = torch.tensor([0.5, 0.5])
         y = torch.tensor([0.75, 0.75, 0.75, 0.75])
torch/_dynamo/eval_frame.py
@@ -113,7 +113,7 @@ from .utils import (
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence
 
-    from torch._dynamo.package import CompilePackage
+    from torch._dynamo.package import CompilePackage, DynamoCaptureOutput
     from torch._dynamo.repro.after_dynamo import WrapBackendDebug
     from torch._subclasses import fake_tensor
     from torch.fx.node import Argument, Node, Target
@@ -2288,3 +2288,83 @@ def skip_code(code: types.CodeType) -> None:
     set_code_exec_strategy(
         code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
     )
+
+
+@dataclass
+class BackendInput:
+    graph_module: torch.fx.GraphModule
+    example_inputs: tuple[Any, ...]
+    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
+
+
+@dataclass
+class CaptureOutput:
+    """
+    Core data structure that contains all the information dynamo generates
+    from fullgraph=True. Ideally, this should be the "return" type if dynamo
+    has a standard API to return compilation artifacts.
+    """
+
+    dynamo_artifacts: DynamoCaptureOutput
+    backend_inputs: dict[str, BackendInput]
+
+
+def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    A helper function which wraps a model and returns a callable, like optimize().
+    The callable can be called with normal inputs, like torch.compile()-ed
+    functions, and the user can dump dynamo compilation artifacts through a
+    `get_artifacts()` call.
+
+    The CaptureOutput is separated into two parts:
+    1. Dynamo-specific information from DynamoCaptureOutput, which includes:
+       - guards
+       - generated bytecode
+       - python source information
+    2. Backend-specific information (indexed by unique backend id), such as:
+       - fx graph
+       - example inputs
+
+    Example:
+        def fn(*args):
+            ...
+
+        compiled_fn = fullgraph_capture(fn)
+        compiled_fn(args)
+        compiled_fn(another_args)
+        artifacts = compiled_fn.get_artifacts()
+    """
+    from torch._dynamo.package import CompilePackage
+
+    package = CompilePackage(model)
+
+    backend_inputs: dict[str, BackendInput] = {}
+
+    def _backend(
+        gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...]
+    ) -> torch.fx.GraphModule:
+        from torch._guards import TracingContext
+
+        fake_mode = TracingContext.get().fake_mode
+        assert fake_mode is not None
+        backend_id = gm._backend_id
+        assert isinstance(backend_id, str)
+        backend_inputs[backend_id] = BackendInput(gm, example_inputs, fake_mode)
+        return gm
+
+    # TODO For now we use eval_frame to give us the frame. This can be
+    # simplified to a manual frame creation helper.
+    optimized_model = optimize(nopython=True, backend=_backend, package=package)(model)
+
+    @functools.wraps(model)
+    def capture_context(*args: Any, **kwargs: Any) -> Any:
+        return optimized_model(*args, **kwargs)
+
+    def get_artifacts() -> CaptureOutput:
+        cache_entry = package.cache_entry()
+        assert len(cache_entry.codes) == 1
+        return CaptureOutput(
+            dynamo_artifacts=cache_entry.codes[0], backend_inputs=backend_inputs
+        )
+
+    capture_context.get_artifacts = get_artifacts  # type: ignore[attr-defined]
+    return capture_context
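The _backend hook above uses the same custom-backend contract that torch.compile already exposes: a callable that receives the traced torch.fx.GraphModule plus example inputs and returns a callable. A minimal standalone sketch of that recording pattern (hypothetical names, independent of this PR):

    import torch

    recorded = {}

    def recording_backend(gm: torch.fx.GraphModule, example_inputs):
        # Stash the captured graph and inputs, then run the graph unchanged.
        recorded["gm"] = gm
        recorded["example_inputs"] = example_inputs
        return gm

    fn = torch.compile(lambda x: x * 2 + 1, backend=recording_backend, fullgraph=True)
    fn(torch.randn(4))
    print(recorded["gm"].graph)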
torch/_dynamo/package.py
@@ -112,7 +112,17 @@ class InlinedSource:
 
 
 @dataclasses.dataclass
-class _DynamoCodeCacheEntry:
+class DynamoCaptureOutput:
+    """
+    Core information generated from Dynamo for fullgraph=True.
+    """
+
+    guarded_codes: list[_GuardedCodeCacheEntry]
+    backend_ids: list[_BackendId]
+
+
+@dataclasses.dataclass
+class _DynamoCodeCacheEntry(DynamoCaptureOutput):
     """
     Contains the serializable information associated with a single code object
     in dynamo. To restore an execution of compiled code, we will need the following
@@ -135,9 +145,7 @@ class _DynamoCodeCacheEntry:
     python_code: SerializedCode
     python_module: str
     function_names: list[_FunctionId]
-    guarded_codes: list[_GuardedCodeCacheEntry]
     import_sources: dict[str, str]
-    backend_ids: list[_BackendId]
     code_source: Optional[str]
     install_to_global: bool
     has_compile_id: bool = False
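
A side note on the refactor above: because dataclass inheritance places base-class fields first in the generated __init__, the hoisted guarded_codes and backend_ids now lead _DynamoCodeCacheEntry's constructor, while the only defaulted field (has_compile_id) stays last, keeping the dataclass field-ordering rule satisfied. A standalone sketch (hypothetical names) of that behavior:

    import dataclasses

    @dataclasses.dataclass
    class Base:
        guarded: list
        backends: list

    @dataclasses.dataclass
    class Child(Base):
        code: str
        has_default: bool = False

    # Inherited fields come first; fields with defaults must come after those without.
    print([f.name for f in dataclasses.fields(Child)])
    # -> ['guarded', 'backends', 'code', 'has_default']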