Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Inductor] support mixed dtype in the native_layer_norm_backward meta function (#159830)
Fixes #159829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159830
Approved by: https://github.com/albanD
committed by PyTorch MergeBot
parent dfda2dfd53
commit 63276edb7c
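A minimal sketch of the scenario this change covers, condensed from the new test added below. It assumes the post-fix behavior: with a float32 input and float16 weight/bias on the meta device, the weight and bias gradients keep their own dtype instead of being reported in the input's dtype.

import torch

# Mixed-dtype layer norm on the meta device: float32 input, float16 affine params.
x = torch.randn(4, 8, 5, dtype=torch.float32, device="meta", requires_grad=True)
w = torch.randn(5, dtype=torch.float16, device="meta", requires_grad=True)
b = torch.randn(5, dtype=torch.float16, device="meta", requires_grad=True)

out = torch.nn.functional.layer_norm(x, x.shape[-1:], w, b)
gx, gw, gb = torch.autograd.grad(out, [x, w, b], torch.ones_like(out))

# With the fixed meta function the gradients match their parameters' dtypes:
# gx.dtype == torch.float32, gw.dtype == torch.float16, gb.dtype == torch.float16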
@@ -1827,6 +1827,43 @@ class TestMeta(TestCase):
         self.assertEqual(out.stride(), f_out.stride())

+
+    @parametrize("in_dtype", [torch.float32, torch.float16])
+    @parametrize("bias_dtype", [torch.float32, torch.float16, None])
+    def test_mixed_dtype_for_native_layer_norm_backward(self, in_dtype, bias_dtype):
+        if in_dtype == torch.float16 and bias_dtype == torch.float32:
+            self.skipTest(f"not supported input dtype is {in_dtype} and bias dtype is {bias_dtype}")
+        device = "meta"
+
+        def fn(input, weight, bias, need_grad_input):
+            outputs = torch.nn.functional.layer_norm(input, input.shape[-1:], weight, bias)
+            grad_outs = torch.ones_like(outputs)
+            grad_ins = torch.autograd.grad(outputs, need_grad_input, grad_outs)
+            return grad_ins
+
+        input = torch.randn([4, 8, 5], dtype=in_dtype, device=device, requires_grad=True)
+        need_grad_input = [input]
+
+        if bias_dtype:
+            weight = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            bias = torch.randn(
+                [5], dtype=bias_dtype, device=device, requires_grad=True
+            )
+            need_grad_input.append(weight)
+            need_grad_input.append(bias)
+        else:
+            weight = None
+            bias = None
+
+        outs = fn(input, weight, bias, need_grad_input)
+        out_dtype = [t.dtype for t in outs]
+        if bias_dtype:
+            self.assertEqual(out_dtype, [in_dtype, bias_dtype, bias_dtype])
+        else:
+            self.assertEqual(out_dtype, [in_dtype,])
+
 instantiate_device_type_tests(TestMeta, globals())


 def print_op_str_if_not_supported(op_str):
@@ -1710,8 +1710,8 @@ def native_layer_norm_backward(

     return (
         _maybe_cast(d_input, input.dtype),
-        _maybe_cast(d_weight, input.dtype),
-        _maybe_cast(d_bias, input.dtype),
+        _maybe_cast(d_weight, weight.dtype if weight is not None else None),
+        _maybe_cast(d_bias, bias.dtype if bias is not None else None),
     )

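For context, the new casting rule returns each gradient in the dtype of the tensor it corresponds to, and passes it through unchanged when that tensor is absent. A rough, hypothetical stand-in for a helper like _maybe_cast follows; the names and behavior here are assumptions for illustration, not the actual PyTorch implementation.

from typing import Optional

import torch

def maybe_cast(x: Optional[torch.Tensor], dtype: Optional[torch.dtype]) -> Optional[torch.Tensor]:
    # Hypothetical sketch: cast only when both a tensor and a target dtype are given.
    if x is None or dtype is None:
        return x
    return x.to(dtype)

# A weight gradient held in float32 is cast back to the weight's own dtype:
weight = torch.ones(5, dtype=torch.float16)
d_weight = torch.ones(5, dtype=torch.float32)
print(maybe_cast(d_weight, weight.dtype if weight is not None else None).dtype)  # torch.float16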