mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
log_softmax
: fix meta function output argument dtype check. (#140289)
Tracking issue: #138399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140289 Approved by: https://github.com/ezyang ghstack dependencies: #140186, #140286, #140288
This commit is contained in:
committed by
PyTorch MergeBot
parent
435286e985
commit
48a276c5a0
@ -168,7 +168,6 @@ meta_consistency_out_dtype_mismatch_xfails = {
|
||||
xfail("linalg.solve"),
|
||||
xfail("linalg.solve_ex"),
|
||||
xfail("linalg.solve_triangular"),
|
||||
xfail("log_softmax"),
|
||||
xfail("logcumsumexp"),
|
||||
xfail("lu_solve"),
|
||||
xfail("lu_unpack"),
|
||||
|
@ -1220,7 +1220,7 @@ def _softmax(x: Tensor, dim: int, half_to_float: bool):
|
||||
|
||||
|
||||
@register_decomposition(aten._log_softmax)
|
||||
@out_wrapper()
|
||||
@out_wrapper(exact_dtype=True)
|
||||
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
|
||||
# eager log_softmax returns a contiguous tensor. Ensure that decomp also
|
||||
# returns a contiguous tensor.
|
||||
|
Reference in New Issue
Block a user