detect fake mode in proxy_tensor creation in make_fx (#144168)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/143742

A FakeTensorMode may already exist when we are setting the "val" meta of a proxy tensor. We should detect existing FakeTensorMode before creating a new one.

Otherwise, we could cause an error when using `detect_fake_mode` later, because there are now multiple FakeTensorModes existing.

Test Plan: The error in https://github.com/pytorch/pytorch/issues/143742

Differential Revision: D67813111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144168
Approved by: https://github.com/BoyuanFeng, https://github.com/tugsbayasgalan
This commit is contained in:
Shangdi Yu
2025-01-06 19:02:08 +00:00
committed by PyTorch MergeBot
parent e56768f030
commit e3aac7f8a0
2 changed files with 19 additions and 1 deletions

View File

@ -924,6 +924,20 @@ def forward(self, x_1):
continue
self.assertTrue('val' in n.meta)
def test_fake_tensor_mode(self):
def f(a):
d = a.cos()
return d
from torch._guards import detect_fake_mode
existing_fake_mode = FakeTensorMode()
with existing_fake_mode:
out = make_fx(f, tracing_mode="real")(torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]))
fake_mode = detect_fake_mode([node.meta.get('val', None) for node in out.graph.nodes])
self.assertEqual(fake_mode, existing_fake_mode)
def _get_node(fx_g, cond):
for n in fx_g.graph.nodes:
if cond(n):