[Inductor] support mixed dtype in the native_layer_norm_backward meta function (#159830)

Fixes #159829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159830
Approved by: https://github.com/albanD
This commit is contained in:
Shaobin Ma
2025-09-17 20:29:08 +00:00
committed by PyTorch MergeBot
parent dfda2dfd53
commit 63276edb7c
2 changed files with 39 additions and 2 deletions

View File

@ -1710,8 +1710,8 @@ def native_layer_norm_backward(
return (
_maybe_cast(d_input, input.dtype),
_maybe_cast(d_weight, input.dtype),
_maybe_cast(d_bias, input.dtype),
_maybe_cast(d_weight, weight.dtype if weight is not None else None),
_maybe_cast(d_bias, bias.dtype if bias is not None else None),
)