Fused RMSNorm Housekeeping (#159317)

Small PR to address review comments from the original fused RMSNorm PR that were not addressed before it landed.

Changes:
- Warn once when input.dtype does not match weight.dtype and fall back to the unfused implementation (see the sketch after this list)
- Ensure the default epsilon value is correct
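
A minimal sketch of how the dtype-mismatch warning surfaces from Python. This is not part of this PR's diff; it is shown on CUDA, where the fused kernel would otherwise be used, and the warning is emitted only once per process:

import warnings
import torch
import torch.nn.functional as F

# Hypothetical repro: float16 input with a float32 weight.
x = torch.randn(2, 8, dtype=torch.float16, device="cuda")
w = torch.ones(8, dtype=torch.float32, device="cuda")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = F.rms_norm(x, (8,), weight=w)

# Expect a warning along the lines of
# "Mismatch dtype between input and weight ... Cannot dispatch to fused implementation.",
# with the result still computed via the composite (unfused) path.
print([str(c.message) for c in caught])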

Comments:
https://github.com/pytorch/pytorch/pull/153666#discussion_r2114735005
https://github.com/pytorch/pytorch/pull/153666#discussion_r2223518064

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159317
Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/eqy
Author: AaronWang04
Date: 2025-07-29 22:39:18 +00:00
Committed by: PyTorch MergeBot
Parent: b4619f0272
Commit: dc286aef61
2 changed files with 26 additions and 2 deletions


@@ -342,8 +342,8 @@ Tensor rms_norm_symint(
   if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) {
     TORCH_WARN_ONCE(
-      "Mismatch dtype between input and module: input dtype = ", input.dtype(),
-      ", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation"
+      "Mismatch dtype between input and weight: input dtype = ", input.dtype(),
+      ", weight dtype = ", weight_opt.value().dtype(), ", Cannot dispatch to fused implementation."
     );
     return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
   }


@@ -8748,6 +8748,30 @@ class TestNNDeviceType(NNTestCase):
         self.assertEqual(Y_ref, Y)
 
+    @onlyNativeDeviceTypes
+    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
+    def test_rmsnorm_epsilon(self, device, dtype):
+        def rms_norm_reference_fn(i, normalized_shape):
+            eps = torch.finfo(i.dtype).eps
+            ndim = i.ndim
+            dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+            if i.dtype is not torch.float64:
+                upcasted_i = i.float()
+            else:
+                upcasted_i = i
+            result = upcasted_i * torch.rsqrt(
+                upcasted_i.pow(2).mean(dim=dims, keepdim=True) + eps
+            )
+            return result.type_as(i)
+
+        shape = (2, 2)
+        X = torch.tensor([[1e-12, -1e-12], [1e-12, -1e-12]], dtype=dtype, device=device)
+        Y = torch.nn.functional.rms_norm(X, shape)
+        Y_ref = rms_norm_reference_fn(X, shape)
+        self.assertEqual(Y_ref, Y)
+
     @onlyCPU
     def test_glu_bfloat16(self, device):
         def test_dtype(fn, input, dtype):
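
A usage-level sketch of what the new epsilon test exercises, assuming the documented default of eps = torch.finfo(input.dtype).eps when eps is not passed: omitting eps should match supplying it explicitly.

import torch
import torch.nn.functional as F

dtype = torch.float16
# Tiny values make the output sensitive to which epsilon is actually used.
x = torch.tensor([[1e-12, -1e-12], [1e-12, -1e-12]], dtype=dtype)

y_default = F.rms_norm(x, (2, 2))
y_explicit = F.rms_norm(x, (2, 2), eps=torch.finfo(dtype).eps)
torch.testing.assert_close(y_default, y_explicit)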