[WIP] Make constructor calls in experimental MetaTracer serializable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76789
Approved by: https://github.com/pbelevich
Committed by: PyTorch MergeBot
Parent: 3d561ee926
Commit: 7311390d35
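In broad strokes, this change records the MetaTracer's constructor arguments (the meta_args) on the traced Graph so that a pickled GraphModule can be re-traced on load. Below is a minimal end-to-end sketch of the behavior being added, mirroring the updated test; the Embedding/LayerNorm sizes are illustrative assumptions, not taken from the patch:

import pickle
import torch
import torch.fx.experimental.meta_tracer

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Sizes are illustrative; the patch does not show them.
        self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
        self.layernorm = torch.nn.LayerNorm(16)

    def forward(self, x):
        emb = self.emb(x)
        emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
        lol = self.layernorm(emb)
        # Data-dependent control flow: resolvable because tracing runs on
        # meta tensors that carry real shapes.
        return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

m = MyModule()
x = torch.zeros(15, dtype=torch.long).random_(42)
gm = torch.fx.experimental.meta_tracer.symbolic_trace(
    m, meta_args={'x': x.to(device='meta')})

# The point of the patch: this round-trip previously lost the meta_args
# needed to rebuild the module, so unpickling could not re-trace it.
loaded = pickle.loads(pickle.dumps(gm))
torch.testing.assert_close(loaded(x), m(x))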
FX backward-compatibility signature snapshot (under test/expect):

@@ -6,7 +6,7 @@ torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str
 torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph
 torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph_module.GraphModule
 torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable])
-torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None)
+torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None)
 torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
 torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
 torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
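Since tracer_extras defaults to None, the new keyword leaves existing call sites untouched; a quick sketch:

import torch.fx

g_old = torch.fx.Graph()  # existing callers are unchanged
g_new = torch.fx.Graph(tracer_extras={'meta_args': {}})  # new optional keyword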
test/test_fx_experimental.py:

@@ -3,7 +3,9 @@
 import math
 import numbers
 import operator
+import pickle
 import sys
+import tempfile
 import unittest
 from typing import Callable, Dict, Union, List, Optional
 from types import BuiltinFunctionType
@@ -26,7 +28,7 @@ from torch.fx.experimental.partitioner_utils import (
 )
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
-from torch.fx.experimental.meta_tracer import MetaTracer
+import torch.fx.experimental.meta_tracer
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
@@ -670,8 +672,6 @@ class TestFXExperimental(JitTestCase):
         self.assertEqual(traced(3, 3), m(3, 3))
 
     def test_meta_tracer(self):
-        mt = MetaTracer()
-
         class MetaTracerTestModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -680,16 +680,27 @@ class TestFXExperimental(JitTestCase):
             def forward(self, x):
                 emb = self.emb(x)
+                emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
                 lol = self.layernorm(emb)
                 return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)
 
         mttm = MetaTracerTestModule()
         for BS in [15, 35]:
             x = torch.zeros(BS, dtype=torch.long).random_(42)
-            graph = mt.trace(mttm, meta_args={'x' : x.to(device='meta')})
-            gm = torch.fx.GraphModule(mttm, graph)
+            meta_args = {'x' : x.to(device='meta')}
+            gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
             torch.testing.assert_close(gm(x), mttm(x))
 
+            # Test serialization/deserialization
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
+                    pickle.dump(gm, f)
+
+                with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
+                    loaded = pickle.load(f)
+
+                torch.testing.assert_close(loaded(x), mttm(x))
+
     def test_proxy_tensor(self):
         def f(x):
             val = x.cos().cos().sum()
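The loop over BS in [15, 35] is deliberate: lol.shape[0] < 30 takes the relu branch for the small batch and the sigmoid branch for the large one, so each trace specializes on the shapes supplied in meta_args. A sketch of the idea, reusing MetaTracerTestModule from the test above:

mttm = MetaTracerTestModule()
x_small = torch.zeros(15, dtype=torch.long).random_(42)
x_large = torch.zeros(35, dtype=torch.long).random_(42)

gm_small = torch.fx.experimental.meta_tracer.symbolic_trace(
    mttm, meta_args={'x': x_small.to(device='meta')})
gm_large = torch.fx.experimental.meta_tracer.symbolic_trace(
    mttm, meta_args={'x': x_large.to(device='meta')})
# gm_small's graph baked in torch.relu; gm_large's baked in torch.sigmoid.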
torch/fx/experimental/meta_tracer.py:

@@ -4,7 +4,7 @@ import warnings
 import functools
 import builtins
 
-from typing import Callable, Dict
+from typing import Any, Callable, Dict, Optional, Union
 
 def embedding_override(self, input):
     return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
@@ -250,7 +250,19 @@ class MetaTracer(torch.fx.Tracer):
             self.orig_fns.add(orig)
 
         try:
-            return super().trace(root, concrete_args)
+            graph = super().trace(root, concrete_args)
+            graph._tracer_extras = {'meta_args': meta_args}
+            return graph
         finally:
             for name, (_, orig) in self.patched_torch_methods.items():
                 setattr(torch, name, orig)
+
+
+def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
+                   meta_args : Dict[str, torch.Tensor] = None,
+                   concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
+    tracer = MetaTracer()
+    graph = tracer.trace(root, meta_args, concrete_args)
+    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    gm = torch.fx.GraphModule(tracer.root, graph, name)
+    return gm
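The new module-level symbolic_trace simply folds the usual tracer boilerplate into one call; continuing the test's names, the two forms below should be equivalent (sketch):

# Manual form, as the test used to do it:
tracer = torch.fx.experimental.meta_tracer.MetaTracer()
graph = tracer.trace(mttm, meta_args={'x': x.to(device='meta')})
gm_manual = torch.fx.GraphModule(tracer.root, graph, mttm.__class__.__name__)

# One-liner added by this commit; either way, trace() now records
# meta_args on the graph as _tracer_extras, so both results pickle.
gm = torch.fx.experimental.meta_tracer.symbolic_trace(
    mttm, meta_args={'x': x.to(device='meta')})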
torch/fx/graph.py:

@@ -598,7 +598,8 @@ class Graph:
     """
 
     @compatibility(is_backward_compatible=True)
-    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
+    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
+                 tracer_extras: Optional[Dict[str, Any]] = None):
         """
         Construct an empty Graph.
         """
@@ -610,6 +611,7 @@ class Graph:
         self._owners = 0
         self._owning_module = owning_module
         self._tracer_cls = tracer_cls
+        self._tracer_extras = tracer_extras
         self._codegen = CodeGen()
 
     @property
torch/fx/graph_module.py:

@@ -160,7 +160,8 @@ def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
 
     com = CodeOnlyModule(body)
 
-    graph = KeepModules().trace(com)
+    tracer_extras = body.get('_tracer_extras', {})
+    graph = KeepModules().trace(com, **tracer_extras)
 
     # Manually set Tracer class on the reconstructed Graph, to avoid
     # referencing the private local subclass KeepModules.
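On unpickling, body is (roughly) the saved GraphModule __dict__, and KeepModules subclasses the original tracer class recorded under _tracer_cls; for a MetaTracer-produced module, **tracer_extras therefore re-feeds meta_args into MetaTracer.trace over the generated forward() source. A sketch of the observable effect, with gm from the earlier example:

payload = pickle.dumps(gm)            # __dict__ carries '_tracer_extras'
restored = pickle.loads(payload)      # re-traced via KeepModules(MetaTracer)
print(sorted(restored._tracer_extras))  # ['meta_args']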
@@ -373,6 +374,10 @@ class GraphModule(torch.nn.Module):
         if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
             self._tracer_cls = self.graph._tracer_cls
 
+        self._tracer_extras = {}
+        if self.graph._tracer_extras:
+            self._tracer_extras = self.graph._tracer_extras
+
         # TorchScript breaks trying to compile the graph setter because of the
         # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
         #
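After this change the constructor arguments are visible on both the Graph and the GraphModule, which is what lets them ride along through pickling. For instance, with m and x from the first sketch above:

gm = torch.fx.experimental.meta_tracer.symbolic_trace(
    m, meta_args={'x': x.to(device='meta')})
print(gm.graph._tracer_extras)  # {'meta_args': {'x': <meta tensor>}}
print(gm._tracer_extras)        # same dict, copied in GraphModule.__init__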