[PyTorch] MHA: fix contiguity assumption in transform_bias_rescale_qkv (#72465)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72465

This code path incorrectly assumed that the input tensors were contiguous. Now we obtain guaranteed-contiguous tensors via `expect_contiguous()`, which borrows the input when it is already contiguous and materializes a contiguous copy only when it is not.
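
For context, a minimal sketch (not part of this diff) of the `expect_contiguous()` pattern the fix relies on: it returns a `c10::MaybeOwned<Tensor>` that borrows the input when it is already contiguous and otherwise owns a freshly materialized contiguous copy, so raw `data_ptr()` indexing stays valid either way. The helper name `read_first_element` is hypothetical.

    #include <ATen/ATen.h>
    #include <c10/util/MaybeOwned.h>

    // Hypothetical helper illustrating the expect_contiguous() pattern.
    float read_first_element(const at::Tensor& t) {
      // Borrows `t` if it is already contiguous; otherwise owns a copy.
      c10::MaybeOwned<at::Tensor> t_contig = t.expect_contiguous();
      // Safe: *t_contig is guaranteed contiguous, so plain pointer
      // arithmetic over data_ptr() matches logical element order.
      return t_contig->data_ptr<float>()[0];
    }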
ghstack-source-id: 149201476

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D34007665

fbshipit-source-id: c43438f2495e32304ea3f7846e01eceb4a9448f7
(cherry picked from commit 0767b225f23846c1636ac3622f46b0c5ec071d96)
Author: Scott Wolchok
Date: 2022-02-16 10:18:51 -08:00
Committed by: PyTorch MergeBot
Parent: dadbf43eff
Commit: 41ad221751

@@ -113,15 +113,18 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
   TORCH_CHECK(_3D % 3 == 0);
   const auto dim_per_head = D / num_head;
   auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v.is_contiguous());
+  const auto qkv_contig = qkv.expect_contiguous();
+  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half,
       ScalarType::BFloat16,
       qkv.scalar_type(),
       "transform_bias_rescale_qkv",
       [&] {
-        scalar_t* qkv_data = qkv.data_ptr<scalar_t>();
-        scalar_t* qkv_bias_data = qkv_bias.data_ptr<scalar_t>();
+        scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
+        scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
         scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
         const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));
@@ -134,6 +137,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
       });
   auto q_k_v_s =
       at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
   return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
 }
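
As a usage note (my illustration, not from the PR): a non-contiguous `qkv` can easily reach this kernel, for example through a slice, which is exactly the kind of input the old raw-pointer indexing mishandled.

    #include <ATen/ATen.h>

    int main() {
      // Slicing the last dimension yields a view with non-trivial strides.
      at::Tensor base = at::randn({2, 4, 48});        // B=2, T=4
      at::Tensor qkv = base.slice(/*dim=*/2, 0, 24);  // _3D=24 -> D=8
      TORCH_CHECK(!qkv.is_contiguous());
      // Before this fix, the kernel indexed qkv.data_ptr() assuming
      // contiguous strides, so views like this one were read incorrectly;
      // expect_contiguous() now materializes a contiguous copy first.
      return 0;
    }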