[WIP] Make constructor calls in experimental MetaTracer serializable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76789

Approved by: https://github.com/pbelevich
This commit is contained in:
James Reed
2022-05-10 23:54:37 +00:00
committed by PyTorch MergeBot
parent 3d561ee926
commit 7311390d35
5 changed files with 40 additions and 10 deletions

View File

@ -6,7 +6,7 @@ torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.modul
torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph
torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph_module.GraphModule
torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable])
torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None)
torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None)
torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node

View File

@ -3,7 +3,9 @@
import math
import numbers
import operator
import pickle
import sys
import tempfile
import unittest
from typing import Callable, Dict, Union, List, Optional
from types import BuiltinFunctionType
@ -26,7 +28,7 @@ from torch.fx.experimental.partitioner_utils import (
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.experimental.meta_tracer import MetaTracer
import torch.fx.experimental.meta_tracer
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
@ -670,8 +672,6 @@ class TestFXExperimental(JitTestCase):
self.assertEqual(traced(3, 3), m(3, 3))
def test_meta_tracer(self):
mt = MetaTracer()
class MetaTracerTestModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -680,16 +680,27 @@ class TestFXExperimental(JitTestCase):
def forward(self, x):
emb = self.emb(x)
emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
lol = self.layernorm(emb)
return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)
mttm = MetaTracerTestModule()
for BS in [15, 35]:
x = torch.zeros(BS, dtype=torch.long).random_(42)
graph = mt.trace(mttm, meta_args={'x' : x.to(device='meta')})
gm = torch.fx.GraphModule(mttm, graph)
meta_args = {'x' : x.to(device='meta')}
gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
torch.testing.assert_close(gm(x), mttm(x))
# Test serialization/deserialization
with tempfile.TemporaryDirectory() as tmp_dir:
with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
pickle.dump(gm, f)
with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
loaded = pickle.load(f)
torch.testing.assert_close(loaded(x), mttm(x))
def test_proxy_tensor(self):
def f(x):
val = x.cos().cos().sum()

View File

@ -4,7 +4,7 @@ import warnings
import functools
import builtins
from typing import Callable, Dict
from typing import Any, Callable, Dict, Optional, Union
def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
@ -250,7 +250,19 @@ class MetaTracer(torch.fx.Tracer):
self.orig_fns.add(orig)
try:
return super().trace(root, concrete_args)
graph = super().trace(root, concrete_args)
graph._tracer_extras = {'meta_args': meta_args}
return graph
finally:
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Dict[str, torch.Tensor] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
gm = torch.fx.GraphModule(tracer.root, graph, name)
return gm

View File

@ -598,7 +598,8 @@ class Graph:
"""
@compatibility(is_backward_compatible=True)
def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
tracer_extras: Optional[Dict[str, Any]] = None):
"""
Construct an empty Graph.
"""
@ -610,6 +611,7 @@ class Graph:
self._owners = 0
self._owning_module = owning_module
self._tracer_cls = tracer_cls
self._tracer_extras = tracer_extras
self._codegen = CodeGen()
@property

View File

@ -160,7 +160,8 @@ def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
com = CodeOnlyModule(body)
graph = KeepModules().trace(com)
tracer_extras = body.get('_tracer_extras', {})
graph = KeepModules().trace(com, **tracer_extras)
# Manually set Tracer class on the reconstructed Graph, to avoid
# referencing the private local subclass KeepModules.
@ -373,6 +374,10 @@ class GraphModule(torch.nn.Module):
if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
self._tracer_cls = self.graph._tracer_cls
self._tracer_extras = {}
if self.graph._tracer_extras:
self._tracer_extras = self.graph._tracer_extras
# TorchScript breaks trying to compile the graph setter because of the
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
#