Don't detach when making views; force caller to detach (#84893)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84893
Approved by: https://github.com/soulitzer, https://github.com/SherlockNoMad
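For context (not part of the commit itself), here is a minimal sketch of the aliasing contract this change enforces. It only uses public entry points (torch.ops.aten.alias, Tensor.data_ptr); the comments reflect a reading of the patch, not an authoritative description.

import torch

x = torch.randn(3)
y = torch.ops.aten.alias(x)          # a view: fresh TensorImpl over the same storage
assert y is not x                    # distinct tensor object
assert y.data_ptr() == x.data_ptr()  # no data was copied

# After this commit, autograd no longer detaches on the caller's behalf when it
# records a view; a view implementation that simply returns its input (so base
# and result share one TensorImpl) is rejected with an error instead.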
Commit 3bb8d6a93c (parent ec916bf6af), committed by PyTorch MergeBot
.github/ci_commit_pins/xla.txt (vendored): 2 changed lines
@@ -1 +1 @@
-09be9870437684ba2da6741af3eb10126c04aede
+8a78bec5dbb43c1047c30abffe89ac622ac7911b
@@ -2102,7 +2102,7 @@ Tensor slice(
     auto quantizer = create_subtensor_quantizer(self, false, start_val, end_val, dim, step);
     result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, quantizer);
   } else {
-    result = self.as_strided(sizes, strides, storage_offset);
+    result = as_strided_tensorimpl(self, sizes, strides, storage_offset);
   }
   namedinference::propagate_names(result, self);
   return result;
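Regardless of which helper constructs the result in the hunk above, aten::slice returns a strided view of its input. A rough illustration of the invariants that view satisfies (a sketch, not part of the diff):

import torch

x = torch.arange(6.)
s = x[1:4]                       # dispatches to aten::slice
assert s.storage_offset() == 1   # same storage as x, shifted by one element
assert s.data_ptr() == x.data_ptr() + s.storage_offset() * x.element_size()
s[0] = 42.0
assert x[1] == 42.0              # writes through the view are visible in the base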
@@ -1214,19 +1214,19 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
 }
 
 torch::Tensor view_op(const torch::Tensor& self) {
-  return self;
+  return self.alias();
 }
 
 torch::Tensor view_op_with_extra_arg(
     const torch::Tensor& self,
     const torch::Tensor& other) {
-  return self;
+  return self.alias();
 }
 
 std::vector<torch::Tensor> ret_tensor_vector_view(
     const torch::Tensor& self,
     const torch::Tensor& other) {
-  return {self, self};
+  return {self.alias(), self.alias()};
 }
 
 std::vector<at::Tensor> ret_tensor_vector(
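The test operators above switch from `return self;` to `return self.alias();`. A hypothetical Python-side analogue of the same fix for a view-like operator (the name `identity_view` is made up for illustration; only torch.ops.aten.alias is a real entry point):

import torch

def identity_view(x: torch.Tensor) -> torch.Tensor:
    # return x  # old pattern: hands back the same TensorImpl as the input,
    #           # which the new autograd check rejects for view ops
    return torch.ops.aten.alias(x)  # new pattern: fresh TensorImpl, shared storage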
@@ -732,6 +732,17 @@ class TestFakeProxyTensor(TestCase):
         x, y = torch.randn(2), torch.randn(2)
         self.assertEqual(g(x, y), f(x, y))
 
+    def test_alias(self):
+        def f(x):
+            return torch.ops.aten.alias(x)
+
+        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
+        # NB: this should not have a detach call
+        self.assertExpectedInline(r, """\
+def forward(self, x_1):
+    alias = torch.ops.aten.alias.default(x_1); x_1 = None
+    return alias""")
+
 def _get_node(fx_g, cond):
     for n in fx_g.graph.nodes:
         if cond(n):
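To reproduce the traced graph from the new test outside the test suite, something like the following should work on a build that includes this change (same entry points as the test; the exact generated code may vary slightly):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.ops.aten.alias(x)

gm = make_fx(f, tracing_mode="fake")(torch.randn(2))
print(gm.code)  # expected: a single aten.alias.default call and no detach node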
@@ -81,6 +81,18 @@ inline void throw_error_for_complex_autograd(
   }
 }
 
+inline void throw_error_if_base_and_tensor_are_same(
+    const at::Tensor& base,
+    const at::Tensor& tensor) {
+  TORCH_CHECK(
+      base.unsafeGetTensorImpl() != tensor.unsafeGetTensorImpl(),
+      "View operation returned a tensor that is the same as the input base tensor. This "
+      "is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). "
+      "As a user, you could have made a mistake implementing __torch_dispatch__ or a Python "
+      "operator decomposition or meta registration; if that's not the case, please "
+      "report a bug to PyTorch or the backend you are using.");
+}
+
 inline void throw_error_for_complex_autograd(
     const at::TensorList& tensorlist,
     const char* name) {
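The error message above suggests .detach() as the explicit way to create the new tensor. A brief sketch of why that satisfies the check (standard PyTorch behavior, not specific to this patch):

import torch

x = torch.randn(3, requires_grad=True)
y = x.detach()                        # allocates a fresh TensorImpl
assert y.data_ptr() == x.data_ptr()   # ...that still aliases x's memory
assert not y.requires_grad            # ...but is detached from the autograd graph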
@@ -167,6 +179,7 @@ inline at::Tensor as_view(
   // be used for both of them.
   if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
       is_bw_differentiable && is_fw_differentiable) {
+    throw_error_if_base_and_tensor_are_same(base, tensor);
    if (diff_view_meta) {
       creation_meta = propagate_creation_meta(
           diff_view_meta->get_creation_meta(), creation_meta);
@@ -220,6 +233,7 @@ inline at::Tensor as_view(
       creation_meta = propagate_creation_meta(
           diff_view_meta->get_creation_meta(), creation_meta);
     }
+    throw_error_if_base_and_tensor_are_same(base, tensor);
     return make_variable_differentiable_view(
         tensor,
         std::move(new_bw_info),
@@ -687,13 +687,17 @@ inline Variable make_variable_differentiable_view(
     CreationMeta creation_meta,
     bool allow_tensor_metadata_change = true) {
   if (data.defined()) {
-    // If we already did a TensorImpl allocation for data, just reuse it.
-    // Otherwise(e.g tensor.swapdim(0, 0) when we return the same tensor as
-    // input), we have to use shallow_copy_and_detach to create a new TensorImpl
-    // to avoid moving leaf node into graph interior. This guarantees only 1
-    // TensorImpl allocation happens in view ops.
-    if (data.getIntrusivePtr().unique() &&
-        data.getIntrusivePtr()->unique_version()) {
+    TORCH_CHECK(
+        data.getIntrusivePtr()->autograd_meta() == nullptr,
+        "Attempted to make a tensor into a differentiable view, but the "
+        "tensor already had autograd metadata associated with it. If you are "
+        "using a __torch_dispatch__ mode, the most common cause for this "
+        "problem is that you used torch.overrides.enable_reentrant_dispatch() "
+        "improperly; tensors created within the extent of reentrant dispatch "
+        "MUST NOT be directly returned from __torch_dispatch__; instead, they "
+        "must be wrapped into fresh tensors that serve as the output. If you "
+        "are not using wrappers, you probably don't need reentrant dispatch. "
+        "If this doesn't seem applicable, please file a bug to PyTorch.");
     at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
     data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
     data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
@@ -703,21 +707,6 @@ inline Variable make_variable_differentiable_view(
         shared_view_info,
         creation_meta));
     return data;
-  } else {
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    c10::intrusive_ptr<at::TensorImpl> data_impl_copy =
-        data.getIntrusivePtr()->shallow_copy_and_detach(
-            /*version_counter=*/0,
-            /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
-    data_impl_copy->set_autograd_meta(
-        std::make_unique<DifferentiableViewMeta>(
-            data_impl_copy.get(),
-            std::move(backward_info),
-            std::move(forward_info),
-            shared_view_info,
-            creation_meta));
-    return Variable(data_impl_copy);
-  }
   }
   return Variable();
 }