Revert "Fused RMSNorm implementation (#153666)"

This reverts commit e1aee86646aa6d1b9cb9d34351e43936401c5efc.

Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to causing build failures on main branch [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](e1aee86646) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176))
PyTorch MergeBot
2025-07-01 18:46:45 +00:00
parent 3a5677a380
commit 6401d1d53d
14 changed files with 184 additions and 839 deletions

View File

@@ -418,7 +418,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
            aten.native_dropout_backward,
            aten.native_group_norm_backward,
            aten.native_layer_norm_backward,
            aten._fused_rms_norm_backward,
            aten.new_empty,
            aten.new_full,
            aten.new_ones,

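Context for the hunk above: the revert drops aten._fused_rms_norm_backward from the table built by _core_aten_decompositions_post_autograd. If you want to see which ops remain registered, the table can be inspected directly; the following is a minimal sketch, assuming a PyTorch build where torch._decomp is importable, and it looks up aten.native_layer_norm_backward only because that is one of the neighbouring entries in the hunk.

import torch
from torch._decomp import core_aten_decompositions

# The core ATen decomposition table maps ATen overloads to Python decompositions.
table = core_aten_decompositions()

# Ops listed in _core_aten_decompositions_post_autograd show up as keys here;
# after this revert, aten._fused_rms_norm_backward is no longer among them.
print(torch.ops.aten.native_layer_norm_backward.default in table)  # expected: True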
View File

@@ -1743,81 +1743,6 @@ def native_layer_norm_backward_out(
    return grad_input


@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: list[int],
    rstd: Tensor,
    weight: Optional[Tensor],
    output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)
    grad_out_cast = grad_out.to(
        computation_dtype, memory_format=torch.contiguous_format
    )
    input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
    weight_cast = (
        weight.to(computation_dtype, memory_format=torch.contiguous_format)
        if weight is not None
        else None
    )
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: list[int] = []
    outer_dim_indices: list[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
        return (
            input.new_zeros(input_shape) if output_mask[0] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
        )

    rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast

    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None

    x_hat = input_cast * rstd

    if output_mask[0]:
        sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
        d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd

    if output_mask[1] and weight_cast is not None:
        d_weight_full_shape = grad_out_cast * x_hat
        if len(outer_dim_indices) > 0:
            d_weight = torch.sum(
                d_weight_full_shape, dim=outer_dim_indices, keepdim=False
            )
        else:
            d_weight = d_weight_full_shape

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
    )


def native_batch_norm_helper(
    input: Tensor,
    weight: Optional[Tensor],
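For reference, the d_input / d_weight formulas in the reverted decomposition can be checked against autograd on a plain RMSNorm forward. This is a standalone sketch, not part of the PR: rms_norm_ref and rms_norm_backward_ref are illustrative helpers, the forward is assumed to be y = x * rsqrt(mean(x**2) + eps) * weight, and normalization is restricted to the last dimension.

import torch

def rms_norm_ref(x, weight, eps=1e-6):
    # Forward: normalize by the root-mean-square over the last dim, then scale.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight, rstd

def rms_norm_backward_ref(grad_out, x, weight, rstd):
    # Mirrors the reverted decomposition for a 1-D normalized_shape:
    # d_input = (grad_x_hat - (x_hat / N) * sum(x_hat * grad_x_hat)) * rstd
    # d_weight = sum of grad_out * x_hat over all non-normalized dims
    N = x.shape[-1]
    x_hat = x * rstd
    grad_x_hat = grad_out * weight
    sum_val = (x_hat * grad_x_hat).sum(dim=-1, keepdim=True)
    d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
    d_weight = (grad_out * x_hat).reshape(-1, N).sum(dim=0)
    return d_input, d_weight

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(8, dtype=torch.double, requires_grad=True)
y, rstd = rms_norm_ref(x, w)
g = torch.randn_like(y)
dx_auto, dw_auto = torch.autograd.grad(y, (x, w), grad_outputs=g)
dx_ref, dw_ref = rms_norm_backward_ref(g, x.detach(), w.detach(), rstd.detach())
print(torch.allclose(dx_auto, dx_ref), torch.allclose(dw_auto, dw_ref))  # expected: True True

The output_mask, guard_size_oblivious empty-tensor path, and dtype casting in the real decomposition are deliberately omitted from this sketch; it only exercises the core gradient algebra.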