Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Remove ExclusivelyOwned from register_dispatch_key (#106791)
This fixes a bug that could occur with Python decompositions. When an operation is intercepted in PyTorch's C++ code, the outputs are created as `ExclusivelyOwned<at::Tensor>`s. Later, when the operation dispatches back to Python for the decomposition, these tensors have their ownership shared with Python. In the normal case the exclusively owned tensor is released and its value is returned from the operation as a non-exclusively-owned tensor. However, if the Python decomposition throws an error, the `ExclusivelyOwned` wrapper destroys the `at::Tensor`, leaving Python with a reference to a tensor that is no longer alive (and causing PyTorch to fall over in debug mode). Note that this will incur a performance hit when handling errors.

Fixes #106790

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106791
Approved by: https://github.com/ezyang
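To make the failure mode concrete, here is a minimal sketch of the lifetime problem, assuming an ATen/C++ build. It is not the generated kernel code; `python_decomposition` is a hypothetical stand-in for the dispatch back to Python.

```cpp
// Illustrative sketch only, not the generated structured-kernel wrapper.
#include <ATen/ATen.h>
#include <c10/util/ExclusivelyOwned.h>

at::Tensor wrapper_sketch() {
  // The generated wrapper used to keep its outputs exclusively owned.
  c10::ExclusivelyOwned<at::Tensor> out(at::empty({2, 2}));

  // Dispatching to a Python decomposition hands out references to *out, so
  // Python may now also hold the tensor. If that call throws, stack unwinding
  // runs ~ExclusivelyOwned(), which destroys the tensor even though Python
  // still references it -- the dangling reference this commit fixes by
  // storing a plain at::Tensor instead.
  // python_decomposition(*out);  // hypothetical call; may raise

  // Happy path: release exclusive ownership and return the tensor.
  return std::move(out).take();
}
```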
Committed by: PyTorch MergeBot
Parent: d97b18d769
Commit: c9cdcb299a
@@ -577,7 +577,6 @@ class StructuredRegisterDispatchKey(RegisterDispatchKey):
             set_output_super = ""
 
         def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
-            maybe_star = "*" if k is SchemaKind.functional else ""
             return f"""
 void set_output_{name}(
     int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
@@ -585,7 +584,7 @@ void set_output_{name}(
 ) override {{
 {textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), "    ")}
     if (!names.empty()) {{
-      namedinference::propagate_names({maybe_star}outputs_[output_idx], names);
+      namedinference::propagate_names(outputs_[output_idx], names);
     }}
     // super must happen after, so that downstream can use maybe_get_output
     // to retrieve the output
@@ -621,7 +620,7 @@ if (C10_UNLIKELY(current_device.has_value())) {
             create_proxy = """
 auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
 if (C10_UNLIKELY(maybe_proxy.has_value())) {
-    proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
+    proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
 }
 """
         else:
@@ -683,17 +682,17 @@ resize_out(out, sizes, strides, options);
         generate_super: bool,
     ) -> str:
         if k is SchemaKind.functional:
-            output_type = "c10::ExclusivelyOwned<Tensor>"
-            output_value = "*outputs_[output_idx]"
+            output_type = "Tensor"
+            output_value = "outputs_[output_idx]"
             proxy_field = ""
         elif k is SchemaKind.inplace:
             output_type = "std::reference_wrapper<Tensor>"
-            output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()"
-            proxy_field = f"std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, {len(f.func.returns)}> proxy_outputs_;"
+            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
+            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
         elif k is SchemaKind.out:
             output_type = "std::reference_wrapper<Tensor>"
-            output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()"
-            proxy_field = f"std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, {len(f.func.returns)}> proxy_outputs_;"
+            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
+            proxy_field = f"std::array<c10::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
 
         if self.backend_index.dispatch_key == DispatchKey.CUDA:
             if self.rocm:
@@ -886,8 +885,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
             if k is SchemaKind.out:
                 expr = f"op.maybe_get_output({i})"
             else:
-                maybe_star = "*" if k is SchemaKind.functional else ""
-                expr = f"{maybe_star}op.outputs_[{i}]"
+                expr = f"op.outputs_[{i}]"
 
             context.append(
                 Expr(
@@ -942,17 +940,17 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
         if k is SchemaKind.out or k is SchemaKind.inplace:
             for i in range(len(f.func.returns)):
                 sig_body.append(
-                    f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(**op.proxy_outputs_[{i}]);"
+                    f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
                 )
 
         # Destructively return the final tensors
         # TODO: Do this in translate instead
         if k is SchemaKind.functional:
             if len(f.func.returns) == 1:
-                ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
+                ret_expr = "std::move(op.outputs_[0])"  # small optimization
             else:
                 moved = ", ".join(
-                    f"std::move(op.outputs_[{i}]).take()"
+                    f"std::move(op.outputs_[{i}])"
                     for i in range(len(f.func.returns))
                 )
                 ret_expr = f"std::make_tuple({moved})"
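The hunks above also switch the generated copy-back for out/inplace kernels from `**op.proxy_outputs_[i]` to `*op.proxy_outputs_[i]`. Below is a hedged C++ sketch of that difference, not torchgen output: with the `ExclusivelyOwned` wrapper gone, only the `c10::optional` dereference remains. `N` and the function names are illustrative stand-ins.

```cpp
// Illustrative sketch of the copy-back loop emitted for out/inplace kernels
// before and after this change. `N` stands in for len(f.func.returns).
#include <array>
#include <cstddef>
#include <functional>
#include <ATen/ATen.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/Optional.h>

constexpr std::size_t N = 1;

// Before: optional<ExclusivelyOwned<Tensor>> needs two dereferences, one for
// the optional and one for the ExclusivelyOwned wrapper.
void copy_back_before(
    std::array<c10::optional<c10::ExclusivelyOwned<at::Tensor>>, N>& proxy_outputs,
    std::array<std::reference_wrapper<at::Tensor>, N>& outputs) {
  for (std::size_t i = 0; i < N; ++i) {
    if (proxy_outputs[i].has_value()) outputs[i].get().copy_(**proxy_outputs[i]);
  }
}

// After: a plain optional<Tensor> needs only the optional dereference, and the
// proxy tensor follows ordinary refcounted destruction if an exception unwinds.
void copy_back_after(
    std::array<c10::optional<at::Tensor>, N>& proxy_outputs,
    std::array<std::reference_wrapper<at::Tensor>, N>& outputs) {
  for (std::size_t i = 0; i < N; ++i) {
    if (proxy_outputs[i].has_value()) outputs[i].get().copy_(*proxy_outputs[i]);
  }
}
```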