mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix addbmm & addmv & baddbmm out dtype check (#148176)
---- - torch.addbmm - torch.addmv - torch.baddbmm ISSUE related: https://github.com/pytorch/pytorch/issues/138399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148176 Approved by: https://github.com/jansel ghstack dependencies: #148174
This commit is contained in:
@ -1482,7 +1482,7 @@ def _addmm_activation(
|
||||
|
||||
|
||||
@register_decomposition(aten.addmv)
|
||||
@out_wrapper()
|
||||
@out_wrapper(exact_dtype=True)
|
||||
@pw_cast_for_opmath
|
||||
def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
|
||||
if not self.is_floating_point() and not self.is_complex():
|
||||
@ -5031,7 +5031,7 @@ def register_inplace(aten_op, outplace_op):
|
||||
|
||||
|
||||
@register_decomposition([aten.baddbmm])
|
||||
@out_wrapper()
|
||||
@out_wrapper(exact_dtype=True)
|
||||
@pw_cast_for_opmath
|
||||
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
|
||||
if not self.is_floating_point() and not self.is_complex():
|
||||
|
Reference in New Issue
Block a user