Fused RMSNorm implementation (#153666)

Relevant: #72643

Benchmarked against the unfused torch implementation and a torch.compile (Inductor) implementation: roughly a 9x speedup over the unfused implementation on CUDA, and slightly faster than the Inductor-compiled version on a 5090.

```py
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype))
        x_normed = x / (rms_x + self.eps)
        return self.scale * x_normed

def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16):
    rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype)
    input_data = torch.randn(input_shape, device='cuda', dtype=dtype)

    for _ in range(warmup_iterations):
        _ = rms_norm_layer(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = rms_norm_layer(input_data)

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations

    print(f"--- RMSNorm CUDA Benchmark ---")
    print(f"Input Shape: {input_shape}")
    print(f"Normalized Dimension: {normalized_dim}")
    print(f"Benchmark Iterations: {num_iterations}")
    print(f"--- Fused Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim).to(device="cuda", dtype=dtype))
    for _ in range(warmup_iterations):
        _ = compiled_rms_norm(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = compiled_rms_norm(input_data)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations

    print(f"--- TorchCompile Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    print("-" * 50)

if __name__ == '__main__':
    parameter_sets = [
        {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16},
        {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32},
        {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16},
    ]

    num_benchmark_iterations = 200
    num_warmup_iterations = 20

    for params in parameter_sets:
        batch_size = params['batch_size']
        sequence_length = params['sequence_length']
        hidden_features = params['hidden_features']
        data_type = params.get('dtype', torch.float16)

        shape = (batch_size, sequence_length, hidden_features)
        norm_dim_to_normalize = hidden_features

        print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}")
        benchmark_rmsnorm_cuda(input_shape=shape,
                               normalized_dim=norm_dim_to_normalize,
                               num_iterations=num_benchmark_iterations,
                               warmup_iterations=num_warmup_iterations,
                               dtype=data_type)
```
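As a quick sanity check that eager `torch.nn.functional.rms_norm` really dispatches to a single fused CUDA kernel (rather than the chain of unfused elementwise ops), one can profile a call and inspect the launched kernels. This is a minimal sketch, not part of the benchmark above, and it assumes a CUDA build of this branch; the exact kernel name is an implementation detail.

```py
# Hedged sanity check: list the CUDA kernels launched by one eager rms_norm call.
import torch
import torch.nn.functional as F

x = torch.randn(32, 512, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(1024, device="cuda", dtype=torch.float16)

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    F.rms_norm(x, (1024,), w)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```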

Here are the Triton compile tests, run on a 5090 (this branch vs. main):
```py
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code

torch.manual_seed(0)

device = torch.device("cuda")

for batch in range(0, 9):
    for i in range(9, 16):
        normalized_shape_arg = (2**batch, 2**i)
        input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True)
        weight_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True)

        model = torch.nn.functional.rms_norm
        compiled_model = torch.compile(model)
        loss = torch.randn_like(input_tensor)

        num_iter = 5
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        num_iter = 10
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        end_event.record()
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        avg_time_ms = round(elapsed_time_ms / num_iter, 5)
        print(2**batch, 2**i, avg_time_ms)
```
main (columns: 2**batch, 2**i, avg time per iteration in ms)
```
32 512 0.1812
32 1024 0.19021
32 2048 0.18871
32 4096 0.17019
32 8192 0.21944
32 16384 0.38871
32 32768 0.83282
64 512 0.14705
64 1024 0.13987
64 2048 0.14111
64 4096 0.21699
64 8192 0.43141
64 16384 0.90652
64 32768 2.18573
128 512 0.19361
128 1024 0.1963
128 2048 0.20122
128 4096 0.38888
128 8192 0.93795
128 16384 2.23437
128 32768 5.50079
256 512 0.16722
256 1024 0.22856
256 2048 0.39421
256 4096 0.96621
256 8192 2.48746
256 16384 5.53571
256 32768 11.97932
```
current branch (columns: 2**batch, 2**i, avg time per iteration in ms)
```
32 512 0.16328
32 1024 0.18104
32 2048 0.15508
32 4096 0.14356
32 8192 0.20111
32 16384 0.45974
32 32768 0.94799
64 512 0.16874
64 1024 0.18701
64 2048 0.16107
64 4096 0.20152
64 8192 0.46568
64 16384 0.96599
64 32768 2.21661
128 512 0.14982
128 1024 0.15565
128 2048 0.22241
128 4096 0.46128
128 8192 0.88883
128 16384 2.3097
128 32768 5.84448
256 512 0.14346
256 1024 0.2007
256 2048 0.45927
256 4096 0.87876
256 8192 2.10571
256 16384 5.73948
256 32768 12.98581
```
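For convenience, a small throwaway helper along these lines turns the two dumps above into per-shape ratios; it assumes the blocks have been saved to `main.txt` and `branch.txt` (hypothetical filenames).

```py
# Hypothetical post-processing helper: compare the two timing dumps above.
# Each file is expected to contain lines of "outer_dim inner_dim avg_ms" as printed by the script.
def load(path):
    with open(path) as f:
        return {(int(a), int(b)): float(t)
                for a, b, t in (line.split() for line in f if line.strip())}

main, branch = load("main.txt"), load("branch.txt")
for key in sorted(main):
    ratio = branch[key] / main[key]
    print(f"{key[0]:4d} {key[1]:6d}  main={main[key]:.5f} ms  branch={branch[key]:.5f} ms  branch/main={ratio:.2f}")
```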
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/albanD
Author: AaronWang04
Date: 2025-07-18 23:24:21 +00:00
Committed by: PyTorch MergeBot
Parent: 60b9b06a53
Commit: 15ef4f28df
19 changed files with 845 additions and 185 deletions


@@ -158,6 +158,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(kron);
OP_DECOMPOSE(l1_loss);
m.impl("layer_norm", native::layer_norm_symint);
m.impl("_fused_rms_norm", native::rms_norm_composite);
OP_DECOMPOSE2(ldexp, Tensor);
OP_DECOMPOSE2(less_equal, Tensor );
OP_DECOMPOSE2(less, Tensor );
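The new batching-decomposition entry above routes `aten::_fused_rms_norm` through `rms_norm_composite` under functorch transforms, so `vmap` over `rms_norm` keeps working. A minimal sketch of the kind of call this covers (shapes, device handling, and tolerances are illustrative):

```py
# Illustrative vmap call that exercises the batched decomposition of _fused_rms_norm.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
xs = torch.randn(8, 16, 64, device=device)  # batch of 8 inputs of shape (16, 64)
w = torch.randn(64, device=device)

out = torch.vmap(lambda x: F.rms_norm(x, (64,), w, eps=1e-6))(xs)
ref = F.rms_norm(xs, (64,), w, eps=1e-6)  # normalizing the last dim is batch-independent
# Loose tolerance: the decomposed and fused paths may reduce in a different order.
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)
```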

File diff suppressed because it is too large.


@@ -261,30 +261,11 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
return outputs;
}
Tensor rms_norm_symint(
std::tuple<Tensor, Tensor> rms_norm_composite(
const Tensor& input,
c10::SymIntArrayRef normalized_shape,
IntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt /* optional */,
std::optional<double> eps) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
_check_rms_norm_inputs_symint(input, normalized_shape, weight);
#ifdef USE_MPS
if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
const Tensor weight = weight_opt.value();
const bool any_nested = input.is_nested() || weight.is_nested();
const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
const bool is_input_fp = isFloatingType(input.scalar_type());
const bool is_weight_fp = isFloatingType(weight.scalar_type());
if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
}
}
#endif
std::vector<int64_t> dims_to_reduce;
for (const auto i : c10::irange(normalized_shape.size())) {
@@ -321,10 +302,60 @@ Tensor rms_norm_symint(
upcasted_result = upcasted_result.mul(weight_opt.value());
}
return upcasted_result;
// if nested do not make contiguous
if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){
return std::make_tuple(upcasted_result, rqrst_input);
}
if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
return std::make_tuple(upcasted_result, rqrst_input);
}
return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous());
});
return result.type_as(input);
return std::make_tuple(
std::get<0>(result).type_as(input), // Cast normalized result to original input type
std::get<1>(result) // rsqrt_val
);
}
Tensor rms_norm_symint(
const Tensor& input,
c10::SymIntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt /* optional */,
const std::optional<double> eps) {
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
_check_rms_norm_inputs_symint(input, normalized_shape, weight);
// composite fallback for channels last
if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
// composite fallback for complex datatypes
if(input.is_complex()){
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
#ifdef USE_MPS
if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
const Tensor weight = weight_opt.value();
const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
if (!(GradMode::is_enabled() && any_inputs_require_grad)) {
return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
}
if (input.device().type() == DeviceType::MPS){
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
#endif
return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
} // namespace at::native
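Note that `rms_norm_symint` above keeps a composite fallback for inputs the fused kernel does not handle (channels-last memory formats, complex dtypes, and some MPS cases). A hedged correctness sketch for the channels-last fallback against a hand-written reference; a similar check applies to complex inputs:

```py
# Illustrative check that the channels-last fallback path matches a reference implementation.
import torch
import torch.nn.functional as F

def ref_rms_norm(x, ndim, w, eps):
    dims = tuple(range(x.dim() - ndim, x.dim()))
    rstd = torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    return x * rstd * w

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 8, 16, 32, device=device).to(memory_format=torch.channels_last)
w = torch.randn(32, device=device)
torch.testing.assert_close(F.rms_norm(x, (32,), w, eps=1e-6), ref_rms_norm(x, 1, w, 1e-6))
```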


@@ -106,6 +106,12 @@ void layer_norm_cpu_out(
int64_t M,
int64_t N);
std::tuple<Tensor, Tensor> rms_norm_composite(
const Tensor& input,
IntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt /* optional */,
std::optional<double> eps);
Tensor rms_norm_symint(
const Tensor& input,
c10::SymIntArrayRef normalized_shape,


@@ -19,7 +19,14 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/RMSNorm_metallib.h>
#endif
Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
IntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt,
const std::optional<double> eps) {
const Tensor weight = weight_opt.value().contiguous();
const int64_t normalized_ndim = normalized_shape.size();
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
auto output = at::empty_like(input);
const auto input_shape = input.sizes();
@@ -41,7 +48,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output));
id<MTLComputePipelineState> rms_norm_pso = lib.getPipelineStateForFunc(kernel);
[computeEncoder setComputePipelineState:rms_norm_pso];
mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1);
mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1);
const auto maxThreadsPerGroup = static_cast<size_t>([rms_norm_pso maxTotalThreadsPerThreadgroup]);
size_t threadgroup_size = maxThreadsPerGroup;
@@ -58,7 +65,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
}
});
return output;
return std::make_tuple(output, Tensor());
}
} // namespace at::native


@@ -3314,9 +3314,15 @@
dispatch:
CompositeImplicitAutograd: rms_norm_symint
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
dispatch:
CUDA: _fused_rms_norm_cuda
MPS: _fused_rms_norm_mps
CompositeImplicitAutograd: rms_norm_composite
- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)
dispatch:
CUDA: _fused_rms_norm_backward_cuda
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
variants: function, method
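With the updated schema above, `aten::_fused_rms_norm` takes the full `normalized_shape`, an optional weight, and an optional eps, and returns both the normalized output and the saved rstd. A hedged sketch of calling the op directly through `torch.ops` (ordinary user code should keep using `torch.nn.functional.rms_norm`):

```py
# Illustrative direct call against the updated schema:
#   _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(16, 1024, device=device)
w = torch.randn(1024, device=device)

out, rstd = torch.ops.aten._fused_rms_norm(x, [1024], w, 1e-6)
# Loose tolerance: the fused kernel and the composite path may reduce in a different order.
torch.testing.assert_close(out, F.rms_norm(x, (1024,), w, eps=1e-6), rtol=1e-4, atol=1e-4)
print(rstd.shape)  # rstd statistics saved for the backward pass
```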


@@ -374,7 +374,6 @@ aten::_fused_adamw_.tensor_lr
aten::_fused_moving_avg_obs_fq_helper
aten::_fused_moving_avg_obs_fq_helper.out
aten::_fused_moving_avg_obs_fq_helper_functional
aten::_fused_rms_norm
aten::_fused_sdp_choice
aten::_fused_sgd
aten::_fused_sgd.out


@@ -139,6 +139,8 @@ ALLOW_LIST = [
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
# TODO: add back restriction when c10d ops can be exported
("c10d::.*", datetime.date(9999, 1, 1)),
# Previously MPS_only did not support backward
("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
]
ALLOW_LIST_COMPILED = [


@@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _is_cia_op
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
@@ -1226,6 +1226,33 @@ class DecompOneOffTests(TestCase):
for o_ref, o in zip(out_ref, out):
self.assertEqual(o_ref.dtype, o.dtype)
@onlyCUDA
@unittest.skipIf(not SM70OrLater, "triton")
def test_rms_norm_decomp_cuda(self, device):
@torch.compile
def rms_norm_sinh(a, b, c):
output = torch.nn.functional.rms_norm(a, b, c)
return torch.sinh(output)
normalized_shape_arg = (3, 3, 3)
input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
def forward_pass_fn():
return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor)
model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
forward_pass_fn
)
# check RMSNorm was fused with sinh
self.assertTrue(
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
)
instantiate_device_type_tests(DecompOneOffTests, globals())


@@ -1267,6 +1267,11 @@
mean: not_implemented("native_layer_norm_backward mean")
rstd: not_implemented("native_layer_norm_backward rstd")
- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple<Tensor, Tensor>())"
result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape)
result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape)
- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
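The new `_fused_rms_norm` entry above gives the op a first-order backward, a double-backward path via `infinitely_differentiable_native_rms_norm_backward`, and forward-mode derivatives via the JVP helpers. A minimal sketch of the checks this enables, in double precision so `gradcheck`'s finite differences are well conditioned (whether the fused CUDA kernel itself covers float64 is an assumption here; on CPU the composite path is exercised instead):

```py
# Illustrative autograd checks for the new derivative rules.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3, 5, device=device, dtype=torch.float64, requires_grad=True)
w = torch.randn(5, device=device, dtype=torch.float64, requires_grad=True)

def fn(x, w):
    return F.rms_norm(x, (5,), w, eps=1e-6)

torch.autograd.gradcheck(fn, (x, w))                           # first-order backward
torch.autograd.gradgradcheck(fn, (x, w))                       # second-order (double backward)
torch.autograd.gradcheck(fn, (x, w), check_forward_ad=True)    # forward-mode JVP
```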


@@ -418,6 +418,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,
aten.new_ones,
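Since `aten._fused_rms_norm_backward` is now in the core ATen decomposition table, export and Inductor can lower the backward op without a dedicated lowering. A quick hedged way to confirm the registration:

```py
# Illustrative check that the backward op is present in the core decomposition table.
import torch
from torch._decomp import core_aten_decompositions

print(torch.ops.aten._fused_rms_norm_backward.default in core_aten_decompositions())  # expected: True
```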


@@ -1743,6 +1743,81 @@ def native_layer_norm_backward_out(
return grad_input
@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,
input: Tensor,
normalized_shape: list[int],
rstd: Tensor,
weight: Optional[Tensor],
output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
computation_dtype = utils.get_computation_dtype(input.dtype)
grad_out_cast = grad_out.to(
computation_dtype, memory_format=torch.contiguous_format
)
input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
weight_cast = (
weight.to(computation_dtype, memory_format=torch.contiguous_format)
if weight is not None
else None
)
assert grad_out_cast is not None
axis = input_ndim - len(normalized_shape)
inner_dims = input_shape[axis:]
outer_dims = input_shape[:axis]
inner_dim_indices: list[int] = []
outer_dim_indices: list[int] = []
for i in range(input_ndim):
if i >= axis:
inner_dim_indices.append(i)
else:
outer_dim_indices.append(i)
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
return (
input.new_zeros(input_shape) if output_mask[0] else None,
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
)
rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
if weight_cast is not None:
grad_x_hat = grad_out_cast * weight_cast
else:
grad_x_hat = grad_out_cast
d_input: Optional[Tensor] = None
d_weight: Optional[Tensor] = None
x_hat = input_cast * rstd
if output_mask[0]:
sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
if output_mask[1] and weight_cast is not None:
d_weight_full_shape = grad_out_cast * x_hat
if len(outer_dim_indices) > 0:
d_weight = torch.sum(
d_weight_full_shape, dim=outer_dim_indices, keepdim=False
)
else:
d_weight = d_weight_full_shape
return (
_maybe_cast(d_input, input.dtype),
_maybe_cast(d_weight, input.dtype),
)
def native_batch_norm_helper(
input: Tensor,
weight: Optional[Tensor],
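For reference, the `_fused_rms_norm_backward` decomposition above implements the usual RMSNorm backward. Writing $g$ for `grad_out`, $w$ for the weight, $\hat{x}_i = x_i \cdot \mathrm{rstd}$, and $N$ for the number of normalized elements, the quantities it computes are:

```latex
% Gradients computed by _fused_rms_norm_backward above.
\begin{aligned}
\mathrm{rstd} &= \Big(\tfrac{1}{N}\textstyle\sum_{j} x_j^{2} + \varepsilon\Big)^{-1/2},
\qquad \hat{x}_i = x_i \cdot \mathrm{rstd},\\[2pt]
\frac{\partial L}{\partial x_i}
  &= \mathrm{rstd}\Big(g_i w_i - \frac{\hat{x}_i}{N}\textstyle\sum_{j} g_j w_j \hat{x}_j\Big),\\[2pt]
\frac{\partial L}{\partial w_i}
  &= \textstyle\sum_{\text{outer dims}} g_i\,\hat{x}_i .
\end{aligned}
```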


@@ -5023,6 +5023,103 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}
std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
const Tensor& dY,
const Tensor& drstd,
const Tensor& input,
IntArrayRef normalized_shape,
const Tensor& rstd,
const std::optional<Tensor>& weight_opt,
std::array<bool, 2> grad_input_mask) {
c10::MaybeOwned<at::Tensor> weight_maybe_owned =
at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const auto input_shape = input.sizes();
const auto input_ndim = input.dim();
const int normalized_ndim = normalized_shape.size();
const int axis = input_ndim - normalized_ndim;
int64_t N_rms = 1;
for (int i = 0; i < normalized_ndim; ++i) {
N_rms *= input_shape[axis + i];
}
Tensor dX;
Tensor dgamma;
std::vector<int64_t> rstd_view_shape = rstd.sizes().vec();
for (int i = 0;
i < std::max(static_cast<int>(normalized_ndim - rstd.dim()), 0);
++i) {
rstd_view_shape.push_back(1);
}
Tensor rstd_broadcast = rstd.view(rstd_view_shape);
Tensor rstd_pow3 = rstd_broadcast.pow(3);
Tensor grad_x_hat;
if (dY.defined()) {
if (weight.defined()) {
grad_x_hat = dY * weight;
} else {
grad_x_hat = dY;
}
}
if (grad_input_mask[0]) {
Tensor dX_from_dY_path;
Tensor dX_from_drstd_path;
std::vector<int64_t> inner_sum_dims;
inner_sum_dims.reserve(normalized_ndim);
for (int i = 0; i < normalized_ndim; ++i) {
inner_sum_dims.push_back(axis + i);
}
if (dY.defined() && grad_x_hat.defined()) {
Tensor sum_input_times_grad_x_hat =
sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true);
dX_from_dY_path = rstd_broadcast * grad_x_hat -
(input * rstd_pow3 / static_cast<double>(N_rms)) *
sum_input_times_grad_x_hat;
}
if (drstd.defined()) {
Tensor drstd_broadcast = drstd.view(rstd_view_shape);
dX_from_drstd_path =
-(input * rstd_pow3 / static_cast<double>(N_rms)) * drstd_broadcast;
}
if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) {
dX = dX_from_dY_path + dX_from_drstd_path;
} else if (dX_from_dY_path.defined()) {
dX = dX_from_dY_path;
} else if (dX_from_drstd_path.defined()) {
dX = dX_from_drstd_path;
}
}
if (grad_input_mask[1] && weight.defined()) {
if (dY.defined()) {
Tensor x_hat = input * rstd_broadcast;
Tensor dgamma_full_shape = dY * x_hat;
if (axis > 0) {
std::vector<int64_t> outer_sum_dims;
outer_sum_dims.reserve(axis);
for (int i = 0; i < axis; ++i) {
outer_sum_dims.push_back(i);
}
dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false);
} else {
dgamma = dgamma_full_shape;
}
}
}
return std::make_tuple(dX, dgamma);
}
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
const Tensor& dY,
@@ -6377,6 +6474,98 @@ Tensor layer_norm_jvp(
bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
}
Tensor rms_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();
int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(static_cast<int64_t>(i));
view_size[i] = 1;
dims.push_back(static_cast<int64_t>(i));
}
}
auto rstd_p = saved_rstd.view(view_size);
Tensor rstd_t;
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
input_t._is_zerotensor()) {
rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
} else {
rstd_t = input_t * input_p;
rstd_t *= -rstd_p.pow(3);
}
rstd_t = rstd_t.sum(dims, true);
rstd_t /= numel;
Tensor result_t;
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
input_t._is_zerotensor()) {
result_t = (input_t)*rstd_p + (input_p)*rstd_t;
} else {
result_t = input_t * rstd_p;
auto temp = input_p * rstd_t;
result_t += temp;
}
std::optional<Tensor> result_p = std::nullopt;
if (weight_p.defined()) {
result_p = std::optional<Tensor>(input_p * rstd_p);
}
return _affine_jvp(
result_p,
result_t,
weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
Tensor());
}
Tensor rms_norm_rstd_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();
int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(static_cast<int64_t>(i));
view_size[i] = 1;
dims.push_back(static_cast<int64_t>(i));
}
}
auto rstd_p = saved_rstd.view(view_size);
Tensor rstd_t;
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
input_t._is_zerotensor()) {
rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
} else {
rstd_t = input_t * input_p;
rstd_t *= -rstd_p.pow(3);
}
rstd_t = rstd_t.sum(dims, true);
rstd_t /= numel;
return rstd_t;
}
Tensor group_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
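The two JVP helpers above supply forward-mode derivatives for the normalized output and for the saved rstd. A minimal forward-mode sketch that goes through them via the public API:

```py
# Illustrative forward-mode AD call that exercises rms_norm_jvp.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 8, device=device)
w = torch.randn(8, device=device)
tx = torch.randn_like(x)  # tangent for the input
tw = torch.randn_like(w)  # tangent for the weight

out, out_tangent = torch.func.jvp(
    lambda x, w: F.rms_norm(x, (8,), w, eps=1e-6), (x, w), (tx, tw)
)
print(out.shape, out_tangent.shape)
```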


@@ -826,6 +826,15 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
c10::SymIntArrayRef normalized_shape,
std::array<bool, 3> output_mask);
std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
const Tensor& dY,
const Tensor& drstd,
const Tensor& input,
IntArrayRef normalized_shape,
const Tensor& rstd,
const std::optional<Tensor>& weight_opt,
std::array<bool, 2> grad_input_mask);
std::tuple<Tensor, Tensor> householder_product_backward(
const Tensor& grad,
const Tensor& result,
@@ -965,6 +974,20 @@ Tensor layer_norm_jvp(
const Tensor& saved_invstd,
c10::SymIntArrayRef normalized_shape);
Tensor rms_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape);
Tensor rms_norm_rstd_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape);
Tensor group_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,


@@ -29,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self,
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);


@@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);


@@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);


@@ -13,6 +13,7 @@ extern "C" {
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);


@@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
torch.native_dropout: lambda input, p, train: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
torch.native_channel_shuffle: lambda input, groups: -1,