Fix torch.matmul related out dtype check (#148174)

----

- torch.matmul -> CompositeImplicitAutograd -> dot_out (when left_dim == 1 & right_dim == 1)
                                            -> mv_out (when left_dim == 2 & right_dim == 1)
                                            -> mm_out (when left_dim == 1 & right_dim == 2)
                                            -> ...
- torch.dot
- torch.vdot
- torch.mm
- torch.mv

ISSUE related:
https://github.com/pytorch/pytorch/issues/138399
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148174
Approved by: https://github.com/jansel
This commit is contained in:
FFFrog
2025-04-08 11:53:48 +00:00
committed by PyTorch MergeBot
parent 173f126068
commit 3e0038ae85
4 changed files with 4 additions and 11 deletions

View File

@ -135,7 +135,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("cummin"),
xfail("diag"),
xfail("diagonal_copy"),
xfail("dot"),
xfail("expand_copy"),
xfail("fft.ihfft2"),
xfail("fft.ihfftn"),
@ -159,7 +158,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("linalg.lu_factor"),
xfail("linalg.lu_factor_ex"),
xfail("linalg.lu_solve"),
xfail("linalg.matrix_power"),
xfail("linalg.qr"),
xfail("linalg.slogdet"),
xfail("linalg.solve"),
@ -168,12 +166,9 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("logcumsumexp"),
xfail("lu_solve"),
xfail("lu_unpack"),
xfail("matmul"),
xfail("mm"),
xfail("mode"),
xfail("msort"),
xfail("multinomial"),
xfail("mv"),
xfail("nan_to_num"),
xfail("nanmean"),
xfail("narrow_copy"),
@ -182,7 +177,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("nn.functional.avg_pool3d"),
xfail("nn.functional.gelu"),
xfail("nn.functional.hardshrink"),
xfail("nn.functional.linear"),
xfail("nn.functional.logsigmoid"),
xfail("nn.functional.softplus"),
xfail("nn.functional.softshrink"),
@ -210,7 +204,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("triu"),
xfail("unfold_copy"),
xfail("unsqueeze_copy"),
xfail("vdot"),
xfail("view_copy"),
xfail("where"),
# Output has dynamic shape.