mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Graph break when faking named tensors (#120779)
Fixes #120644 Pull Request resolved: https://github.com/pytorch/pytorch/pull/120779 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
1104e0798c
commit
a911eb74ae
@ -4766,7 +4766,7 @@ class CommonTemplate:
|
||||
dtype=torch.float32,
|
||||
device=a.device,
|
||||
),
|
||||
torch.zeros(2, 3, names=None),
|
||||
torch.zeros(2, 3),
|
||||
a + torch.ones(8, device=a.device),
|
||||
torch.full((2, 3), 3.1416, device=a.device),
|
||||
)
|
||||
|
||||
@ -686,10 +686,12 @@ class TestNamedTensor(TestCase):
|
||||
|
||||
self.assertEqual(op(a, a).names, ('N', 'C'))
|
||||
self.assertEqual(op(a, c).names, ('N', 'C'))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "do not match"):
|
||||
# TODO: dynamo will throw a slightly different
|
||||
# error message because it's adding fake tensors
|
||||
# `must match the size of` portion is the dynamo error
|
||||
with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
|
||||
op(a, d)
|
||||
with self.assertRaisesRegex(RuntimeError, "do not match"):
|
||||
with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
|
||||
op(a, b)
|
||||
|
||||
def test_wildcard(op):
|
||||
|
||||
@ -157,6 +157,11 @@ def constructors(fake_mode, func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
if "names" in kwargs:
|
||||
raise UnsupportedOperatorException(
|
||||
"torch.compile doesn't support named tensors"
|
||||
)
|
||||
|
||||
if func in _like_tensor_constructors:
|
||||
default_device = new_kwargs["input"].device
|
||||
# TODO: file issue
|
||||
|
||||
@ -448,6 +448,13 @@ class FakeTensor(torch.Tensor):
|
||||
# that have dispatch keys which are higher than the "meta" key:
|
||||
# https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189
|
||||
|
||||
# We don't support named tensors; graph break
|
||||
@property
|
||||
def names(self):
|
||||
raise UnsupportedFakeTensorException(
|
||||
"torch.compile doesn't support named tensors"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, fake_mode, elem, device, constant=None):
|
||||
self = torch.Tensor._make_subclass(
|
||||
|
||||
Reference in New Issue
Block a user