return_and_correct_aliasing: skip dispatcher when swapping storage (#132524)
`return_and_correct_aliasing` is used by FunctionalTensor today to ensure that when we call view/inplace ops, the input and output `FunctionalTensor`s share the same storage. This was previously done with a dispatcher call to `aten.set_`. In this PR I swap it out with a util that just manually does the storage swap.

Benefits:
(1) We know this is safe for the specific way FunctionalTensor uses it: avoiding the extra assertions in `aten.set_` is necessary to avoid some unbacked symint errors.
(2) This should improve compile times a bit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132524
Approved by: https://github.com/ezyang
ghstack dependencies: #132243, #132337, #132322
Commit 26c6786109 (parent eca0cb0fbe), committed by PyTorch MergeBot.
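For context on where this helper sits: `return_and_correct_aliasing` is what a traceable wrapper tensor subclass typically calls at the end of its `__torch_dispatch__`, so that outputs of view/inplace ops alias their inputs at the wrapper level. The sketch below is illustrative only: `ThinWrapper` is a hypothetical minimal subclass (not part of this PR), and the example assumes a PyTorch build that includes this change (on older builds the same call goes through the `aten.set_` path instead).

    import torch
    from torch.utils._python_dispatch import return_and_correct_aliasing
    from torch.utils._pytree import tree_map_only

    class ThinWrapper(torch.Tensor):
        # Minimal traceable wrapper subclass: holds one inner tensor.
        @staticmethod
        def __new__(cls, inner):
            # The outer wrapper mirrors the inner tensor's metadata but holds no data itself.
            return torch.Tensor._make_wrapper_subclass(
                cls,
                inner.shape,
                strides=inner.stride(),
                storage_offset=inner.storage_offset(),
                dtype=inner.dtype,
                device=inner.device,
                requires_grad=inner.requires_grad,
            )

        def __init__(self, inner):
            self.inner = inner

        def __tensor_flatten__(self):
            return ["inner"], None

        @staticmethod
        def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
            return ThinWrapper(inner_tensors["inner"])

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            unwrap = lambda t: t.inner
            out = func(
                *tree_map_only(ThinWrapper, unwrap, args),
                **tree_map_only(ThinWrapper, unwrap, kwargs),
            )
            wrapped = tree_map_only(torch.Tensor, ThinWrapper, out)
            # For ops that alias (views / in-place), this fixes up the output
            # wrappers so they correctly alias their inputs; this PR changes how
            # that storage swap is performed (no more dispatcher call to aten.set_).
            return return_and_correct_aliasing(func, args, kwargs, wrapped)

    x = ThinWrapper(torch.randn(4))
    y = x.view(2, 2)  # view op: y's outer storage is fixed up to alias x's

The only line that matters for this PR is the final `return_and_correct_aliasing(...)` call; everything before it is standard wrapper-subclass boilerplate.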
@@ -792,6 +792,9 @@ def gen_pyi(
             "_functionalize_commit_update": [
                 "def _functionalize_commit_update(t: Tensor) -> None: ..."
             ],
+            "_functionalize_unsafe_set": [
+                "def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..."
+            ],
             "_functionalize_mark_mutation_hidden_from_autograd": [
                 "def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..."
             ],
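These hint entries feed the generated type stubs for the private functionalization bindings. Assuming the string is emitted essentially verbatim into the generated `.pyi` (the target stub file is not shown in this hunk, and `Tensor` is already in scope there), the new declaration added by this PR reads:

    def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ...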
@@ -21,6 +21,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/FunctionalTensorWrapper.h>
+#include <ATen/native/Resize.h>
 
 #include <Python.h>
 #include <fmt/format.h>
@@ -711,6 +712,28 @@ void initTorchFunctions(PyObject* module) {
         auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
         return wrapper->was_storage_changed();
       });
+  py_module.def(
+      "_functionalize_unsafe_set", [](at::Tensor& dst, const at::Tensor& src) {
+        // Forcefully/unsafely dumps src.storage into dst.
+        // This API is technically not specific to functionalization
+        // (it just runs set_() without the safety checks).
+        // But its main intended purpose today is during functionalization.
+        // In particular: when we generate a new FunctionalTensor from a view
+        // op, we need to ensure it shares a storage with the view input.
+        //
+        // Other subclasses shouldn't really need to care about this,
+        // because we define aliasing on wrapper subclasses such that:
+        // - differentiable aliasing: subclass_x and subclass_y share a ._base.
+        // - non-differentiable aliasing: aliasing of subclass_x and subclass_y
+        //   is defined recursively based on the aliasing of their inner
+        //   tensors.
+        at::native::checkSetStorage(
+            dst,
+            src.storage(),
+            dst.sym_storage_offset(),
+            dst.sym_sizes(),
+            dst.sym_strides());
+      });
   py_module.def(
       "_functionalize_mark_mutation_hidden_from_autograd",
       [](const at::Tensor& t) {
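To make the new binding concrete: it swaps `dst`'s storage for `src`'s (via `at::native::checkSetStorage`, which also performs the swap when the two storages do not already alias) while leaving `dst`'s sizes, strides, and storage offset untouched. Below is a small sketch using plain CPU tensors; this is a private helper intended for functionalization internals, so calling it directly like this is purely illustrative.

    import torch

    src = torch.arange(6, dtype=torch.float32)    # storage holding 0..5
    dst = torch.empty(2, 3, dtype=torch.float32)  # same dtype/device, its own storage

    # Unsafe storage swap: no dispatcher round-trip, no set_() safety checks.
    torch._functionalize_unsafe_set(dst, src)

    # dst kept its (2, 3) metadata but now reads from src's storage.
    assert dst.untyped_storage().data_ptr() == src.untyped_storage().data_ptr()
    assert torch.equal(dst, src.view(2, 3))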
@@ -463,44 +463,23 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
                     r
                 ), f"""Called {str(func)} with input of type {type(arg)}
 and output of type {type(ret)}. But expected types to match."""
-        # Need to run under no_dispatch, because we explicitly do **not**
+        # Need to call a non-dispatcher helper, because we explicitly do **not**
         # want our subclass to intercept the set_() call.
         # instead, our subclass should directly have its storage swapped out.
-        with torch.utils._mode_utils.no_dispatch():
-            # See Note: [Fake Tensor Dispatch Keys]
-            # we're borrowing the way it modifies dispatch key TLS.
-            meta_in_tls = torch._C._meta_in_tls_dispatch_include()
-            torch._C._set_meta_in_tls_dispatch_include(True)
-            try:
-                # directly calling this overload, and passing ret.shape, because we **explicitly**
-                # don't want to reset the sizes on ret, if the storage implies a size change.
-                # Why?
-                # The purpose of this API is *not* to change the size/strides of our output - we assume it's already correct.
-                # We just want to "fix up" the storage aliasing, without modifying our output's metadata.
-                # Example: out = inp.expand(inp.shape[0], inp.shape[0])
-                # This requires swapping the storage of out to be the same as inp,
-                # but we do *not* want it to change the sizes/strides that were computed for out.
+        # we **explicitly** don't want to reset the sizes on ret, if the storage implies a size change.
+        # Why?
+        # The purpose of this API is *not* to change the size/strides of our output - we assume it's already correct.
+        # We just want to "fix up" the storage aliasing, without modifying our output's metadata.
+        # Example: out = inp.expand(inp.shape[0], inp.shape[0])
+        # This requires swapping the storage of out to be the same as inp,
+        # but we do *not* want it to change the sizes/strides that were computed for out.
 
-                if isinstance(ret, list):
-                    for r in ret:
-                        torch.ops.aten.set_.source_Storage_storage_offset(
-                            r,
-                            arg.untyped_storage(),
-                            r.storage_offset(),
-                            r.shape,
-                            r.stride(),
-                        )
-                else:
-                    assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
-                    torch.ops.aten.set_.source_Storage_storage_offset(
-                        ret,
-                        arg.untyped_storage(),
-                        ret.storage_offset(),
-                        ret.shape,
-                        ret.stride(),
-                    )
-            finally:
-                torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
+        if isinstance(ret, list):
+            for r in ret:
+                torch._functionalize_unsafe_set(r, arg)
+        else:
+            assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
+            torch._functionalize_unsafe_set(ret, arg)
 
     def is_read_only_alias_match(arg, ret):
         shared_aliases = arg.alias_set & ret.alias_set
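As a follow-up to the `expand` example in the comments above, here is a sketch of the invariant being preserved: the swap never resets the destination's sizes/strides, even though they are not what `set_()` would infer from the new storage. `torch.empty_strided` stands in for an output whose view metadata was already computed; as before, calling the private helper directly on plain tensors is for illustration only.

    import torch

    inp = torch.arange(3, dtype=torch.float32)

    # Stand-in for the functionalized output of inp.expand(3, 3):
    # correct sizes (3, 3) and strides (0, 1), but (initially) its own storage.
    out = torch.empty_strided((3, 3), (0, 1), dtype=torch.float32)

    torch._functionalize_unsafe_set(out, inp)

    # The storage now aliases inp, while out's sizes/strides were left untouched.
    assert out.untyped_storage().data_ptr() == inp.untyped_storage().data_ptr()
    assert out.shape == (3, 3) and out.stride() == (0, 1)
    assert torch.equal(out, inp.expand(3, 3))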