cpp_wrapper/aot_inductor: handle conjugation and negation dispatch keys (#145095)

Handles conjugation and negation in the same way that runtime dispatch does: by cloning the tensor on the fly with the conjugation or negation key applied, as appropriate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145095
Approved by: https://github.com/desertfire
This commit is contained in:
Benjamin Glass
2025-02-04 18:04:09 +00:00
committed by PyTorch MergeBot
parent 09b0dfdc90
commit 7c0fe7a045
4 changed files with 100 additions and 43 deletions

View File

@ -59,7 +59,7 @@ base_type_to_aten_type = {
}
base_type_to_callsite_expr = {
BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
BaseTy.Tensor: "resolve_tensor_dispatch_flags",
BaseTy.bool: "",
BaseTy.int: "",
BaseTy.SymInt: "",
@ -75,21 +75,30 @@ base_type_to_callsite_expr = {
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name( # type: ignore[return]
def convert_arg_type_and_name(
typ: Type,
name: str,
is_write: bool = False,
) -> tuple[list[str], list[str], list[str], list[str]]:
if isinstance(typ, BaseType):
if typ.name in base_type_to_c_type:
if typ.name == BaseTy.Tensor and is_write:
# For output tensors, our normal call to resolve_tensor_dispatch_flags
# results in an rvalue tensor, which can't be passed to at::Tensor&.
# Override this case specifically.
callsite_expr = [f"*tensor_handle_to_tensor_pointer({name})"]
else:
callsite_expr = [
f"{base_type_to_callsite_expr[typ.name]}({name})"
if base_type_to_callsite_expr[typ.name]
else name
]
return (
[base_type_to_c_type[typ.name]],
[name],
[base_type_to_aten_type[typ.name]],
[
f"{base_type_to_callsite_expr[typ.name]}({name})"
if base_type_to_callsite_expr[typ.name]
else name
],
callsite_expr,
)
elif typ.name == BaseTy.Device:
return (
@ -128,6 +137,10 @@ def convert_arg_type_and_name( # type: ignore[return]
f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
)
j += 2
elif aten_type == "at::Tensor":
new_aten_types.append(f"::std::optional<{aten_type}>")
new_callsite_exprs.append(f"resolve_tensor_dispatch_flags({names[j]})")
j += 1
else:
new_aten_types.append(f"::std::optional<{aten_type}>")
new_callsite_exprs.append(
@ -159,10 +172,14 @@ def convert_arg_type_and_name( # type: ignore[return]
# construct std::array<bool, N> instead
assert typ.size is not None
callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
elif atype == "at::Tensor" and not is_write:
callsite_exprs.append(
f"resolve_tensor_list_dispatch_flags({name}, {name}_len_)"
)
elif atype == "::std::optional<at::Tensor>":
# convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
callsite_exprs.append(
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(resolve_tensor_list_dispatch_flags({name}, {name}_len_)))"
)
else:
callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
@ -174,6 +191,7 @@ def convert_arg_type_and_name( # type: ignore[return]
aten_types,
callsite_exprs,
)
raise NotImplementedError(f"Argument type {repr(typ)} not supported!")
def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
@ -187,7 +205,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[s
callsite_exprs = []
for arg in flat_arguments:
new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
arg.type, arg.name
arg.type, arg.name, arg.is_write
)
types.extend(new_types)
new_names.extend(names)