Added to docs for out_dtype arg in torch gemms (#151704)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151704
Approved by: https://github.com/bdhirsh
This commit is contained in:
PaulZhang12
2025-04-18 14:20:53 -07:00
committed by PyTorch MergeBot
parent 1a6effc5d8
commit 191b0237a6

View File

@ -541,7 +541,7 @@ Example::
add_docstr(
torch.addmm,
r"""
addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor
addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor
Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
The matrix :attr:`input` is added to the final result.
@ -578,6 +578,9 @@ Args:
input (Tensor): matrix to be added
mat1 (Tensor): the first matrix to be matrix multiplied
mat2 (Tensor): the second matrix to be matrix multiplied
out_dtype (dtype, optional): the dtype of the output tensor,
Supported only on CUDA and for torch.float32 given
torch.float16/torch.bfloat16 input dtypes
Keyword args:
beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
@ -1324,7 +1327,7 @@ Example::
add_docstr(
torch.baddbmm,
r"""
baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor
baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor
Performs a batch matrix-matrix product of matrices in :attr:`batch1`
and :attr:`batch2`.
@ -1358,6 +1361,9 @@ Args:
input (Tensor): the tensor to be added
batch1 (Tensor): the first batch of matrices to be multiplied
batch2 (Tensor): the second batch of matrices to be multiplied
out_dtype (dtype, optional): the dtype of the output tensor,
Supported only on CUDA and for torch.float32 given
torch.float16/torch.bfloat16 input dtypes
Keyword args:
beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
@ -1500,7 +1506,7 @@ Example::
add_docstr(
torch.bmm,
r"""
bmm(input, mat2, *, out=None) -> Tensor
bmm(input, mat2, out_dtype=None, *, out=None) -> Tensor
Performs a batch matrix-matrix product of matrices stored in :attr:`input`
and :attr:`mat2`.
@ -1526,6 +1532,9 @@ If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a
Args:
input (Tensor): the first batch of matrices to be multiplied
mat2 (Tensor): the second batch of matrices to be multiplied
out_dtype (dtype, optional): the dtype of the output tensor,
Supported only on CUDA and for torch.float32 given
torch.float16/torch.bfloat16 input dtypes
Keyword Args:
{out}
@ -7353,7 +7362,7 @@ Example::
add_docstr(
torch.mm,
r"""
mm(input, mat2, *, out=None) -> Tensor
mm(input, mat2, out_dtype=None, *, out=None) -> Tensor
Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`.
@ -7379,6 +7388,9 @@ layout will be deduced from that of :attr:`input`.
Args:
input (Tensor): the first matrix to be matrix multiplied
mat2 (Tensor): the second matrix to be matrix multiplied
out_dtype (dtype, optional): the dtype of the output tensor,
Supported only on CUDA and for torch.float32 given
torch.float16/torch.bfloat16 input dtypes
Keyword args:
{out}