mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 16:44:54 +08:00
Deprecate torch.lu
**BC-breaking note**: This PR deprecates `torch.lu` in favor of `torch.linalg.lu_factor`. A upgrade guide is added to the documentation for `torch.lu`. Note this PR DOES NOT remove `torch.lu`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/73804 Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
9dc8f2562f
commit
a5bbfd94fb
@ -5809,16 +5809,16 @@ add_docstr(torch.lu_solve,
|
||||
lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor
|
||||
|
||||
Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
|
||||
LU factorization of A from :meth:`torch.lu`.
|
||||
LU factorization of A from :func:`~linalg.lu_factor`.
|
||||
|
||||
This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
|
||||
|
||||
Arguments:
|
||||
b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
|
||||
is zero or more batch dimensions.
|
||||
LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu` of size :math:`(*, m, m)`,
|
||||
LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`,
|
||||
where :math:`*` is zero or more batch dimensions.
|
||||
LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`torch.lu` of size :math:`(*, m)`,
|
||||
LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`,
|
||||
where :math:`*` is zero or more batch dimensions.
|
||||
The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of
|
||||
:attr:`LU_data`.
|
||||
@ -5830,9 +5830,9 @@ Example::
|
||||
|
||||
>>> A = torch.randn(2, 3, 3)
|
||||
>>> b = torch.randn(2, 3, 1)
|
||||
>>> A_LU = torch.lu(A)
|
||||
>>> x = torch.lu_solve(b, *A_LU)
|
||||
>>> torch.norm(torch.bmm(A, x) - b)
|
||||
>>> LU, pivots = torch.linalg.lu_factor(A)
|
||||
>>> x = torch.lu_solve(b, LU, pivots)
|
||||
>>> torch.dist(A @ x, b)
|
||||
tensor(1.00000e-07 *
|
||||
2.8312)
|
||||
""".format(**common_args))
|
||||
|
||||
Reference in New Issue
Block a user