[PyTorch] MHA: fix contiguity assumption in transform_bias_rescale_qkv (#72465)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72465

This code path incorrectly assumed its input tensors were contiguous and took raw data pointers from them directly. It now normalizes both inputs with expect_contiguous() before reading them, so a non-contiguous qkv or qkv_bias is materialized as a contiguous copy instead of being misread.

ghstack-source-id: 149201476

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D34007665

fbshipit-source-id: c43438f2495e32304ea3f7846e01eceb4a9448f7
(cherry picked from commit 0767b225f23846c1636ac3622f46b0c5ec071d96)
Committed by: PyTorch MergeBot
Parent: dadbf43eff
Commit: 41ad221751
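For context, here is a minimal sketch (not from this PR) of the failure mode the commit message describes. A strided view such as a transpose shares storage with its base tensor, so data_ptr() hands back the base's storage pointer, and a kernel that indexes that buffer linearly silently reads elements in the wrong order:

#include <ATen/ATen.h>
#include <cassert>

int main() {
  auto base = at::arange(6, at::kFloat).reshape({2, 3});  // storage: 0..5
  auto view = base.t();  // shape {3, 2}; same storage, non-contiguous strides
  assert(!view.is_contiguous());

  const float* raw = view.data_ptr<float>();
  // Logically, view[0][1] is base[1][0] == 3.0f, but raw[1] is still 1.0f:
  // linear indexing over the raw buffer follows base's layout, not view's.
  assert(view[0][1].item<float>() == 3.0f);
  assert(raw[1] == 1.0f);
  return 0;
}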
@@ -113,15 +113,18 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
   TORCH_CHECK(_3D % 3 == 0);
   const auto dim_per_head = D / num_head;
   auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv.options());
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v.is_contiguous());

+  const auto qkv_contig = qkv.expect_contiguous();
+  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::Half,
       ScalarType::BFloat16,
       qkv.scalar_type(),
       "transform_bias_rescale_qkv",
       [&] {
-        scalar_t* qkv_data = qkv.data_ptr<scalar_t>();
-        scalar_t* qkv_bias_data = qkv_bias.data_ptr<scalar_t>();
+        scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
+        scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
         scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
         const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));

@@ -134,6 +137,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
       });
   auto q_k_v_s =
       at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
   return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
 }
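The fix relies on Tensor::expect_contiguous(), which returns a c10::MaybeOwned<Tensor>: a non-owning borrow when the tensor is already contiguous (no copy), and an owned contiguous copy otherwise. A hedged sketch of the same pattern outside this kernel follows; add_bias and its 2-D shapes are illustrative, not from the PR:

#include <ATen/ATen.h>

// Adds bias[j] to column j of a 2-D input using raw row-major indexing.
// Taking data_ptr() from `input` directly would break for transposed or
// sliced views; expect_contiguous() makes the layout assumption explicit.
at::Tensor add_bias(const at::Tensor& input, const at::Tensor& bias) {
  const auto input_contig = input.expect_contiguous();  // MaybeOwned<Tensor>
  const auto bias_contig = bias.expect_contiguous();
  const int64_t rows = input_contig->size(0);
  const int64_t cols = input_contig->size(1);
  auto out = at::empty({rows, cols}, input_contig->options());

  const float* in = input_contig->data_ptr<float>();
  const float* b = bias_contig->data_ptr<float>();
  float* o = out.data_ptr<float>();
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      o[i * cols + j] = in[i * cols + j] + b[j];  // valid only on contiguous memory
    }
  }
  return out;
}

This is also why the patch uses expect_contiguous() rather than calling .contiguous(): on the common already-contiguous path the borrow avoids any copy, while still guaranteeing that the pointer arithmetic inside the dispatch lambda is valid.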