mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[WIP] Make constructor calls in experimental MetaTracer serializable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76789 Approved by: https://github.com/pbelevich
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							3d561ee926
						
					
				
				
					commit
					7311390d35
				
			| @ -6,7 +6,7 @@ torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.modul | ||||
| torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph | ||||
| torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph_module.GraphModule | ||||
| torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable]) | ||||
| torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None) | ||||
| torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None) | ||||
| torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node | ||||
| torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node | ||||
| torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node | ||||
|  | ||||
| @ -3,7 +3,9 @@ | ||||
| import math | ||||
| import numbers | ||||
| import operator | ||||
| import pickle | ||||
| import sys | ||||
| import tempfile | ||||
| import unittest | ||||
| from typing import Callable, Dict, Union, List, Optional | ||||
| from types import BuiltinFunctionType | ||||
| @ -26,7 +28,7 @@ from torch.fx.experimental.partitioner_utils import ( | ||||
| ) | ||||
| from torch.fx.experimental.rewriter import RewritingTracer | ||||
| from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema | ||||
| from torch.fx.experimental.meta_tracer import MetaTracer | ||||
| import torch.fx.experimental.meta_tracer | ||||
| from torch.fx.experimental.proxy_tensor import make_fx | ||||
| from torch.fx.graph_module import GraphModule | ||||
| from torch.fx.node import Node | ||||
| @ -670,8 +672,6 @@ class TestFXExperimental(JitTestCase): | ||||
|         self.assertEqual(traced(3, 3), m(3, 3)) | ||||
|  | ||||
|     def test_meta_tracer(self): | ||||
|         mt = MetaTracer() | ||||
|  | ||||
|         class MetaTracerTestModule(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
| @ -680,16 +680,27 @@ class TestFXExperimental(JitTestCase): | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 emb = self.emb(x) | ||||
|                 emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device) | ||||
|                 lol = self.layernorm(emb) | ||||
|                 return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) | ||||
|  | ||||
|         mttm = MetaTracerTestModule() | ||||
|         for BS in [15, 35]: | ||||
|             x = torch.zeros(BS, dtype=torch.long).random_(42) | ||||
|             graph = mt.trace(mttm, meta_args={'x' : x.to(device='meta')}) | ||||
|             gm = torch.fx.GraphModule(mttm, graph) | ||||
|             meta_args = {'x' : x.to(device='meta')} | ||||
|             gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args) | ||||
|             torch.testing.assert_close(gm(x), mttm(x)) | ||||
|  | ||||
|             # Test serialization/deserialization | ||||
|             with tempfile.TemporaryDirectory() as tmp_dir: | ||||
|                 with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f: | ||||
|                     pickle.dump(gm, f) | ||||
|  | ||||
|                 with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f: | ||||
|                     loaded = pickle.load(f) | ||||
|  | ||||
|                 torch.testing.assert_close(loaded(x), mttm(x)) | ||||
|  | ||||
|     def test_proxy_tensor(self): | ||||
|         def f(x): | ||||
|             val = x.cos().cos().sum() | ||||
|  | ||||
| @ -4,7 +4,7 @@ import warnings | ||||
| import functools | ||||
| import builtins | ||||
|  | ||||
| from typing import Callable, Dict | ||||
| from typing import Any, Callable, Dict, Optional, Union | ||||
|  | ||||
| def embedding_override(self, input): | ||||
|     return torch.empty(*input.shape, self.weight.shape[-1], device='meta') | ||||
| @ -250,7 +250,19 @@ class MetaTracer(torch.fx.Tracer): | ||||
|             self.orig_fns.add(orig) | ||||
|  | ||||
|         try: | ||||
|             return super().trace(root, concrete_args) | ||||
|             graph = super().trace(root, concrete_args) | ||||
|             graph._tracer_extras = {'meta_args': meta_args} | ||||
|             return graph | ||||
|         finally: | ||||
|             for name, (_, orig) in self.patched_torch_methods.items(): | ||||
|                 setattr(torch, name, orig) | ||||
|  | ||||
|  | ||||
| def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], | ||||
|                    meta_args : Dict[str, torch.Tensor] = None, | ||||
|                    concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule: | ||||
|     tracer = MetaTracer() | ||||
|     graph = tracer.trace(root, meta_args, concrete_args) | ||||
|     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ | ||||
|     gm = torch.fx.GraphModule(tracer.root, graph, name) | ||||
|     return gm | ||||
|  | ||||
| @ -598,7 +598,8 @@ class Graph: | ||||
|     """ | ||||
|  | ||||
|     @compatibility(is_backward_compatible=True) | ||||
|     def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None): | ||||
|     def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, | ||||
|                  tracer_extras: Optional[Dict[str, Any]] = None): | ||||
|         """ | ||||
|         Construct an empty Graph. | ||||
|         """ | ||||
| @ -610,6 +611,7 @@ class Graph: | ||||
|         self._owners = 0 | ||||
|         self._owning_module = owning_module | ||||
|         self._tracer_cls = tracer_cls | ||||
|         self._tracer_extras = tracer_extras | ||||
|         self._codegen = CodeGen() | ||||
|  | ||||
|     @property | ||||
|  | ||||
| @ -160,7 +160,8 @@ def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module: | ||||
|  | ||||
|     com = CodeOnlyModule(body) | ||||
|  | ||||
|     graph = KeepModules().trace(com) | ||||
|     tracer_extras = body.get('_tracer_extras', {}) | ||||
|     graph = KeepModules().trace(com, **tracer_extras) | ||||
|  | ||||
|     # Manually set Tracer class on the reconstructed Graph, to avoid | ||||
|     # referencing the private local subclass KeepModules. | ||||
| @ -373,6 +374,10 @@ class GraphModule(torch.nn.Module): | ||||
|         if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__: | ||||
|             self._tracer_cls = self.graph._tracer_cls | ||||
|  | ||||
|         self._tracer_extras = {} | ||||
|         if self.graph._tracer_extras: | ||||
|             self._tracer_extras = self.graph._tracer_extras | ||||
|  | ||||
|     # TorchScript breaks trying to compile the graph setter because of the | ||||
|     # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 | ||||
|     # | ||||
|  | ||||
		Reference in New Issue
	
	Block a user