log_softmax: fix meta function output argument dtype check. (#140289)

Tracking issue: #138399
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140289
Approved by: https://github.com/ezyang
ghstack dependencies: #140186, #140286, #140288
This commit is contained in:
Yukio Siraichi
2024-11-11 21:01:40 -03:00
committed by PyTorch MergeBot
parent 435286e985
commit 48a276c5a0
2 changed files with 1 additions and 2 deletions

View File

@@ -168,7 +168,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
xfail("linalg.solve"),
xfail("linalg.solve_ex"),
xfail("linalg.solve_triangular"),
xfail("log_softmax"),
xfail("logcumsumexp"),
xfail("lu_solve"),
xfail("lu_unpack"),

View File

@@ -1220,7 +1220,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
@register_decomposition(aten._log_softmax)
@out_wrapper()
@out_wrapper(exact_dtype=True)
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
# eager log_softmax returns a contiguous tensor. Ensure that decomp also
# returns a contiguous tensor.