mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
cpp_wrapper/aot_inductor: handle conjugation and negation dispatch keys (#145095)
Handles conjugation and negation in the same way that runtime dispatch does: by on-the-fly cloning a tensor with either key applied. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145095 Approved by: https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
09b0dfdc90
commit
7c0fe7a045
@ -59,7 +59,7 @@ base_type_to_aten_type = {
|
||||
}
|
||||
|
||||
base_type_to_callsite_expr = {
|
||||
BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
|
||||
BaseTy.Tensor: "resolve_tensor_dispatch_flags",
|
||||
BaseTy.bool: "",
|
||||
BaseTy.int: "",
|
||||
BaseTy.SymInt: "",
|
||||
@ -75,21 +75,30 @@ base_type_to_callsite_expr = {
|
||||
|
||||
|
||||
# convert args to C types, names in declarations, and expressions in function bodies
|
||||
def convert_arg_type_and_name( # type: ignore[return]
|
||||
def convert_arg_type_and_name(
|
||||
typ: Type,
|
||||
name: str,
|
||||
is_write: bool = False,
|
||||
) -> tuple[list[str], list[str], list[str], list[str]]:
|
||||
if isinstance(typ, BaseType):
|
||||
if typ.name in base_type_to_c_type:
|
||||
if typ.name == BaseTy.Tensor and is_write:
|
||||
# For output tensors, our normal call to resolve_tensor_dispatch_flags
|
||||
# results in an rvalue tensor, which can't be passed to at::Tensor&.
|
||||
# Override this case specifically.
|
||||
callsite_expr = [f"*tensor_handle_to_tensor_pointer({name})"]
|
||||
else:
|
||||
callsite_expr = [
|
||||
f"{base_type_to_callsite_expr[typ.name]}({name})"
|
||||
if base_type_to_callsite_expr[typ.name]
|
||||
else name
|
||||
]
|
||||
|
||||
return (
|
||||
[base_type_to_c_type[typ.name]],
|
||||
[name],
|
||||
[base_type_to_aten_type[typ.name]],
|
||||
[
|
||||
f"{base_type_to_callsite_expr[typ.name]}({name})"
|
||||
if base_type_to_callsite_expr[typ.name]
|
||||
else name
|
||||
],
|
||||
callsite_expr,
|
||||
)
|
||||
elif typ.name == BaseTy.Device:
|
||||
return (
|
||||
@ -128,6 +137,10 @@ def convert_arg_type_and_name( # type: ignore[return]
|
||||
f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
|
||||
)
|
||||
j += 2
|
||||
elif aten_type == "at::Tensor":
|
||||
new_aten_types.append(f"::std::optional<{aten_type}>")
|
||||
new_callsite_exprs.append(f"resolve_tensor_dispatch_flags({names[j]})")
|
||||
j += 1
|
||||
else:
|
||||
new_aten_types.append(f"::std::optional<{aten_type}>")
|
||||
new_callsite_exprs.append(
|
||||
@ -159,10 +172,14 @@ def convert_arg_type_and_name( # type: ignore[return]
|
||||
# construct std::array<bool, N> instead
|
||||
assert typ.size is not None
|
||||
callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
|
||||
elif atype == "at::Tensor" and not is_write:
|
||||
callsite_exprs.append(
|
||||
f"resolve_tensor_list_dispatch_flags({name}, {name}_len_)"
|
||||
)
|
||||
elif atype == "::std::optional<at::Tensor>":
|
||||
# convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
|
||||
callsite_exprs.append(
|
||||
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
|
||||
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(resolve_tensor_list_dispatch_flags({name}, {name}_len_)))"
|
||||
)
|
||||
else:
|
||||
callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
|
||||
@ -174,6 +191,7 @@ def convert_arg_type_and_name( # type: ignore[return]
|
||||
aten_types,
|
||||
callsite_exprs,
|
||||
)
|
||||
raise NotImplementedError(f"Argument type {repr(typ)} not supported!")
|
||||
|
||||
|
||||
def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
|
||||
@ -187,7 +205,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[s
|
||||
callsite_exprs = []
|
||||
for arg in flat_arguments:
|
||||
new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
|
||||
arg.type, arg.name
|
||||
arg.type, arg.name, arg.is_write
|
||||
)
|
||||
types.extend(new_types)
|
||||
new_names.extend(names)
|
||||
|
Reference in New Issue
Block a user