diff --git a/test/test_meta.py b/test/test_meta.py index b3e5faab4f65..4e79e59cfe62 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1827,6 +1827,43 @@ class TestMeta(TestCase): self.assertEqual(out.stride(), f_out.stride()) + + @parametrize("in_dtype", [torch.float32, torch.float16]) + @parametrize("bias_dtype", [torch.float32, torch.float16, None]) + def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype): + if in_dtype == torch.float16 and bias_dtype == torch.float32: + self.skipTest(f"not supported input dtype is {in_dtype} and bias dtype is {bias_dtype}") + device = "meta" + + def fn(input, weight, bias, need_grad_input): + outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias) + grad_outs = torch.ones_like(outputs) + grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs) + return grad_ins + + input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True) + need_grad_input = [input] + + if bias_dtype: + weight = torch.randn( + [5], dtype=bias_dtype, device=device, requires_grad=True + ) + bias = torch.randn( + [5], dtype=bias_dtype, device=device, requires_grad=True + ) + need_grad_input.append(weight) + need_grad_input.append(bias) + else: + weight = None + bias = None + + outs = fn(input, weight, bias, need_grad_input) + out_dtype = [t.dtype for t in outs] + if bias_dtype: + self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype]) + else: + self.assertEqual(out_dtype, [in_dtype,]) + instantiate_device_type_tests(TestMeta, globals()) def print_op_str_if_not_supported(op_str): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index ba09c6173c5f..2a00c57419da 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1710,8 +1710,8 @@ def native_layer_norm_backward( return ( _maybe_cast(d_input, input.dtype), - _maybe_cast(d_weight, input.dtype), - _maybe_cast(d_bias, input.dtype), + _maybe_cast(d_weight, weight.dtype if weight is not None else None), + _maybe_cast(d_bias, bias.dtype if bias is not None else None), )