Don't detach when making views; force caller to detach (#84893)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84893
Approved by: https://github.com/soulitzer, https://github.com/SherlockNoMad
Authored by: Edward Z. Yang
Date: 2022-09-13 11:42:12 -07:00
Committed by: PyTorch MergeBot
Parent: ec916bf6af
Commit: 3bb8d6a93c
6 changed files with 50 additions and 36 deletions
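The practical upshot of this change: any implementation of a view-style op (a C++ kernel, a __torch_dispatch__ handler, a Python decomposition, or a meta registration) must now return a fresh tensor that aliases its input rather than the input tensor itself; autograd no longer detaches on the implementer's behalf. A minimal sketch of the new contract, with hypothetical function names that are not part of this PR:

import torch

# Hypothetical no-op view implementations; only the names are invented here.
def noop_view_before(x):
    return x  # returns the base's own TensorImpl: rejected after this change

def noop_view_after(x):
    return torch.ops.aten.alias(x)  # fresh TensorImpl sharing the base's storage

x = torch.randn(3)
y = noop_view_after(x)
assert y is not x and y.data_ptr() == x.data_ptr()

The C++ hunks below apply the same fix to the in-tree test operators by returning self.alias() instead of self.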

View File

@@ -1 +1 @@
-09be9870437684ba2da6741af3eb10126c04aede
+8a78bec5dbb43c1047c30abffe89ac622ac7911b

View File

@@ -2102,7 +2102,7 @@ Tensor slice(
     auto quantizer = create_subtensor_quantizer(self, false, start_val, end_val, dim, step);
     result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, quantizer);
   } else {
-    result = self.as_strided(sizes, strides, storage_offset);
+    result = as_strided_tensorimpl(self, sizes, strides, storage_offset);
   }
   namedinference::propagate_names(result, self);
   return result;

View File

@@ -1214,19 +1214,19 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
 }
 
 torch::Tensor view_op(const torch::Tensor& self) {
-  return self;
+  return self.alias();
 }
 
 torch::Tensor view_op_with_extra_arg(
     const torch::Tensor& self,
     const torch::Tensor& other) {
-  return self;
+  return self.alias();
 }
 
 std::vector<torch::Tensor> ret_tensor_vector_view(
     const torch::Tensor& self,
     const torch::Tensor& other) {
-  return {self, self};
+  return {self.alias(), self.alias()};
 }
 
 std::vector<at::Tensor> ret_tensor_vector(

View File

@@ -732,6 +732,17 @@ class TestFakeProxyTensor(TestCase):
         x, y = torch.randn(2), torch.randn(2)
         self.assertEqual(g(x, y), f(x, y))
 
+    def test_alias(self):
+        def f(x):
+            return torch.ops.aten.alias(x)
+
+        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
+        # NB: this should not have a detach call
+        self.assertExpectedInline(r, """\
+def forward(self, x_1):
+    alias = torch.ops.aten.alias.default(x_1); x_1 = None
+    return alias""")
+
 def _get_node(fx_g, cond):
     for n in fx_g.graph.nodes:
         if cond(n):
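Outside the test harness, the same behavior can be reproduced directly; this is a hedged, stand-alone restatement of what the new test asserts, not code taken from the PR:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.ops.aten.alias(x)

# Tracing aten.alias under fake tensors now records only the alias call;
# before this change an extra aten.detach node showed up in the graph.
gm = make_fx(f, tracing_mode="fake")(torch.randn(2))
print(gm.code)
assert "detach" not in gm.code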

View File

@@ -81,6 +81,18 @@ inline void throw_error_for_complex_autograd(
   }
 }
 
+inline void throw_error_if_base_and_tensor_are_same(
+    const at::Tensor& base,
+    const at::Tensor& tensor) {
+  TORCH_CHECK(
+      base.unsafeGetTensorImpl() != tensor.unsafeGetTensorImpl(),
+      "View operation returned a tensor that is the same as the input base tensor. This "
+      "is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). "
+      "As a user, you could have made a mistake implementing __torch_dispatch__ or a Python "
+      "operator decomposition or meta registration; if that's not the case, please "
+      "report a bug to PyTorch or the backend you are using.");
+}
+
 inline void throw_error_for_complex_autograd(
     const at::TensorList& tensorlist,
     const char* name) {
@@ -167,6 +179,7 @@ inline at::Tensor as_view(
   // be used for both of them.
   if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
       is_bw_differentiable && is_fw_differentiable) {
+    throw_error_if_base_and_tensor_are_same(base, tensor);
     if (diff_view_meta) {
       creation_meta = propagate_creation_meta(
           diff_view_meta->get_creation_meta(), creation_meta);
@@ -220,6 +233,7 @@ inline at::Tensor as_view(
       creation_meta = propagate_creation_meta(
           diff_view_meta->get_creation_meta(), creation_meta);
     }
+    throw_error_if_base_and_tensor_are_same(base, tensor);
     return make_variable_differentiable_view(
         tensor,
         std::move(new_bw_info),
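The error message above points at the caller-side remedy: produce a new tensor explicitly. A minimal sketch (illustrative only, not from the PR) of why either aten.alias or .detach() satisfies the new check while returning the base does not:

import torch

base = torch.randn(4)

# Both calls yield a new TensorImpl that still views the same storage, so
# throw_error_if_base_and_tensor_are_same passes; returning `base` itself
# from a view op would now raise this error instead.
for t in (torch.ops.aten.alias(base), base.detach()):
    assert t is not base                    # distinct TensorImpl
    assert t.data_ptr() == base.data_ptr()  # still aliases the base's memory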

View File

@@ -687,13 +687,17 @@ inline Variable make_variable_differentiable_view(
     CreationMeta creation_meta,
     bool allow_tensor_metadata_change = true) {
   if (data.defined()) {
-    // If we already did a TensorImpl allocation for data, just reuse it.
-    // Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as
-    // input), we have to use shallow_copy_and_detach to create a new TensorImpl
-    // to avoid moving leaf node into graph interior. This guarantees only 1
-    // TensorImpl allocation happens in view ops.
-    if (data.getIntrusivePtr().unique() &&
-        data.getIntrusivePtr()->unique_version()) {
+    TORCH_CHECK(
+        data.getIntrusivePtr()->autograd_meta() == nullptr,
+        "Attempted to make a tensor into a differentiable view, but the "
+        "tensor already had autograd metadata associated with it. If you are "
+        "using a __torch_dispatch__ mode, the most common cause for this "
+        "problem is that you used torch.overrides.enable_reentrant_dispatch() "
+        "improperly; tensors created within the extent of reentrant dispatch "
+        "MUST NOT be directly returned from __torch_dispatch__; instead, they "
+        "must be wrapped into fresh tensors that serve as the output. If you "
+        "are not using wrappers, you probably don't need reentrant dispatch. "
+        "If this doesn't seem applicable, please file a bug to PyTorch.");
     at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
     data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
     data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
@@ -703,21 +707,6 @@ inline Variable make_variable_differentiable_view(
         shared_view_info,
         creation_meta));
     return data;
-  } else {
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    c10::intrusive_ptr<at::TensorImpl> data_impl_copy =
-        data.getIntrusivePtr()->shallow_copy_and_detach(
-            /*version_counter=*/0,
-            /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
-    data_impl_copy->set_autograd_meta(
-        std::make_unique<DifferentiableViewMeta>(
-            data_impl_copy.get(),
-            std::move(backward_info),
-            std::move(forward_info),
-            shared_view_info,
-            creation_meta));
-    return Variable(data_impl_copy);
-  }
   }
   return Variable();
 }