Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Fused RMSNorm implementation (#153666)"
This reverts commit e1aee86646aa6d1b9cb9d34351e43936401c5efc.
Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to causing build failures on main branch [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](e1aee86646) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176))
@@ -418,7 +418,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
             aten.native_dropout_backward,
             aten.native_group_norm_backward,
             aten.native_layer_norm_backward,
-            aten._fused_rms_norm_backward,
             aten.new_empty,
             aten.new_full,
             aten.new_ones,
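The hunk above drops aten._fused_rms_norm_backward from the op list in _core_aten_decompositions_post_autograd, so the backward op no longer appears in the core ATen decomposition table. A minimal sketch of how one might check for the entry at runtime, assuming a build whose torch._decomp exposes core_aten_decompositions and that defines the op at all (this lookup is illustrative, not part of the diff):

    import torch
    from torch._decomp import core_aten_decompositions

    # Build the core ATen decomposition table (OpOverload -> Python decomposition).
    table = core_aten_decompositions()

    # The op only exists on builds that shipped the fused RMSNorm change,
    # so guard the attribute lookup.
    op = getattr(torch.ops.aten, "_fused_rms_norm_backward", None)
    if op is None:
        print("aten._fused_rms_norm_backward is not defined in this build")
    else:
        print("decomposition registered:", op.default in table)

With this revert applied, the lookup is expected to report the decomposition as absent.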
@@ -1743,81 +1743,6 @@ def native_layer_norm_backward_out(
     return grad_input
 
 
-@register_decomposition(aten._fused_rms_norm_backward.default)
-def _fused_rms_norm_backward(
-    grad_out: Tensor,
-    input: Tensor,
-    normalized_shape: list[int],
-    rstd: Tensor,
-    weight: Optional[Tensor],
-    output_mask: list[bool],
-) -> tuple[Optional[Tensor], Optional[Tensor]]:
-    input_shape = input.shape
-    input_ndim = input.dim()
-    computation_dtype = utils.get_computation_dtype(input.dtype)
-
-    grad_out_cast = grad_out.to(
-        computation_dtype, memory_format=torch.contiguous_format
-    )
-    input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
-    weight_cast = (
-        weight.to(computation_dtype, memory_format=torch.contiguous_format)
-        if weight is not None
-        else None
-    )
-    assert grad_out_cast is not None
-
-    axis = input_ndim - len(normalized_shape)
-    inner_dims = input_shape[axis:]
-    outer_dims = input_shape[:axis]
-    inner_dim_indices: list[int] = []
-    outer_dim_indices: list[int] = []
-    for i in range(input_ndim):
-        if i >= axis:
-            inner_dim_indices.append(i)
-        else:
-            outer_dim_indices.append(i)
-
-    N = prod(inner_dims)  # type: ignore[arg-type]
-    M = prod(outer_dims)  # type: ignore[arg-type]
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
-    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
-        return (
-            input.new_zeros(input_shape) if output_mask[0] else None,
-            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
-        )
-
-    rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
-    if weight_cast is not None:
-        grad_x_hat = grad_out_cast * weight_cast
-    else:
-        grad_x_hat = grad_out_cast
-
-    d_input: Optional[Tensor] = None
-    d_weight: Optional[Tensor] = None
-
-    x_hat = input_cast * rstd
-
-    if output_mask[0]:
-        sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
-        d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
-
-    if output_mask[1] and weight_cast is not None:
-        d_weight_full_shape = grad_out_cast * x_hat
-        if len(outer_dim_indices) > 0:
-            d_weight = torch.sum(
-                d_weight_full_shape, dim=outer_dim_indices, keepdim=False
-            )
-        else:
-            d_weight = d_weight_full_shape
-
-    return (
-        _maybe_cast(d_input, input.dtype),
-        _maybe_cast(d_weight, input.dtype),
-    )
-
-
 def native_batch_norm_helper(
     input: Tensor,
     weight: Optional[Tensor],
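The removed decomposition implements the standard RMSNorm backward: with x_hat = input * rstd, d_input = (grad_x_hat - x_hat * sum(x_hat * grad_x_hat) / N) * rstd, and d_weight is grad_out * x_hat summed over the outer (non-normalized) dimensions. A small self-contained sketch that checks these formulas against autograd on a plain PyTorch RMSNorm forward; the reference forward, shapes, and eps below are assumptions for illustration, not taken from this diff:

    import torch

    def rms_norm_ref(x, weight, eps=1e-6):
        # Reference RMSNorm forward over the last dimension.
        rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x * rstd * weight, rstd

    torch.manual_seed(0)
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    w = torch.randn(8, dtype=torch.double, requires_grad=True)
    grad_out = torch.randn(4, 8, dtype=torch.double)

    out, rstd = rms_norm_ref(x, w)
    out.backward(grad_out)

    # Manual backward, mirroring the formulas in the removed decomposition.
    N = x.shape[-1]
    x_hat = x.detach() * rstd.detach()
    grad_x_hat = grad_out * w.detach()
    sum_val = (x_hat * grad_x_hat).sum(dim=-1, keepdim=True)
    d_input = (grad_x_hat - x_hat / N * sum_val) * rstd.detach()
    d_weight = (grad_out * x_hat).sum(dim=0)

    print(torch.allclose(d_input, x.grad))   # expected: True
    print(torch.allclose(d_weight, w.grad))  # expected: True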