2D bias through Lt only for activation epilogues

Nikita Vedeneev
2025-11-05 14:31:18 +00:00
parent 7e1a5f8c8b
commit 05f45e231e


@@ -147,7 +147,15 @@ static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
/*
* Check whether for the given input we want to enable the Lt interface
*/
-static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
+static bool isInputCompliesAddmmCudaLt(
+    Tensor& result,
+    const Tensor& self,
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Scalar& beta,
+    const Scalar& alpha,
+    Activation activation
+) {
#ifdef USE_ROCM
// Implies 2D bias which we currently do not send through Lt.
// TODO: this check is done pre col-major input preparation,
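The reason the compliance check now receives the activation is that the value of the Lt path lies in cuBLASLt's fused epilogues (bias and/or activation applied inside the GEMM). Below is a small, purely illustrative helper that is not part of this diff or of Blas.cpp: it shows how an Activation value could map onto the cuBLASLt epilogue enums. The Activation enum here is assumed to mirror the one used in Blas.cpp, and the GELU epilogues assume a CUDA toolkit recent enough to provide them.

```cpp
// Hypothetical illustration only: maps an Activation onto cuBLASLt epilogue
// modes. The enum constants come from <cublasLt.h>; the helper name and its
// structure are assumptions, not the actual PyTorch mapping code.
#include <cublasLt.h>
#include <cstdio>

enum class Activation { None, RELU, GELU }; // assumed to mirror Blas.cpp

cublasLtEpilogue_t epilogueFor(Activation act, bool fuse_1d_bias) {
  switch (act) {
    case Activation::RELU:
      return fuse_1d_bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
    case Activation::GELU:
      return fuse_1d_bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
    default: // Activation::None
      return fuse_1d_bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
  }
}

int main() {
  std::printf("GELU epilogue id: %d\n",
              static_cast<int>(epilogueFor(Activation::GELU, /*fuse_1d_bias=*/false)));
}
```

A 2D bias cannot go through the 1D bias epilogue, so the Lt path only pays off for it when there is an activation to fuse, which is exactly what the new condition below encodes.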
@@ -191,11 +199,12 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
// NOTE: fine to have 1-len dims to the left from the right-most one
(self.dim() == 1 || self.squeeze().dim() == 1) &&
self.sizes().back() == mat2_sizes[1]
-) ||
-( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
+)
+|| ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
// and we need to copy self into result otherwise, so the self's layout becomes irrelevant.
// See also TODO from above.
-self.dim() == 2 && self.sizes()[0] == mat1_sizes[0] && self.sizes()[1] == mat2_sizes[1]
+activation != Activation::None && // Lt is faster when activation is fused
+(self.dim() == 2 && self.sizes()[0] == mat1_sizes[0] && self.sizes()[1] == mat2_sizes[1])
)
)
&& ( // some dtype restrictions
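To make the new gate concrete, here is a minimal standalone sketch of the rule the condition above encodes, using plain vectors instead of at::Tensor and a hypothetical helper name: a 1D bias of length N stays eligible for Lt unconditionally, while a 2D [M, N] bias is eligible only when an activation epilogue will be fused. (The real check is looser for the 1D case: it also accepts biases whose extra leading dimensions are all of size 1.)

```cpp
// Standalone sketch of the bias/activation gating above; not the ATen code.
// biasShapeAllowsLt and the plain size vectors are illustrative stand-ins for
// the checks performed on at::Tensor inside isInputCompliesAddmmCudaLt.
#include <cstdint>
#include <iostream>
#include <vector>

enum class Activation { None, RELU, GELU };

bool biasShapeAllowsLt(const std::vector<int64_t>& bias_sizes,
                       int64_t M, int64_t N, Activation activation) {
  const bool bias_1d = bias_sizes.size() == 1 && bias_sizes.back() == N;
  const bool bias_2d = bias_sizes.size() == 2 &&
                       bias_sizes[0] == M && bias_sizes[1] == N;
  // 1D bias: always eligible. 2D bias: only with a fused activation, since
  // Lt beats the plain path only when the epilogue is fused into the GEMM.
  return bias_1d || (activation != Activation::None && bias_2d);
}

int main() {
  const int64_t M = 8, N = 16;
  std::cout << std::boolalpha
            << biasShapeAllowsLt({N}, M, N, Activation::None) << '\n'     // true: 1D bias
            << biasShapeAllowsLt({M, N}, M, N, Activation::None) << '\n'  // false: 2D bias, no epilogue
            << biasShapeAllowsLt({M, N}, M, N, Activation::RELU) << '\n'; // true: 2D bias + ReLU
}
```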
@@ -385,7 +394,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
-disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
+disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
// }
at::ScalarType scalar_type = mat1.scalar_type();
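Finally, a sketch of the accumulation pattern used at this call site, with stubbed helpers standing in for isGloballyDisabledAddmmCudaLt and isInputCompliesAddmmCudaLt (which take real tensors in ATen): each check can only switch disable_addmm_cuda_lt on, never back off, and the compliance check now also sees the requested activation.

```cpp
// Stub-based sketch of the disable-flag accumulation shown above; the helper
// names and bodies are placeholders, not the ATen implementations.
#include <iostream>

enum class Activation { None, RELU, GELU };

bool globallyDisabledLt() { return false; }           // e.g. env-var / arch override
bool inputCompliesLt(bool bias_is_2d, Activation a) { // stands in for the shape/dtype check
  return !bias_is_2d || a != Activation::None;        // 2D bias needs a fused activation
}

int main() {
  bool disable_addmm_cuda_lt = false;
  // Each condition can only turn the Lt path off.
  disable_addmm_cuda_lt = globallyDisabledLt() || disable_addmm_cuda_lt;
  disable_addmm_cuda_lt =
      !inputCompliesLt(/*bias_is_2d=*/true, Activation::None) || disable_addmm_cuda_lt;
  std::cout << (disable_addmm_cuda_lt ? "fallback GEMM path" : "cuBLASLt epilogue path") << '\n';
}
```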