mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
addmm
: error on output dtype mismatch. (#138520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138520 Approved by: https://github.com/ezyang ghstack dependencies: #138515
This commit is contained in:
committed by
PyTorch MergeBot
parent
6da3a043a8
commit
fef5e94657
@ -123,8 +123,6 @@ aten = torch.ops.aten
|
||||
meta_consistency_out_dtype_mismatch_xfails = {
|
||||
xfail("abs"),
|
||||
xfail("addbmm"),
|
||||
xfail("addmm"),
|
||||
xfail("addmm", "decomposed"),
|
||||
xfail("addmv"),
|
||||
xfail("alias_copy"),
|
||||
xfail("all"),
|
||||
|
@ -1479,7 +1479,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
|
||||
|
||||
# TODO: this doesn't appear to have enough precision in bfloat16
|
||||
@register_decomposition(aten.addmm)
|
||||
@out_wrapper()
|
||||
@out_wrapper(exact_dtype=True)
|
||||
@pw_cast_for_opmath
|
||||
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
|
||||
if not self.is_floating_point() and not self.is_complex():
|
||||
|
Reference in New Issue
Block a user