make view.dtype always return an alias (#136074)
Fixes https://github.com/pytorch/pytorch/issues/136064

In the linked repro, the issue was that there was some code like this:

```
# x has dtype torch.float32
def f(x):
    y = x.view(torch.float32)
    y.copy_(...)
```

Because `view.dtype` is implemented today to potentially return its input directly, we would end up clobbering the proxy for our graph input (replacing its FX proxy value from `arg0_1` to `view_1`). This is not desirable: we have careful assertions in AOTDispatcher that mutations only ever happen on graph inputs, but this clobbering made the mutation appear, from the perspective of the FX graph, to happen on a view of the input.

Why is this normally not a problem? Ordinarily, the `ADInplaceOrView` kernel for `view.dtype` will take the output of the view kernel [and detach() it](https://github.com/pytorch/pytorch/blob/main/tools/autograd/gen_inplace_or_view_type.py#L466), properly creating a fresh `TensorImpl`. This does **not** happen, though, if you are executing the kernel from within a `__torch_dispatch__` region: the `ADInplaceOrView` logic has already run above you, so that key will be in the TLS exclude set.

This PR changes eager behavior; at first I considered trying to only change behavior under compile. But this problem isn't technically specific to PT2: if you ever rely on tensor identity from inside of a `__torch_dispatch__` call, then we need to make sure the raw `view.dtype` kernel doesn't directly return the input. I am also making the assumption that "`view.dtype` no-op'ing when the dtype is the same" is not a case worth optimizing in eager mode, and that the overhead of the `TensorImpl` creation is relatively negligible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136074
Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/albanD
ghstack dependencies: #136041
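As a minimal sketch of the tensor-identity reliance described above (it mirrors the test added in the diff below; the mode name and tensor values here are just illustrative, not part of the change itself), a `TorchDispatchMode` can observe whether a same-dtype `aten.view.dtype` call hands back the very same tensor object it was given:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class AliasCheckMode(TorchDispatchMode):
    """Reports whether aten.view.dtype returned its input unchanged."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if func == torch.ops.aten.view.dtype:
            # After this change, `out` should always be a fresh TensorImpl,
            # even when the requested dtype matches the input's dtype.
            print("returned the input itself:", out is args[0])
        return out

x = torch.ones(4, dtype=torch.float32)
with AliasCheckMode():
    y = x.view(torch.float32)  # same-dtype view exercises the former no-op path
```

Note the `ADInplaceOrView` detach() described above does not run here, since the dispatch key is excluded by the time `__torch_dispatch__` is reached; that is exactly why the raw kernel itself must not return its input.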
Committed by: PyTorch MergeBot
Parent: d463a81c27
Commit: dc82d274e6
@@ -1793,6 +1793,23 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
         with self.assertRaises(AssertionError):
             self.assertEqual(x, None)
 
+    # See https://github.com/pytorch/pytorch/issues/136064
+    def test_view_returns_alias_under_torch_dispatch(self):
+        class MyMode(TorchDispatchMode):
+            def __init__(self, testcase):
+                self.testcase = testcase
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = func(*args, **kwargs)
+                if func == torch.ops.aten.view.dtype:
+                    # view should return a fresh TensorImpl
+                    self.testcase.assertTrue(out is not args[0])
+                return out
+
+        with MyMode(self):
+            x = torch.ones(4, dtype=torch.float32)
+            out = x.view(torch.float32)
+
     def test_record_stream(self) -> None:
         class TestMode(TorchDispatchMode):
             def __init__(self, testcase):