Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[PyTorch] Round T up to next multiple of 8 in NestedTensor case
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77903

Code comment should explain why; in brief, it lets us use Tensor cores.

Differential Revision: [D36527773](https://our.internmc.facebook.com/intern/diff/D36527773/)
Approved by: https://github.com/ngimel, https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: 51c4c79e3d
Commit: 7e4730d017
@@ -969,7 +969,7 @@ Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const c10:
   // expand mask to [B, H, T, T] and treat it like regular mask
   // TODO We should have special fast kernel for TxT mask as well
   bool is_TxT_mask = input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
-  TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input");
+  TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());
 
   auto input = input_.dim() == 0 ? input_.view(1) : input_;
   auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_;
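For illustration only (outside the patch): the is_TxT_mask condition above amounts to the following shape test, sketched in Python with made-up tensor sizes.

import torch

# Sketch of the is_TxT_mask condition: a square [T, T] mask is accepted for a
# [B, H, T, T] input even though the two shapes differ.
def is_TxT_mask(inp, mask):
    return (inp.dim() == 4 and mask.dim() == 2
            and inp.size(3) == mask.size(1) and inp.size(2) == mask.size(0)
            and mask.size(0) == mask.size(1))

inp = torch.randn(2, 4, 3, 3)                                   # [B, H, T, T]
print(is_TxT_mask(inp, torch.zeros(3, 3, dtype=torch.bool)))    # True
print(is_TxT_mask(inp, torch.zeros(2, 3, dtype=torch.bool)))    # False ([B, T]-style mask)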
@@ -316,7 +316,17 @@ Tensor nested_from_padded_generic(
         padded.size(2),
         padded.size(1) * padded.size(3)});
   }
-  const auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  // There may be extra padding on padded beyond the max size in the nested tensor.
+  // Make the mask size match.
+  const size_t dim = padded_transformed.dim();
+  TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, "target_size: ", target_size.size());
+  for (size_t ii = 0; ii < dim - 1; ++ii) {
+    const auto padded_size_i = padded_transformed.sizes()[ii + 1];
+    if (target_size[ii] < padded_size_i) {
+      target_size[ii] = padded_size_i;
+    }
+  }
   IntArrayRef target_size_arr(target_size);
   std::vector<at::Tensor> masks;
   std::vector<at::Tensor> all_sizes = sizes.unbind();
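For illustration only (outside the patch): the loop added above just grows target_size so no entry is smaller than the corresponding dimension of the over-padded dense input. A minimal Python restatement, using a hypothetical grow_target_size helper:

import torch

# Grow each target dimension to at least the padded input's size (dim 0 is the batch).
def grow_target_size(target_size, padded_transformed):
    padded_sizes = list(padded_transformed.shape[1:])
    assert len(padded_sizes) == len(target_size)
    return [max(t, p) for t, p in zip(target_size, padded_sizes)]

padded = torch.zeros(2, 16, 64)            # padded out beyond the nested max length of 13
print(grow_target_size([13, 64], padded))  # [16, 64]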
@@ -104,7 +104,7 @@ Tensor NestedTensor_batch_offsets_from_size_tensor(
   return offsets;
 }
 
-Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim) {
+Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length) {
   auto* nt_impl = get_nested_tensor_impl(nt);
   TORCH_CHECK(
       !mask_dim || *mask_dim < nt.dim(),
@@ -123,7 +123,7 @@ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim) {
   const auto& sizes = nt_impl->get_nested_size_tensor();
   // Shape: # of tensors in our NestedTensor by max size along first dim
   // TODO: calculate this without allocating a std::vector.
-  const auto result_size_1 = NestedTensor_get_max_size(*nt_impl)[0];
+  const auto result_size_1 = mask_dim_length ? *mask_dim_length : NestedTensor_get_max_size(*nt_impl)[0];
   auto result = at::ones({sizes.sizes()[0], result_size_1}, at::kBool);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
   auto* result_data = result.data_ptr<bool>();
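For illustration only (outside the patch): a rough Python sketch of what the new mask_dim_length argument changes in the 2-D mask case, assuming (as the at::ones initialization suggests) that True marks padding; nested_to_padding_mask is a hypothetical stand-in, not the real API.

import torch

def nested_to_padding_mask(lengths, mask_dim_length=None):
    # When mask_dim_length is given, it overrides the nested tensor's max length,
    # so the mask can match an attention-score tensor that was padded further.
    max_len = mask_dim_length if mask_dim_length is not None else max(lengths)
    mask = torch.ones(len(lengths), max_len, dtype=torch.bool)
    for i, n in enumerate(lengths):
        mask[i, :n] = False   # real tokens become False, padding stays True
    return mask

print(nested_to_padding_mask([3, 5]).shape)      # torch.Size([2, 5])
print(nested_to_padding_mask([3, 5], 8).shape)   # torch.Size([2, 8])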
@@ -50,7 +50,7 @@ Tensor NestedTensor_from_padded_tensor_cpu(
     const Tensor& padded,
     const NestedTensorImpl& nt);
 
-Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim);
+Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length);
 
 template <typename T>
 void remove_padding_kernelLauncher(
@@ -122,7 +122,7 @@ Tensor masked_softmax(
   if (query.is_nested() && !attn_mask) {
     // TODO: maybe we could do better than generating a mask every time?
 
-    attn_mask = NestedTensor_to_mask(query, 2);
+    attn_mask = NestedTensor_to_mask(query, 2, attn_scores.size(2));
     // TODO: CPU path does not support transformer mask yet.
     if (attn_scores.is_cpu()) {
       attn_mask = attn_mask->view({-1, 1, 1, attn_scores.sizes()[3]});
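For illustration only (outside the patch): once T is rounded up, attn_scores is wider than the nested tensor's own max length, so the mask has to be built to attn_scores.size(2). A small sketch with made-up sizes, mirroring the CPU branch's view:

import torch

B, H, T_pad = 2, 4, 8          # T rounded up to 8 even though the real lengths are 3 and 5
lengths = [3, 5]
attn_scores = torch.randn(B, H, T_pad, T_pad)

# Padding mask sized to the (rounded-up) score length; True = padded slot.
mask = torch.ones(B, attn_scores.size(2), dtype=torch.bool)
for i, n in enumerate(lengths):
    mask[i, :n] = False

# The CPU branch reshapes the mask so it broadcasts over heads and query rows.
print(mask.view(-1, 1, 1, attn_scores.size(3)).shape)   # torch.Size([2, 1, 1, 8])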
@@ -317,6 +317,16 @@ __host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
   auto T = qkv.is_nested()
       ? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0]
       : qkv.size(1);
+  if (qkv.is_nested()) {
+    // Don't mess with non-nested case for now since it's not set up to fiddle
+    // with mask size.
+
+    // Round T up to next multiple of 8 so as to be able to utilize Tensor
+    // cores. Otherwise, sometimes with padding, *no* row will have the maximum
+    // sequence length and so we'll have a non-divisible-by-8 dimension even if
+    // the model author chose a multiple of 8.
+    T = T + (8 - (T % 8)) % 8;
+  }
   auto _3D = qkv_bias.size(0);
   auto D = _3D / 3;
   TORCH_CHECK(D % num_head == 0);
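For illustration only (outside the patch): the rounding identity used above, checked on a few values.

# Add just enough to reach the next multiple of 8; add nothing when T already is one.
def round_up_to_8(t: int) -> int:
    return t + (8 - (t % 8)) % 8

for t in (5, 8, 13, 16):
    print(t, "->", round_up_to_8(t))   # 5 -> 8, 8 -> 8, 13 -> 16, 16 -> 16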
@@ -54,9 +54,19 @@ class TestMHADeviceType(TestCase):
         def simple_transform_bias_rescale_qkv(qkv, bias):
             (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
             (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)
+
+            def embiggen(x):
+                if not use_nt:
+                    return x
+                b, t, d = x.size()
+                t = t + (8 - t % 8) % 8
+                newsize = (b, t, d)
+                new_x = torch.zeros(newsize, device=device, dtype=dtype)
+                new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
+                return new_x
             return tuple(
-                x.reshape(
-                    (bs, sl, num_heads, embed_dim // num_heads)
+                embiggen(x).reshape(
+                    (bs, -1, num_heads, embed_dim // num_heads)
                 ).transpose(2, 1)
                 for x in (
                     (q + q_bias) / math.sqrt(embed_dim // num_heads),
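For illustration only (outside the patch): what the test's embiggen helper does on a concrete shape, namely zero-pad the time dimension up to the next multiple of 8 so the reference result lines up with the rounded-up kernel output.

import torch

x = torch.arange(2 * 5 * 3, dtype=torch.float).reshape(2, 5, 3)   # (b, t, d) with t = 5
b, t, d = x.size()
t_pad = t + (8 - t % 8) % 8                                        # 5 -> 8
new_x = torch.zeros(b, t_pad, d)
new_x[:, :x.size(1), :] = x                                        # original data fills the first 5 rows
print(new_x.shape)   # torch.Size([2, 8, 3])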