pytorch/test/dynamo
Yiming Zhou 33f3413bd3 [WIP][precompile] Set fake_mode of base tensor in fx graph pickler (#163738)
Summary:
When unpickling a fake tensor, the fx graph pickler only sets the fake mode of the current tensor's metadata to the one that is consistent with the pickler's `unpickle_state`. It does not set the fake mode of the tensor's base tensor when that tensor is a view.

This causes an issue when dumping and loading the following graph:
```
class GraphModule(torch.nn.Module):
    def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, 8]"):
        l_x_ = L_x_
        chunk = l_x_.chunk(2, dim = -1);  l_x_ = None
        y: "f32[s77, 4]" = chunk[0];  chunk = None
        y_repeat: "f32[s77, 8]" = y.repeat_interleave(2, dim = -1);  y = None
        return (y_repeat,)
```
because `repeat_interleave` creates an intermediate fake tensor of size `[s77, 2, 4]`, which becomes the base of the `meta['val']` of node `y_repeat`.

This causes issues during the deserialization phase when applying AOT precompile to DeepSeek in vLLM.
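
To make the failure mode concrete, here is a toy model in plain Python. It is a sketch only and does not use PyTorch: `ToyMode`, `ToyTensor`, `rebind_shallow`, and `rebind_with_base` are hypothetical names invented for illustration, not the pickler's actual code. It shows why rebinding the mode on the view alone leaves its base pointing at the old mode.

```python
from dataclasses import dataclass
from typing import Optional


class ToyMode:
    """Hypothetical stand-in for a fake mode; only its identity matters."""


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a fake tensor that may be a view of `base`."""
    mode: ToyMode
    base: Optional["ToyTensor"] = None


def rebind_shallow(t: ToyTensor, mode: ToyMode) -> ToyTensor:
    # Models the behavior described above: only the tensor itself is rebound.
    t.mode = mode
    return t


def rebind_with_base(t: ToyTensor, mode: ToyMode) -> ToyTensor:
    # Models the intended fix: also rebind every tensor along the base chain.
    cur: Optional[ToyTensor] = t
    while cur is not None:
        cur.mode = mode
        cur = cur.base
    return t


def modes_consistent(t: ToyTensor, mode: ToyMode) -> bool:
    # True only if the tensor and all of its bases share the given mode.
    cur: Optional[ToyTensor] = t
    while cur is not None:
        if cur.mode is not mode:
            return False
        cur = cur.base
    return True


old_mode, new_mode = ToyMode(), ToyMode()

view = ToyTensor(old_mode, base=ToyTensor(old_mode))
rebind_shallow(view, new_mode)
shallow_ok = modes_consistent(view, new_mode)  # the base still holds old_mode

view = ToyTensor(old_mode, base=ToyTensor(old_mode))
rebind_with_base(view, new_mode)
deep_ok = modes_consistent(view, new_mode)

print(shallow_ok, deep_ok)  # prints: False True
```

In the real pickler the analogous inconsistency is between the fake mode passed to `GraphPickler.loads` and the fake mode still attached to a view's base tensor.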

Test Plan:
This has been tested in vLLM with DeepSeek.

As for a unit test, ideally it would be `test_aot_compile_repeat_interleave` with `mark_dynamic` turned on. However, that currently leads to other pickle issues.

```
python test/dynamo/test_aot_compile.py -k test_aot_compile_repeat_interleave
```

I have yet to figure out a more appropriate unit test, but a proof-of-concept demo is the following:
```
import inspect
from unittest.mock import patch

import sympy
import torch
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
from torch.fx.experimental.symbolic_shapes import ShapeEnv

class M(torch.nn.Module):
    def forward(self, x):
        chunk = x.chunk(2, dim=-1)
        y = chunk[0]
        y_repeat = y.repeat_interleave(2, dim=-1)
        return y_repeat

def my_custom_backend(gm, example_inputs):
    # Capture the traced GraphModule so it can be pickled below.
    global gm_global
    gm_global = gm
    return gm.forward

m = M()
m_opt = torch.compile(m, backend=my_custom_backend, fullgraph=True)

sample_inputs = (torch.randn(2, 8),)
torch._dynamo.mark_dynamic(sample_inputs[0], [0])
opt_out = m_opt(*sample_inputs)

graph_reducer_override = GraphPickler.reducer_override

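def _graph_reducer_override(self, obj):
    # Reduce sympy function classes that expose a custom unpickler hook.
    if (inspect.isclass(obj) and issubclass(obj, sympy.Function)
            and hasattr(obj, "_torch_unpickler")):
        return obj._torch_unpickler, (obj._torch_handler_name, )
    # Pickle FakeTensorMode instances as None (type(None)() returns None).
    if isinstance(obj, FakeTensorMode):
        return type(None), ()
    # Defer everything else to the original reducer.
    return graph_reducer_override(self, obj)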

with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
    pickled_gm = GraphPickler.dumps(gm_global, Options(ops_filter=None))

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
loaded_gm = GraphPickler.loads(pickled_gm, fake_mode)
```

Differential Revision: D83112599

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163738
Approved by: https://github.com/zhxchen17
2025-09-26 14:36:37 +00:00