return_and_correct_aliasing: skip dispatcher when swapping storage (#132524)

`return_and_correct_aliasing` is used by FunctionalTensor today to ensure that when we call view/inplace ops, the input and output `FunctionalTensors` share the same storage.
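For context, here is a minimal sketch (not part of this PR) of where `return_and_correct_aliasing` is typically called: a hypothetical wrapper tensor subclass that forwards ops to an inner tensor and relies on the util to fix up input/output storage aliasing for view and in-place ops. The `WrapperTensor` name and structure are illustrative only.

```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torch.utils._pytree import tree_map_only


class WrapperTensor(torch.Tensor):
    # Hypothetical wrapper subclass holding a single inner tensor.
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op on the inner tensors, then re-wrap any tensor outputs.
        out = func(
            *tree_map_only(WrapperTensor, lambda t: t.inner, args),
            **tree_map_only(WrapperTensor, lambda t: t.inner, kwargs),
        )
        wrapped = tree_map_only(torch.Tensor, WrapperTensor, out)
        # For view/in-place ops, this fixes up the wrapped outputs so they
        # share storage with the corresponding inputs.
        return return_and_correct_aliasing(func, args, kwargs, wrapped)
```

FunctionalTensor's `__torch_dispatch__` uses the same util, which is why the cost of that storage fix-up matters here.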

That storage fix-up was previously done with a dispatcher call to `aten.set_`. In this PR I swap it out for a util that just manually does the storage swap (see the sketch after the list below). Benefits:

(1) we know this is safe for the specific way FunctionalTensor uses it: skipping the extra assertions in `aten.set_` avoids some unbacked symint errors

(2) this should improve compile times a bit
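To make the new util's contract concrete, here is a rough sketch of its observable behavior (assuming a build that includes this PR; `torch._functionalize_unsafe_set` is an internal API): the destination adopts the source's storage, while its own sizes, strides, and storage offset are left untouched.

```python
import torch

dst = torch.empty(2, 2)   # keeps its (2, 2) metadata throughout
src = torch.arange(4.0)

# Internal helper added by this PR: forcefully point dst at src's storage,
# without going through the dispatcher or the extra assertions in aten.set_.
torch._functionalize_unsafe_set(dst, src)

assert dst.untyped_storage().data_ptr() == src.untyped_storage().data_ptr()
assert dst.shape == (2, 2)  # sizes/strides are not recomputed from the storage
print(dst)  # reinterprets src's data: tensor([[0., 1.], [2., 3.]])
```

This is exactly the fix-up `_correct_storage_aliasing` needs: the output of a view op already has correct metadata, and only its storage needs to be redirected to the input's.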

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132524
Approved by: https://github.com/ezyang
ghstack dependencies: #132243, #132337, #132322
Author: Brian Hirsh
Date: 2024-08-05 14:35:25 -07:00
Committed by: PyTorch MergeBot
Parent: eca0cb0fbe
Commit: 26c6786109
3 changed files with 40 additions and 35 deletions


@@ -792,6 +792,9 @@ def gen_pyi(
"_functionalize_commit_update": [
"def _functionalize_commit_update(t: Tensor) -> None: ..."
],
"_functionalize_unsafe_set": [
"def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
],
"_functionalize_mark_mutation_hidden_from_autograd": [
"def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..."
],


@@ -21,6 +21,7 @@
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/native/Resize.h>
#include <Python.h>
#include <fmt/format.h>
@@ -711,6 +712,28 @@ void initTorchFunctions(PyObject* module) {
auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
return wrapper->was_storage_changed();
});
py_module.def(
"_functionalize_unsafe_set", [](at::Tensor& dst, const at::Tensor& src) {
// Forcefully/unsafely dumps src.storage into dst.
// This API is technically not specific to functionalization
// (it just runs set_() without the safety checks).
// But its main intended purpose today is during functionalization.
// In particular: when we generate a new FunctionalTensor from a view
// op, we need to ensure it shares a storage with the view input.
//
// Other subclasses shouldn't really need to care about this,
// because we define aliasing on wrapper subclasses such that:
// - differentiable aliasing: subclass_x and subclass_y share a ._base.
// - non-differentiable aliasing: aliasing of subclass_x and subclass_y
// is defined recursively based on the aliasing of their inner
// tensors.
at::native::checkSetStorage(
dst,
src.storage(),
dst.sym_storage_offset(),
dst.sym_sizes(),
dst.sym_strides());
});
py_module.def(
"_functionalize_mark_mutation_hidden_from_autograd",
[](const at::Tensor& t) {

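The aliasing convention described in the comment above can be illustrated with a short sketch (reusing the hypothetical `WrapperTensor` from the earlier example; none of this is code from the PR):

```python
import torch

# Differentiable aliasing is tracked on the outer tensors: views of the same
# tensor share a ._base.
a = torch.randn(4)
v1 = a.view(2, 2)
v2 = a.view(4, 1)
assert v1._base is a and v2._base is a

# Non-differentiable aliasing of wrapper subclasses is defined recursively:
# two wrappers alias iff their inner tensors share storage.
base = torch.randn(4)
wx = WrapperTensor(base)             # hypothetical subclass from the earlier sketch
wy = WrapperTensor(base.view(2, 2))
assert wx.inner.untyped_storage().data_ptr() == wy.inner.untyped_storage().data_ptr()
```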

@@ -463,44 +463,23 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
r
), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
# Need to run under no_dispatch, because we explicitly do **not**
# Need to call a non-dispatcher helper, because we explicitly do **not**
# want our subclass to intercept the set_() call.
# instead, our subclass should directly have its storage swapped out.
with torch.utils._mode_utils.no_dispatch():
# See Note: [Fake Tensor Dispatch Keys]
# we're borrowing the way it modifies dispatch key TLS.
meta_in_tls = torch._C._meta_in_tls_dispatch_include()
torch._C._set_meta_in_tls_dispatch_include(True)
try:
# directly calling this overload, and passing ret.shape, because we **explicitly**
# don't want to reset the sizes on ret, if the storage implies a size change.
# Why?
# The purpose of this API is *not* to change the sizes/strides of our output; we assume they're already correct.
# We just want to "fix up" the storage aliasing, without modifying our output's metadata.
# Example: out = inp.expand(inp.shape[0], inp.shape[0])
# This requires swapping the storage of out to be the same as inp,
# but we do *not* want it to change the sizes/strides that were computed for out.
# we **explicitly** don't want to reset the sizes on ret, if the storage implies a size change.
# Why?
# The purpose of this API is *not* to change the sizes/strides of our output; we assume they're already correct.
# We just want to "fix up" the storage aliasing, without modifying our output's metadata.
# Example: out = inp.expand(inp.shape[0], inp.shape[0])
# This requires swapping the storage of out to be the same as inp,
# but we do *not* want it to change the sizes/strides that were computed for out.
if isinstance(ret, list):
for r in ret:
torch.ops.aten.set_.source_Storage_storage_offset(
r,
arg.untyped_storage(),
r.storage_offset(),
r.shape,
r.stride(),
)
else:
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
torch.ops.aten.set_.source_Storage_storage_offset(
ret,
arg.untyped_storage(),
ret.storage_offset(),
ret.shape,
ret.stride(),
)
finally:
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
if isinstance(ret, list):
for r in ret:
torch._functionalize_unsafe_set(r, arg)
else:
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
torch._functionalize_unsafe_set(ret, arg)
def is_read_only_alias_match(arg, ret):
shared_aliases = arg.alias_set & ret.alias_set