mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
detect fake mode in proxy_tensor creation in make_fx (#144168)
Summary: Fixes https://github.com/pytorch/pytorch/issues/143742 A FakeTensorMode may already exist when we are setting the "val" meta of a proxy tensor. We should detect existing FakeTensorMode before creating a new one. Otherwise, we could cause an error when using `detect_fake_mode` later, because there are now multiple FakeTensorModes existing. Test Plan: The error in https://github.com/pytorch/pytorch/issues/143742 Differential Revision: D67813111 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144168 Approved by: https://github.com/BoyuanFeng, https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
e56768f030
commit
e3aac7f8a0
@ -924,6 +924,20 @@ def forward(self, x_1):
|
||||
continue
|
||||
self.assertTrue('val' in n.meta)
|
||||
|
||||
def test_fake_tensor_mode(self):
|
||||
def f(a):
|
||||
d = a.cos()
|
||||
return d
|
||||
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
existing_fake_mode = FakeTensorMode()
|
||||
with existing_fake_mode:
|
||||
out = make_fx(f, tracing_mode="real")(torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]))
|
||||
|
||||
fake_mode = detect_fake_mode([node.meta.get('val', None) for node in out.graph.nodes])
|
||||
self.assertEqual(fake_mode, existing_fake_mode)
|
||||
|
||||
def _get_node(fx_g, cond):
|
||||
for n in fx_g.graph.nodes:
|
||||
if cond(n):
|
||||
|
||||
Reference in New Issue
Block a user