Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[PyTorch] Add native fast path for transformer encoder inference
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75809

The current PyTorch multi-head attention and transformer implementations are slow. This should speed them up for inference.

Differential Revision: [D35239925](https://our.internmc.facebook.com/intern/diff/D35239925/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D35239925/)!

Approved by: https://github.com/ezyang
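For orientation, a minimal usage sketch (it mirrors the new test/test_native_mha.py below; torch._native_multi_head_attention is a private ATen entry point rather than a public API, and the sizes here are illustrative):

import torch

embed_dim, num_heads, bs, sl = 64, 4, 16, 8
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
proj = torch.nn.Linear(embed_dim, embed_dim)
x = torch.randn(bs, sl, embed_dim)

with torch.no_grad():
    # Self-attention: query, key, and value are the same tensor.
    out, _ = torch._native_multi_head_attention(
        x, x, x, embed_dim, num_heads,
        qkv.weight, qkv.bias, proj.weight, proj.bias,
        None,                # optional key-padding mask
        need_weights=False,  # second return value is unused in this case
    )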
Committed by: PyTorch MergeBot
Parent: 36420b5e8c
Commit: b369b89f23
@ -224,6 +224,11 @@ filegroup(
|
||||
),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "aten_native_transformers_cpp",
|
||||
srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "aten_native_mkl_cpp",
|
||||
srcs = glob(["aten/src/ATen/native/mkl/*.cpp", "aten/src/ATen/mkl/*.cpp"]),
|
||||
@ -274,6 +279,7 @@ filegroup(
|
||||
"aten/src/ATen/native/miopen/*.cpp",
|
||||
"aten/src/ATen/native/nested/cuda/*.cpp",
|
||||
"aten/src/ATen/native/sparse/cuda/*.cpp",
|
||||
"aten/src/ATen/native/transformers/cuda/*.cpp",
|
||||
"aten/src/THC/*.cpp",
|
||||
],
|
||||
),
|
||||
@ -288,6 +294,7 @@ filegroup(
|
||||
"aten/src/ATen/native/nested/cuda/*.cu",
|
||||
"aten/src/ATen/native/quantized/cuda/*.cu",
|
||||
"aten/src/ATen/native/sparse/cuda/*.cu",
|
||||
"aten/src/ATen/native/transformers/cuda/*.cu",
|
||||
]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
|
||||
# It's a bit puzzling to me why it's not necessary to declare the
|
||||
# target that generates these sources...
|
||||
@ -389,6 +396,7 @@ cc_library(
|
||||
":aten_native_quantized_cpp",
|
||||
":aten_native_sparse_cpp",
|
||||
":aten_native_nested_cpp",
|
||||
":aten_native_transformers_cpp",
|
||||
":aten_native_xnnpack",
|
||||
":aten_src_ATen_config",
|
||||
] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
|
||||
|
@ -105,6 +105,7 @@ file(GLOB native_quantized_cpp
|
||||
"native/quantized/*.cpp"
|
||||
"native/quantized/cpu/*.cpp")
|
||||
file(GLOB native_nested_cpp "native/nested/*.cpp")
|
||||
file(GLOB native_transformers_cpp "native/transformers/*.cpp")
|
||||
|
||||
file(GLOB native_h "native/*.h")
|
||||
file(GLOB native_ao_sparse_h
|
||||
@ -128,6 +129,8 @@ file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
|
||||
file(GLOB native_quantized_cuda_cu "native/quantized/cuda/*.cu")
|
||||
file(GLOB native_quantized_cuda_cpp "native/quantized/cuda/*.cpp")
|
||||
file(GLOB native_quantized_cudnn_cpp "native/quantized/cudnn/*.cpp")
|
||||
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
|
||||
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
|
||||
|
||||
file(GLOB native_hip_hip "native/hip/*.hip")
|
||||
file(GLOB native_hip_cpp "native/hip/*.cpp")
|
||||
@ -140,6 +143,8 @@ file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
|
||||
file(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
|
||||
file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
|
||||
file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
|
||||
file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
|
||||
file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
|
||||
file(GLOB native_utils_cpp "native/utils/*.cpp")
|
||||
|
||||
# XNNPACK
|
||||
@ -162,6 +167,7 @@ else()
|
||||
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
|
||||
${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
|
||||
${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
|
||||
${native_transformers_cpp}
|
||||
${native_utils_cpp} ${native_xnnpack} ${generated_sources} ${core_generated_sources}
|
||||
${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}
|
||||
)
|
||||
@ -205,6 +211,7 @@ if(USE_CUDA)
|
||||
${native_nested_cuda_cu}
|
||||
${native_sparse_cuda_cu}
|
||||
${native_quantized_cuda_cu}
|
||||
${native_transformers_cuda_cu}
|
||||
${cuda_generated_sources}
|
||||
)
|
||||
list(APPEND ATen_CUDA_CPP_SRCS
|
||||
@ -216,6 +223,7 @@ if(USE_CUDA)
|
||||
${native_quantized_cuda_cpp}
|
||||
${native_quantized_cudnn_cpp}
|
||||
${native_sparse_cuda_cpp}
|
||||
${native_transformers_cuda_cpp}
|
||||
)
|
||||
set(ATen_CUDA_LINALG_SRCS ${native_cuda_linalg_cpp})
|
||||
if(NOT BUILD_LAZY_CUDA_LINALG)
|
||||
@ -238,9 +246,9 @@ endif()
|
||||
|
||||
if(USE_ROCM)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
||||
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip})
|
||||
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip} ${native_transformers_hip_hip})
|
||||
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
|
||||
set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
|
||||
set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
|
||||
set(all_hip_cpp ${native_miopen_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp})
|
||||
endif()
|
||||
|
||||
|
@ -4663,6 +4663,12 @@
|
||||
|
||||
- func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
|
||||
|
||||
# Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads).
|
||||
- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CPU, NestedTensorCPU: transform_bias_rescale_qkv_cpu
|
||||
CUDA, NestedTensorCUDA: transform_bias_rescale_qkv_cuda
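As a reference for what this op computes, a hedged pure-PyTorch sketch (it mirrors simple_transform_bias_rescale_qkv in the new test file; names and shapes are illustrative):

import math
import torch

def transform_bias_rescale_qkv_reference(qkv, qkv_bias, num_heads):
    # qkv: [B, T, 3 * D], qkv_bias: [3 * D]
    B, T, three_D = qkv.shape
    D = three_D // 3
    dim_per_head = D // num_heads
    q, k, v = torch.split(qkv, D, dim=-1)
    q_b, k_b, v_b = torch.split(qkv_bias, D, dim=-1)
    # Reshape each of q/k/v to [B, num_heads, T, dim_per_head].
    split_heads = lambda t: t.reshape(B, T, num_heads, dim_per_head).transpose(1, 2)
    return (split_heads((q + q_b) / math.sqrt(dim_per_head)),
            split_heads(k + k_b),
            split_heads(v + v_b))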
|
||||
|
||||
- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
|
||||
device_check: NoCheck # cpu_nested_shape_example will always be on CPU
|
||||
dispatch:
|
||||
@ -11603,3 +11609,14 @@
|
||||
variants: method
|
||||
dispatch:
|
||||
NestedTensorCPU, NestedTensorCUDA: NestedTensor_layer_norm
|
||||
|
||||
# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
|
||||
- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: transformer_encoder_layer_forward
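In plain PyTorch terms, the fused forward corresponds to the post-norm encoder layer sketched below (norm_first is rejected for now; see transformer.cpp later in this diff). This is a hedged sketch where attn stands in for the multi-head attention call and all other names are illustrative:

import torch.nn.functional as F

def encoder_layer_reference(src, attn, norm_w1, norm_b1, norm_w2, norm_b2,
                            ffn_w1, ffn_b1, ffn_w2, ffn_b2, eps, use_gelu):
    x = src + attn(src)                                        # self-attention + residual
    x = F.layer_norm(x, x.shape[-1:], norm_w1, norm_b1, eps)
    h = F.linear(x, ffn_w1, ffn_b1)
    h = F.gelu(h) if use_gelu else F.relu(h)                   # FFN activation
    x = x + F.linear(h, ffn_w2, ffn_b2)                        # FFN + residual
    return F.layer_norm(x, x.shape[-1:], norm_w2, norm_b2, eps)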
|
||||
|
||||
- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True) -> (Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
|
||||
|
@ -1,3 +1,5 @@
|
||||
#include <ATen/native/nested/NestedTensorMath.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
|
@ -1,5 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
struct NestedTensorImpl;
|
||||
@ -7,7 +11,7 @@ struct NestedTensorImpl;
|
||||
// TODO: cache this and only do it once per NestedTensor
|
||||
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
|
||||
|
||||
std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
|
||||
TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -35,6 +35,12 @@ Tensor nested_from_padded_cuda(
|
||||
const Tensor& sizes,
|
||||
bool do_transform_0213) {
|
||||
if (padded.dim() > 1 && padded.dim() < 5) {
|
||||
if (padded.dtype() != kFloat && padded.dtype() != kHalf) {
|
||||
TORCH_WARN_ONCE(
|
||||
"nested_from_padded CUDA kernels only support fp32/fp16; falling "
|
||||
"back to slower generic kernel");
|
||||
return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
(padded.dim() == 4 && do_transform_0213) ||
|
||||
(padded.dim() == 3 && !do_transform_0213),
|
||||
|
aten/src/ATen/native/transformers/attention.cpp (new file, 482 lines)
@ -0,0 +1,482 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/TensorIndexing.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/cat.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
Tensor gemm_nt(const Tensor& self, const Tensor& other) {
|
||||
if (self.is_nested()) {
|
||||
return NestedTensor_matmul(self, other.t());
|
||||
} else {
|
||||
return at::native::matmul(self, other.t());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void transform_bias_rescale_qkv_inner_loop(
|
||||
int64_t B,
|
||||
int64_t T,
|
||||
int64_t _3D,
|
||||
int64_t D,
|
||||
int64_t num_head,
|
||||
int64_t dim_per_head,
|
||||
scalar_t* qkv_data,
|
||||
scalar_t* qkv_bias_data,
|
||||
scalar_t* q_k_v_data,
|
||||
scalar_t inv_sqrt_dim_per_head,
|
||||
int64_t begin,
|
||||
int64_t end) {
|
||||
for (auto i : c10::irange(begin, end)) {
|
||||
auto t = i % T;
|
||||
i /= T;
|
||||
auto nh = i % num_head;
|
||||
i /= num_head;
|
||||
auto b = i;
|
||||
using Vec = vec::Vectorized<scalar_t>;
|
||||
auto V = vec::Vectorized<scalar_t>::size();
|
||||
auto dh = 0;
|
||||
auto d = nh * dim_per_head;
|
||||
for (; dh + V <= dim_per_head; dh += V, d += V) {
|
||||
// load
|
||||
auto q_bias_data = Vec::loadu(&qkv_bias_data[d + 0 * D]);
|
||||
auto k_bias_data = Vec::loadu(&qkv_bias_data[d + 1 * D]);
|
||||
auto v_bias_data = Vec::loadu(&qkv_bias_data[d + 2 * D]);
|
||||
|
||||
auto q_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 0 * D]) +
|
||||
q_bias_data;
|
||||
auto k_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 1 * D]) +
|
||||
k_bias_data;
|
||||
auto v_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 2 * D]) +
|
||||
v_bias_data;
|
||||
|
||||
q_data = q_data * Vec(inv_sqrt_dim_per_head);
|
||||
|
||||
q_data.store(&q_k_v_data
|
||||
[0 * B * num_head * T * dim_per_head +
|
||||
b * num_head * T * dim_per_head +
|
||||
nh * T * dim_per_head + t * dim_per_head + dh]);
|
||||
k_data.store(&q_k_v_data
|
||||
[1 * B * num_head * T * dim_per_head +
|
||||
b * num_head * T * dim_per_head +
|
||||
nh * T * dim_per_head + t * dim_per_head + dh]);
|
||||
v_data.store(&q_k_v_data
|
||||
[2 * B * num_head * T * dim_per_head +
|
||||
b * num_head * T * dim_per_head +
|
||||
nh * T * dim_per_head + t * dim_per_head + dh]);
|
||||
}
|
||||
for (; dh < dim_per_head; dh++) {
|
||||
auto d = nh * dim_per_head + dh;
|
||||
auto q_bias = qkv_bias_data[d + 0 * D];
|
||||
auto k_bias = qkv_bias_data[d + 1 * D];
|
||||
auto v_bias = qkv_bias_data[d + 2 * D];
|
||||
auto q_data = qkv_data[b * _3D * T + t * _3D + d + 0 * D] + q_bias;
|
||||
auto k_data = qkv_data[b * _3D * T + t * _3D + d + 1 * D] + k_bias;
|
||||
auto v_data = qkv_data[b * _3D * T + t * _3D + d + 2 * D] + v_bias;
|
||||
q_data = q_data * inv_sqrt_dim_per_head;
|
||||
q_k_v_data
|
||||
[0 * B * num_head * T * dim_per_head +
|
||||
b * num_head * T * dim_per_head + nh * T * dim_per_head +
|
||||
t * dim_per_head + dh] = q_data;
|
||||
q_k_v_data
|
||||
[1 * B * num_head * T * dim_per_head +
|
||||
b * num_head * T * dim_per_head + nh * T * dim_per_head +
|
||||
t * dim_per_head + dh] = k_data;
|
||||
q_k_v_data
|
||||
[2 * B * num_head * T * dim_per_head +
|
||||
b * num_head * T * dim_per_head + nh * T * dim_per_head +
|
||||
t * dim_per_head + dh] = v_data;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor bmm_nt(const Tensor& a, const Tensor& b) {
|
||||
auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
|
||||
auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
|
||||
auto bt_ = b_.transpose(2, 1);
|
||||
auto c_ = at::bmm(a_, bt_);
|
||||
return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
|
||||
}
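bmm_nt collapses the batch and head dimensions so a single batched matmul computes Q @ K^T for every head at once; a shape-level sketch of the equivalent eager code (sizes illustrative):

import torch

B, NH, T, DH = 2, 4, 8, 16
q = torch.randn(B, NH, T, DH)
k = torch.randn(B, NH, T, DH)
scores = torch.bmm(q.view(B * NH, T, DH), k.view(B * NH, T, DH).transpose(1, 2))
scores = scores.view(B, NH, T, T)  # same result as bmm_nt(q, k)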
|
||||
|
||||
Tensor masked_softmax(
|
||||
Tensor& attn_scores,
|
||||
c10::optional<Tensor> attn_mask,
|
||||
const Tensor& query) {
|
||||
if (query.is_nested() && !attn_mask) {
|
||||
// TODO: maybe we could do better than generating a mask every time?
|
||||
|
||||
attn_mask = NestedTensor_to_mask(query, 2);
|
||||
// TODO: CPU path does not support transformer mask yet.
|
||||
if (attn_scores.is_cpu()) {
|
||||
attn_mask = attn_mask->view({-1, 1, 1, attn_scores.sizes()[3]});
|
||||
// 1 means skip, 0 means keep.
|
||||
// want:
|
||||
// 0,0 -> 0
|
||||
// 0,1 -> 1
|
||||
// 1,1 -> 1
|
||||
// so that's logical OR.
|
||||
*attn_mask = *attn_mask | attn_mask->transpose(2, 3);
|
||||
attn_mask = at::expand_inplace(attn_scores, *attn_mask)->contiguous();
|
||||
}
|
||||
attn_mask = attn_mask->to(query.device(), /*non-blocking=*/true);
|
||||
}
|
||||
if (attn_mask && attn_mask->dtype() != at::kBool) {
|
||||
TORCH_WARN(
|
||||
"Converting mask without torch.bool dtype to bool; this will "
|
||||
"negatively affect performance. Prefer to use a boolean mask directly.");
|
||||
attn_mask = attn_mask->to(at::kBool);
|
||||
}
|
||||
if (attn_scores.is_cpu() && attn_mask && attn_mask->dim() == 2) {
|
||||
// TODO: CPU path does not support transformer mask yet.
|
||||
const auto batch_size = attn_scores.sizes()[0];
|
||||
const auto seq_len = attn_scores.sizes()[3];
|
||||
TORCH_CHECK(attn_mask->sizes()[0] == batch_size);
|
||||
TORCH_CHECK(attn_mask->sizes()[1] == seq_len);
|
||||
attn_mask = attn_mask->view({batch_size, 1, 1, seq_len});
|
||||
attn_mask = at::expand_inplace(attn_scores, *attn_mask)->contiguous();
|
||||
}
|
||||
if (attn_mask) {
|
||||
return _masked_softmax(attn_scores, *attn_mask);
|
||||
} else {
|
||||
return _softmax_out(attn_scores, attn_scores, attn_scores.dim() - 1, false);
|
||||
}
|
||||
}
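The "logical OR" trick in the nested-CPU branch above turns a per-token key-padding mask into a full [T, T] attention mask; a standalone illustration (shapes illustrative):

import torch

# True marks a padded position that should be skipped.
key_padding = torch.tensor([[False, False, True]])                        # [B, T]
attn_mask = key_padding.view(1, 1, 1, -1) | key_padding.view(1, 1, -1, 1)
# attn_mask[b, 0, i, j] is True if either position i or j is padding,
# matching the 0/1 truth table in the comment above.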
|
||||
|
||||
Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
|
||||
const std::array<int64_t, 3> newAShape = {
|
||||
a.sizes()[0] * a.sizes()[1], a.sizes()[2], a.sizes()[3]};
|
||||
auto a_ = a.view(newAShape);
|
||||
const std::array<int64_t, 3> newBShape = {
|
||||
b.sizes()[0] * b.sizes()[1], b.sizes()[2], b.sizes()[3]};
|
||||
auto b_ = b.view(newBShape);
|
||||
auto out_ = out.reshape({newAShape[0], newAShape[1], newBShape[2]});
|
||||
auto c_ = at::bmm_out(out_, a_, b_);
|
||||
return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
|
||||
}
|
||||
|
||||
Tensor transform_0213(const Tensor& a) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
|
||||
return a.permute({0, 2, 1, 3})
|
||||
.contiguous()
|
||||
.view({a.size(0), a.size(2), a.size(1) * a.size(3)});
|
||||
}
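transform_0213 merges the heads back into the embedding dimension; a hedged shape-level sketch of the same permute/view (sizes illustrative):

import torch

B, NH, T, DH = 2, 4, 8, 16
attn_ctx = torch.randn(B, NH, T, DH)
# [B, NH, T, DH] -> [B, T, NH, DH] -> [B, T, NH * DH]
merged = attn_ctx.permute(0, 2, 1, 3).contiguous().view(B, T, NH * DH)
assert merged.shape == (B, T, NH * DH)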
|
||||
|
||||
Tensor transform0213_gemm_nt_bias(
|
||||
const Tensor& a,
|
||||
const Tensor& b,
|
||||
const Tensor& c,
|
||||
const Tensor& query) {
|
||||
if (query.is_nested()) {
|
||||
at::Tensor nested_a = _nested_from_padded(
|
||||
a, get_nested_tensor_impl(query)->get_nested_size_tensor(), true);
|
||||
return NestedTensor_times_Tensor_plus_Tensor_addmm(
|
||||
c, nested_a, b.t(), 1, 1);
|
||||
} else {
|
||||
const Tensor a_0213 = transform_0213(a);
|
||||
auto a_ = a_0213.view({a_0213.size(0) * a_0213.size(1), a_0213.size(2)});
|
||||
auto r_ = at::native::linear(a_, b, c);
|
||||
return r_.view({a_0213.size(0), a_0213.size(1), r_.size(1)});
|
||||
}
|
||||
}
|
||||
|
||||
void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
(size_t)t.dim() == shape.size(),
|
||||
"(called from line ",
|
||||
line,
|
||||
") ",
|
||||
"expected ",
|
||||
shape.size(),
|
||||
"-D tensor but got ",
|
||||
t.dim());
|
||||
if (t.is_nested()) {
|
||||
return;
|
||||
}
|
||||
for (auto idx : c10::irange(shape.size())) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
shape[idx] == 0 || t.sizes()[idx] == shape[idx],
|
||||
"(called from line ",
|
||||
line,
|
||||
") ",
|
||||
"expected dim ",
|
||||
idx,
|
||||
" to be ",
|
||||
shape[idx],
|
||||
" but got ",
|
||||
t.sizes()[idx]);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
|
||||
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
|
||||
const Tensor& qkv,
|
||||
const Tensor& qkv_bias,
|
||||
const int64_t num_head) {
|
||||
auto qkv_ = qkv.is_nested()
|
||||
? c10::MaybeOwned<Tensor>::owned((NestedTensor_to_padded_tensor(qkv, 0)))
|
||||
: c10::MaybeOwned<Tensor>::borrowed(qkv);
|
||||
auto B = qkv_->size(0);
|
||||
auto T = qkv_->size(1);
|
||||
auto _3D = qkv_->size(2);
|
||||
auto D = _3D / 3;
|
||||
TORCH_CHECK(D % num_head == 0);
|
||||
TORCH_CHECK(_3D % 3 == 0);
|
||||
const auto dim_per_head = D / num_head;
|
||||
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_->options());
|
||||
|
||||
const auto qkv_contig = qkv_->expect_contiguous();
|
||||
const auto qkv_bias_contig = qkv_bias.expect_contiguous();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
ScalarType::Half,
|
||||
ScalarType::BFloat16,
|
||||
qkv_->scalar_type(),
|
||||
"transform_bias_rescale_qkv",
|
||||
[&] {
|
||||
scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
|
||||
scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
|
||||
scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
|
||||
const scalar_t inv_sqrt_dim_per_head =
|
||||
1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head));
|
||||
|
||||
int64_t grain_size =
|
||||
std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
|
||||
parallel_for(
|
||||
0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
|
||||
transform_bias_rescale_qkv_inner_loop(
|
||||
B,
|
||||
T,
|
||||
_3D,
|
||||
D,
|
||||
num_head,
|
||||
dim_per_head,
|
||||
qkv_data,
|
||||
qkv_bias_data,
|
||||
q_k_v_data,
|
||||
inv_sqrt_dim_per_head,
|
||||
begin,
|
||||
end);
|
||||
});
|
||||
});
|
||||
auto q_k_v_s =
|
||||
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
|
||||
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
|
||||
}
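The final split is a cheap way to carve the packed [3, B, NH, T, DH] buffer into q/k/v views without copying; a sketch of why split-after-view works (sizes illustrative):

import torch

B, NH, T, DH = 2, 4, 8, 16
q_k_v = torch.randn(3, B, NH, T, DH)
q, k, v = torch.split(q_k_v.view(3 * B, NH, T, DH), B, dim=0)
assert torch.equal(q, q_k_v[0]) and torch.equal(v, q_k_v[2])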
|
||||
|
||||
std::tuple<Tensor, Tensor> native_multi_head_attention(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
const int64_t embed_dim,
|
||||
const int64_t num_head,
|
||||
const Tensor& qkv_weight,
|
||||
const Tensor& qkv_bias,
|
||||
const Tensor& proj_weight,
|
||||
const Tensor& proj_bias,
|
||||
const c10::optional<Tensor>& mask,
|
||||
bool need_weights,
|
||||
bool average_attn_weights) {
|
||||
// query shape: [B, T, D]
|
||||
// qkv_weight shape: [3 * D, D]
|
||||
|
||||
TORCH_CHECK(
|
||||
!mask || !query.is_nested(),
|
||||
"NestedTensor with mask is not supported yet");
|
||||
const auto D = embed_dim;
|
||||
TORCH_CHECK(
|
||||
query.dim() == 3,
|
||||
"expected 3-D `query`, got ",
|
||||
query.dim(),
|
||||
"-D tensor");
|
||||
TORCH_CHECK(
|
||||
query.is_nested() || query.sizes()[2] == embed_dim,
|
||||
"passed-in embed_dim ",
|
||||
embed_dim,
|
||||
" didn't match last dim of query ",
|
||||
query.sizes()[2]);
|
||||
TORCH_CHECK(
|
||||
key.dim() == 3,
|
||||
"expected 3-D `key`, got ",
|
||||
key.dim(),
|
||||
"-D tensor");
|
||||
TORCH_CHECK(
|
||||
value.dim() == 3,
|
||||
"expected 3-D `value`, got ",
|
||||
value.dim(),
|
||||
"-D tensor");
|
||||
TORCH_CHECK(
|
||||
query.is_nested() || key.is_nested() || value.is_nested() ||
|
||||
(query.sizes() == key.sizes() && key.sizes() == value.sizes()),
|
||||
"expected `query`/`key`/`value` shapes to match");
|
||||
TORCH_CHECK(
|
||||
qkv_weight.dim() == 2,
|
||||
"expected 2-D `qkv_weight`, got ",
|
||||
qkv_weight.dim(),
|
||||
"-D tensor");
|
||||
TORCH_CHECK(
|
||||
D * 3 == qkv_weight.sizes()[0],
|
||||
"expected `qkv_weight` first dim to be 3x embed_dim");
|
||||
TORCH_CHECK(
|
||||
D == qkv_weight.sizes()[1],
|
||||
"expected `qkv_weight` second dim to be embed_Dim");
|
||||
TORCH_CHECK(
|
||||
qkv_bias.dim() == 1,
|
||||
"expected 2-D `qkv_bias`, got ",
|
||||
qkv_bias.dim(),
|
||||
"-D tensor");
|
||||
TORCH_CHECK(
|
||||
qkv_bias.sizes()[0] == 3 * D,
|
||||
"expected `qkv_bias` first dim and first dim of query to be equal");
|
||||
TORCH_CHECK(D % num_head == 0, "`embed_dim` must be evenly divisible by `num_heads`");
|
||||
|
||||
#ifndef NDEBUG
|
||||
const auto B = query.is_nested()
|
||||
? get_nested_tensor_impl(query)->get_nested_size_tensor().size(0)
|
||||
: query.sizes()[0];
|
||||
auto T = query.is_nested() ? 0 : query.sizes()[1];
|
||||
const auto dim_per_head = D / num_head;
|
||||
#endif
|
||||
|
||||
// shape: [B, T, 3 x D]
|
||||
Tensor qkv;
|
||||
|
||||
if (key.is_same(value)) {
|
||||
if (query.is_same(key)) {
|
||||
// self-attention
|
||||
qkv = gemm_nt(query, qkv_weight);
|
||||
} else {
|
||||
// encoder-decoder attention
|
||||
// TODO: is there a more efficient way to set this up?
|
||||
// TODO: can we stay nested instead of using cat? Probably just make a
|
||||
// NestedTensor out of the matmul results or something?
|
||||
auto q_kv_weight_s =
|
||||
at::native::split_with_sizes(qkv_weight, {D, D * 2}, 0);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
q_kv_weight_s.size() == 2,
|
||||
"expected split to produce 2 tensors but it produced ",
|
||||
q_kv_weight_s.size());
|
||||
auto q = gemm_nt(query, q_kv_weight_s[0]);
|
||||
auto kv = gemm_nt(key, q_kv_weight_s[1]);
|
||||
qkv = at::cat({q, kv}, 2);
|
||||
}
|
||||
} else {
|
||||
auto q_k_v_weight_s = at::native::chunk(qkv_weight, 3, 0);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
q_k_v_weight_s.size() == 3,
|
||||
"expected chunk to produce 3 tensors but it produced ",
|
||||
q_k_v_weight_s.size());
|
||||
// TODO: can we stay nested instead of using cat?
|
||||
auto q = gemm_nt(query, q_k_v_weight_s[0]);
|
||||
auto k = gemm_nt(key, q_k_v_weight_s[1]);
|
||||
auto v = gemm_nt(value, q_k_v_weight_s[2]);
|
||||
qkv = at::cat({q, k, v}, 2);
|
||||
}
|
||||
|
||||
if (!qkv.is_nested() && qkv.numel() == 0) {
|
||||
if (query.is_nested()) {
|
||||
return std::make_tuple(Tensor(), Tensor());
|
||||
}
|
||||
return std::make_tuple(at::empty_like(query), Tensor());
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
if (!query.is_nested() || !qkv.is_nested()) {
|
||||
if (query.is_nested()) {
|
||||
T = qkv.size(1);
|
||||
}
|
||||
debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef DEBUG_PRINT_EACH_STEP
|
||||
if (!qkv.is_nested()) {
|
||||
std::cerr << "qkv: " << qkv << std::endl;
|
||||
}
|
||||
#endif
|
||||
// shape: 3 x [B, num_head, T, dim_per_head]
|
||||
auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
|
||||
qkv = Tensor(); // Not used any more, allow free
|
||||
auto& q = std::get<0>(q_k_v);
|
||||
const auto& k = std::get<1>(q_k_v);
|
||||
const auto& v = std::get<2>(q_k_v);
|
||||
#ifndef NDEBUG
|
||||
debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
|
||||
debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
|
||||
debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
|
||||
#endif
|
||||
#ifdef DEBUG_PRINT_EACH_STEP
|
||||
std::cerr << "q: " << q << std::endl;
|
||||
std::cerr << "k: " << k << std::endl;
|
||||
std::cerr << "v: " << v << std::endl;
|
||||
#endif
|
||||
|
||||
// shape: [B, num_head, T, T]
|
||||
auto qkt = bmm_nt(q, k);
|
||||
// q & k are dead but cannot be freed because they were packed with v
|
||||
#ifndef NDEBUG
|
||||
debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
|
||||
#endif
|
||||
#ifdef DEBUG_PRINT_EACH_STEP
|
||||
std::cerr << "qkt: " << qkt << std::endl;
|
||||
#endif
|
||||
|
||||
// shape: [B, num_head, T, T]
|
||||
// TODO: long-term, have a kernel that works with
|
||||
// NestedTensor directly if there is no mask passed
|
||||
qkt = masked_softmax(qkt, mask, query);
|
||||
#ifdef DEBUG_PRINT_EACH_STEP
|
||||
std::cerr << "qkt after softmax: " << qkt << std::endl;
|
||||
#endif
|
||||
|
||||
// shape: [B, num_head, T, dim_per_head]
|
||||
// reuse storage for q; we're done with it
|
||||
auto attn_ctx = bmm_nn(q, qkt, v);
|
||||
// qkv is not dead; we just reused storage for q!
|
||||
if (!need_weights) {
|
||||
qkt = Tensor();
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
|
||||
#endif
|
||||
#ifdef DEBUG_PRINT_EACH_STEP
|
||||
std::cerr << "attn_ctx: " << attn_ctx << std::endl;
|
||||
#endif
|
||||
|
||||
// shape: [B, T, D]
|
||||
// Fuse transform_0213 inside
|
||||
auto proj = transform0213_gemm_nt_bias(
|
||||
attn_ctx, proj_weight, proj_bias, query);
|
||||
#ifndef NDEBUG
|
||||
debug_assert_shape(__LINE__, proj, {B, T, D});
|
||||
#endif
|
||||
if (need_weights && average_attn_weights) {
|
||||
// weights are not needed for full transformer, so don't worry too
|
||||
// much about performance -- we implement this just so that use
|
||||
// cases that don't disable need_weights still get some speedup.
|
||||
qkt = qkt.sum(1);
|
||||
qkt /= num_head;
|
||||
}
|
||||
return std::make_tuple(std::move(proj), std::move(qkt));
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
aten/src/ATen/native/transformers/cuda/attention.cu (new file, 400 lines)
@ -0,0 +1,400 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <ATen/TensorAccessor.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/KernelUtils.h>
|
||||
#include <ATen/cuda/detail/IndexUtils.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/MemoryAccess.cuh>
|
||||
#include <ATen/native/cuda/PersistentSoftmax.cuh>
|
||||
#include <ATen/native/cuda/block_reduce.cuh>
|
||||
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
|
||||
#include <ATen/native/nested/NestedTensorMath.h>
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
namespace at {
|
||||
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;
|
||||
|
||||
template <typename scalar_t, typename accscalar_t, bool assume_aligned>
|
||||
__global__ void transform_bias_rescale_qkv_kernel(
|
||||
// [B, T, 3 * D]
|
||||
const PackedTensorAccessor64<scalar_t, 3, RestrictPtrTraits> qkv,
|
||||
// [3 * D]
|
||||
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
|
||||
// [3, B, NH, T, DH]
|
||||
PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
|
||||
const scalar_t inv_sqrt_dim_per_head) {
|
||||
// warp per DH.
|
||||
// so launch B * NH * T warps.
|
||||
auto NH = q_k_v.size(2);
|
||||
auto T = q_k_v.size(3);
|
||||
auto DH = q_k_v.size(4);
|
||||
|
||||
auto t = blockIdx.x % T;
|
||||
auto b = blockIdx.x / T;
|
||||
|
||||
auto D = NH * DH;
|
||||
|
||||
if (assume_aligned) {
|
||||
constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
|
||||
using LoadT = memory::aligned_vector<scalar_t, VEC>;
|
||||
for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
|
||||
auto d = d_v * VEC;
|
||||
auto nh = d / DH;
|
||||
auto dh = d % DH;
|
||||
scalar_t qkv_bias_q[VEC];
|
||||
scalar_t qkv_bias_k[VEC];
|
||||
scalar_t qkv_bias_v[VEC];
|
||||
scalar_t qkv_q[VEC];
|
||||
scalar_t qkv_k[VEC];
|
||||
scalar_t qkv_v[VEC];
|
||||
|
||||
// Here we require D % VEC == 0 for these vectorized loads.
|
||||
*reinterpret_cast<LoadT*>(&qkv_bias_q) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_bias_k) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_bias_v) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
|
||||
|
||||
*reinterpret_cast<LoadT*>(&qkv_q) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 0 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_k) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 1 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_v) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 2 * D]);
|
||||
|
||||
#pragma unroll
|
||||
// TODO: specialize for float2half2/half2float2?
|
||||
for (auto ii = 0; ii < VEC; ++ii) {
|
||||
qkv_q[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_q[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_q[ii])) *
|
||||
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
|
||||
qkv_k[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_k[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_k[ii])));
|
||||
qkv_v[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_v[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_v[ii])));
|
||||
}
|
||||
|
||||
// Here we require DH % VEC == 0 for these vectorized stores.
|
||||
*reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_q);
|
||||
*reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_k);
|
||||
*reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_v);
|
||||
}
|
||||
} else {
|
||||
// Same as above, but we can't vectorize memory access.
|
||||
for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
|
||||
auto nh = d / DH;
|
||||
auto dh = d % DH;
|
||||
scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
|
||||
scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
|
||||
scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
|
||||
scalar_t qkv_q = qkv[b][t][d + 0 * D];
|
||||
scalar_t qkv_k = qkv[b][t][d + 1 * D];
|
||||
scalar_t qkv_v = qkv[b][t][d + 2 * D];
|
||||
qkv_q = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_q) +
|
||||
static_cast<accscalar_t>(qkv_bias_q)) *
|
||||
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
|
||||
qkv_k = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_k) +
|
||||
static_cast<accscalar_t>(qkv_bias_k)));
|
||||
qkv_v = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_v) +
|
||||
static_cast<accscalar_t>(qkv_bias_v)));
|
||||
|
||||
q_k_v[0][b][nh][t][dh] = qkv_q;
|
||||
q_k_v[1][b][nh][t][dh] = qkv_k;
|
||||
q_k_v[2][b][nh][t][dh] = qkv_v;
|
||||
}
|
||||
}
|
||||
}
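The host wrapper further down launches one block per (batch, token) pair and lets the threads of a block stride over the embedding dimension in chunks of TRANSFORM_BIAS_RESCALE_VEC; the launch-geometry arithmetic it uses is just the following (sizes illustrative):

B, T, D, VEC = 16, 8, 64, 4         # VEC = TRANSFORM_BIAS_RESCALE_VEC
blocks = B * T                       # one block per (batch, token) pair
threads = max(min(1024, D // VEC), 1)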
|
||||
|
||||
template <typename scalar_t, typename accscalar_t, bool assume_aligned = false>
|
||||
__global__ void transform_bias_rescale_qkv_add_padding_kernel(
|
||||
// [B, T, 3 * D], but it's a NestedTensor buffer
|
||||
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv,
|
||||
// [3 * D]
|
||||
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
|
||||
const int* offsets,
|
||||
const int* input_sizes,
|
||||
// [3, B, NH, T, DH]
|
||||
PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
|
||||
const scalar_t inv_sqrt_dim_per_head) {
|
||||
// warp per DH.
|
||||
// so launch B * NH * T warps.
|
||||
const auto NH = q_k_v.size(2);
|
||||
const auto T = q_k_v.size(3);
|
||||
const auto DH = q_k_v.size(4);
|
||||
|
||||
const auto t = blockIdx.x % T;
|
||||
const auto b = blockIdx.x / T;
|
||||
|
||||
const auto D = NH * DH;
|
||||
const auto _3D = 3 * D;
|
||||
|
||||
const auto offset_for_batch = offsets[b];
|
||||
const auto input_dim = 1;
|
||||
const auto* sizes_i = input_sizes + b * input_dim;
|
||||
if (assume_aligned) {
|
||||
constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
|
||||
using LoadT = memory::aligned_vector<scalar_t, VEC>;
|
||||
for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
|
||||
auto d = d_v * VEC;
|
||||
auto nh = d / DH;
|
||||
auto dh = d % DH;
|
||||
scalar_t qkv_bias_q[VEC];
|
||||
scalar_t qkv_bias_k[VEC];
|
||||
scalar_t qkv_bias_v[VEC];
|
||||
scalar_t qkv_q[VEC];
|
||||
scalar_t qkv_k[VEC];
|
||||
scalar_t qkv_v[VEC];
|
||||
|
||||
const auto first_item_offset = t * _3D + d;
|
||||
const auto last_item_offset = first_item_offset + VEC - 1;
|
||||
const bool first_item_in_bounds = first_item_offset < sizes_i[0];
|
||||
const bool entire_vec_in_bounds = last_item_offset < sizes_i[0];
|
||||
|
||||
// Here we require D % VEC == 0 for these vectorized loads.
|
||||
*reinterpret_cast<LoadT*>(&qkv_bias_q) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_bias_k) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_bias_v) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
|
||||
|
||||
if (entire_vec_in_bounds) {
|
||||
const auto offset = offset_for_batch + first_item_offset;
|
||||
*reinterpret_cast<LoadT*>(&qkv_q) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv[offset + 0 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_k) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv[offset + 1 * D]);
|
||||
*reinterpret_cast<LoadT*>(&qkv_v) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv[offset + 2 * D]);
|
||||
#pragma unroll
|
||||
// TODO: specialize for float2half2/half2float2?
|
||||
for (auto ii = 0; ii < VEC; ++ii) {
|
||||
qkv_q[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_q[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_q[ii])) *
|
||||
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
|
||||
qkv_k[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_k[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_k[ii])));
|
||||
qkv_v[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_v[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_v[ii])));
|
||||
}
|
||||
} else if (first_item_in_bounds) {
|
||||
const auto offset = offset_for_batch + first_item_offset;
|
||||
qkv_q[0] = qkv[offset + 0 * D];
|
||||
qkv_k[0] = qkv[offset + 1 * D];
|
||||
qkv_v[0] = qkv[offset + 2 * D];
|
||||
qkv_q[0] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_q[0]) +
|
||||
static_cast<accscalar_t>(qkv_bias_q[0])) *
|
||||
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
|
||||
qkv_k[0] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_k[0]) +
|
||||
static_cast<accscalar_t>(qkv_bias_k[0])));
|
||||
qkv_v[0] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_v[0]) +
|
||||
static_cast<accscalar_t>(qkv_bias_v[0])));
|
||||
#pragma unroll
|
||||
for (auto ii = 1; ii < VEC; ++ii) {
|
||||
const auto loop_offset = offset + ii;
|
||||
if (loop_offset < sizes_i[0]) {
|
||||
qkv_q[ii] = qkv[loop_offset + 0 * D];
|
||||
qkv_k[ii] = qkv[loop_offset + 1 * D];
|
||||
qkv_v[ii] = qkv[loop_offset + 2 * D];
|
||||
qkv_q[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_q[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_q[ii])) *
|
||||
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
|
||||
qkv_k[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_k[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_k[ii])));
|
||||
qkv_v[ii] = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_v[ii]) +
|
||||
static_cast<accscalar_t>(qkv_bias_v[ii])));
|
||||
} else {
|
||||
qkv_q[ii] = 0;
|
||||
qkv_k[ii] = 0;
|
||||
qkv_v[ii] = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (auto ii = 0; ii < VEC; ++ii) {
|
||||
qkv_q[ii] = 0;
|
||||
qkv_k[ii] = 0;
|
||||
qkv_v[ii] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Here we require DH % VEC == 0 for these vectorized stores.
|
||||
*reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_q);
|
||||
*reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_k);
|
||||
*reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
|
||||
*reinterpret_cast<const LoadT*>(&qkv_v);
|
||||
}
|
||||
} else {
|
||||
for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
|
||||
auto nh = d / DH;
|
||||
auto dh = d % DH;
|
||||
scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
|
||||
scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
|
||||
scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
|
||||
|
||||
const auto item_offset = t * _3D + d;
|
||||
const bool in_bounds = item_offset < sizes_i[0];
|
||||
scalar_t qkv_q, qkv_k, qkv_v;
|
||||
if (in_bounds) {
|
||||
const auto qkv_offset = offset_for_batch + item_offset;
|
||||
qkv_q = qkv[qkv_offset + 0 * D];
|
||||
qkv_k = qkv[qkv_offset + 1 * D];
|
||||
qkv_v = qkv[qkv_offset + 2 * D];
|
||||
qkv_q = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_q) +
|
||||
static_cast<accscalar_t>(qkv_bias_q)) *
|
||||
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
|
||||
qkv_k = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_k) +
|
||||
static_cast<accscalar_t>(qkv_bias_k)));
|
||||
qkv_v = static_cast<scalar_t>(
|
||||
(static_cast<accscalar_t>(qkv_v) +
|
||||
static_cast<accscalar_t>(qkv_bias_v)));
|
||||
} else {
|
||||
qkv_q = 0;
|
||||
qkv_k = 0;
|
||||
qkv_v = 0;
|
||||
}
|
||||
|
||||
q_k_v[0][b][nh][t][dh] = qkv_q;
|
||||
q_k_v[1][b][nh][t][dh] = qkv_k;
|
||||
q_k_v[2][b][nh][t][dh] = qkv_v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor collapse_dims_1_and_2(const Tensor& sizes) {
|
||||
auto sizes_dim1 = at::native::narrow(sizes, 1, 0, 1);
|
||||
auto sizes_dim2 = at::native::narrow(sizes, 1, 1, 1);
|
||||
|
||||
return (sizes_dim1 * sizes_dim2).contiguous();
|
||||
}
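For the nested case each sample's buffer holds T_i * 3D elements, so collapsing the two size columns yields per-sample buffer lengths; a small sketch (values illustrative):

import torch

# Each row of the nested size tensor is (T_i, 3 * D) for one batch element.
sizes = torch.tensor([[5, 192], [3, 192]])
lengths = (sizes[:, 0:1] * sizes[:, 1:2]).contiguous()  # [[960], [576]]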
|
||||
|
||||
} // namespace
|
||||
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
|
||||
__host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
|
||||
const Tensor& qkv,
|
||||
const Tensor& qkv_bias,
|
||||
const int64_t num_head) {
|
||||
auto B = qkv.is_nested()
|
||||
? get_nested_tensor_impl(qkv)->get_nested_size_tensor().size(0)
|
||||
: qkv.size(0);
|
||||
// TODO: calculate this without the std::vector -- NestedTensor_to_mask wants
|
||||
// this too
|
||||
auto T = qkv.is_nested()
|
||||
? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0]
|
||||
: qkv.size(1);
|
||||
auto _3D = qkv_bias.size(0);
|
||||
auto D = _3D / 3;
|
||||
TORCH_CHECK(D % num_head == 0);
|
||||
const auto dim_per_head = D / num_head;
|
||||
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options());
|
||||
#define CALL_KERNEL(assume_aligned) \
|
||||
transform_bias_rescale_qkv_kernel<scalar_t, accscalar_t, assume_aligned> \
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
|
||||
qkv.packed_accessor64<scalar_t, 3, RestrictPtrTraits>(), \
|
||||
qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
|
||||
q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(), \
|
||||
1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
|
||||
#define CALL_ADD_PADDING_KERNEL(assume_aligned) \
|
||||
transform_bias_rescale_qkv_add_padding_kernel< \
|
||||
scalar_t, \
|
||||
accscalar_t, \
|
||||
assume_aligned> \
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
|
||||
nt_qkv->get_buffer() \
|
||||
.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
|
||||
qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
|
||||
offsets_ptr, \
|
||||
sizes_ptr, \
|
||||
q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(), \
|
||||
1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
ScalarType::Half,
|
||||
ScalarType::BFloat16,
|
||||
qkv.scalar_type(),
|
||||
"transform_bias_rescale_qkv",
|
||||
[&] {
|
||||
using accscalar_t = acc_type<scalar_t, true>;
|
||||
auto threads = std::max(
|
||||
std::min<int32_t>(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1);
|
||||
auto blocks = B * T;
|
||||
const bool aligned =
|
||||
((dim_per_head % TRANSFORM_BIAS_RESCALE_VEC) == 0) &&
|
||||
((reinterpret_cast<intptr_t>(qkv_bias.data_ptr()) %
|
||||
TRANSFORM_BIAS_RESCALE_VEC) == 0);
|
||||
if (aligned) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
D % TRANSFORM_BIAS_RESCALE_VEC == 0,
|
||||
"D = num_heads * dim_per_head, so we should have dim_per_head % "
|
||||
"TRANSFORM_BIAS_RESCALE_VEC == 0 => "
|
||||
"D % TRANSFORM_BIAS_RESCALE_VEC == 0");
|
||||
}
|
||||
if (qkv.is_nested()) {
|
||||
auto* nt_qkv = get_nested_tensor_impl(qkv);
|
||||
auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_size_tensor());
|
||||
auto offsets =
|
||||
NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
|
||||
at::native::narrow(offsets, 0, sizes.numel() + 1, sizes.numel())
|
||||
.copy_(sizes.reshape({-1}));
|
||||
auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
|
||||
const auto offsets_ptr = metadata.data_ptr<int>();
|
||||
const auto sizes_ptr = offsets_ptr + sizes.numel() + 1;
|
||||
const auto input_dim = sizes.sizes()[1];
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input_dim == 1);
|
||||
if (aligned &&
|
||||
((reinterpret_cast<intptr_t>(nt_qkv->get_buffer().data_ptr()) %
|
||||
TRANSFORM_BIAS_RESCALE_VEC) == 0)) {
|
||||
CALL_ADD_PADDING_KERNEL(true);
|
||||
} else {
|
||||
CALL_ADD_PADDING_KERNEL(false);
|
||||
}
|
||||
} else if (aligned) {
|
||||
CALL_KERNEL(true);
|
||||
} else {
|
||||
CALL_KERNEL(false);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
#undef CALL_ADD_PADDING_KERNEL
|
||||
#undef CALL_KERNEL
|
||||
auto q_k_v_s =
|
||||
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
|
||||
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
|
||||
}
|
||||
} // namespace native
|
||||
} // namespace at
|
aten/src/ATen/native/transformers/transformer.cpp (new file, 137 lines)
@ -0,0 +1,137 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
Tensor linear_for_ffn(
|
||||
const Tensor& bias,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
c10::optional<bool> use_gelu) {
|
||||
if (mat1.is_nested()) {
|
||||
return NestedTensor_times_Tensor_plus_Tensor_addmm(
|
||||
bias, mat1, mat2.t(), 1, 1, use_gelu);
|
||||
}
|
||||
|
||||
auto mat1_ = mat1.view({mat1.sizes()[0] * mat1.sizes()[1], mat1.sizes()[2]});
|
||||
Tensor result;
|
||||
if (use_gelu.has_value()) {
|
||||
result = at::_addmm_activation(bias, mat1_, mat2.t(), 1, 1, *use_gelu);
|
||||
} else {
|
||||
result = at::addmm(bias, mat1_, mat2.t());
|
||||
}
|
||||
return result.view({mat1.sizes()[0], mat1.sizes()[1], -1});
|
||||
}
|
||||
|
||||
Tensor ffn(
|
||||
const Tensor& input,
|
||||
const Tensor& w1,
|
||||
const Tensor& b1,
|
||||
const Tensor& w2,
|
||||
const Tensor& b2,
|
||||
bool use_gelu,
|
||||
bool add_norm) {
|
||||
TORCH_CHECK(add_norm == false, "TODO: add_norm is not yet supported in FFN");
|
||||
TORCH_CHECK(input.dim() == 3, "batched input size should be 3");
|
||||
TORCH_CHECK(w1.dim() == 2, "2d weights expected");
|
||||
TORCH_CHECK(w2.dim() == 2, "2d weights expected");
|
||||
Tensor res = linear_for_ffn(b1, input, w1, use_gelu);
|
||||
res = linear_for_ffn(b2, res, w2, c10::nullopt);
|
||||
return res;
|
||||
}
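ffn is the standard two-layer feed-forward block; note that on the dense path the first linear and its activation are fused via _addmm_activation. An unfused reference sketch (names illustrative):

import torch.nn.functional as F

def ffn_reference(x, w1, b1, w2, b2, use_gelu):
    h = F.linear(x, w1, b1)
    h = F.gelu(h) if use_gelu else F.relu(h)
    return F.linear(h, w2, b2)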
|
||||
} // namespace
|
||||
|
||||
Tensor transformer_encoder_layer_forward(
|
||||
const Tensor& src,
|
||||
const int64_t embed_dim,
|
||||
const int64_t num_heads,
|
||||
const Tensor& qkv_weight,
|
||||
const Tensor& qkv_bias,
|
||||
const Tensor& proj_weight,
|
||||
const Tensor& proj_bias,
|
||||
const bool use_gelu,
|
||||
const bool norm_first,
|
||||
const double layer_norm_eps,
|
||||
const Tensor& layer_norm_weight_1,
|
||||
const Tensor& layer_norm_bias_1,
|
||||
const Tensor& layer_norm_weight_2,
|
||||
const Tensor& layer_norm_bias_2,
|
||||
const Tensor& ffn_weight_1,
|
||||
const Tensor& ffn_bias_1,
|
||||
const Tensor& ffn_weight_2,
|
||||
const Tensor& ffn_bias_2,
|
||||
const c10::optional<Tensor>& mask) {
|
||||
{
|
||||
const Tensor& check_for_empty = src.is_nested() ? get_nested_tensor_impl(src)->get_buffer() : src;
|
||||
if (check_for_empty.numel() == 0) {
|
||||
return src.is_nested()
|
||||
? at::detail::make_tensor<NestedTensorImpl>(check_for_empty, get_nested_tensor_impl(src)->get_nested_size_tensor())
|
||||
: src.clone();
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(!norm_first, "norm_first is not supported yet");
|
||||
const bool use_nested_tensor = src.is_nested();
|
||||
auto x = std::get<0>(native_multi_head_attention(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
qkv_weight,
|
||||
qkv_bias,
|
||||
proj_weight,
|
||||
proj_bias,
|
||||
mask,
|
||||
false /* need_weights */));
|
||||
if (use_nested_tensor) {
|
||||
NestedTensor_add_NestedTensor_in_place(x, src);
|
||||
x = NestedTensor_layer_norm(
|
||||
x, layer_norm_weight_1, layer_norm_bias_1, layer_norm_eps);
|
||||
} else {
|
||||
x.add_(src);
|
||||
x = at::layer_norm(
|
||||
x,
|
||||
{embed_dim},
|
||||
layer_norm_weight_1,
|
||||
layer_norm_bias_1,
|
||||
layer_norm_eps,
|
||||
true);
|
||||
}
|
||||
|
||||
auto pre_ffn_res = x;
|
||||
x = ffn(
|
||||
x,
|
||||
ffn_weight_1,
|
||||
ffn_bias_1,
|
||||
ffn_weight_2,
|
||||
ffn_bias_2,
|
||||
use_gelu,
|
||||
/* add_norm */ false);
|
||||
if (use_nested_tensor) {
|
||||
NestedTensor_add_NestedTensor_in_place(x, pre_ffn_res);
|
||||
x = NestedTensor_layer_norm(
|
||||
x, layer_norm_weight_2, layer_norm_bias_2, layer_norm_eps);
|
||||
} else {
|
||||
x.add_(pre_ffn_res);
|
||||
x = at::layer_norm(
|
||||
x,
|
||||
{embed_dim},
|
||||
layer_norm_weight_2,
|
||||
layer_norm_bias_2,
|
||||
layer_norm_eps,
|
||||
true);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
test/test_native_mha.py (new file, 306 lines)
@ -0,0 +1,306 @@
|
||||
# Owner(s): ["module: nn"]
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
dtypesIfCUDA,
|
||||
instantiate_device_type_tests,
|
||||
onlyCUDA,
|
||||
skipMeta,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
class TestMHADeviceType(TestCase):
|
||||
@torch.no_grad()
|
||||
def _test_transform_bias_rescale_qkv_impl(
|
||||
self, device, dtype, use_nt, use_padding=False
|
||||
):
|
||||
tests = [
|
||||
(64, 4, 16, 8),
|
||||
# dim_per_head = 12 is not evenly divisible by the CPU vectorization length of 8
|
||||
(24, 2, 4, 2),
|
||||
# Make sure CUDA can handle small input sizes
|
||||
(2, 2, 2, 2),
|
||||
# dim_per_head = 6 is not evenly divisible by the CUDA vectorization length of 4,
|
||||
# causes alignment issues
|
||||
(24, 4, 4, 2),
|
||||
(48, 4, 16, 8),
|
||||
]
|
||||
for (embed_dim, num_heads, bs, sl) in tests:
|
||||
with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
|
||||
torch.manual_seed(9343)
|
||||
dense_x = x = (
|
||||
torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
|
||||
)
|
||||
if use_padding:
|
||||
x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
|
||||
if use_nt:
|
||||
xs = list(torch.unbind(x))
|
||||
if use_padding:
|
||||
xs[0] = xs[0][:-1]
|
||||
x = torch.nested_tensor(xs, device=device, dtype=dtype)
|
||||
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# We have to use inference_mode here because q/k/v are
|
||||
# all views of the same Tensor, which autograd doesn't
|
||||
# like. This is fine because this function is only
|
||||
# exposed to Python for purposes of writing this test.
|
||||
with torch.inference_mode():
|
||||
(q, k, v) = torch._transform_bias_rescale_qkv(
|
||||
x, qkv.bias, num_heads=num_heads
|
||||
)
|
||||
|
||||
def simple_transform_bias_rescale_qkv(qkv, bias):
|
||||
(q, k, v) = torch.split(qkv, embed_dim, dim=-1)
|
||||
(q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)
|
||||
return tuple(
|
||||
x.reshape(
|
||||
(bs, sl, num_heads, embed_dim // num_heads)
|
||||
).transpose(2, 1)
|
||||
for x in (
|
||||
(q + q_bias) / math.sqrt(embed_dim // num_heads),
|
||||
(k + k_bias),
|
||||
(v + v_bias),
|
||||
)
|
||||
)
|
||||
|
||||
correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
|
||||
dense_x, qkv.bias
|
||||
)
|
||||
if use_nt and use_padding:
|
||||
for t in (correct_q, correct_k, correct_v):
|
||||
t[t == float("-Inf")] = 0
|
||||
|
||||
self.assertEqual(q.size(), correct_q.size())
|
||||
torch.testing.assert_close(q, correct_q)
|
||||
torch.testing.assert_close(k, correct_k)
|
||||
torch.testing.assert_close(v, correct_v)
|
||||
|
||||
@dtypesIfCUDA(torch.float)
|
||||
@dtypes(torch.float)
|
||||
@skipMeta
|
||||
def test_transform_bias_rescale_qkv(self, device, dtype):
|
||||
for use_padding in (False, True):
|
||||
with self.subTest(use_padding=use_padding):
|
||||
self._test_transform_bias_rescale_qkv_impl(
|
||||
device, dtype, use_nt=False, use_padding=use_padding
|
||||
)
|
||||
|
||||
@dtypesIfCUDA(torch.float)
|
||||
@dtypes(torch.float)
|
||||
@skipMeta
|
||||
@onlyCUDA
|
||||
def test_transform_bias_rescale_qkv_nested(self, device, dtype):
|
||||
for use_padding in (False, True):
|
||||
with self.subTest(use_padding=use_padding):
|
||||
self._test_transform_bias_rescale_qkv_impl(
|
||||
device, dtype, use_nt=True, use_padding=use_padding
|
||||
)
|
||||
|
||||
def _test_multihead_attention_impl(
|
||||
self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
|
||||
):
|
||||
embed_dim = 64
|
||||
num_heads = 4
|
||||
bs = 16
|
||||
sl = 8
|
||||
|
||||
q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
|
||||
if use_padding:
|
||||
if pad_all:
|
||||
for q_i in q:
|
||||
q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype)
|
||||
mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
|
||||
for mask_i in mask:
|
||||
mask_i[-1] = True
|
||||
else:
|
||||
q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype)
|
||||
mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
|
||||
mask[0][-1] = True
|
||||
if mode == "self":
|
||||
k = q
|
||||
v = q
|
||||
elif mode == "encdec":
|
||||
k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
|
||||
v = k
|
||||
elif mode == "generic":
|
||||
k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
|
||||
v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
|
||||
else:
|
||||
self.fail(f"invalid mode `{mode}`!")
|
||||
|
||||
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
|
||||
proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
pt = torch.nn.MultiheadAttention(
|
||||
embed_dim, num_heads, batch_first=True, device=device, dtype=dtype
|
||||
)
|
||||
pt.in_proj_weight = qkv.weight
|
||||
pt.in_proj_bias = qkv.bias
|
||||
pt.out_proj.weight = proj.weight
|
||||
pt.out_proj.bias = proj.bias
|
||||
|
||||
class NativeMHA(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, qkv, proj):
|
||||
super().__init__()
|
||||
self.qkv = qkv
|
||||
self.proj = proj
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, q, k, v, key_padding_mask):
|
||||
return torch._native_multi_head_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.qkv.weight,
|
||||
self.qkv.bias,
|
||||
self.proj.weight,
|
||||
self.proj.bias,
|
||||
key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
average_attn_weights=average_attn_weights,
|
||||
)
|
||||
|
||||
npt = NativeMHA(
|
||||
embed_dim=embed_dim, num_heads=num_heads, qkv=qkv, proj=proj
|
||||
).to(dtype)
|
||||
|
||||
if device == "cuda":
|
||||
pt = pt.cuda()
|
||||
npt = npt.cuda()
|
||||
|
||||
ypt, weight_pt = pt(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
need_weights=need_weights,
|
||||
average_attn_weights=average_attn_weights,
|
||||
key_padding_mask=mask if use_padding else None,
|
||||
)
|
||||
if use_nt:
|
||||
qs = list(torch.unbind(q))
if use_padding:
if pad_all:
qs = [x[:-1] for x in qs]
else:
qs[0] = qs[0][:-1]
q = torch.nested_tensor(qs, device=device, dtype=dtype)
if mode == "self":
k = v = q
elif mode == "encdec":
k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
v = k
else:
k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
v = torch.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

ynpt, weight_npt = npt(
q, k, v, key_padding_mask=mask if use_padding and not use_nt else None
)
if use_nt:
ynpt = ynpt.to_padded_tensor(0)
if pad_all:
ynpt_final = torch.zeros_like(ypt)
ynpt_final[:, :ynpt.shape[1], :] = ynpt
ynpt = ynpt_final

def do_pad_all(tensors):
for t in tensors:
for t_i in t:
t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)

# PyTorch implementation returns non-zero junk in the padding
# locations; overwrite it so that the comparison works out.
if use_padding:
ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
if pad_all:
do_pad_all((ypt, ynpt))
# Zero the last row of each TxT weight matrix
if need_weights:
if average_attn_weights:
weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
if pad_all:
do_pad_all((weight_pt, weight_npt))
else:
for nh in range(num_heads):
weight_pt[0][nh][-1] = torch.zeros_like(weight_pt[0][nh][-1], device=device, dtype=dtype)
weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype)

if dtype == torch.half:
torch.testing.assert_close(ypt, ynpt, atol=1e-3, rtol=1e-3)
else:
# High rtol seems necessary for
# test_native_multihead_attention_cpu_float32 on Windows,
# otherwise 2e-4 would likely be fine.
torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

if need_weights:
torch.testing.assert_close(weight_pt, weight_npt)
else:
self.assertEqual(weight_pt, weight_npt)

@dtypesIfCUDA(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
def test_native_multihead_self_attention(self, device, dtype):
for (use_padding, pad_all) in ((False, False), (True, False), (True, True)):
for use_nt in (False, True):
# Figuring out exactly which elements of the weights are garbage in this
# case eludes me, and it's not particularly enlightening to test anyway
# because padding doesn't especially affect the intermediate weights.
for need_weights in (False, not pad_all):
for average_attn_weights in (False, True):
with self.subTest(use_padding=use_padding, pad_all=pad_all,
use_nt=use_nt, need_weights=need_weights,
average_attn_weights=average_attn_weights):
self._test_multihead_attention_impl(
device,
dtype,
"self",
use_nt=use_nt,
use_padding=use_padding,
pad_all=pad_all,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
)

@dtypesIfCUDA(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
def test_native_multihead_encoder_decoder_attention(self, device, dtype):
self._test_multihead_attention_impl(
device,
dtype,
"encdec",
use_nt=False,
need_weights=False,
average_attn_weights=False,
)

@dtypesIfCUDA(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
def test_native_multihead_attention(self, device, dtype):
self._test_multihead_attention_impl(
device,
dtype,
"generic",
use_nt=False,
need_weights=False,
average_attn_weights=False,
)


instantiate_device_type_tests(TestMHADeviceType, globals())

if __name__ == "__main__":
run_tests()

test/test_nn.py
@ -1,5 +1,6 @@
# Owner(s): ["module: nn"]

import contextlib
import math
import random
import string
@ -38,8 +39,8 @@ from torch.nn.parallel._functions import Broadcast
from torch.testing._internal.common_dtype import integral_types, floating_types_and, get_all_math_dtypes, \
floating_and_complex_types_and
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
get_function_arglist, load_tests, \
skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
download_file, get_function_arglist, load_tests, \
suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC, \
parametrize as parametrize_test, subtest, instantiate_parametrized_tests, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
@ -5765,9 +5766,7 @@ class TestNN(NNTestCase):
self.assertIsNone(mha.in_proj_bias)
self.assertIsNone(mha.out_proj.bias)

def test_multihead_attn_invalid_shape(self):
mha = torch.nn.MultiheadAttention(3, 3)

def _test_multihead_attn_invalid_shape_impl(self, mha):
# Batched (3D) query cases
query = torch.randn(3, 3, 3)
key = torch.randn(3, 3, 3)
@ -5823,6 +5822,113 @@ class TestNN(NNTestCase):
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.randn(4, 3, 3).bernoulli_().to(torch.bool))

def test_multihead_attn_invalid_shape(self):
mha = torch.nn.MultiheadAttention(3, 3)
self._test_multihead_attn_invalid_shape_impl(mha)
# Give the test a chance to hit the fast path. (Right now, it
# won't, but gating may be less restricted in the future.)
with torch.no_grad():
self._test_multihead_attn_invalid_shape_impl(mha.eval())

@torch.no_grad()
def test_multihead_attn_fast_path_invalid_shape(self):
mha = torch.nn.MultiheadAttention(3, 3, batch_first=True).eval()

# Batched (3D) query cases
query = torch.randn(3, 3, 3)
key = torch.randn(3, 3, 3)
value = torch.randn(3, 3, 3)

# Currently, this case will just go to the slow path and get
# the usual message because it fails the requirement to be
# batched.
msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively"
# 3D query, 2D key and 3D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, torch.randn(3, 3), value, need_weights=False)

# Currently, this case will just go to the slow path and get
# the usual message because it fails the requirement to be
# batched.
msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively"
# 3D query, 3D key and 2D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, torch.randn(3, 3), need_weights=False)

msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead"
# 3D query, 3D key, 3D value and 1D key_padding_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, key_padding_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False)

msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
# 3D query, 3D key, 3D value and 1D attn_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False)

# Unbatched (2D) query cases
# NOTE: error messages are the same as regular path because the fast path doesn't support 2D.
query = torch.randn(3, 3)
key = torch.randn(3, 3)
value = torch.randn(3, 3)

msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively"
# 2D query, 3D key and 2D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, torch.randn(3, 3, 3), value)

msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively"
# 2D query, 2D key and 3D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, torch.randn(3, 3, 3))

msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead"
# 2D query, 2D key, 2D value and 2D key_padding_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, key_padding_mask=torch.tensor([[False, True, True] * 2], dtype=torch.bool))

msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
# 2D query, 2D key, 2D value and 1D attn_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.tensor([False, True, True], dtype=torch.bool))

msg = r"Expected `attn_mask` shape to be \(3, 3, 3\)"
# 2D query, 2D key, 2D value and 3D incorrect attn_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.randn(4, 3, 3).bernoulli_().to(torch.bool))

def test_multihead_attn_nested_tensor_outside_fast_path(self):
mha = torch.nn.MultiheadAttention(3, 3, batch_first=True).eval()
nt = torch.nested_tensor([torch.randn(3, 3)])
# One tested platform (linux-bionic-py3.7-clang) has a torch_function for one
# or more of these. Take advantage of that to test the torch_function bailout.
has_torch_func = torch.overrides.has_torch_function(
(nt, mha.in_proj_weight, mha.in_proj_bias, mha.out_proj.weight, mha.out_proj.bias))
if has_torch_func:
msg = "MultiheadAttention does not support NestedTensor.*argument has_torch_function"
else:
msg = ("MultiheadAttention does not support NestedTensor outside of its fast path.*grad is " +
"enabled and.*or biases requires_grad")
with self.assertRaisesRegex(AssertionError, msg):
mha(nt, nt, nt)

if has_torch_func:
# Just give up, they're all going to fail with the same message.
return

with torch.no_grad():
mha(nt, nt, nt)
with torch.inference_mode():
mha(nt, nt, nt)
nt = torch.nested_tensor([torch.randn(3, 3, requires_grad=False)])
nt.requires_grad = False
with self.assertRaisesRegex(AssertionError, msg):
mha(nt, nt, nt)
mha.in_proj_weight.requires_grad = False
mha.in_proj_bias.requires_grad = False
mha.out_proj.weight.requires_grad = False
mha.out_proj.bias.requires_grad = False
mha(nt, nt, nt)

def test_normalize(self):
inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
@ -7718,7 +7824,7 @@ class TestNN(NNTestCase):
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

for batch_first in (True, False):
def _test(batch_first, training):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x

@ -7726,6 +7832,8 @@ class TestNN(NNTestCase):
batch_first=batch_first)

model = nn.TransformerEncoder(encoder_layer, 1).to(device)
if not training:
model = model.eval()

# deterministic input
encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
@ -7779,6 +7887,8 @@ class TestNN(NNTestCase):

# test case 2, multiple layers no norm
model = nn.TransformerEncoder(encoder_layer, 2).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
[2.419102, 0.017452, -0.608703, -0.085026]],
@ -7795,6 +7905,8 @@ class TestNN(NNTestCase):
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

model = nn.TransformerEncoder(encoder_layer, 6).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
[2.419101, 0.017453, -0.608704, -0.085025]],
@ -7814,6 +7926,8 @@ class TestNN(NNTestCase):
# d_model = 4
norm = nn.LayerNorm(4)
model = nn.TransformerEncoder(encoder_layer, 2, norm=norm).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
[1.695955, -0.357639, -0.893050, -0.445266]],
@ -7830,6 +7944,8 @@ class TestNN(NNTestCase):
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

model = nn.TransformerEncoder(encoder_layer, 6, norm=norm).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
[1.695955, -0.357639, -0.893051, -0.445265]],
@ -7844,7 +7960,15 @@ class TestNN(NNTestCase):
)).to(device)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

for batch_first in (True, False):
for training in (True, False):
# Fast path requires inference mode.
if training:
cm = contextlib.nullcontext()
else:
cm = torch.no_grad()
with cm:
_test(batch_first, training)

def test_transformerdecoder(self):
def get_a_test_layer(use_cuda, activation, batch_first=False):
@ -12921,17 +13045,20 @@ class TestNNDeviceType(NNTestCase):
output.sum().backward()
self.assertEqualTypeString(output, input)

def _test_module_empty_input(self, module, inp, check_size=True):
inp.requires_grad_(True)
def _test_module_empty_input(self, module, inp, check_size=True, inference=False):
if not inference:
inp.requires_grad_(True)
out = module(inp)
gO = torch.rand_like(out)
out.backward(gO)
if not inference:
gO = torch.rand_like(out)
out.backward(gO)
if check_size:
self.assertEqual(out.size(), inp.size())
for p in module.parameters():
if p.requires_grad:
self.assertEqual(p.grad, torch.zeros_like(p.grad))
self.assertEqual(inp.grad, torch.zeros_like(inp))
if not inference:
for p in module.parameters():
if p.requires_grad:
self.assertEqual(p.grad, torch.zeros_like(p.grad))
self.assertEqual(inp.grad, torch.zeros_like(inp))

def _test_module_empty_inputs(self, module, inputs):
for _inp in inputs:
@ -14500,11 +14627,29 @@ class TestNNDeviceType(NNTestCase):
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@onlyNativeDeviceTypes
def test_TransformerEncoderLayer_empty(self, device):
for batch_first, input_shape in [(True, (0, 10, 512)),
(False, (10, 0, 512))]:
input = torch.rand(*input_shape, device=device)
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
self._test_module_empty_input(encoder_layer, input, check_size=False)
for training in (True, False):
for batch_first, input_shape in [(True, (0, 10, 512)),
(False, (10, 0, 512))]:
input = torch.rand(*input_shape, device=device)
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
if not training:
encoder_layer = encoder_layer.eval()
with torch.no_grad():
self._test_module_empty_input(encoder_layer, input, check_size=False, inference=True)
if batch_first and not TEST_WITH_CROSSREF:
with torch.no_grad():
# A NestedTensor with no tensors inside it doesn't have dim 3 (or dim
# 2, for that matter) so it can't hit the fast path, nor can we give a
# result.
with self.assertRaisesRegex(
AssertionError, 'MultiheadAttention does not support NestedTensor outside'):
nt = torch.nested_tensor([], device=device)
self._test_module_empty_input(encoder_layer, nt, check_size=False, inference=True)

nt = torch.nested_tensor([torch.rand(0, 512, device=device)], device=device)
self._test_module_empty_input(encoder_layer, nt, check_size=False, inference=True)
else:
self._test_module_empty_input(encoder_layer, input, check_size=False)

@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@onlyNativeDeviceTypes
@ -17935,6 +18080,32 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(q.size(), out[0].size())
self.assertEqual(dtype, out[0].dtype)

@onlyCUDA
@dtypes(torch.half, torch.float, torch.double)
def test_multihead_attention_dtype_batch_first(self, device, dtype):
embed_dim = 128
num_heads = 8
sl = 10
bs = 8
# With batch_first=True, we have the possibility of hitting
# the native fast path if we call .eval() and enable inference
# mode. Test both paths.
for training in (True, False):
model = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda().to(dtype)
if not training:
model = model.eval()
cm = torch.no_grad()
else:
cm = contextlib.nullcontext()
with cm:
q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
# fast path currently doesn't support weights
out = model(q, k, v, need_weights=False)
self.assertEqual(q.size(), out[0].size())
self.assertEqual(dtype, out[0].dtype)

@dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
@dtypes(torch.float)
def test_Conv2d_naive_groups(self, device, dtype):
@ -19898,7 +20069,7 @@ class TestNNDeviceType(NNTestCase):
m(output_size)(t)

@dtypes(torch.float)
@dtypesIfCUDA(torch.float, torch.half)
@dtypesIfCUDA(torch.double, torch.float, torch.half)
def test_transformerencoderlayer(self, device, dtype):
# this is a deterministic test for TransformerEncoderLayer
d_model = 4
@ -19913,13 +20084,17 @@ class TestNNDeviceType(NNTestCase):
atol = 1e-3
rtol = 1e-2

for batch_first in (False, True):
def _test(training, batch_first, atol, rtol):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x

model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
batch_first=batch_first, device=device, dtype=dtype)

if not training:
assert dropout == 0
model = model.eval()

# set constant weights of the model
for idx, p in enumerate(model.parameters()):
x = p.data
@ -19936,6 +20111,7 @@ class TestNNDeviceType(NNTestCase):
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
# 0 values are NOT masked. This shouldn't mask anything.
mask = torch.tensor([[0]], device=device) == 1
# TODO: enable fast path for calls with a mask!
result = model(encoder_input, src_key_padding_mask=mask)
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
@ -19990,6 +20166,7 @@ class TestNNDeviceType(NNTestCase):
[2.422901, 0.024187, -0.606178, -0.074929]]], device=device, dtype=dtype))
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)

# all 0
mask = torch.zeros([2, 5], device=device) == 1
result = model(encoder_input, src_key_padding_mask=mask)
@ -20012,6 +20189,65 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)

# NestedTensor is only supported for the fast path
# currently, which won't be used if training.
if (batch_first and not training and
('cuda' in str(device) or 'cpu' in str(device)) and not TEST_WITH_CROSSREF):
encoder_input[0][-1] = torch.zeros_like(encoder_input[0][1])
mask = torch.zeros(encoder_input.shape[:-1], device=device, dtype=torch.bool)
mask[0][-1] = True

nt = torch.nested_tensor([encoder_input[0][:-1], encoder_input[1]], device=device)
result = model(nt)
ref_output = torch.tensor(
[
[
[2.4268184, 0.02042419, -0.603311, -0.08476824],
[2.423306, 0.01889652, -0.6057701, -0.08519465],
[2.431538, 0.02078694, -0.5999354, -0.08746159],
[2.4348664, 0.02212971, -0.5975677, -0.08733892],
[2.423133, 0.02097577, -0.60594773, -0.08113337],
],
[
[2.4279876, 0.02121329, -0.60249615, -0.08410317],
[2.4138637, 0.02221113, -0.6124869, -0.07249016],
[2.4251041, 0.01974815, -0.6045152, -0.08483928],
[2.4335563, 0.0218913, -0.59850943, -0.08683228],
[2.4229012, 0.02418739, -0.6061784, -0.07492948],
],
],
device=device, dtype=dtype
)
result = result.to_padded_tensor(0)
ref_output[0][-1] = torch.zeros_like(
ref_output[0][-1], device=device, dtype=dtype
)
result[0][-1] = torch.zeros_like(
result[0][-1], device=device, dtype=dtype
)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
if 'cuda' in device:
if dtype == torch.float:
atol = 2e-4
rtol = 4e-3
else:
atol = 7e-4
rtol = 2e-2
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
else:
torch.testing.assert_close(result, ref_output)


for batch_first in (True, False):
for training in (True, False):
if training:
cm = contextlib.nullcontext()
else:
# Fast path requires inference mode.
cm = torch.no_grad()
with cm:
_test(batch_first=batch_first, training=training, atol=atol, rtol=rtol)

@dtypes(torch.float)
@dtypesIfCUDA(torch.half, torch.float)
def test_transformerencoderlayer_gelu(self, device, dtype):
@ -20028,12 +20264,15 @@ class TestNNDeviceType(NNTestCase):
atol = 1e-3
rtol = 1e-2

for activation, batch_first in product(('gelu', F.gelu, nn.GELU()), (True, False)):
def _test(activation, batch_first, training):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x

model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, batch_first=batch_first, device=device, dtype=dtype)
if not training:
assert dropout == 0
model = model.eval()

# set constant weights of the model
for idx, p in enumerate(model.parameters()):
@ -20080,6 +20319,14 @@ class TestNNDeviceType(NNTestCase):
[[2.41383916, 0.02686345, -0.61256377, -0.06380707],
[2.42000277, 0.03800944, -0.60824798, -0.04754947]]], device=device, dtype=dtype))
torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
for activation, batch_first, training in product(('gelu', F.gelu, nn.GELU()), (True, False), (True, False)):
# Fast path requires inference mode.
if training:
cm = contextlib.nullcontext()
else:
cm = torch.no_grad()
with cm:
_test(activation=activation, batch_first=batch_first, training=training)


class TestModuleGlobalHooks(TestCase):

@ -88,6 +88,7 @@ includes = [
"aten/src/ATen/native/nested/cuda/*",
"aten/src/ATen/native/sparse/cuda/*",
"aten/src/ATen/native/quantized/cuda/*",
"aten/src/ATen/native/transformers/cuda/*",
"aten/src/THC/*",
"aten/src/ATen/test/*",
# CMakeLists.txt isn't processed by default, but there are a few

@ -1393,6 +1393,8 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/sparse/SparseTensorMath.cpp",
"aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
"aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
"aten/src/ATen/native/transformers/attention.cpp",
"aten/src/ATen/native/transformers/transformer.cpp",
"aten/src/ATen/native/utils/Factory.cpp",
"aten/src/ATen/native/xnnpack/Activation.cpp",
"aten/src/ATen/native/xnnpack/ChannelShuffle.cpp",

@ -893,6 +893,29 @@ class MultiheadAttention(Module):

where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

``forward()`` will use a special optimized implementation if all of the following
conditions are met:

- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- dropout is 0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed

If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
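
A minimal usage sketch of these conditions (sizes and names are arbitrary assumptions, not taken from the patch): an eval-mode, ``batch_first`` module called under inference mode with the same tensor for query, key, and value.

>>> # illustrative only: eval() + inference_mode() + batch_first + self attention
>>> mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).eval()
>>> x = torch.randn(8, 16, 64)  # (batch, seq, embed_dim)
>>> with torch.inference_mode():
...     out, _ = mha(x, x, x, need_weights=False)  # query is key is value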

Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
@ -911,6 +934,7 @@ class MultiheadAttention(Module):

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ['batch_first']
bias_k: Optional[torch.Tensor]

@ -1034,6 +1058,67 @@ class MultiheadAttention(Module):
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
why_not_fast_path = ''
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
why_not_fast_path = "key_padding_mask and attn_mask were both supplied"

if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
why_not_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
f"The fast path was not hit because {why_not_fast_path}")

if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:

@ -289,6 +289,29 @@ class TransformerEncoderLayer(Module):
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)

Fast path:
forward() will use a special optimized implementation if all of the following
conditions are met:

- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
argument ``requires_grad``
- training is disabled (using ``.eval()``)
- batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
- norm_first is ``False`` (this restriction may be loosened in the future)
- activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
- at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
- if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
nor ``src_key_padding_mask`` is passed
- the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
unless the caller has manually modified one without modifying the other)

If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
passed for ``src`` to represent padding more efficiently than using a padding
mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
returned, and an additional speedup proportional to the fraction of the input that
is padding can be expected.
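
A minimal usage sketch of the conditions above (sizes are arbitrary assumptions, not taken from the patch): an eval-mode, ``batch_first`` layer called under ``torch.no_grad()``, once with a dense batch and once with a NestedTensor.

>>> # illustrative only; d_model, nhead and shapes are assumptions
>>> layer = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True).eval()
>>> with torch.no_grad():
...     dense_out = layer(torch.randn(8, 16, 64))  # (batch, seq, d_model)
...     nt = torch.nested_tensor([torch.randn(15, 64), torch.randn(16, 64)])
...     nt_out = layer(nt)  # NestedTensor in, NestedTensor out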
"""
__constants__ = ['batch_first', 'norm_first']

@ -313,16 +336,25 @@ class TransformerEncoderLayer(Module):

# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
activation = _get_activation_fn(activation)

# We can't test self.activation in forward() in TorchScript,
# so stash some information about it instead.
if activation is F.relu:
self.activation_relu_or_gelu = 1
elif activation is F.gelu:
self.activation_relu_or_gelu = 2
else:
self.activation = activation
self.activation_relu_or_gelu = 0
self.activation = activation

def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super(TransformerEncoderLayer, self).__setstate__(state)

def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the input through the encoder layer.

Args:
@ -336,6 +368,53 @@ class TransformerEncoderLayer(Module):

# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

if (not self.norm_first and not self.training and
self.self_attn.batch_first and src.dim() == 3 and self.self_attn._qkv_same_embed_dim and
self.activation_relu_or_gelu and self.norm1.eps == self.norm2.eps and
((src_mask is None and src_key_padding_mask is None)
if src.is_nested
else (src_mask is None or src_key_padding_mask is None))):
tensor_args = (
src,
self.self_attn.in_proj_weight,
self.self_attn.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.norm1.weight,
self.norm1.bias,
self.norm2.weight,
self.norm2.bias,
self.linear1.weight,
self.linear1.bias,
self.linear2.weight,
self.linear2.bias,
)
if (not torch.overrides.has_torch_function(tensor_args) and
# We have to use a list comprehension here because TorchScript
# doesn't support generator expressions.
all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]) and
(not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]))):
return torch._transformer_encoder_layer_fwd(
src,
self.self_attn.embed_dim,
self.self_attn.num_heads,
self.self_attn.in_proj_weight,
self.self_attn.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.activation_relu_or_gelu == 2,
False, # norm_first, currently not supported
self.norm1.eps,
self.norm1.weight,
self.norm1.bias,
self.norm2.weight,
self.norm2.bias,
self.linear1.weight,
self.linear1.bias,
self.linear2.weight,
self.linear2.bias,
src_mask if src_mask is not None else src_key_padding_mask,
)
x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
@ -488,7 +567,7 @@ def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":

@ -811,6 +811,7 @@ def preprocessor(
or f.startswith("ATen/native/nested/cuda")
or f.startswith("ATen/native/quantized/cuda")
or f.startswith("ATen/native/sparse/cuda")
or f.startswith("ATen/native/transformers/cuda")
or f.startswith("THC/")
or (f.startswith("THC") and not f.startswith("THCP"))
):