Create native function for determining which implementation of SDP to call (#89029)

# Summary
Creates a callable native function that can determine which implementation of scaled dot product will get called. This allows to bump re-order the runtime dispatch of SDP to enable autograd.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89029
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Driss Guessous
2022-11-16 03:07:54 +00:00
committed by PyTorch MergeBot
parent 397f100672
commit b291c1213a
9 changed files with 137 additions and 33 deletions

View File

@ -73,6 +73,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"record_stream", # no return
"sparse_dim", # returns an int
"_nested_tensor_offsets", # returns a vector of ints
"_fused_sdp_choice", # returns an int
]
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [