[AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183)

Summary: Update the torchgen rule for inplace ops like bernoulli_, and update InplaceBernoulliFallback to codegen in the ABI-compatible mode. Fixes https://github.com/pytorch/pytorch/issues/121809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126183
Approved by: https://github.com/angelayi
ghstack dependencies: #126181, #126182
This commit is contained in:
Bin Bao
2024-05-15 07:05:09 -07:00
committed by PyTorch MergeBot
parent 5792bc3c3e
commit 0332b5812e
8 changed files with 38 additions and 22 deletions

View File

@ -249,18 +249,18 @@ def gen_declaration_and_definition(
return declaration_definition_cache[(func_name, device, backend_call)]
if schema.is_out_fn():
# out_variant has out arguments in the front, and it's ok to ignore return value
# out_variant has out arguments in the front, and it's ok to ignore return values
# because C shim functions only return AOTITorchError
# Somehow at::native out-variant functions have out arguments in the back
args, callsite_exprs = gen_arguments(
[*schema.arguments.flat_non_out, *schema.arguments.out]
if "at::native" in backend_call
else [*schema.arguments.out, *schema.arguments.flat_non_out],
[*schema.arguments.out, *schema.arguments.flat_non_out]
)
ret_assignments: List[str] = []
else:
args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
ret_declarations, ret_assignments = gen_returns(schema)
# ignore return values for inplace ops
ret_declarations, ret_assignments = (
([], []) if schema.name.name.inplace else gen_returns(schema)
)
args.extend(ret_declarations)
declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"