mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Graph break on torch.Tensor.data
assignment with mismatched dtype (#156623)
Fixes #152162. Discussed with @bdhirsh and decided this is the easiest workaround for now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156623 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
e8cf5ff564
commit
d06a406656
@ -4460,6 +4460,20 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
# frame_count should stay at 1.
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
def test_tensor_set_data_mismatched_dtype(self):
|
||||
def func(x, y):
|
||||
x.data = y.to(dtype=torch.bfloat16)
|
||||
|
||||
x1 = torch.tensor([], dtype=torch.float32)
|
||||
x2 = torch.tensor([], dtype=torch.float32)
|
||||
y1 = torch.tensor([1, 2, 3], dtype=torch.float32)
|
||||
y2 = torch.tensor([1, 2, 3], dtype=torch.float32)
|
||||
func(x1, y1)
|
||||
torch.compile(func, backend="eager")(x2, y2)
|
||||
self.assertEqual(x1, x2)
|
||||
self.assertEqual(x1.data, x2.data)
|
||||
self.assertEqual(y1, y2)
|
||||
|
||||
def test_user_ctor_ctx_manager(self):
|
||||
class UserCtxManager:
|
||||
def __enter__(self):
|
||||
|
@ -2162,5 +2162,19 @@
|
||||
"Report an issue to the backend compiler repo."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0220": [
|
||||
{
|
||||
"Gb_type": "Failed to mutate tensor data attribute to different dtype",
|
||||
"Context": "setattr({obj}, {name}, {val})",
|
||||
"Explanation": "Dyanmo only supports mutating `.data` of tensor to a new one with the same dtype",
|
||||
"Hints": [
|
||||
"Don't mutate `.data` on this tensor, or move ",
|
||||
"the mutation out of `torch.compile` region"
|
||||
],
|
||||
"Additional_Info": [
|
||||
"INFO"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -2175,6 +2175,17 @@ class BuiltinVariable(VariableTracker):
|
||||
"the mutation out of `torch.compile` region",
|
||||
],
|
||||
)
|
||||
elif obj.dtype != val.dtype: # type: ignore[attr-defined]
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to mutate tensor data attribute to different dtype",
|
||||
context=f"setattr({obj}, {name}, {val})",
|
||||
explanation="Dyanmo only supports mutating `.data`"
|
||||
" of tensor to a new one with the same dtype",
|
||||
hints=[
|
||||
"Don't mutate `.data` on this tensor, or move "
|
||||
"the mutation out of `torch.compile` region",
|
||||
],
|
||||
)
|
||||
|
||||
# Remove the old reference in tracked fakes - if we don't do this
|
||||
# new .data value size and shape differences will cause
|
||||
|
Reference in New Issue
Block a user