mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
2d bias Lt only for activation epilogues
This commit is contained in:
@ -147,7 +147,15 @@ static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
|
||||
/*
|
||||
* Check whether for the given input we want to enable the Lt interface
|
||||
*/
|
||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
|
||||
static bool isInputCompliesAddmmCudaLt(
|
||||
Tensor& result,
|
||||
const Tensor& self,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Activation activation
|
||||
) {
|
||||
#ifdef USE_ROCM
|
||||
// Implies 2D bias which we currently not send through Lt.
|
||||
// TODO: this check is done pre col-major input preparation,
|
||||
@ -191,11 +199,12 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
// NOTE: fine to have 1-len dims to the left from the right-most one
|
||||
(self.dim() == 1 || self.squeeze().dim() == 1) &&
|
||||
self.sizes().back() == mat2_sizes[1]
|
||||
) ||
|
||||
( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
|
||||
)
|
||||
|| ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
|
||||
// and we need to copy self into result otherwise, so the self's layout becomes irrelevant.
|
||||
// See also TODO from above.
|
||||
self.dim() == 2 && self.sizes()[0] == mat1_sizes[0] && self.sizes()[1] == mat2_sizes[1]
|
||||
activation != Activation::None && // Lt is faster when activation is fused
|
||||
(self.dim() == 2 && self.sizes()[0] == mat1_sizes[0] && self.sizes()[1] == mat2_sizes[1])
|
||||
)
|
||||
)
|
||||
&& ( // some dtype restrictions
|
||||
@ -385,7 +394,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
|
||||
Reference in New Issue
Block a user