Fused RMSNorm Housekeeping (#159317)
Small PR to address comments from the original fused RMSNorm PR (#153666) that had not yet been addressed.

Changes:
- Warn when input.dtype does not match weight.dtype
- Ensure the default epsilon value is correct

Comments:
https://github.com/pytorch/pytorch/pull/153666#discussion_r2114735005
https://github.com/pytorch/pytorch/pull/153666#discussion_r2223518064

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159317
Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/eqy
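As a rough illustration of the two behaviors described above (not part of the PR, and assuming a PyTorch build that ships torch.nn.functional.rms_norm), a call with mismatched input/weight dtypes should hit the new warning and fall back to the composite path, while leaving eps unset exercises the default-epsilon behavior the new test covers:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, dtype=torch.float16)
w = torch.ones(4, dtype=torch.float32)   # dtype intentionally differs from the input

# Mismatched dtypes: the change below adds a TORCH_WARN_ONCE for this case and
# dispatches to the composite (non-fused) implementation instead.
y = F.rms_norm(x, (4,), weight=w)

# eps left as None: the default epsilon is used, which the new test compares
# against torch.finfo(dtype).eps in a float32 reference computation.
y_default_eps = F.rms_norm(x, (4,))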
Commit dc286aef61 (parent b4619f0272), committed by PyTorch MergeBot.
@@ -342,8 +342,8 @@ Tensor rms_norm_symint(
   if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) {
     TORCH_WARN_ONCE(
-      "Mismatch dtype between input and module: input dtype = ", input.dtype(),
-      ", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation"
+      "Mismatch dtype between input and weight: input dtype = ", input.dtype(),
+      ", weight dtype = ", weight_opt.value().dtype(), ", Cannot dispatch to fused implementation."
     );
     return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
   }
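For context only (not from this commit): TORCH_WARN_ONCE generally surfaces in Python as a UserWarning, so a sketch like the following, assuming that mapping and a build with torch.nn.functional.rms_norm, could be used to observe the reworded message:

import warnings
import torch
import torch.nn.functional as F

x = torch.randn(3, 8, dtype=torch.bfloat16)
weight = torch.ones(8, dtype=torch.float32)   # mismatched dtype on purpose

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    F.rms_norm(x, (8,), weight=weight)

# TORCH_WARN_ONCE fires at most once per process, so this list may be empty if
# the warning was already emitted earlier; otherwise it should contain the
# "Mismatch dtype between input and weight ..." message from the hunk above.
print([str(item.message) for item in caught])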
@@ -8748,6 +8748,30 @@ class TestNNDeviceType(NNTestCase):
 
         self.assertEqual(Y_ref, Y)
 
+    @onlyNativeDeviceTypes
+    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
+    def test_rmsnorm_epsilon(self, device, dtype):
+        def rms_norm_reference_fn(i, normalized_shape):
+            eps = torch.finfo(i.dtype).eps
+            ndim = i.ndim
+            dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+            if i.dtype is not torch.float64:
+                upcasted_i = i.float()
+            else:
+                upcasted_i = i
+            result = upcasted_i * torch.rsqrt(
+                upcasted_i.pow(2).mean(dim=dims, keepdim=True) + eps
+            )
+            return result.type_as(i)
+
+        shape = (2, 2)
+        X = torch.tensor([[1e-12, -1e-12], [1e-12, -1e-12]], dtype=dtype, device=device)
+
+        Y = torch.nn.functional.rms_norm(X, shape)
+        Y_ref = rms_norm_reference_fn(X, shape)
+
+        self.assertEqual(Y_ref, Y)
+
     @onlyCPU
     def test_glu_bfloat16(self, device):
         def test_dtype(fn, input, dtype):
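A quick worked check (not part of the test) of why such tiny inputs isolate the epsilon: with entries of magnitude 1e-12, mean(x**2) is about 1e-24, which is negligible next to torch.finfo(torch.float32).eps (~1.19e-07), so the output is dominated by the epsilon term; with no epsilon the same input would normalize to roughly ±1. The snippet below reduces over the last dimension only, for brevity, which gives the same numbers for this symmetric input:

import torch

x = torch.tensor([[1e-12, -1e-12], [1e-12, -1e-12]], dtype=torch.float32)
eps = torch.finfo(torch.float32).eps   # ~1.1921e-07; used as the default when eps is None

with_eps = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
no_eps = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True))

print(with_eps)  # roughly +/-2.9e-09: the epsilon dominates the denominator
print(no_eps)    # roughly +/-1.0: what a missing or zero epsilon would produce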