Fix addbmm & addmv & baddbmm out dtype check (#148176)

----

- torch.addbmm
- torch.addmv
- torch.baddbmm

ISSUE related:
https://github.com/pytorch/pytorch/issues/138399
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148176
Approved by: https://github.com/jansel
ghstack dependencies: #148174
This commit is contained in:
FFFrog
2025-03-17 11:01:31 +08:00
committed by PyTorch MergeBot
parent 4d6ff6ca5c
commit b01877aa13
3 changed files with 4 additions and 7 deletions

View File

@ -1482,7 +1482,7 @@ def _addmm_activation(
@register_decomposition(aten.addmv)
@out_wrapper()
@out_wrapper(exact_dtype=True)
@pw_cast_for_opmath
def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
if not self.is_floating_point() and not self.is_complex():
@ -5031,7 +5031,7 @@ def register_inplace(aten_op, outplace_op):
@register_decomposition([aten.baddbmm])
@out_wrapper()
@out_wrapper(exact_dtype=True)
@pw_cast_for_opmath
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
if not self.is_floating_point() and not self.is_complex():