Summary: When unpickling a fake tensor in the fx graph pickler, only the fake mode of the current tensor's metadata is set to the one consistent with the pickler's `unpickle_state`. However, the fake mode of the tensor's base is not set when that tensor is a view. This causes an issue when dumping and loading the following graph:

```
class GraphModule(torch.nn.Module):
    def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 8]"):
        l_x_ = L_x_

        chunk = l_x_.chunk(2, dim = -1);  l_x_ = None
        y: "f32[s77, 4]" = chunk[0];  chunk = None
        y_repeat: "f32[s77, 8]" = y.repeat_interleave(2, dim = -1);  y = None
        return (y_repeat,)
```

because `repeat_interleave` creates an intermediate fake tensor of size `[s77, 2, 4]`, which becomes the base of the node `y_repeat`'s `meta['val']`. This caused failures during the deserialization phase when applying AOT precompile to DeepSeek in vLLM.

Test Plan: This has been tested in vLLM with DeepSeek. As for a unit test, ideally it would be `test_aot_compile_repeat_interleave` with `mark_dynamic` turned on; however, that currently runs into other pickling issues:

```
python test/dynamo/test_aot_compile.py -k test_aot_compile_repeat_interleave
```

I have yet to find a more appropriate unit test.
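The view/base relationship underlying the bug can be illustrated with eager tensors: `chunk` returns views whose `_base` attribute points at the original tensor, so any per-tensor state restored during unpickling must also be propagated to the base. This is only a minimal sketch of the idea; `_restored_mode` is a hypothetical marker, not the pickler's real field:

```python
import torch

x = torch.randn(2, 8)
y = x.chunk(2, dim=-1)[0]  # chunk returns views; y shares storage with x
assert y._base is x        # the view keeps a reference to its base


def propagate(tensor, mode):
    # Restoring state only on `tensor` would leave its base stale, so walk
    # the whole view chain. `_restored_mode` is illustrative only.
    t = tensor
    while t is not None:
        t._restored_mode = mode
        t = t._base


propagate(y, "fake_mode_token")
assert x._restored_mode == "fake_mode_token"  # base was updated via the view
```

In the failing case above, the tracing decomposition makes `y_repeat`'s `meta['val']` a view of the intermediate `[s77, 2, 4]` fake tensor, so skipping the base during unpickling leaves that intermediate tensor with the wrong fake mode.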
But a proof-of-concept demo would be the following:

```python
import inspect
from unittest.mock import patch

import sympy
import torch
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
from torch.fx.experimental.symbolic_shapes import ShapeEnv


class M(torch.nn.Module):
    def forward(self, x):
        chunk = x.chunk(2, dim=-1)
        y = chunk[0]
        y_repeat = y.repeat_interleave(2, dim=-1)
        return y_repeat


def my_custom_backend(gm, example_inputs):
    global gm_global
    gm_global = gm
    return gm.forward


m = M()
m_opt = torch.compile(m, backend=my_custom_backend, fullgraph=True)
sample_inputs = (torch.randn(2, 8),)
torch._dynamo.mark_dynamic(sample_inputs[0], [0])
opt_out = m_opt(*sample_inputs)

graph_reducer_override = GraphPickler.reducer_override


def _graph_reducer_override(self, obj):
    if (
        inspect.isclass(obj)
        and issubclass(obj, sympy.Function)
        and hasattr(obj, "_torch_unpickler")
    ):
        return obj._torch_unpickler, (obj._torch_handler_name,)
    if isinstance(obj, FakeTensorMode):
        return type(None), ()
    return graph_reducer_override(self, obj)


with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
    pickled_gm = GraphPickler.dumps(gm_global, Options(ops_filter=None))

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
loaded_gm = GraphPickler.loads(pickled_gm, fake_mode)
```

Differential Revision: D83112599

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163738
Approved by: https://github.com/zhxchen17