cpp_wrapper: Properly handle scalars when input to tensor arguments (#144910)

Additionally, reduce code duplication in `cpp_wrapper_cpu_array_ref.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144910
Approved by: https://github.com/desertfire
This commit is contained in:
Benjamin Glass
2025-01-21 16:14:27 +00:00
committed by PyTorch MergeBot
parent 47e65077b1
commit d5629889f1
6 changed files with 68 additions and 124 deletions

View File

@ -1970,11 +1970,37 @@ if (custom_op_wrapper.get() == NULL) {
# Py_NewRef is only available since Python 3.10
self.include_extra_header("torch/csrc/utils/pythoncapi_compat.h")
def handle_scalar(scalar):
    """Return a C expression string that builds a Python object for ``scalar``.

    Handles the plain Python number types (bool, int, float, complex) and the
    symbolic types listed in ``SymTypes``.  Float components are rendered with
    the enclosing wrapper's ``generate_float_value``; symbolic values are
    rendered by passing ``scalar.node.expr`` through ``cexpr``.

    Args:
        scalar: The Python or symbolic scalar to convert.

    Returns:
        A string containing a ``Py*_From*`` call expression.

    Raises:
        NotImplementedError: If ``scalar`` is none of the supported types.
    """
    # bool must be tested before int: bool is a subclass of int, so an int
    # check placed first would capture True/False and emit
    # "PyLong_FromLongLong(True)", which is not a valid C expression.
    if isinstance(scalar, bool):
        return f"PyBool_FromLong({1 if scalar else 0})"
    if isinstance(scalar, int):
        return f"PyLong_FromLongLong({scalar})"
    if isinstance(scalar, float):
        return f"PyFloat_FromDouble({self.generate_float_value(scalar)})"
    if isinstance(scalar, complex):
        real = self.generate_float_value(scalar.real)
        imag = self.generate_float_value(scalar.imag)
        return f"PyComplex_FromDoubles({real}, {imag})"
    if isinstance(scalar, SymTypes):
        scalar_var = cexpr(scalar.node.expr)
        if isinstance(scalar, torch.SymBool):
            return f"PyBool_FromLong({scalar_var})"
        if isinstance(scalar, torch.SymFloat):
            return f"PyFloat_FromDouble({scalar_var})"
        # Any remaining member of SymTypes is emitted as an integer.
        return f"PyLong_FromLongLong({scalar_var})"
    raise NotImplementedError(
        f"scalar {scalar}, {type(scalar)} cannot be handled by handle_scalar"
    )
if raw_arg is None:
# Py_None is a singleton, so we have to explicitly incref it here
lines.append("Py_INCREF(Py_None);\n")
return "Py_None"
elif isinstance(arg_type, torch.TensorType):
# In some cases, scalar arguments may be passed in place of tensors.
if not hasattr(raw_arg, "codegen_reference"):
return handle_scalar(raw_arg)
# Store AtenTensorHandle as void*
base_handle = raw_arg.codegen_reference()
(
@ -2005,23 +2031,7 @@ if (custom_op_wrapper.get() == NULL) {
elif isinstance(arg_type, torch.NumberType):
# Union[bool, int, float, complex]
# torch/_prims_common/__init__.py
if isinstance(raw_arg, int):
return f"PyLong_FromLongLong({raw_arg})"
elif isinstance(raw_arg, float):
return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})"
elif isinstance(raw_arg, bool):
return f"PyBool_FromLong({1 if raw_arg else 0})"
elif isinstance(raw_arg, complex):
real = self.generate_float_value(raw_arg.real)
imag = self.generate_float_value(raw_arg.imag)
return f"PyComplex_FromDoubles({real}, {imag})"
elif isinstance(raw_arg, torch.SymInt):
expr = raw_arg.node.expr
return f"PyLong_FromLongLong({cexpr(expr)})"
else:
raise NotImplementedError(
f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper"
)
return handle_scalar(raw_arg)
elif isinstance(raw_arg, torch.device):
# device
self.include_extra_header("torch/csrc/Device.h")
@ -2221,6 +2231,8 @@ reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_
elif isinstance(val, int):
# uint64_t is long on Linux, but long long on MacOS and Windows
return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L"
elif isinstance(val, complex):
return f"c10::complex<double>{{ {self.generate_float_value(val.real)}, {self.generate_float_value(val.imag)} }}"
elif isinstance(val, str):
return f'"{val}"'
elif isinstance(
@ -2309,7 +2321,7 @@ reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_
self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();")
return f"&{var_name}"
elif isinstance(type_, torch.ListType):
if isinstance(type_, torch.ListType):
assert isinstance(
val, (list, tuple)
), f"{val} does not match with arg type {type_}"
@ -2329,6 +2341,11 @@ reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_
# Need to pass the array length because we can't use std::vector
return f"{var_name}, {len(val)}"
val_is_scalar = isinstance(val, (bool, complex, float, int, *SymTypes))
if isinstance(type_, torch.TensorType) and val_is_scalar:
val_str = self.val_to_arg_str_for_prim_type(val, None)
return self.codegen_scalar_to_tensor(val_str)
return self.val_to_arg_str_for_prim_type(val, type_)
def create_tmp_raii_handle_var(self, base_handle):

View File

@ -953,56 +953,14 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
return final_tmp_name
def val_to_arg_str(self, val, type_=None) -> str:
if val is None:
# None needs special care. It either represents nullopt or an empty tensor
if type_ is None or isinstance(type_, torch.OptionalType):
if type_ is not None and isinstance(
type_.getElementType(),
(
torch.ListType,
torch.TupleType,
torch.DeviceObjType,
),
if (
val is not None
and isinstance(type_, torch.OptionalType)
and isinstance(type_.getElementType(), torch.TensorType)
):
return "0, 0"
else:
return "0" # nullptr is not available in C
elif isinstance(type_, torch.TensorType):
# create an empty tensor, the equivalent of at::Tensor()
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(f"AtenTensorHandle {var_name}_handle;")
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));"
)
self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);")
return var_name
else:
raise AssertionError("Can not map None to a known data type")
if isinstance(type_, torch.OptionalType):
element_type = type_.getElementType()
if not isinstance(element_type, torch.TensorType):
var_name = f"var_{next(self.arg_var_id)}"
if isinstance(
element_type,
(torch.ListType, torch.TupleType, torch.DeviceObjType),
):
# type_ is something like Optional[List] or Optional[Device]
arg_str = self.val_to_arg_str(val, element_type)
# For datatypes with auxiliary info, we need to hoist out the extra arguments.
# NOTE: This only works if there is one additional argument, though it can easily be generalized.
main_value, aux = arg_str.rsplit(", ")
self.writeline(f"auto {var_name} = {main_value};")
return f"&{var_name}, {aux}"
else:
self.writeline(
f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};"
)
return f"&{var_name}"
else:
# type_ is Optional[Tensor]
# Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
base_handle = self.val_to_arg_str(val, element_type)
base_handle = self.val_to_arg_str(val, torch.TensorType)
if config.aot_inductor.use_minimal_arrayref_interface:
if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor():
base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})"
@ -1019,27 +977,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();")
return f"&{var_name}"
elif isinstance(type_, torch.ListType):
assert isinstance(
val, (list, tuple)
), f"{val} does not match with arg type {type_}"
element_type = type_.getElementType()
var_name = f"var_array_{next(self.var_array_id)}"
if len(val) == 0:
# Zero-size array is not supported in the C or C++ standard, so
# we declare a null pointer for it.
self.writeline(
f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;"
)
else:
result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"
self.writeline(
f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};"
)
# Need to pass the array length because we can't use std::vector
return f"{var_name}, {len(val)}"
return self.val_to_arg_str_for_prim_type(val, type_)
return super().val_to_arg_str(val, type_)
def codegen_tensor_item(
self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None

View File

@ -6292,27 +6292,6 @@ class FallbackKernel(ExternKernelAlloc):
*,
unbacked_bindings=None,
) -> None:
# When aten binary ops have constant second args, cpp wrapper expects the scalar
# version. This should long-term be handled as in
# https://github.com/pytorch/pytorch/issues/90923.
BINARY_OP_MAPPING = {
aten.add.Tensor: aten.add.Scalar,
aten.div.Tensor: aten.div.Scalar,
aten.divide.Tensor: aten.divide.Scalar,
aten.floor_divide: aten.floor_divide.Scalar,
aten.mul.Tensor: aten.mul.Scalar,
aten.multiply.Tensor: aten.multiply.Scalar,
aten.sub.Tensor: aten.sub.Scalar,
aten.subtract.Tensor: aten.subtract.Scalar,
aten.true_divide.Tensor: aten.true_divide.Scalar,
}
if (
kernel in BINARY_OP_MAPPING
and len(tensor_args) == 1
and len(nontensor_args) == 1
):
kernel = BINARY_OP_MAPPING[kernel]
super().__init__(
layout,
tuple(tensor_args),

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/util/complex.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
namespace torch::aot_inductor {
@ -30,6 +31,8 @@ AOTI_RUNTIME_SCALAR_TO_TENSOR(int16, int16_t)
AOTI_RUNTIME_SCALAR_TO_TENSOR(int32, int32_t)
AOTI_RUNTIME_SCALAR_TO_TENSOR(int64, int64_t)
AOTI_RUNTIME_SCALAR_TO_TENSOR(bool, bool)
AOTI_RUNTIME_SCALAR_TO_TENSOR(complex64, c10::complex<float>)
AOTI_RUNTIME_SCALAR_TO_TENSOR(complex128, c10::complex<double>)
#undef AOTI_RUNTIME_SCALAR_TO_TENSOR
} // namespace torch::aot_inductor

View File

@ -209,6 +209,9 @@ aoti_torch_scalar_to_tensor_bool(bool value, AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex64(
c10::complex<float> value,
AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex128(
c10::complex<double> value,
AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled();
AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled);

View File

@ -223,6 +223,10 @@ AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int32, int32_t, Int)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int64, int64_t, Long)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(bool, bool, Bool)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(complex64, c10::complex<float>, ComplexFloat)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(
complex128,
c10::complex<double>,
ComplexDouble)
#undef AOTI_TORCH_SCALAR_TO_TENSOR_IMPL
bool aoti_torch_grad_mode_is_enabled() {