Rename convert_arrayref_tensor_to_tensor to copy_arrayref_tensor_to_tensor (#142182)

Be explicit about what we are doing, in preparation for adding borrow_arrayref_tensor_as_tensor.

Differential Revision: [D66847772](https://our.internmc.facebook.com/intern/diff/D66847772/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142182
Approved by: https://github.com/desertfire
ghstack dependencies: #142340
Author:    Scott Wolchok
Date:      2024-12-09 10:29:55 -08:00
Committer: PyTorch MergeBot
Parent:    dc1ef9afb4
Commit:    18d25aa7aa

3 changed files with 18 additions and 26 deletions
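
Before the diff itself, a minimal sketch of the copy-versus-borrow distinction the commit message alludes to. Only the copy helper exists after this PR; the borrow call is hypothetical (announced, not implemented here), and `buf0` stands in for an arbitrary generated ArrayRefTensor buffer:

    // copy_arrayref_tensor_to_tensor (the rename in this PR): always materializes a
    // brand-new tensor, i.e. it pays for an allocation plus a data copy.
    RAIIAtenTensorHandle t = copy_arrayref_tensor_to_tensor(buf0);

    // Planned follow-up from the commit message (signature is a guess, not in this PR):
    // wrap buf0's existing storage without copying, so buf0 must outlive the result.
    // AtenTensorHandle v = borrow_arrayref_tensor_as_tensor(buf0);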

@@ -659,16 +659,16 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         for x in args:
             pieces = x.split(", ")
             for piece in pieces:
-                # We only really *need* convert_arrayref_tensor_to_tensor for
+                # We only really *need* copy_arrayref_tensor_to_tensor for
                 # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
-                # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
+                # which copy_arrayref_tensor_to_tensor would blindly coerce to int,
                 # so just avoid wrapping integers.
                 # Name matching is to find tensor is hacky, but fixing all the
                 # ArrayRefTensor issues is not a priority for now.
                 if isinstance(piece, str) and piece.startswith(
                     ("buf", "arg", "wrap_with_raii_handle_if_needed")
                 ):
-                    piece = f"convert_arrayref_tensor_to_tensor({piece})"
+                    piece = f"copy_arrayref_tensor_to_tensor({piece})"
                 wrapped_args.append(piece)

         debug_printer_manager.set_printer_args(args, kernel, None, None, "extern")
@@ -696,14 +696,10 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
         cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
         inputs_wrapped = [
-            (
-                f"convert_arrayref_tensor_to_tensor({x})"
-                if isinstance(x, str)
-                else str(x)
-            )
+            (f"copy_arrayref_tensor_to_tensor({x})" if isinstance(x, str) else str(x))
             for x in inputs
         ]
-        line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
+        line = f"{cpp_kernel_name}(copy_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"

         if python_kernel_name.startswith("aten.scatter_reduce"):
             line += f", {','.join(kwargs)}"
@@ -728,22 +728,18 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         # tensor prematurely deallocated, thus this std::vector().data() trick here.
         indices_str = (
             "std::vector<AtenTensorHandle>{"
-            + (
-                ", ".join(
-                    [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
-                )
-            )
+            + (", ".join([f"copy_arrayref_tensor_to_tensor({ind})" for ind in indices]))
             + "}.data()"
         )
         args = [
-            f"convert_arrayref_tensor_to_tensor({x})",
+            f"copy_arrayref_tensor_to_tensor({x})",
             indices_str,
             str(len(indices)),
-            f"convert_arrayref_tensor_to_tensor({values})",
+            f"copy_arrayref_tensor_to_tensor({values})",
             accumulate,
         ]
         args.insert(
-            0, f"convert_arrayref_tensor_to_tensor({x})"
+            0, f"copy_arrayref_tensor_to_tensor({x})"
         )  # set x as the output tensor, this fallback mutates x.

         self.writeline(self.wrap_kernel_call(kernel, args))
@@ -985,7 +977,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
                 # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
                 base_handle = self.val_to_arg_str(val, element_type)
                 if config.aot_inductor.use_minimal_arrayref_interface:
-                    base_handle = f"convert_arrayref_tensor_to_tensor({base_handle})"
+                    base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})"
                 (
                     tmp_raii_handle_var,
                     tmp_raii_handle_var_decl,
@@ -1029,8 +1021,8 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
             scalar_tmp = f"{scalar}_tmp"
             writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")
-            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
-            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
+            # need copy_arrayref_tensor_to_tensor for ArrayRefTensors
+            tensor = f"copy_arrayref_tensor_to_tensor({tensor})"
             writer.writeline(
                 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
@@ -1039,8 +1031,8 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         else:
             writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
-            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
-            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"
+            # need copy_arrayref_tensor_to_tensor for ArrayRefTensors
+            tensor = f"copy_arrayref_tensor_to_tensor({tensor})"
             writer.writeline(
                 f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
@@ -1049,7 +1041,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
     def create_tmp_raii_handle_var(self, base_handle):
         if base_handle.startswith(
             (
-                "convert_arrayref_tensor_to_tensor",
+                "copy_arrayref_tensor_to_tensor",
                 "wrap_with_raii_handle_if_needed",
             )
         ):

@@ -156,7 +156,7 @@ class DebugPrinterManager:
             self.args_to_print_or_save = args_to_print_or_save_extern
         elif kernel_type == "cpp":
             args_to_print_or_save_cpp = [
-                f"convert_arrayref_tensor_to_tensor({arg})"
+                f"copy_arrayref_tensor_to_tensor({arg})"
                 for arg in args_to_print_or_save
                 if arg.startswith(("buf", "arg"))
             ]

@@ -341,12 +341,12 @@ inline AtenTensorHandle expensive_copy_to_tensor_if_needed(
 }

 template <typename T>
-const T& convert_arrayref_tensor_to_tensor(const T& t) {
+const T& copy_arrayref_tensor_to_tensor(const T& t) {
   return t;
 }

 template <typename T>
-RAIIAtenTensorHandle convert_arrayref_tensor_to_tensor(
+RAIIAtenTensorHandle copy_arrayref_tensor_to_tensor(
     const ArrayRefTensor<T>& art) {
   return art.expensiveCopyToTensor();
 }
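
For reference, a rough usage sketch of the two overloads above; `arr` is an illustrative local name rather than anything emitted by the codegen:

    // ArrayRefTensor overload: allocates a fresh tensor and copies arr's data into it.
    ArrayRefTensor<float> arr = /* e.g. a stack-allocated kernel output */;
    RAIIAtenTensorHandle copied = copy_arrayref_tensor_to_tensor(arr);

    // Generic overload: a pass-through for anything that is already a tensor handle;
    // the argument is returned by const reference, so no copy happens here.
    const RAIIAtenTensorHandle& same = copy_arrayref_tensor_to_tensor(copied);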