mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183)
Summary: Update the torchgen rule for inplace ops like bernoulli_, and update InplaceBernoulliFallback to codegen in the ABI-compatible mode. Fixes https://github.com/pytorch/pytorch/issues/121809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126183 Approved by: https://github.com/angelayi ghstack dependencies: #126181, #126182
This commit is contained in:
committed by
PyTorch MergeBot
parent
5792bc3c3e
commit
0332b5812e
@ -249,18 +249,18 @@ def gen_declaration_and_definition(
|
||||
return declaration_definition_cache[(func_name, device, backend_call)]
|
||||
|
||||
if schema.is_out_fn():
|
||||
# out_variant has out arguments in the front, and it's ok to ignore return value
|
||||
# out_variant has out arguments in the front, and it's ok to ignore return values
|
||||
# because C shim functions only return AOTITorchError
|
||||
# Somehow at::native out-variant functions have out arguments in the back
|
||||
args, callsite_exprs = gen_arguments(
|
||||
[*schema.arguments.flat_non_out, *schema.arguments.out]
|
||||
if "at::native" in backend_call
|
||||
else [*schema.arguments.out, *schema.arguments.flat_non_out],
|
||||
[*schema.arguments.out, *schema.arguments.flat_non_out]
|
||||
)
|
||||
ret_assignments: List[str] = []
|
||||
else:
|
||||
args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
|
||||
ret_declarations, ret_assignments = gen_returns(schema)
|
||||
# ignore return values for inplace ops
|
||||
ret_declarations, ret_assignments = (
|
||||
([], []) if schema.name.name.inplace else gen_returns(schema)
|
||||
)
|
||||
args.extend(ret_declarations)
|
||||
|
||||
declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
|
||||
|
||||
Reference in New Issue
Block a user