mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor] support mixed dtype in the native_layer_norm_backward meta function (#159830)
Fixes #159829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159830
Approved by: https://github.com/albanD
committed by PyTorch MergeBot
parent dfda2dfd53
commit 63276edb7c
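A condensed repro of what this fix enables, distilled from the test added below (shapes and dtypes are taken from that test): with a float32 input and float16 affine parameters on the meta device, the backward pass now reports each gradient in the dtype of its source tensor rather than casting everything to the input dtype.

import torch

# Mixed dtypes: fp32 input, fp16 weight/bias, all on the meta device.
inp = torch.randn(4, 8, 5, dtype=torch.float32, device="meta", requires_grad=True)
weight = torch.randn(5, dtype=torch.float16, device="meta", requires_grad=True)
bias = torch.randn(5, dtype=torch.float16, device="meta", requires_grad=True)

out = torch.nn.functional.layer_norm(inp, inp.shape[-1:], weight, bias)
d_inp, d_w, d_b = torch.autograd.grad(out, [inp, weight, bias], torch.ones_like(out))

# Before this fix the meta function cast every grad to the input dtype;
# now each grad matches its source tensor (see #159829).
assert d_inp.dtype == torch.float32
assert d_w.dtype == torch.float16 and d_b.dtype == torch.float16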
@@ -1827,6 +1827,43 @@ class TestMeta(TestCase):
         self.assertEqual(out.stride(), f_out.stride())
 
+    @parametrize("in_dtype", [torch.float32, torch.float16])
+    @parametrize("bias_dtype", [torch.float32, torch.float16, None])
+    def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype):
+        if in_dtype == torch.float16 and bias_dtype == torch.float32:
+            self.skipTest(f"unsupported combination: input dtype is {in_dtype} and bias dtype is {bias_dtype}")
+        device = "meta"
+
+        def fn(input, weight, bias, need_grad_input):
+            outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias)
+            grad_outs = torch.ones_like(outputs)
+            grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs)
+            return grad_ins
+
+        input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
+        need_grad_input = [input]
+
+        if bias_dtype:
+            weight = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            bias = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            need_grad_input.append(weight)
+            need_grad_input.append(bias)
+        else:
+            weight = None
+            bias = None
+
+        outs = fn(input, weight, bias, need_grad_input)
+        out_dtype = [t.dtype for t in outs]
+        if bias_dtype:
+            self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype])
+        else:
+            self.assertEqual(out_dtype, [in_dtype])
+
 
 instantiate_device_type_tests(TestMeta, globals())
 
 
 def print_op_str_if_not_supported(op_str):
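Assuming a PyTorch dev checkout where this test lives in test/test_meta.py (the path is inferred from the TestMeta class and is not stated in the diff), the new test can be run on its own with a standard pytest invocation; the -k substring still matches after instantiate_device_type_tests appends a device suffix to the test name:

    python -m pytest test/test_meta.py -k test_mixed_dtype_for_native_layer_norm_backward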
@@ -1710,8 +1710,8 @@ def native_layer_norm_backward(
     return (
         _maybe_cast(d_input, input.dtype),
-        _maybe_cast(d_weight, input.dtype),
-        _maybe_cast(d_bias, input.dtype),
+        _maybe_cast(d_weight, weight.dtype if weight is not None else None),
+        _maybe_cast(d_bias, bias.dtype if bias is not None else None),
     )
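The substance of the fix is in the second argument to _maybe_cast: the weight and bias gradients are now cast to the dtype of the corresponding parameter instead of the input's dtype, and the is-not-None guards cover layer norms created without affine parameters, where weight and bias (and hence their gradients) are absent. The helper's body is not shown in the diff; the following is a minimal sketch of the contract implied by this call site, an assumption rather than the actual PyTorch helper:

from typing import Optional
import torch

def _maybe_cast(
    t: Optional[torch.Tensor], dtype: Optional[torch.dtype]
) -> Optional[torch.Tensor]:
    # Inferred contract (assumption): tolerate a missing tensor or dtype
    # and only cast when both are present, so None grads pass through.
    if t is None or dtype is None:
        return t
    return t.to(dtype)

Under that reading, the old code retyped the parameter gradients to the input dtype whenever the input and parameters disagreed, which would not match the dtypes the real kernels produce for mixed-dtype layer norm.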