[AOTI] Add more fallback ops (#126720)

Summary: These ops are either in either unit tests or TorchBench. Fixes https://github.com/pytorch/pytorch/issues/122050

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126720
Approved by: https://github.com/chenyang78
This commit is contained in:
Bin Bao
2024-05-21 11:23:36 -07:00
committed by PyTorch MergeBot
parent 0d17aae242
commit 19cd4484ec
6 changed files with 17 additions and 5 deletions

View File

@ -333,6 +333,8 @@ def get_backend_index_for_aoti(
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
return backend_index
@ -471,6 +473,7 @@ extern "C" {{
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif