[PyTorch] Round T up to next multiple of 8 in NestedTensor case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77903

The added code comment explains why; in brief, it lets us use Tensor cores.
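
As a plain-Python illustration (not part of the change itself), the rounding arithmetic used in the CUDA file below behaves like this:

    def round_up_to_multiple_of_8(t):
        # Same arithmetic as the kernel change: T + (8 - (T % 8)) % 8.
        # The outer % 8 leaves values that are already multiples of 8 unchanged.
        return t + (8 - (t % 8)) % 8

    assert round_up_to_multiple_of_8(13) == 16
    assert round_up_to_multiple_of_8(16) == 16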

Differential Revision: [D36527773](https://our.internmc.facebook.com/intern/diff/D36527773/)

Approved by: https://github.com/ngimel, https://github.com/cpuhrsch
Author:    Scott Wolchok
Date:      2022-05-24 13:36:34 -07:00
Committer: PyTorch MergeBot
Parent:    51c4c79e3d
Commit:    7e4730d017

7 changed files with 38 additions and 8 deletions

View File

@@ -969,7 +969,7 @@ Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const c10:
   // expand mask to [B, H, T, T] and treat it like regular mask
   // TODO We should have special fast kernel for TxT mask as well
   bool is_TxT_mask = input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1);
-  TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input");
+  TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes());
   auto input = input_.dim() == 0 ? input_.view(1) : input_;
   auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_;

View File

@@ -316,7 +316,17 @@ Tensor nested_from_padded_generic(
            padded.size(2),
            padded.size(1) * padded.size(3)});
   }
-  const auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  // There may be extra padding on padded beyond the max size in the nested tensor.
+  // Make the mask size match.
+  const size_t dim = padded_transformed.dim();
+  TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, "target_size: ", target_size.size());
+  for (size_t ii = 0; ii < dim - 1; ++ii) {
+    const auto padded_size_i = padded_transformed.sizes()[ii + 1];
+    if (target_size[ii] < padded_size_i) {
+      target_size[ii] = padded_size_i;
+    }
+  }
   IntArrayRef target_size_arr(target_size);
   std::vector<at::Tensor> masks;
   std::vector<at::Tensor> all_sizes = sizes.unbind();
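
A minimal standalone sketch of the size-matching loop above, in plain Python with illustrative names (match_target_size_to_padding is not a real ATen function):

    def match_target_size_to_padding(target_size, padded_sizes):
        # target_size: per-dimension max sizes from the nested tensor (batch dim excluded).
        # padded_sizes: sizes of the padded input, including the leading batch dim.
        # If the padded input carries extra padding beyond the nested max sizes,
        # grow target_size so the generated mask covers the whole padded extent.
        assert len(padded_sizes) - 1 == len(target_size)
        return [max(t, p) for t, p in zip(target_size, padded_sizes[1:])]

    # e.g. nested max sizes [13, 64], but the padded tensor is [4, 16, 64]:
    assert match_target_size_to_padding([13, 64], [4, 16, 64]) == [16, 64]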

View File

@@ -104,7 +104,7 @@ Tensor NestedTensor_batch_offsets_from_size_tensor(
   return offsets;
 }
 
-Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim) {
+Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length) {
   auto* nt_impl = get_nested_tensor_impl(nt);
   TORCH_CHECK(
       !mask_dim || *mask_dim < nt.dim(),
@@ -123,7 +123,7 @@ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim) {
   const auto& sizes = nt_impl->get_nested_size_tensor();
   // Shape: # of tensors in our NestedTensor by max size along first dim
   // TODO: calculate this without allocating a std::vector.
-  const auto result_size_1 = NestedTensor_get_max_size(*nt_impl)[0];
+  const auto result_size_1 = mask_dim_length ? *mask_dim_length : NestedTensor_get_max_size(*nt_impl)[0];
   auto result = at::ones({sizes.sizes()[0], result_size_1}, at::kBool);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
   auto* result_data = result.data_ptr<bool>();
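
For context, a rough Python sketch of what a 2-D padding mask with an explicit length looks like. This mirrors, rather than reproduces, NestedTensor_to_mask, and it assumes True marks padding positions:

    import torch

    def lengths_to_padding_mask(lengths, mask_dim_length=None):
        # Width defaults to the longest entry; mask_dim_length lets the caller
        # ask for a wider mask, e.g. when T was rounded up to a multiple of 8.
        t = mask_dim_length if mask_dim_length is not None else max(lengths)
        mask = torch.ones(len(lengths), t, dtype=torch.bool)  # True = padding
        for i, length in enumerate(lengths):
            mask[i, :length] = False
        return mask

    print(lengths_to_padding_mask([3, 5], mask_dim_length=8))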

View File

@@ -50,7 +50,7 @@ Tensor NestedTensor_from_padded_tensor_cpu(
     const Tensor& padded,
     const NestedTensorImpl& nt);
 
-Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim);
+Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length);
 
 template <typename T>
 void remove_padding_kernelLauncher(

View File

@@ -122,7 +122,7 @@ Tensor masked_softmax(
   if (query.is_nested() && !attn_mask) {
     // TODO: maybe we could do better than generating a mask every time?
-    attn_mask = NestedTensor_to_mask(query, 2);
+    attn_mask = NestedTensor_to_mask(query, 2, attn_scores.size(2));
     // TODO: CPU path does not support transformer mask yet.
     if (attn_scores.is_cpu()) {
       attn_mask = attn_mask->view({-1, 1, 1, attn_scores.sizes()[3]});

View File

@@ -317,6 +317,16 @@ __host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
   auto T = qkv.is_nested()
       ? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0]
       : qkv.size(1);
+  if (qkv.is_nested()) {
+    // Don't mess with non-nested case for now since it's not set up to fiddle
+    // with mask size.
+    // Round T up to next multiple of 8 so as to be able to utilize Tensor
+    // cores. Otherwise, sometimes with padding, *no* row will have the maximum
+    // sequence length and so we'll have a non-divisible-by-8 dimension even if
+    // the model author chose a multiple of 8.
+    T = T + (8 - (T % 8)) % 8;
+  }
   auto _3D = qkv_bias.size(0);
   auto D = _3D / 3;
   TORCH_CHECK(D % num_head == 0);
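
A worked example (illustrative numbers only) of the situation the new comment describes:

    # The model author pads to 16 (a multiple of 8), but after building the
    # NestedTensor the longest sequence actually present is only 13.
    lengths = [13, 11]
    t_max = max(lengths)                      # 13, not divisible by 8
    t_rounded = t_max + (8 - (t_max % 8)) % 8
    assert t_rounded == 16                    # rounded up so Tensor cores can be used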

View File

@@ -54,9 +54,19 @@ class TestMHADeviceType(TestCase):
         def simple_transform_bias_rescale_qkv(qkv, bias):
             (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
             (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)
+
+            def embiggen(x):
+                if not use_nt:
+                    return x
+                b, t, d = x.size()
+                t = t + (8 - t % 8) % 8
+                newsize = (b, t, d)
+                new_x = torch.zeros(newsize, device=device, dtype=dtype)
+                new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
+                return new_x
             return tuple(
-                x.reshape(
-                    (bs, sl, num_heads, embed_dim // num_heads)
+                embiggen(x).reshape(
+                    (bs, -1, num_heads, embed_dim // num_heads)
                 ).transpose(2, 1)
                 for x in (
                     (q + q_bias) / math.sqrt(embed_dim // num_heads),