mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fused RMSNorm implementation (#153666)
Relevant #72643 Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090. ```py import torch import torch.nn as nn class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype)) x_normed = x / (rms_x + self.eps) return self.scale * x_normed def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16): rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype) input_data = torch.randn(input_shape, device='cuda', dtype=dtype) for _ in range(warmup_iterations): _ = rms_norm_layer(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = rms_norm_layer(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- RMSNorm CUDA Benchmark ---") print(f"Input Shape: {input_shape}") print(f"Normalized Dimension: {normalized_dim}") print(f"Benchmark Iterations: {num_iterations}") print(f"--- Fused Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda() for _ in range(warmup_iterations): _ = compiled_rms_norm(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = compiled_rms_norm(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- TorchCompile Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") print("-" * 50) if __name__ == '__main__': parameter_sets = [ {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16}, {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32}, {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16}, ] num_benchmark_iterations = 200 num_warmup_iterations = 20 for params in parameter_sets: batch_size = params['batch_size'] sequence_length = params['sequence_length'] hidden_features = params['hidden_features'] data_type = params.get('dtype', torch.float16) shape = (batch_size, sequence_length, hidden_features) norm_dim_to_normalize = hidden_features print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}") benchmark_rmsnorm_cuda(input_shape=shape, normalized_dim=norm_dim_to_normalize, num_iterations=num_benchmark_iterations, warmup_iterations=num_warmup_iterations, dtype=data_type) ``` Here are the triton compile tests ran on a 5090 (comparing this branch vs main) ```py import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code torch.manual_seed(0) device = torch.device("cuda") for batch in range(0, 9): for i in range(9, 16): normalized_shape_arg = (2**batch, 2**i) input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True) weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True) model = torch.nn.functional.rms_norm compiled_model = torch.compile(model) loss = torch.randn_like(input_tensor) num_iter = 5 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() num_iter = 10 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = round(elapsed_time_ms / num_iter, 5) print(2**batch, 2**i, avg_time_ms) ``` main ``` 32 512 0.1812 32 1024 0.19021 32 2048 0.18871 32 4096 0.17019 32 8192 0.21944 32 16384 0.38871 32 32768 0.83282 64 512 0.14705 64 1024 0.13987 64 2048 0.14111 64 4096 0.21699 64 8192 0.43141 64 16384 0.90652 64 32768 2.18573 128 512 0.19361 128 1024 0.1963 128 2048 0.20122 128 4096 0.38888 128 8192 0.93795 128 16384 2.23437 128 32768 5.50079 256 512 0.16722 256 1024 0.22856 256 2048 0.39421 256 4096 0.96621 256 8192 2.48746 256 16384 5.53571 256 32768 11.97932 ``` current branch ``` 32 512 0.16328 32 1024 0.18104 32 2048 0.15508 32 4096 0.14356 32 8192 0.20111 32 16384 0.45974 32 32768 0.94799 64 512 0.16874 64 1024 0.18701 64 2048 0.16107 64 4096 0.20152 64 8192 0.46568 64 16384 0.96599 64 32768 2.21661 128 512 0.14982 128 1024 0.15565 128 2048 0.22241 128 4096 0.46128 128 8192 0.88883 128 16384 2.3097 128 32768 5.84448 256 512 0.14346 256 1024 0.2007 256 2048 0.45927 256 4096 0.87876 256 8192 2.10571 256 16384 5.73948 256 32768 12.98581 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666 Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
60b9b06a53
commit
15ef4f28df
@ -158,6 +158,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
|
|||||||
OP_DECOMPOSE(kron);
|
OP_DECOMPOSE(kron);
|
||||||
OP_DECOMPOSE(l1_loss);
|
OP_DECOMPOSE(l1_loss);
|
||||||
m.impl("layer_norm", native::layer_norm_symint);
|
m.impl("layer_norm", native::layer_norm_symint);
|
||||||
|
m.impl("_fused_rms_norm", native::rms_norm_composite);
|
||||||
OP_DECOMPOSE2(ldexp, Tensor);
|
OP_DECOMPOSE2(ldexp, Tensor);
|
||||||
OP_DECOMPOSE2(less_equal, Tensor );
|
OP_DECOMPOSE2(less_equal, Tensor );
|
||||||
OP_DECOMPOSE2(less, Tensor );
|
OP_DECOMPOSE2(less, Tensor );
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -261,30 +261,11 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor rms_norm_symint(
|
std::tuple<Tensor, Tensor> rms_norm_composite(
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
c10::SymIntArrayRef normalized_shape,
|
IntArrayRef normalized_shape,
|
||||||
const std::optional<Tensor>& weight_opt /* optional */,
|
const std::optional<Tensor>& weight_opt /* optional */,
|
||||||
std::optional<double> eps) {
|
std::optional<double> eps) {
|
||||||
// See [Note: hacky wrapper removal for optional tensor]
|
|
||||||
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
|
|
||||||
const Tensor& weight = *weight_maybe_owned;
|
|
||||||
_check_rms_norm_inputs_symint(input, normalized_shape, weight);
|
|
||||||
|
|
||||||
#ifdef USE_MPS
|
|
||||||
if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
|
|
||||||
const Tensor weight = weight_opt.value();
|
|
||||||
const bool any_nested = input.is_nested() || weight.is_nested();
|
|
||||||
const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
|
|
||||||
const bool is_input_fp = isFloatingType(input.scalar_type());
|
|
||||||
const bool is_weight_fp = isFloatingType(weight.scalar_type());
|
|
||||||
|
|
||||||
if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
|
|
||||||
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
|
|
||||||
return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::vector<int64_t> dims_to_reduce;
|
std::vector<int64_t> dims_to_reduce;
|
||||||
for (const auto i : c10::irange(normalized_shape.size())) {
|
for (const auto i : c10::irange(normalized_shape.size())) {
|
||||||
@ -321,10 +302,60 @@ Tensor rms_norm_symint(
|
|||||||
upcasted_result = upcasted_result.mul(weight_opt.value());
|
upcasted_result = upcasted_result.mul(weight_opt.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
return upcasted_result;
|
// if nested do not make contiguous
|
||||||
|
if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){
|
||||||
|
return std::make_tuple(upcasted_result, rqrst_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
|
||||||
|
return std::make_tuple(upcasted_result, rqrst_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous());
|
||||||
});
|
});
|
||||||
|
return std::make_tuple(
|
||||||
return result.type_as(input);
|
std::get<0>(result).type_as(input), // Cast normalized result to original input type
|
||||||
|
std::get<1>(result) // rsqrt_val
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Tensor rms_norm_symint(
|
||||||
|
const Tensor& input,
|
||||||
|
c10::SymIntArrayRef normalized_shape,
|
||||||
|
const std::optional<Tensor>& weight_opt /* optional */,
|
||||||
|
const std::optional<double> eps) {
|
||||||
|
|
||||||
|
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
|
||||||
|
const Tensor& weight = *weight_maybe_owned;
|
||||||
|
_check_rms_norm_inputs_symint(input, normalized_shape, weight);
|
||||||
|
|
||||||
|
// composite fallback for channels last
|
||||||
|
if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
|
||||||
|
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
// composite fallback for complex datatypes
|
||||||
|
if(input.is_complex()){
|
||||||
|
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef USE_MPS
|
||||||
|
if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
|
||||||
|
const Tensor weight = weight_opt.value();
|
||||||
|
const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
|
||||||
|
|
||||||
|
if (!(GradMode::is_enabled() && any_inputs_require_grad)) {
|
||||||
|
return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input.device().type() == DeviceType::MPS){
|
||||||
|
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace at::native
|
} // namespace at::native
|
||||||
|
@ -106,6 +106,12 @@ void layer_norm_cpu_out(
|
|||||||
int64_t M,
|
int64_t M,
|
||||||
int64_t N);
|
int64_t N);
|
||||||
|
|
||||||
|
std::tuple<Tensor, Tensor> rms_norm_composite(
|
||||||
|
const Tensor& input,
|
||||||
|
IntArrayRef normalized_shape,
|
||||||
|
const std::optional<Tensor>& weight_opt /* optional */,
|
||||||
|
std::optional<double> eps);
|
||||||
|
|
||||||
Tensor rms_norm_symint(
|
Tensor rms_norm_symint(
|
||||||
const Tensor& input,
|
const Tensor& input,
|
||||||
c10::SymIntArrayRef normalized_shape,
|
c10::SymIntArrayRef normalized_shape,
|
||||||
|
@ -19,7 +19,14 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
|
|||||||
#include <ATen/native/mps/RMSNorm_metallib.h>
|
#include <ATen/native/mps/RMSNorm_metallib.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
|
std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
|
||||||
|
IntArrayRef normalized_shape,
|
||||||
|
const std::optional<Tensor>& weight_opt,
|
||||||
|
const std::optional<double> eps) {
|
||||||
|
const Tensor weight = weight_opt.value().contiguous();
|
||||||
|
const int64_t normalized_ndim = normalized_shape.size();
|
||||||
|
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
|
||||||
|
|
||||||
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
|
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
|
||||||
auto output = at::empty_like(input);
|
auto output = at::empty_like(input);
|
||||||
const auto input_shape = input.sizes();
|
const auto input_shape = input.sizes();
|
||||||
@ -41,7 +48,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
|
|||||||
const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output));
|
const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output));
|
||||||
id<MTLComputePipelineState> rms_norm_pso = lib.getPipelineStateForFunc(kernel);
|
id<MTLComputePipelineState> rms_norm_pso = lib.getPipelineStateForFunc(kernel);
|
||||||
[computeEncoder setComputePipelineState:rms_norm_pso];
|
[computeEncoder setComputePipelineState:rms_norm_pso];
|
||||||
mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1);
|
mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1);
|
||||||
|
|
||||||
const auto maxThreadsPerGroup = static_cast<size_t>([rms_norm_pso maxTotalThreadsPerThreadgroup]);
|
const auto maxThreadsPerGroup = static_cast<size_t>([rms_norm_pso maxTotalThreadsPerThreadgroup]);
|
||||||
size_t threadgroup_size = maxThreadsPerGroup;
|
size_t threadgroup_size = maxThreadsPerGroup;
|
||||||
@ -58,7 +65,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return output;
|
return std::make_tuple(output, Tensor());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace at::native
|
} // namespace at::native
|
||||||
|
@ -3314,9 +3314,15 @@
|
|||||||
dispatch:
|
dispatch:
|
||||||
CompositeImplicitAutograd: rms_norm_symint
|
CompositeImplicitAutograd: rms_norm_symint
|
||||||
|
|
||||||
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
|
- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
|
||||||
dispatch:
|
dispatch:
|
||||||
|
CUDA: _fused_rms_norm_cuda
|
||||||
MPS: _fused_rms_norm_mps
|
MPS: _fused_rms_norm_mps
|
||||||
|
CompositeImplicitAutograd: rms_norm_composite
|
||||||
|
|
||||||
|
- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)
|
||||||
|
dispatch:
|
||||||
|
CUDA: _fused_rms_norm_backward_cuda
|
||||||
|
|
||||||
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
|
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
@ -374,7 +374,6 @@ aten::_fused_adamw_.tensor_lr
|
|||||||
aten::_fused_moving_avg_obs_fq_helper
|
aten::_fused_moving_avg_obs_fq_helper
|
||||||
aten::_fused_moving_avg_obs_fq_helper.out
|
aten::_fused_moving_avg_obs_fq_helper.out
|
||||||
aten::_fused_moving_avg_obs_fq_helper_functional
|
aten::_fused_moving_avg_obs_fq_helper_functional
|
||||||
aten::_fused_rms_norm
|
|
||||||
aten::_fused_sdp_choice
|
aten::_fused_sdp_choice
|
||||||
aten::_fused_sgd
|
aten::_fused_sgd
|
||||||
aten::_fused_sgd.out
|
aten::_fused_sgd.out
|
||||||
|
@ -139,6 +139,8 @@ ALLOW_LIST = [
|
|||||||
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
|
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
|
||||||
# TODO: add back restriction when c10d ops can be exported
|
# TODO: add back restriction when c10d ops can be exported
|
||||||
("c10d::.*", datetime.date(9999, 1, 1)),
|
("c10d::.*", datetime.date(9999, 1, 1)),
|
||||||
|
# Previously MPS_only did not support backward
|
||||||
|
("aten::_fused_rms_norm", datetime.date(2025, 12, 30)),
|
||||||
]
|
]
|
||||||
|
|
||||||
ALLOW_LIST_COMPILED = [
|
ALLOW_LIST_COMPILED = [
|
||||||
|
@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
|
|||||||
from torch._export.utils import _is_cia_op
|
from torch._export.utils import _is_cia_op
|
||||||
from torch._ops import DispatchKey
|
from torch._ops import DispatchKey
|
||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_cuda import tf32_off
|
from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
instantiate_device_type_tests,
|
instantiate_device_type_tests,
|
||||||
onlyCPU,
|
onlyCPU,
|
||||||
@ -1226,6 +1226,33 @@ class DecompOneOffTests(TestCase):
|
|||||||
for o_ref, o in zip(out_ref, out):
|
for o_ref, o in zip(out_ref, out):
|
||||||
self.assertEqual(o_ref.dtype, o.dtype)
|
self.assertEqual(o_ref.dtype, o.dtype)
|
||||||
|
|
||||||
|
@onlyCUDA
|
||||||
|
@unittest.skipIf(not SM70OrLater, "triton")
|
||||||
|
def test_rms_norm_decomp_cuda(self, device):
|
||||||
|
@torch.compile
|
||||||
|
def rms_norm_sinh(a, b, c):
|
||||||
|
output = torch.nn.functional.rms_norm(a, b, c)
|
||||||
|
return torch.sinh(output)
|
||||||
|
|
||||||
|
normalized_shape_arg = (3, 3, 3)
|
||||||
|
input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
|
||||||
|
weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
|
||||||
|
|
||||||
|
def forward_pass_fn():
|
||||||
|
return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor)
|
||||||
|
|
||||||
|
model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
|
||||||
|
forward_pass_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
# check RMSNorm was fused with sinh
|
||||||
|
self.assertTrue(
|
||||||
|
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
instantiate_device_type_tests(DecompOneOffTests, globals())
|
instantiate_device_type_tests(DecompOneOffTests, globals())
|
||||||
|
|
||||||
|
@ -1267,6 +1267,11 @@
|
|||||||
mean: not_implemented("native_layer_norm_backward mean")
|
mean: not_implemented("native_layer_norm_backward mean")
|
||||||
rstd: not_implemented("native_layer_norm_backward rstd")
|
rstd: not_implemented("native_layer_norm_backward rstd")
|
||||||
|
|
||||||
|
- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
|
||||||
|
input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple<Tensor, Tensor>())"
|
||||||
|
result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape)
|
||||||
|
result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape)
|
||||||
|
|
||||||
- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
|
- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
|
||||||
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
|
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
|
||||||
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
|
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
|
||||||
|
@ -418,6 +418,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
|
|||||||
aten.native_dropout_backward,
|
aten.native_dropout_backward,
|
||||||
aten.native_group_norm_backward,
|
aten.native_group_norm_backward,
|
||||||
aten.native_layer_norm_backward,
|
aten.native_layer_norm_backward,
|
||||||
|
aten._fused_rms_norm_backward,
|
||||||
aten.new_empty,
|
aten.new_empty,
|
||||||
aten.new_full,
|
aten.new_full,
|
||||||
aten.new_ones,
|
aten.new_ones,
|
||||||
|
@ -1743,6 +1743,81 @@ def native_layer_norm_backward_out(
|
|||||||
return grad_input
|
return grad_input
|
||||||
|
|
||||||
|
|
||||||
|
@register_decomposition(aten._fused_rms_norm_backward.default)
|
||||||
|
def _fused_rms_norm_backward(
|
||||||
|
grad_out: Tensor,
|
||||||
|
input: Tensor,
|
||||||
|
normalized_shape: list[int],
|
||||||
|
rstd: Tensor,
|
||||||
|
weight: Optional[Tensor],
|
||||||
|
output_mask: list[bool],
|
||||||
|
) -> tuple[Optional[Tensor], Optional[Tensor]]:
|
||||||
|
input_shape = input.shape
|
||||||
|
input_ndim = input.dim()
|
||||||
|
computation_dtype = utils.get_computation_dtype(input.dtype)
|
||||||
|
|
||||||
|
grad_out_cast = grad_out.to(
|
||||||
|
computation_dtype, memory_format=torch.contiguous_format
|
||||||
|
)
|
||||||
|
input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
|
||||||
|
weight_cast = (
|
||||||
|
weight.to(computation_dtype, memory_format=torch.contiguous_format)
|
||||||
|
if weight is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
assert grad_out_cast is not None
|
||||||
|
|
||||||
|
axis = input_ndim - len(normalized_shape)
|
||||||
|
inner_dims = input_shape[axis:]
|
||||||
|
outer_dims = input_shape[:axis]
|
||||||
|
inner_dim_indices: list[int] = []
|
||||||
|
outer_dim_indices: list[int] = []
|
||||||
|
for i in range(input_ndim):
|
||||||
|
if i >= axis:
|
||||||
|
inner_dim_indices.append(i)
|
||||||
|
else:
|
||||||
|
outer_dim_indices.append(i)
|
||||||
|
|
||||||
|
N = prod(inner_dims) # type: ignore[arg-type]
|
||||||
|
M = prod(outer_dims) # type: ignore[arg-type]
|
||||||
|
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||||
|
|
||||||
|
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
|
||||||
|
return (
|
||||||
|
input.new_zeros(input_shape) if output_mask[0] else None,
|
||||||
|
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
|
||||||
|
if weight_cast is not None:
|
||||||
|
grad_x_hat = grad_out_cast * weight_cast
|
||||||
|
else:
|
||||||
|
grad_x_hat = grad_out_cast
|
||||||
|
|
||||||
|
d_input: Optional[Tensor] = None
|
||||||
|
d_weight: Optional[Tensor] = None
|
||||||
|
|
||||||
|
x_hat = input_cast * rstd
|
||||||
|
|
||||||
|
if output_mask[0]:
|
||||||
|
sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
|
||||||
|
d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
|
||||||
|
|
||||||
|
if output_mask[1] and weight_cast is not None:
|
||||||
|
d_weight_full_shape = grad_out_cast * x_hat
|
||||||
|
if len(outer_dim_indices) > 0:
|
||||||
|
d_weight = torch.sum(
|
||||||
|
d_weight_full_shape, dim=outer_dim_indices, keepdim=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
d_weight = d_weight_full_shape
|
||||||
|
|
||||||
|
return (
|
||||||
|
_maybe_cast(d_input, input.dtype),
|
||||||
|
_maybe_cast(d_weight, input.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def native_batch_norm_helper(
|
def native_batch_norm_helper(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
weight: Optional[Tensor],
|
weight: Optional[Tensor],
|
||||||
|
@ -5023,6 +5023,103 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
|
|||||||
return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
|
return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
|
||||||
|
const Tensor& dY,
|
||||||
|
const Tensor& drstd,
|
||||||
|
const Tensor& input,
|
||||||
|
IntArrayRef normalized_shape,
|
||||||
|
const Tensor& rstd,
|
||||||
|
const std::optional<Tensor>& weight_opt,
|
||||||
|
std::array<bool, 2> grad_input_mask) {
|
||||||
|
c10::MaybeOwned<at::Tensor> weight_maybe_owned =
|
||||||
|
at::borrow_from_optional_tensor(weight_opt);
|
||||||
|
const Tensor& weight = *weight_maybe_owned;
|
||||||
|
|
||||||
|
const auto input_shape = input.sizes();
|
||||||
|
const auto input_ndim = input.dim();
|
||||||
|
const int normalized_ndim = normalized_shape.size();
|
||||||
|
const int axis = input_ndim - normalized_ndim;
|
||||||
|
|
||||||
|
int64_t N_rms = 1;
|
||||||
|
for (int i = 0; i < normalized_ndim; ++i) {
|
||||||
|
N_rms *= input_shape[axis + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor dX;
|
||||||
|
Tensor dgamma;
|
||||||
|
|
||||||
|
std::vector<int64_t> rstd_view_shape = rstd.sizes().vec();
|
||||||
|
for (int i = 0;
|
||||||
|
i < std::max(static_cast<int>(normalized_ndim - rstd.dim()), 0);
|
||||||
|
++i) {
|
||||||
|
rstd_view_shape.push_back(1);
|
||||||
|
}
|
||||||
|
Tensor rstd_broadcast = rstd.view(rstd_view_shape);
|
||||||
|
Tensor rstd_pow3 = rstd_broadcast.pow(3);
|
||||||
|
Tensor grad_x_hat;
|
||||||
|
|
||||||
|
if (dY.defined()) {
|
||||||
|
if (weight.defined()) {
|
||||||
|
grad_x_hat = dY * weight;
|
||||||
|
} else {
|
||||||
|
grad_x_hat = dY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (grad_input_mask[0]) {
|
||||||
|
Tensor dX_from_dY_path;
|
||||||
|
Tensor dX_from_drstd_path;
|
||||||
|
|
||||||
|
std::vector<int64_t> inner_sum_dims;
|
||||||
|
inner_sum_dims.reserve(normalized_ndim);
|
||||||
|
for (int i = 0; i < normalized_ndim; ++i) {
|
||||||
|
inner_sum_dims.push_back(axis + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dY.defined() && grad_x_hat.defined()) {
|
||||||
|
Tensor sum_input_times_grad_x_hat =
|
||||||
|
sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true);
|
||||||
|
dX_from_dY_path = rstd_broadcast * grad_x_hat -
|
||||||
|
(input * rstd_pow3 / static_cast<double>(N_rms)) *
|
||||||
|
sum_input_times_grad_x_hat;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (drstd.defined()) {
|
||||||
|
Tensor drstd_broadcast = drstd.view(rstd_view_shape);
|
||||||
|
dX_from_drstd_path =
|
||||||
|
-(input * rstd_pow3 / static_cast<double>(N_rms)) * drstd_broadcast;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) {
|
||||||
|
dX = dX_from_dY_path + dX_from_drstd_path;
|
||||||
|
} else if (dX_from_dY_path.defined()) {
|
||||||
|
dX = dX_from_dY_path;
|
||||||
|
} else if (dX_from_drstd_path.defined()) {
|
||||||
|
dX = dX_from_drstd_path;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (grad_input_mask[1] && weight.defined()) {
|
||||||
|
if (dY.defined()) {
|
||||||
|
Tensor x_hat = input * rstd_broadcast;
|
||||||
|
Tensor dgamma_full_shape = dY * x_hat;
|
||||||
|
|
||||||
|
if (axis > 0) {
|
||||||
|
std::vector<int64_t> outer_sum_dims;
|
||||||
|
outer_sum_dims.reserve(axis);
|
||||||
|
for (int i = 0; i < axis; ++i) {
|
||||||
|
outer_sum_dims.push_back(i);
|
||||||
|
}
|
||||||
|
dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false);
|
||||||
|
} else {
|
||||||
|
dgamma = dgamma_full_shape;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(dX, dgamma);
|
||||||
|
}
|
||||||
|
|
||||||
std::tuple<Tensor, Tensor, Tensor>
|
std::tuple<Tensor, Tensor, Tensor>
|
||||||
infinitely_differentiable_native_group_norm_backward(
|
infinitely_differentiable_native_group_norm_backward(
|
||||||
const Tensor& dY,
|
const Tensor& dY,
|
||||||
@ -6377,6 +6474,98 @@ Tensor layer_norm_jvp(
|
|||||||
bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
|
bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor rms_norm_jvp(
|
||||||
|
const Tensor& input_p,
|
||||||
|
const Tensor& input_t,
|
||||||
|
const Tensor& weight_p,
|
||||||
|
const Tensor& weight_t,
|
||||||
|
const Tensor& saved_rstd,
|
||||||
|
IntArrayRef normalized_shape) {
|
||||||
|
auto dims = std::vector<int64_t>{};
|
||||||
|
auto view_size = input_t.sizes().vec();
|
||||||
|
auto view_size_affine = input_t.sizes().vec();
|
||||||
|
|
||||||
|
int64_t numel = 1;
|
||||||
|
for (const auto i : c10::irange(view_size.size())) {
|
||||||
|
if (i < view_size.size() - normalized_shape.size()) {
|
||||||
|
view_size_affine[i] = 1;
|
||||||
|
} else {
|
||||||
|
numel *= input_t.size(static_cast<int64_t>(i));
|
||||||
|
view_size[i] = 1;
|
||||||
|
dims.push_back(static_cast<int64_t>(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rstd_p = saved_rstd.view(view_size);
|
||||||
|
|
||||||
|
Tensor rstd_t;
|
||||||
|
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
|
||||||
|
input_t._is_zerotensor()) {
|
||||||
|
rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
|
||||||
|
} else {
|
||||||
|
rstd_t = input_t * input_p;
|
||||||
|
rstd_t *= -rstd_p.pow(3);
|
||||||
|
}
|
||||||
|
rstd_t = rstd_t.sum(dims, true);
|
||||||
|
rstd_t /= numel;
|
||||||
|
|
||||||
|
Tensor result_t;
|
||||||
|
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
|
||||||
|
input_t._is_zerotensor()) {
|
||||||
|
result_t = (input_t)*rstd_p + (input_p)*rstd_t;
|
||||||
|
} else {
|
||||||
|
result_t = input_t * rstd_p;
|
||||||
|
auto temp = input_p * rstd_t;
|
||||||
|
result_t += temp;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<Tensor> result_p = std::nullopt;
|
||||||
|
if (weight_p.defined()) {
|
||||||
|
result_p = std::optional<Tensor>(input_p * rstd_p);
|
||||||
|
}
|
||||||
|
|
||||||
|
return _affine_jvp(
|
||||||
|
result_p,
|
||||||
|
result_t,
|
||||||
|
weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
|
||||||
|
weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
|
||||||
|
Tensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor rms_norm_rstd_jvp(
|
||||||
|
const Tensor& input_p,
|
||||||
|
const Tensor& input_t,
|
||||||
|
const Tensor& saved_rstd,
|
||||||
|
IntArrayRef normalized_shape) {
|
||||||
|
auto dims = std::vector<int64_t>{};
|
||||||
|
auto view_size = input_t.sizes().vec();
|
||||||
|
auto view_size_affine = input_t.sizes().vec();
|
||||||
|
|
||||||
|
int64_t numel = 1;
|
||||||
|
for (const auto i : c10::irange(view_size.size())) {
|
||||||
|
if (i < view_size.size() - normalized_shape.size()) {
|
||||||
|
view_size_affine[i] = 1;
|
||||||
|
} else {
|
||||||
|
numel *= input_t.size(static_cast<int64_t>(i));
|
||||||
|
view_size[i] = 1;
|
||||||
|
dims.push_back(static_cast<int64_t>(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rstd_p = saved_rstd.view(view_size);
|
||||||
|
Tensor rstd_t;
|
||||||
|
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
|
||||||
|
input_t._is_zerotensor()) {
|
||||||
|
rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
|
||||||
|
} else {
|
||||||
|
rstd_t = input_t * input_p;
|
||||||
|
rstd_t *= -rstd_p.pow(3);
|
||||||
|
}
|
||||||
|
rstd_t = rstd_t.sum(dims, true);
|
||||||
|
rstd_t /= numel;
|
||||||
|
return rstd_t;
|
||||||
|
}
|
||||||
|
|
||||||
Tensor group_norm_jvp(
|
Tensor group_norm_jvp(
|
||||||
const Tensor& input_p,
|
const Tensor& input_p,
|
||||||
const Tensor& input_t,
|
const Tensor& input_t,
|
||||||
|
@ -826,6 +826,15 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
|
|||||||
c10::SymIntArrayRef normalized_shape,
|
c10::SymIntArrayRef normalized_shape,
|
||||||
std::array<bool, 3> output_mask);
|
std::array<bool, 3> output_mask);
|
||||||
|
|
||||||
|
std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
|
||||||
|
const Tensor& dY,
|
||||||
|
const Tensor& drstd,
|
||||||
|
const Tensor& input,
|
||||||
|
IntArrayRef normalized_shape,
|
||||||
|
const Tensor& rstd,
|
||||||
|
const std::optional<Tensor>& weight_opt,
|
||||||
|
std::array<bool, 2> grad_input_mask);
|
||||||
|
|
||||||
std::tuple<Tensor, Tensor> householder_product_backward(
|
std::tuple<Tensor, Tensor> householder_product_backward(
|
||||||
const Tensor& grad,
|
const Tensor& grad,
|
||||||
const Tensor& result,
|
const Tensor& result,
|
||||||
@ -965,6 +974,20 @@ Tensor layer_norm_jvp(
|
|||||||
const Tensor& saved_invstd,
|
const Tensor& saved_invstd,
|
||||||
c10::SymIntArrayRef normalized_shape);
|
c10::SymIntArrayRef normalized_shape);
|
||||||
|
|
||||||
|
Tensor rms_norm_jvp(
|
||||||
|
const Tensor& input_p,
|
||||||
|
const Tensor& input_t,
|
||||||
|
const Tensor& weight_p,
|
||||||
|
const Tensor& weight_t,
|
||||||
|
const Tensor& saved_rstd,
|
||||||
|
IntArrayRef normalized_shape);
|
||||||
|
|
||||||
|
Tensor rms_norm_rstd_jvp(
|
||||||
|
const Tensor& input_p,
|
||||||
|
const Tensor& input_t,
|
||||||
|
const Tensor& saved_rstd,
|
||||||
|
IntArrayRef normalized_shape);
|
||||||
|
|
||||||
Tensor group_norm_jvp(
|
Tensor group_norm_jvp(
|
||||||
const Tensor& input_p,
|
const Tensor& input_p,
|
||||||
const Tensor& input_t,
|
const Tensor& input_t,
|
||||||
|
@ -29,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self,
|
|||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
|
||||||
|
@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT
|
|||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
|
||||||
|
@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64
|
|||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
|
||||||
|
@ -13,6 +13,7 @@ extern "C" {
|
|||||||
|
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
|
||||||
|
@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
|
|||||||
torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
|
torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
|
||||||
torch.native_dropout: lambda input, p, train: -1,
|
torch.native_dropout: lambda input, p, train: -1,
|
||||||
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
|
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
|
||||||
|
torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
|
||||||
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
|
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
|
||||||
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
|
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
|
||||||
torch.native_channel_shuffle: lambda input, groups: -1,
|
torch.native_channel_shuffle: lambda input, groups: -1,
|
||||||
|
Reference in New Issue
Block a user