make view.dtype always return an alias (#136074)

Fixes https://github.com/pytorch/pytorch/issues/136064

In the linked repro, the issue was that there was some code like this:
```
# x has dtype torch.float32
def f(x):
    y = x.view(torch.float32)
    y.copy_(...)
```

Because `view.dtype` is implemented today so that it may return its input directly, we would end up clobbering the proxy for our graph input (replacing its FX proxy value from `arg0_1` to `view_1`). This is not desirable: we have careful assertions in AOTDispatcher that mutations only ever happen on graph inputs, but this clobbering made the mutation appear, from the perspective of the FX graph, to be happening on a view of the input.

Why is this normally not a problem? Ordinarily, the `ADInplaceOrView` kernel for `view.dtype` will take the output of the view kernel, [and detach() it](https://github.com/pytorch/pytorch/blob/main/tools/autograd/gen_inplace_or_view_type.py#L466) (properly creating a fresh `TensorImpl`).
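As a small illustration of the ordinary eager path described above (a sketch, not part of the PR; the variable names are mine), a same-dtype `view` called outside of any `__torch_dispatch__` region hands back a distinct Python tensor object that still aliases the input's memory:

```python
import torch

x = torch.ones(4, dtype=torch.float32)

# Same-dtype view, called from plain eager code (no __torch_dispatch__ involved).
y = x.view(torch.float32)

# The result is a different tensor object from the input...
is_distinct_object = y is not x

# ...but it aliases the same underlying memory.
shares_storage = y.data_ptr() == x.data_ptr()

# Because of the aliasing, an in-place op on the view is visible through x.
y.mul_(2)
mutation_visible = torch.equal(x, torch.full((4,), 2.0))
```

The identity check is the property that the raw kernel did not guarantee on its own before this change.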

This does **not** happen, though, if you are executing the kernel from within a `__torch_dispatch__` region: the `ADInplaceOrView` logic has already run above you, so that key will be in the TLS exclude set.

This PR changes eager behavior. At first I considered only changing behavior under compile, but this problem isn't technically specific to PT2: if you ever rely on tensor identity from inside of a `__torch_dispatch__` call, then we need to make sure the raw `view.dtype` kernel doesn't directly return the input.
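To make "relying on tensor identity" concrete, here is a hypothetical mode (the `NameTrackingMode` class and its labels are invented for illustration and only loosely mimic how proxies are attached to tensors) that labels tensors by object identity. If the raw kernel returned its input, the label recorded for the op's output would overwrite the label of the input:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class NameTrackingMode(TorchDispatchMode):
    """Hypothetical bookkeeping mode: maps id(tensor) -> label."""

    def __init__(self, named_inputs):
        super().__init__()
        self.names = {id(t): name for name, t in named_inputs.items()}
        # Hold references so ids are not reused while the mode is alive.
        self.keepalive = list(named_inputs.values())

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if isinstance(out, torch.Tensor):
            self.keepalive.append(out)
            # If `out` were the very same object as an input, this line
            # would clobber that input's label.
            self.names[id(out)] = f"{func}_{len(self.keepalive)}"
        return out


x = torch.ones(4, dtype=torch.float32)
mode = NameTrackingMode({"arg0_1": x})
with mode:
    y = x.view(torch.float32)

# Label currently associated with the original input tensor.
input_label = mode.names[id(x)]
```

With this change applied, I would expect `input_label` to still be `"arg0_1"`; on a build where the raw kernel returns its input, it would instead be the label generated for the view op.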

I am also making the assumption that "`view.dtype` no-op'ing when the dtype is the same" is not a case worth optimizing in eager mode, and that the overhead of the `TensorImpl` creation is relatively negligible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136074
Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/albanD
ghstack dependencies: #136041
Brian Hirsh
2024-09-13 21:45:46 -07:00
committed by PyTorch MergeBot
parent d463a81c27
commit dc82d274e6
2 changed files with 17 additions and 3 deletions

@ -1793,6 +1793,23 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
        with self.assertRaises(AssertionError):
            self.assertEqual(x, None)

    # See https://github.com/pytorch/pytorch/issues/136064
    def test_view_returns_alias_under_torch_dispatch(self):
        class MyMode(TorchDispatchMode):
            def __init__(self, testcase):
                self.testcase = testcase

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                out = func(*args, **kwargs)
                if func == torch.ops.aten.view.dtype:
                    # view should return a fresh TensorImpl
                    self.testcase.assertTrue(out is not args[0])
                return out

        with MyMode(self):
            x = torch.ones(4, dtype=torch.float32)
            out = x.view(torch.float32)

    def test_record_stream(self) -> None:
        class TestMode(TorchDispatchMode):
            def __init__(self, testcase):