[Inductor] support mixed dtype in the native_layer_norm_backward meta function (#159830)

Fixes #159829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159830
Approved by: https://github.com/albanD
This commit is contained in:
Shaobin Ma
2025-09-17 20:29:08 +00:00
committed by PyTorch MergeBot
parent dfda2dfd53
commit 63276edb7c
2 changed files with 39 additions and 2 deletions

View File

@ -1827,6 +1827,43 @@ class TestMeta(TestCase):
self.assertEqual(out.stride(), f_out.stride()) self.assertEqual(out.stride(), f_out.stride())
@parametrize("in_dtype", [torch.float32, torch.float16])
@parametrize("bias_dtype", [torch.float32, torch.float16, None])
def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype):
if in_dtype == torch.float16 and bias_dtype == torch.float32:
self.skipTest(f"not supported input dtype is {in_dtype} and bias dtype is {bias_dtype}")
device = "meta"
def fn(input, weight, bias, need_grad_input):
outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias)
grad_outs = torch.ones_like(outputs)
grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs)
return grad_ins
input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
need_grad_input = [input]
if bias_dtype:
weight = torch.randn(
[5], dtype=bias_dtype, device=device, requires_grad=True
)
bias = torch.randn(
[5], dtype=bias_dtype, device=device, requires_grad=True
)
need_grad_input.append(weight)
need_grad_input.append(bias)
else:
weight = None
bias = None
outs = fn(input, weight, bias, need_grad_input)
out_dtype = [t.dtype for t in outs]
if bias_dtype:
self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype])
else:
self.assertEqual(out_dtype, [in_dtype,])
instantiate_device_type_tests(TestMeta, globals()) instantiate_device_type_tests(TestMeta, globals())
def print_op_str_if_not_supported(op_str): def print_op_str_if_not_supported(op_str):

View File

@ -1710,8 +1710,8 @@ def native_layer_norm_backward(
return ( return (
_maybe_cast(d_input, input.dtype), _maybe_cast(d_input, input.dtype),
_maybe_cast(d_weight, input.dtype), _maybe_cast(d_weight, weight.dtype if weight is not None else None),
_maybe_cast(d_bias, input.dtype), _maybe_cast(d_bias, bias.dtype if bias is not None else None),
) )