addmm: error on output dtype mismatch. (#138520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138520
Approved by: https://github.com/ezyang
ghstack dependencies: #138515
This commit is contained in:
Yukio Siraichi
2024-10-29 19:02:00 -03:00
committed by PyTorch MergeBot
parent 6da3a043a8
commit fef5e94657
2 changed files with 1 additions and 3 deletions

View File

@ -123,8 +123,6 @@ aten = torch.ops.aten
meta_consistency_out_dtype_mismatch_xfails = {
xfail("abs"),
xfail("addbmm"),
xfail("addmm"),
xfail("addmm", "decomposed"),
xfail("addmv"),
xfail("alias_copy"),
xfail("all"),

View File

@ -1479,7 +1479,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@out_wrapper()
@out_wrapper(exact_dtype=True)
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
if not self.is_floating_point() and not self.is_complex():