mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Inductor] support mixed dtype in the native_layer_norm_backward meta function (#159830)
Fixes #159829 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159830 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
dfda2dfd53
commit
63276edb7c
@ -1710,8 +1710,8 @@ def native_layer_norm_backward(
|
||||
|
||||
return (
|
||||
_maybe_cast(d_input, input.dtype),
|
||||
_maybe_cast(d_weight, input.dtype),
|
||||
_maybe_cast(d_bias, input.dtype),
|
||||
_maybe_cast(d_weight, weight.dtype if weight is not None else None),
|
||||
_maybe_cast(d_bias, bias.dtype if bias is not None else None),
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user