mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now default to preserving the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939 Approved by: https://github.com/bdhirsh ghstack dependencies: #164573
This commit is contained in:
committed by
PyTorch MergeBot
parent
e532f62e0d
commit
d40a9bfb8d
@ -404,6 +404,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
|
||||
aten.max_unpool3d,
|
||||
aten.mish,
|
||||
aten.mish_,
|
||||
aten.mish_backward,
|
||||
aten.mse_loss,
|
||||
aten.mse_loss_backward,
|
||||
aten.multi_margin_loss,
|
||||
@ -419,6 +420,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
|
||||
aten.native_dropout_backward,
|
||||
aten.native_group_norm_backward,
|
||||
aten.native_layer_norm_backward,
|
||||
aten._fused_rms_norm,
|
||||
aten._fused_rms_norm_backward,
|
||||
aten.new_empty,
|
||||
aten.new_full,
|
||||
@ -475,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
|
||||
aten.silu,
|
||||
aten.silu_,
|
||||
aten.silu_backward.grad_input,
|
||||
aten.silu_backward,
|
||||
aten.sinc,
|
||||
aten.sinc_,
|
||||
aten.slice_backward,
|
||||
|
@ -1757,6 +1757,58 @@ def native_layer_norm_backward_out(
|
||||
return grad_input
|
||||
|
||||
|
||||
@register_decomposition(aten._fused_rms_norm.default)
|
||||
def _fused_rms_norm(
|
||||
input: Tensor,
|
||||
normalized_shape: list[int],
|
||||
weight: Optional[Tensor],
|
||||
eps: Optional[float],
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
dims_to_reduce: list[int] = []
|
||||
for i in range(len(normalized_shape)):
|
||||
dims_to_reduce.append(input.dim() - i - 1)
|
||||
|
||||
# upcast is needed for fp16 and bf16
|
||||
computation_dtype = utils.get_computation_dtype(input.dtype)
|
||||
upcasted_input = input.to(computation_dtype)
|
||||
|
||||
# computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
|
||||
if eps is None:
|
||||
if computation_dtype in (torch.float32, torch.complex64):
|
||||
eps_val = sys.float_info.epsilon
|
||||
else:
|
||||
eps_val = sys.float_info.epsilon
|
||||
else:
|
||||
eps_val = eps
|
||||
|
||||
rqrst_input = torch.rsqrt(
|
||||
# NB: don't inplace here, will violate functional IR invariant
|
||||
torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
|
||||
)
|
||||
|
||||
upcasted_result = upcasted_input.mul(rqrst_input)
|
||||
|
||||
if weight is not None:
|
||||
upcasted_result = upcasted_result.mul(weight)
|
||||
|
||||
# NB: nested should be dead here, just here for fidelity
|
||||
is_nested = input.is_nested or (weight is not None and weight.is_nested)
|
||||
memory_format = utils.suggest_memory_format(input)
|
||||
is_channels_last = memory_format in (
|
||||
torch.channels_last,
|
||||
torch.channels_last_3d,
|
||||
)
|
||||
|
||||
if not is_nested and not is_channels_last:
|
||||
upcasted_result = upcasted_result.contiguous()
|
||||
rqrst_input = rqrst_input.contiguous()
|
||||
|
||||
# Cast normalized result back to original input type
|
||||
result = upcasted_result.type_as(input)
|
||||
|
||||
return result, rqrst_input
|
||||
|
||||
|
||||
@register_decomposition(aten._fused_rms_norm_backward.default)
|
||||
def _fused_rms_norm_backward(
|
||||
grad_out: Tensor,
|
||||
|
Reference in New Issue
Block a user