[dynamo] Graph break when faking named tensors (#120779)

Fixes #120644
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120779
Approved by: https://github.com/zou3519
James Wu
2024-02-29 06:21:24 -08:00
committed by PyTorch MergeBot
parent 1104e0798c
commit a911eb74ae
4 changed files with 18 additions and 4 deletions
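The user-visible effect of the change: named-tensor operations inside a compiled region no longer fail during fake-tensor propagation; dynamo graph-breaks and runs them eagerly. A minimal sketch of that behavior (illustrative only, assuming a build that includes this change):

    import torch

    def add_named(x):
        # Named-tensor factory: faking it is unsupported, so dynamo
        # graph-breaks here and the call runs eagerly.
        n = torch.zeros(2, 3, names=("N", "C"))
        return x + n.rename(None)

    compiled = torch.compile(add_named)
    out = compiled(torch.ones(2, 3))  # succeeds via graph break + eager fallback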

View File

@@ -4766,7 +4766,7 @@ class CommonTemplate:
                 dtype=torch.float32,
                 device=a.device,
             ),
-            torch.zeros(2, 3, names=None),
+            torch.zeros(2, 3),
             a + torch.ones(8, device=a.device),
             torch.full((2, 3), 3.1416, device=a.device),
         )
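The inductor test above drops the explicit names=None, presumably because the constructor check added later in this diff fires on the mere presence of the names kwarg; the two spellings produce the same unnamed tensor in eager mode. A quick sketch (illustrative):

    import torch

    a = torch.zeros(2, 3, names=None)  # explicit names=None
    b = torch.zeros(2, 3)              # no names kwarg at all
    assert a.names == b.names == (None, None)  # both tensors are unnamed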

View File

@@ -686,10 +686,12 @@ class TestNamedTensor(TestCase):
             self.assertEqual(op(a, a).names, ('N', 'C'))
             self.assertEqual(op(a, c).names, ('N', 'C'))
-            with self.assertRaisesRegex(RuntimeError, "do not match"):
+            # TODO: dynamo will throw a slightly different
+            # error message because it's adding fake tensors
+            # `must match the size of` portion is the dynamo error
+            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
                 op(a, d)
-            with self.assertRaisesRegex(RuntimeError, "do not match"):
+            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
                 op(a, b)
         def test_wildcard(op):
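assertRaisesRegex matches with re.search, so the alternation accepts either the eager named-tensor wording ("do not match") or the wording produced by dynamo's fake-tensor path ("must match the size of"). A sketch of how the pattern behaves (the strings below are illustrative, not exact PyTorch messages):

    import re

    pattern = "do not match|must match the size of"
    assert re.search(pattern, "dims ['N', 'D'] do not match")     # eager-style wording
    assert re.search(pattern, "must match the size of tensor b")  # dynamo-style wording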

View File

@@ -157,6 +157,11 @@ def constructors(fake_mode, func, *args, **kwargs):
     _, new_kwargs = normalize_function(
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
+    if "names" in kwargs:
+        raise UnsupportedOperatorException(
+            "torch.compile doesn't support named tensors"
+        )
+
     if func in _like_tensor_constructors:
         default_device = new_kwargs["input"].device
         # TODO: file issue
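Because the raised UnsupportedOperatorException is translated into a graph break, the fallback is silent by default; compiling with fullgraph=True turns the break into an error, which is one way to observe that this path was taken. A sketch (illustrative; the exact exception type surfaced is dynamo's):

    import torch

    def make_named():
        return torch.zeros(2, 3, names=("N", "C"))

    torch.compile(make_named)()  # default: graph break + eager fallback, succeeds

    try:
        torch.compile(make_named, fullgraph=True)()  # fullgraph: the break is an error
    except Exception as e:
        print(type(e).__name__)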

View File

@@ -448,6 +448,13 @@ class FakeTensor(torch.Tensor):
     # that have dispatch keys which are higher than the "meta" key:
     # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189
+    # We don't support named tensors; graph break
+    @property
+    def names(self):
+        raise UnsupportedFakeTensorException(
+            "torch.compile doesn't support named tensors"
+        )
+
     @staticmethod
     def __new__(cls, fake_mode, elem, device, constant=None):
         self = torch.Tensor._make_subclass(
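The names property covers the read path: any .names access on a FakeTensor during tracing raises, which dynamo turns into a graph break, so code that inspects names still runs, just eagerly. A sketch of the effect (illustrative only):

    import torch

    def inspect_names(x):
        # Reading .names on the traced (fake) tensor raises inside dynamo,
        # triggering a graph break; the branch then executes eagerly.
        if x.names[0] is None:
            return x + 1
        return x

    out = torch.compile(inspect_names)(torch.ones(2, 3))  # returns ones + 1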