Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can still force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused kernel), we now preserve the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
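To make the equivalence claim concrete, here is a minimal sketch (not part of this commit) of the kind of check the fix enables. It assumes eager mode dispatches F.rms_norm to the fused kernel and that, with this change, torch.compile preserves that same fused call:

    # Hedged sketch: compare eager vs. compiled rms_norm bit-for-bit.
    # Assumes eager dispatches to the fused rms_norm kernel and that the
    # compiled path now preserves the same fused call.
    import torch
    import torch.nn.functional as F

    def rms(x, w):
        return F.rms_norm(x, (x.shape[-1],), weight=w, eps=1e-6)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 8, device=device)
    w = torch.randn(8, device=device)

    eager_out = rms(x, w)
    compiled_out = torch.compile(rms)(x, w)

    # Bitwise equality, not just allclose.
    assert torch.equal(eager_out, compiled_out)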
committed by: PyTorch MergeBot
parent: e532f62e0d
commit: d40a9bfb8d
@@ -52,9 +52,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
     // where we would like to support composite implicit kernels but not
     // explicit kernels therefore we manually add the key to the
     // math_dispatch_keyset
-    DispatchKeySet{DispatchKey::NestedTensor} |
-    // Functionalize should always reuse CompositeImplicit decomps.
-    DispatchKeySet{DispatchKey::Functionalize};
+    DispatchKeySet{DispatchKey::NestedTensor};
 
 constexpr DispatchKeySet nested_dispatch_keyset =
     DispatchKeySet(
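With Functionalize removed from math_dispatch_keyset, functionalization/proxy-tensor tracing keeps the fused call unless a decomposition is supplied explicitly. A hedged illustration (not from this PR), assuming make_fx and get_decompositions behave as in current PyTorch and that a decomposition for aten.rms_norm is registered in torch._decomp:

    # Hedged sketch: proxy-tensor tracing with and without an explicit
    # decomposition table. Assumes torch._decomp has a decomposition
    # registered for aten.rms_norm.
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._decomp import get_decompositions

    def f(x, w):
        return torch.nn.functional.rms_norm(x, (8,), weight=w, eps=1e-6)

    x, w = torch.randn(4, 8), torch.randn(8)

    # No decomposition table: the fused rms_norm call stays in the graph.
    g_fused = make_fx(f)(x, w)

    # Explicit decomposition table: the decomposed form is forced.
    decomps = get_decompositions([torch.ops.aten.rms_norm])
    g_decomp = make_fx(f, decomposition_table=decomps)(x, w)

    print(g_fused.graph)
    print(g_decomp.graph)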