Revert "[PyTorch] Add native fast path for transformer encoder inference"

This reverts commit b369b89f235f54bc9de85d768fb62ac4579681dc.

This has internal changes and should not have been landed via mergebot.

Ref: https://github.com/pytorch/pytorch/pull/75809#issuecomment-1108717166
Author: Jon Janzen
Date: 2022-04-25 11:40:02 -04:00
Parent: 1c35f37c9f
Commit: 2387efd356
16 changed files with 30 additions and 1815 deletions

View File

@ -224,11 +224,6 @@ filegroup(
),
)
filegroup(
name = "aten_native_transformers_cpp",
srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]),
)
filegroup(
name = "aten_native_mkl_cpp",
srcs = glob(["aten/src/ATen/native/mkl/*.cpp", "aten/src/ATen/mkl/*.cpp"]),
@ -279,7 +274,6 @@ filegroup(
"aten/src/ATen/native/miopen/*.cpp",
"aten/src/ATen/native/nested/cuda/*.cpp",
"aten/src/ATen/native/sparse/cuda/*.cpp",
"aten/src/ATen/native/transformers/cuda/*.cpp",
"aten/src/THC/*.cpp",
],
),
@ -294,7 +288,6 @@ filegroup(
"aten/src/ATen/native/nested/cuda/*.cu",
"aten/src/ATen/native/quantized/cuda/*.cu",
"aten/src/ATen/native/sparse/cuda/*.cu",
"aten/src/ATen/native/transformers/cuda/*.cu",
]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
# It's a bit puzzling to me why it's not necessary to declare the
# target that generates these sources...
@ -396,7 +389,6 @@ cc_library(
":aten_native_quantized_cpp",
":aten_native_sparse_cpp",
":aten_native_nested_cpp",
":aten_native_transformers_cpp",
":aten_native_xnnpack",
":aten_src_ATen_config",
] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),

View File

@ -105,7 +105,6 @@ file(GLOB native_quantized_cpp
"native/quantized/*.cpp"
"native/quantized/cpu/*.cpp")
file(GLOB native_nested_cpp "native/nested/*.cpp")
file(GLOB native_transformers_cpp "native/transformers/*.cpp")
file(GLOB native_h "native/*.h")
file(GLOB native_ao_sparse_h
@ -129,8 +128,6 @@ file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
file(GLOB native_quantized_cuda_cu "native/quantized/cuda/*.cu")
file(GLOB native_quantized_cuda_cpp "native/quantized/cuda/*.cpp")
file(GLOB native_quantized_cudnn_cpp "native/quantized/cudnn/*.cpp")
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
file(GLOB native_hip_hip "native/hip/*.hip")
file(GLOB native_hip_cpp "native/hip/*.cpp")
@ -143,8 +140,6 @@ file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
file(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")
# XNNPACK
@ -167,7 +162,6 @@ else()
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
${native_transformers_cpp}
${native_utils_cpp} ${native_xnnpack} ${generated_sources} ${core_generated_sources}
${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}
)
@ -211,7 +205,6 @@ if(USE_CUDA)
${native_nested_cuda_cu}
${native_sparse_cuda_cu}
${native_quantized_cuda_cu}
${native_transformers_cuda_cu}
${cuda_generated_sources}
)
list(APPEND ATen_CUDA_CPP_SRCS
@ -223,7 +216,6 @@ if(USE_CUDA)
${native_quantized_cuda_cpp}
${native_quantized_cudnn_cpp}
${native_sparse_cuda_cpp}
${native_transformers_cuda_cpp}
)
set(ATen_CUDA_LINALG_SRCS ${native_cuda_linalg_cpp})
if(NOT BUILD_LAZY_CUDA_LINALG)
@ -246,9 +238,9 @@ endif()
if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip} ${native_transformers_hip_hip})
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip})
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
set(all_hip_cpp ${native_miopen_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp})
endif()

View File

@ -4663,12 +4663,6 @@
- func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
# Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads).
- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
dispatch:
CPU, NestedTensorCPU: transform_bias_rescale_qkv_cpu
CUDA, NestedTensorCUDA: transform_bias_rescale_qkv_cuda
- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
device_check: NoCheck # cpu_nested_shape_example will always be on CPU
dispatch:
@ -11609,14 +11603,3 @@
variants: method
dispatch:
NestedTensorCPU, NestedTensorCUDA: NestedTensor_layer_norm
# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None) -> Tensor
variants: function
dispatch:
CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: transformer_encoder_layer_forward
- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True) -> (Tensor, Tensor)
variants: function
dispatch:
CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
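For reference, the schemas removed above back private Python bindings that the test file later in this diff exercises. A minimal sketch of how they were called on a build that still contained the reverted change; shapes and argument order follow the schemas and tests shown here, and _transformer_encoder_layer_fwd (whose long flat schema appears above) is omitted for brevity:

import torch

embed_dim, num_heads, bs, sl = 64, 4, 16, 8
x = torch.randn(bs, sl, embed_dim)
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
proj = torch.nn.Linear(embed_dim, embed_dim)

# inference_mode is required: the op returns views of one packed buffer,
# which autograd does not allow (see the comment in the test file below).
with torch.inference_mode():
    # Project without the bias; the fused op adds qkv_bias and rescales q itself.
    packed = x @ qkv.weight.t()                       # [bs, sl, 3 * embed_dim]
    q, k, v = torch._transform_bias_rescale_qkv(packed, qkv.bias, num_heads=num_heads)
    # Fused multi-head self-attention; the mask is passed positionally, as in the test.
    out, weights = torch._native_multi_head_attention(
        x, x, x, embed_dim, num_heads,
        qkv.weight, qkv.bias, proj.weight, proj.bias,
        None, need_weights=False)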

View File

@ -1,5 +1,3 @@
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NamedTensorUtils.h>

View File

@ -1,9 +1,5 @@
#pragma once
#include <c10/macros/Macros.h>
#include <vector>
namespace at {
namespace native {
struct NestedTensorImpl;
@ -11,7 +7,7 @@ struct NestedTensorImpl;
// TODO: cache this and only do it once per NestedTensor
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
} // namespace native
} // namespace at

View File

@ -35,12 +35,6 @@ Tensor nested_from_padded_cuda(
const Tensor& sizes,
bool do_transform_0213) {
if (padded.dim() > 1 && padded.dim() < 5) {
if (padded.dtype() != kFloat && padded.dtype() != kHalf) {
TORCH_WARN_ONCE(
"nested_from_padded CUDA kernels only support fp32/fp16; falling "
"back to slower generic kernel");
return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
}
TORCH_CHECK(
(padded.dim() == 4 && do_transform_0213) ||
(padded.dim() == 3 && !do_transform_0213),

View File

@ -1,482 +0,0 @@
#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/cpu/vec/vec256/vec256.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#endif
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at {
namespace native {
namespace {
Tensor gemm_nt(const Tensor& self, const Tensor& other) {
if (self.is_nested()) {
return NestedTensor_matmul(self, other.t());
} else {
return at::native::matmul(self, other.t());
}
}
template <typename scalar_t>
void transform_bias_rescale_qkv_inner_loop(
int64_t B,
int64_t T,
int64_t _3D,
int64_t D,
int64_t num_head,
int64_t dim_per_head,
scalar_t* qkv_data,
scalar_t* qkv_bias_data,
scalar_t* q_k_v_data,
scalar_t inv_sqrt_dim_per_head,
int64_t begin,
int64_t end) {
for (auto i : c10::irange(begin, end)) {
auto t = i % T;
i /= T;
auto nh = i % num_head;
i /= num_head;
auto b = i;
using Vec = vec::Vectorized<scalar_t>;
auto V = vec::Vectorized<scalar_t>::size();
auto dh = 0;
auto d = nh * dim_per_head;
for (; dh + V <= dim_per_head; dh += V, d += V) {
// load
auto q_bias_data = Vec::loadu(&qkv_bias_data[d + 0 * D]);
auto k_bias_data = Vec::loadu(&qkv_bias_data[d + 1 * D]);
auto v_bias_data = Vec::loadu(&qkv_bias_data[d + 2 * D]);
auto q_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 0 * D]) +
q_bias_data;
auto k_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 1 * D]) +
k_bias_data;
auto v_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 2 * D]) +
v_bias_data;
q_data = q_data * Vec(inv_sqrt_dim_per_head);
q_data.store(&q_k_v_data
[0 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head + t * dim_per_head + dh]);
k_data.store(&q_k_v_data
[1 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head + t * dim_per_head + dh]);
v_data.store(&q_k_v_data
[2 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head +
nh * T * dim_per_head + t * dim_per_head + dh]);
}
for (; dh < dim_per_head; dh++) {
auto d = nh * dim_per_head + dh;
auto q_bias = qkv_bias_data[d + 0 * D];
auto k_bias = qkv_bias_data[d + 1 * D];
auto v_bias = qkv_bias_data[d + 2 * D];
auto q_data = qkv_data[b * _3D * T + t * _3D + d + 0 * D] + q_bias;
auto k_data = qkv_data[b * _3D * T + t * _3D + d + 1 * D] + k_bias;
auto v_data = qkv_data[b * _3D * T + t * _3D + d + 2 * D] + v_bias;
q_data = q_data * inv_sqrt_dim_per_head;
q_k_v_data
[0 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head + nh * T * dim_per_head +
t * dim_per_head + dh] = q_data;
q_k_v_data
[1 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head + nh * T * dim_per_head +
t * dim_per_head + dh] = k_data;
q_k_v_data
[2 * B * num_head * T * dim_per_head +
b * num_head * T * dim_per_head + nh * T * dim_per_head +
t * dim_per_head + dh] = v_data;
}
}
}
Tensor bmm_nt(const Tensor& a, const Tensor& b) {
auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});
auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});
auto bt_ = b_.transpose(2, 1);
auto c_ = at::bmm(a_, bt_);
return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});
}
Tensor masked_softmax(
Tensor& attn_scores,
c10::optional<Tensor> attn_mask,
const Tensor& query) {
if (query.is_nested() && !attn_mask) {
// TODO: maybe we could do better than generating a mask every time?
attn_mask = NestedTensor_to_mask(query, 2);
// TODO: CPU path does not support transformer mask yet.
if (attn_scores.is_cpu()) {
attn_mask = attn_mask->view({-1, 1, 1, attn_scores.sizes()[3]});
// 1 means skip, 0 means keep.
// want:
// 0,0 -> 0
// 0,1 -> 1
// 1,1 -> 1
// so that's logical OR.
*attn_mask = *attn_mask | attn_mask->transpose(2, 3);
attn_mask = at::expand_inplace(attn_scores, *attn_mask)->contiguous();
}
attn_mask = attn_mask->to(query.device(), /*non-blocking=*/true);
}
if (attn_mask && attn_mask->dtype() != at::kBool) {
TORCH_WARN(
"Converting mask without torch.bool dtype to bool; this will "
"negatively affect performance. Prefer to use a boolean mask directly.");
attn_mask = attn_mask->to(at::kBool);
}
if (attn_scores.is_cpu() && attn_mask && attn_mask->dim() == 2) {
// TODO: CPU path does not support transformer mask yet.
const auto batch_size = attn_scores.sizes()[0];
const auto seq_len = attn_scores.sizes()[3];
TORCH_CHECK(attn_mask->sizes()[0] == batch_size);
TORCH_CHECK(attn_mask->sizes()[1] == seq_len);
attn_mask = attn_mask->view({batch_size, 1, 1, seq_len});
attn_mask = at::expand_inplace(attn_scores, *attn_mask)->contiguous();
}
if (attn_mask) {
return _masked_softmax(attn_scores, *attn_mask);
} else {
return _softmax_out(attn_scores, attn_scores, attn_scores.dim() - 1, false);
}
}
Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
const std::array<int64_t, 3> newAShape = {
a.sizes()[0] * a.sizes()[1], a.sizes()[2], a.sizes()[3]};
auto a_ = a.view(newAShape);
const std::array<int64_t, 3> newBShape = {
b.sizes()[0] * b.sizes()[1], b.sizes()[2], b.sizes()[3]};
auto b_ = b.view(newBShape);
auto out_ = out.reshape({newAShape[0], newAShape[1], newBShape[2]});
auto c_ = at::bmm_out(out_, a_, b_);
return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}
Tensor transform_0213(const Tensor& a) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
return a.permute({0, 2, 1, 3})
.contiguous()
.view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}
Tensor transform0213_gemm_nt_bias(
const Tensor& a,
const Tensor& b,
const Tensor& c,
const Tensor& query) {
if (query.is_nested()) {
at::Tensor nested_a = _nested_from_padded(
a, get_nested_tensor_impl(query)->get_nested_size_tensor(), true);
return NestedTensor_times_Tensor_plus_Tensor_addmm(
c, nested_a, b.t(), 1, 1);
} else {
const Tensor a_0213 = transform_0213(a);
auto a_ = a_0213.view({a_0213.size(0) * a_0213.size(1), a_0213.size(2)});
auto r_ = at::native::linear(a_, b, c);
return r_.view({a_0213.size(0), a_0213.size(1), r_.size(1)});
}
}
void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
(size_t)t.dim() == shape.size(),
"(called from line ",
line,
") ",
"expected ",
shape.size(),
"-D tensor but got ",
t.dim());
if (t.is_nested()) {
return;
}
for (auto idx : c10::irange(shape.size())) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
shape[idx] == 0 || t.sizes()[idx] == shape[idx],
"(called from line ",
line,
") ",
"expected dim ",
idx,
" to be ",
shape[idx],
" but got ",
t.sizes()[idx]);
}
}
} // namespace
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
const Tensor& qkv,
const Tensor& qkv_bias,
const int64_t num_head) {
auto qkv_ = qkv.is_nested()
? c10::MaybeOwned<Tensor>::owned((NestedTensor_to_padded_tensor(qkv, 0)))
: c10::MaybeOwned<Tensor>::borrowed(qkv);
auto B = qkv_->size(0);
auto T = qkv_->size(1);
auto _3D = qkv_->size(2);
auto D = _3D / 3;
TORCH_CHECK(D % num_head == 0);
TORCH_CHECK(_3D % 3 == 0);
const auto dim_per_head = D / num_head;
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_->options());
const auto qkv_contig = qkv_->expect_contiguous();
const auto qkv_bias_contig = qkv_bias.expect_contiguous();
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
qkv_->scalar_type(),
"transform_bias_rescale_qkv",
[&] {
scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
const scalar_t inv_sqrt_dim_per_head =
1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head));
int64_t grain_size =
std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
parallel_for(
0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
transform_bias_rescale_qkv_inner_loop(
B,
T,
_3D,
D,
num_head,
dim_per_head,
qkv_data,
qkv_bias_data,
q_k_v_data,
inv_sqrt_dim_per_head,
begin,
end);
});
});
auto q_k_v_s =
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
std::tuple<Tensor, Tensor> native_multi_head_attention(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const int64_t embed_dim,
const int64_t num_head,
const Tensor& qkv_weight,
const Tensor& qkv_bias,
const Tensor& proj_weight,
const Tensor& proj_bias,
const c10::optional<Tensor>& mask,
bool need_weights,
bool average_attn_weights) {
// query shape: [B, T, D]
// qkv_weight shape: [3 * D, D]
TORCH_CHECK(
!mask || !query.is_nested(),
"NestedTensor with mask is not supported yet");
const auto D = embed_dim;
TORCH_CHECK(
query.dim() == 3,
"expected 3-D `query`, got ",
query.dim(),
"-D tensor");
TORCH_CHECK(
query.is_nested() || query.sizes()[2] == embed_dim,
"passed-in embed_dim ",
embed_dim,
" didn't match last dim of query ",
query.sizes()[2]);
TORCH_CHECK(
key.dim() == 3,
"expected 3-D `key`, got ",
key.dim(),
"-D tensor");
TORCH_CHECK(
value.dim() == 3,
"expected 3-D `value`, got ",
value.dim(),
"-D tensor");
TORCH_CHECK(
query.is_nested() || key.is_nested() || value.is_nested() ||
(query.sizes() == key.sizes() && key.sizes() == value.sizes()),
"expected `query`/`key`/`value` shapes to match");
TORCH_CHECK(
qkv_weight.dim() == 2,
"expected 2-D `qkv_weight`, got ",
qkv_weight.dim(),
"-D tensor");
TORCH_CHECK(
D * 3 == qkv_weight.sizes()[0],
"expected `qkv_weight` first dim to be 3x embed_dim");
TORCH_CHECK(
D == qkv_weight.sizes()[1],
"expected `qkv_weight` second dim to be embed_dim");
TORCH_CHECK(
qkv_bias.dim() == 1,
"expected 1-D `qkv_bias`, got ",
qkv_bias.dim(),
"-D tensor");
TORCH_CHECK(
qkv_bias.sizes()[0] == 3 * D,
"expected `qkv_bias` first dim to be 3x embed_dim");
TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`");
#ifndef NDEBUG
const auto B = query.is_nested()
? get_nested_tensor_impl(query)->get_nested_size_tensor().size(0)
: query.sizes()[0];
auto T = query.is_nested() ? 0 : query.sizes()[1];
const auto dim_per_head = D / num_head;
#endif
// shape: [B, T, 3 x D]
Tensor qkv;
if (key.is_same(value)) {
if (query.is_same(key)) {
// self-attention
qkv = gemm_nt(query, qkv_weight);
} else {
// encoder-decoder attention
// TODO: is there a more efficient way to set this up?
// TODO: can we stay nested instead of using cat? Probably just make a
// NestedTensor out of the matmul results or something?
auto q_kv_weight_s =
at::native::split_with_sizes(qkv_weight, {D, D * 2}, 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
q_kv_weight_s.size() == 2,
"expected split to produce 2 tensors but it produced ",
q_kv_weight_s.size());
auto q = gemm_nt(query, q_kv_weight_s[0]);
auto kv = gemm_nt(key, q_kv_weight_s[1]);
qkv = at::cat({q, kv}, 2);
}
} else {
auto q_k_v_weight_s = at::native::chunk(qkv_weight, 3, 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
q_k_v_weight_s.size() == 3,
"expected chunk to produce 3 tensors but it produced ",
q_k_v_weight_s.size());
// TODO: can we stay nested instead of using cat?
auto q = gemm_nt(query, q_k_v_weight_s[0]);
auto k = gemm_nt(key, q_k_v_weight_s[1]);
auto v = gemm_nt(value, q_k_v_weight_s[2]);
qkv = at::cat({q, k, v}, 2);
}
if (!qkv.is_nested() && qkv.numel() == 0) {
if (query.is_nested()) {
return std::make_tuple(Tensor(), Tensor());
}
return std::make_tuple(at::empty_like(query), Tensor());
}
#ifndef NDEBUG
if (!query.is_nested() || !qkv.is_nested()) {
if (query.is_nested()) {
T = qkv.size(1);
}
debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
}
#endif
#ifdef DEBUG_PRINT_EACH_STEP
if (!qkv.is_nested()) {
std::cerr << "qkv: " << qkv << std::endl;
}
#endif
// shape: 3 x [B, num_head, T, dim_per_head]
auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
qkv = Tensor(); // Not used any more, allow free
auto& q = std::get<0>(q_k_v);
const auto& k = std::get<1>(q_k_v);
const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
std::cerr << "q: " << q << std::endl;
std::cerr << "k: " << k << std::endl;
std::cerr << "v: " << v << std::endl;
#endif
// shape: [B, num_head, T, T]
auto qkt = bmm_nt(q, k);
// q & k are dead but cannot be freed because they were packed with v
#ifndef NDEBUG
debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
std::cerr << "qkt: " << qkt << std::endl;
#endif
// shape: [B, num_head, T, T]
// TODO: long-term, have a kernel that works with
// NestedTensor directly if there is no mask passed
qkt = masked_softmax(qkt, mask, query);
#ifdef DEBUG_PRINT_EACH_STEP
std::cerr << "qkt after softmax: " << qkt << std::endl;
#endif
// shape: [B, num_head, T, dim_per_head]
// reuse storage for q; we're done with it
auto attn_ctx = bmm_nn(q, qkt, v);
// qkv is not dead; we just reused storage for q!
if (!need_weights) {
qkt = Tensor();
}
#ifndef NDEBUG
debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif
// shape: [B, T, D]
// Fuse transform_0213 inside
auto proj = transform0213_gemm_nt_bias(
attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
if (need_weights && average_attn_weights) {
// weights are not needed for full transformer, so don't worry too
// much about performance -- we implement this just to make use
// cases that don't disable need_weights still get some speedup.
qkt = qkt.sum(1);
qkt /= num_head;
}
return std::make_tuple(std::move(proj), std::move(qkt));
}
} // namespace native
} // namespace at
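The comment above transform_bias_rescale_qkv_cpu summarizes the whole transform: add the in-projection bias, split into q/k/v heads, and rescale only q by 1/sqrt(dim_per_head). A plain eager-mode restatement, equivalent in spirit to the simple_transform_bias_rescale_qkv helper in the test file later in this diff (a sketch only; the fused loops above avoid these intermediates):

import math

def reference_transform_bias_rescale_qkv(qkv, qkv_bias, num_heads):
    # qkv: [B, T, 3*D] packed projection output; qkv_bias: [3*D]
    B, T, three_D = qkv.shape
    D = three_D // 3
    dim_per_head = D // num_heads
    q, k, v = (qkv + qkv_bias).split(D, dim=-1)       # bias add, then unpack
    def to_heads(t):
        # [B, T, D] -> [B, num_heads, T, dim_per_head]
        return t.reshape(B, T, num_heads, dim_per_head).transpose(1, 2)
    # Only q is rescaled by 1/sqrt(dim_per_head).
    return to_heads(q) / math.sqrt(dim_per_head), to_heads(k), to_heads(v)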

View File

@ -1,400 +0,0 @@
#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at {
namespace native {
namespace {
static constexpr int TRANSFORM_BIAS_RESCALE_VEC = 4;
template <typename scalar_t, typename accscalar_t, bool assume_aligned>
__global__ void transform_bias_rescale_qkv_kernel(
// [B, T, 3 * D]
const PackedTensorAccessor64<scalar_t, 3, RestrictPtrTraits> qkv,
// [3 * D]
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
// [3, B, NH, T, DH]
PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
const scalar_t inv_sqrt_dim_per_head) {
// warp per DH.
// so launch B * NH * T warps.
auto NH = q_k_v.size(2);
auto T = q_k_v.size(3);
auto DH = q_k_v.size(4);
auto t = blockIdx.x % T;
auto b = blockIdx.x / T;
auto D = NH * DH;
if (assume_aligned) {
constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
using LoadT = memory::aligned_vector<scalar_t, VEC>;
for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
auto d = d_v * VEC;
auto nh = d / DH;
auto dh = d % DH;
scalar_t qkv_bias_q[VEC];
scalar_t qkv_bias_k[VEC];
scalar_t qkv_bias_v[VEC];
scalar_t qkv_q[VEC];
scalar_t qkv_k[VEC];
scalar_t qkv_v[VEC];
// Here we require D % VEC == 0 for these vectorized loads.
*reinterpret_cast<LoadT*>(&qkv_bias_q) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
*reinterpret_cast<LoadT*>(&qkv_bias_k) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
*reinterpret_cast<LoadT*>(&qkv_bias_v) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
*reinterpret_cast<LoadT*>(&qkv_q) =
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 0 * D]);
*reinterpret_cast<LoadT*>(&qkv_k) =
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 1 * D]);
*reinterpret_cast<LoadT*>(&qkv_v) =
*reinterpret_cast<const LoadT*>(&qkv[b][t][d + 2 * D]);
#pragma unroll
// TODO: specialize for float2half2/half2float2?
for (auto ii = 0; ii < VEC; ++ii) {
qkv_q[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q[ii]) +
static_cast<accscalar_t>(qkv_bias_q[ii])) *
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
qkv_k[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k[ii]) +
static_cast<accscalar_t>(qkv_bias_k[ii])));
qkv_v[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v[ii]) +
static_cast<accscalar_t>(qkv_bias_v[ii])));
}
// Here we require DH % VEC == 0 for these vectorized stores.
*reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_q);
*reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_k);
*reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_v);
}
} else {
// Same as above, but we can't vectorize memory access.
for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
auto nh = d / DH;
auto dh = d % DH;
scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
scalar_t qkv_q = qkv[b][t][d + 0 * D];
scalar_t qkv_k = qkv[b][t][d + 1 * D];
scalar_t qkv_v = qkv[b][t][d + 2 * D];
qkv_q = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q) +
static_cast<accscalar_t>(qkv_bias_q)) *
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
qkv_k = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k) +
static_cast<accscalar_t>(qkv_bias_k)));
qkv_v = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v) +
static_cast<accscalar_t>(qkv_bias_v)));
q_k_v[0][b][nh][t][dh] = qkv_q;
q_k_v[1][b][nh][t][dh] = qkv_k;
q_k_v[2][b][nh][t][dh] = qkv_v;
}
}
}
template <typename scalar_t, typename accscalar_t, bool assume_aligned = false>
__global__ void transform_bias_rescale_qkv_add_padding_kernel(
// [B, T, 3 * D], but it's a NestedTensor buffer
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv,
// [3 * D]
const PackedTensorAccessor64<scalar_t, 1, RestrictPtrTraits> qkv_bias,
const int* offsets,
const int* input_sizes,
// [3, B, NH, T, DH]
PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v,
const scalar_t inv_sqrt_dim_per_head) {
// warp per DH.
// so launch B * NH * T warps.
const auto NH = q_k_v.size(2);
const auto T = q_k_v.size(3);
const auto DH = q_k_v.size(4);
const auto t = blockIdx.x % T;
const auto b = blockIdx.x / T;
const auto D = NH * DH;
const auto _3D = 3 * D;
const auto offset_for_batch = offsets[b];
const auto input_dim = 1;
const auto* sizes_i = input_sizes + b * input_dim;
if (assume_aligned) {
constexpr int VEC = TRANSFORM_BIAS_RESCALE_VEC;
using LoadT = memory::aligned_vector<scalar_t, VEC>;
for (int32_t d_v = threadIdx.x; d_v < D / VEC; d_v += blockDim.x) {
auto d = d_v * VEC;
auto nh = d / DH;
auto dh = d % DH;
scalar_t qkv_bias_q[VEC];
scalar_t qkv_bias_k[VEC];
scalar_t qkv_bias_v[VEC];
scalar_t qkv_q[VEC];
scalar_t qkv_k[VEC];
scalar_t qkv_v[VEC];
const auto first_item_offset = t * _3D + d;
const auto last_item_offset = first_item_offset + VEC - 1;
const bool first_item_in_bounds = first_item_offset < sizes_i[0];
const bool entire_vec_in_bounds = last_item_offset < sizes_i[0];
// Here we require D % VEC == 0 for these vectorized loads.
*reinterpret_cast<LoadT*>(&qkv_bias_q) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 0 * D]);
*reinterpret_cast<LoadT*>(&qkv_bias_k) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 1 * D]);
*reinterpret_cast<LoadT*>(&qkv_bias_v) =
*reinterpret_cast<const LoadT*>(&qkv_bias[d + 2 * D]);
if (entire_vec_in_bounds) {
const auto offset = offset_for_batch + first_item_offset;
*reinterpret_cast<LoadT*>(&qkv_q) =
*reinterpret_cast<const LoadT*>(&qkv[offset + 0 * D]);
*reinterpret_cast<LoadT*>(&qkv_k) =
*reinterpret_cast<const LoadT*>(&qkv[offset + 1 * D]);
*reinterpret_cast<LoadT*>(&qkv_v) =
*reinterpret_cast<const LoadT*>(&qkv[offset + 2 * D]);
#pragma unroll
// TODO: specialize for float2half2/half2float2?
for (auto ii = 0; ii < VEC; ++ii) {
qkv_q[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q[ii]) +
static_cast<accscalar_t>(qkv_bias_q[ii])) *
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
qkv_k[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k[ii]) +
static_cast<accscalar_t>(qkv_bias_k[ii])));
qkv_v[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v[ii]) +
static_cast<accscalar_t>(qkv_bias_v[ii])));
}
} else if (first_item_in_bounds) {
const auto offset = offset_for_batch + first_item_offset;
qkv_q[0] = qkv[offset + 0 * D];
qkv_k[0] = qkv[offset + 1 * D];
qkv_v[0] = qkv[offset + 2 * D];
qkv_q[0] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q[0]) +
static_cast<accscalar_t>(qkv_bias_q[0])) *
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
qkv_k[0] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k[0]) +
static_cast<accscalar_t>(qkv_bias_k[0])));
qkv_v[0] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v[0]) +
static_cast<accscalar_t>(qkv_bias_v[0])));
#pragma unroll
for (auto ii = 1; ii < VEC; ++ii) {
const auto loop_offset = offset + ii;
if (loop_offset < sizes_i[0]) {
qkv_q[ii] = qkv[loop_offset + 0 * D];
qkv_k[ii] = qkv[loop_offset + 1 * D];
qkv_v[ii] = qkv[loop_offset + 2 * D];
qkv_q[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q[ii]) +
static_cast<accscalar_t>(qkv_bias_q[ii])) *
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
qkv_k[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k[ii]) +
static_cast<accscalar_t>(qkv_bias_k[ii])));
qkv_v[ii] = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v[ii]) +
static_cast<accscalar_t>(qkv_bias_v[ii])));
} else {
qkv_q[ii] = 0;
qkv_k[ii] = 0;
qkv_v[ii] = 0;
}
}
} else {
#pragma unroll
for (auto ii = 0; ii < VEC; ++ii) {
qkv_q[ii] = 0;
qkv_k[ii] = 0;
qkv_v[ii] = 0;
}
}
// Here we require DH % VEC == 0 for these vectorized stores.
*reinterpret_cast<LoadT*>(&q_k_v[0][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_q);
*reinterpret_cast<LoadT*>(&q_k_v[1][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_k);
*reinterpret_cast<LoadT*>(&q_k_v[2][b][nh][t][dh]) =
*reinterpret_cast<const LoadT*>(&qkv_v);
}
} else {
for (int32_t d = threadIdx.x; d < D; d += blockDim.x) {
auto nh = d / DH;
auto dh = d % DH;
scalar_t qkv_bias_q = qkv_bias[d + 0 * D];
scalar_t qkv_bias_k = qkv_bias[d + 1 * D];
scalar_t qkv_bias_v = qkv_bias[d + 2 * D];
const auto item_offset = t * _3D + d;
const bool in_bounds = item_offset < sizes_i[0];
scalar_t qkv_q, qkv_k, qkv_v;
if (in_bounds) {
const auto qkv_offset = offset_for_batch + item_offset;
qkv_q = qkv[qkv_offset + 0 * D];
qkv_k = qkv[qkv_offset + 1 * D];
qkv_v = qkv[qkv_offset + 2 * D];
qkv_q = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_q) +
static_cast<accscalar_t>(qkv_bias_q)) *
static_cast<accscalar_t>(inv_sqrt_dim_per_head));
qkv_k = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_k) +
static_cast<accscalar_t>(qkv_bias_k)));
qkv_v = static_cast<scalar_t>(
(static_cast<accscalar_t>(qkv_v) +
static_cast<accscalar_t>(qkv_bias_v)));
} else {
qkv_q = 0;
qkv_k = 0;
qkv_v = 0;
}
q_k_v[0][b][nh][t][dh] = qkv_q;
q_k_v[1][b][nh][t][dh] = qkv_k;
q_k_v[2][b][nh][t][dh] = qkv_v;
}
}
}
Tensor collapse_dims_1_and_2(const Tensor& sizes) {
auto sizes_dim1 = at::native::narrow(sizes, 1, 0, 1);
auto sizes_dim2 = at::native::narrow(sizes, 1, 1, 1);
return (sizes_dim1 * sizes_dim2).contiguous();
}
} // namespace
// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
__host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
const Tensor& qkv,
const Tensor& qkv_bias,
const int64_t num_head) {
auto B = qkv.is_nested()
? get_nested_tensor_impl(qkv)->get_nested_size_tensor().size(0)
: qkv.size(0);
// TODO: calculate this without the std::vector -- NestedTensor_to_mask wants
// this too
auto T = qkv.is_nested()
? NestedTensor_get_max_size(*get_nested_tensor_impl(qkv))[0]
: qkv.size(1);
auto _3D = qkv_bias.size(0);
auto D = _3D / 3;
TORCH_CHECK(D % num_head == 0);
const auto dim_per_head = D / num_head;
auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options());
#define CALL_KERNEL(assume_aligned) \
transform_bias_rescale_qkv_kernel<scalar_t, accscalar_t, assume_aligned> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
qkv.packed_accessor64<scalar_t, 3, RestrictPtrTraits>(), \
qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(), \
1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
#define CALL_ADD_PADDING_KERNEL(assume_aligned) \
transform_bias_rescale_qkv_add_padding_kernel< \
scalar_t, \
accscalar_t, \
assume_aligned> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
nt_qkv->get_buffer() \
.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
qkv_bias.packed_accessor64<scalar_t, 1, RestrictPtrTraits>(), \
offsets_ptr, \
sizes_ptr, \
q_k_v.packed_accessor64<scalar_t, 5, RestrictPtrTraits>(), \
1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head)))
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
qkv.scalar_type(),
"transform_bias_rescale_qkv",
[&] {
using accscalar_t = acc_type<scalar_t, true>;
auto threads = std::max(
std::min<int32_t>(1024, D / TRANSFORM_BIAS_RESCALE_VEC), 1);
auto blocks = B * T;
const bool aligned =
((dim_per_head % TRANSFORM_BIAS_RESCALE_VEC) == 0) &&
((reinterpret_cast<intptr_t>(qkv_bias.data_ptr()) %
TRANSFORM_BIAS_RESCALE_VEC) == 0);
if (aligned) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
D % TRANSFORM_BIAS_RESCALE_VEC == 0,
"D = num_heads * dim_per_head, so we should have dim_per_head % "
"TRANSFORM_BIAS_RESCALE_VEC == 0 => "
"D % TRANSFORM_BIAS_RESCALE_VEC == 0");
}
if (qkv.is_nested()) {
auto* nt_qkv = get_nested_tensor_impl(qkv);
auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_size_tensor());
auto offsets =
NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
at::native::narrow(offsets, 0, sizes.numel() + 1, sizes.numel())
.copy_(sizes.reshape({-1}));
auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
const auto offsets_ptr = metadata.data_ptr<int>();
const auto sizes_ptr = offsets_ptr + sizes.numel() + 1;
const auto input_dim = sizes.sizes()[1];
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input_dim == 1);
if (aligned &&
((reinterpret_cast<intptr_t>(nt_qkv->get_buffer().data_ptr()) %
TRANSFORM_BIAS_RESCALE_VEC) == 0)) {
CALL_ADD_PADDING_KERNEL(true);
} else {
CALL_ADD_PADDING_KERNEL(false);
}
} else if (aligned) {
CALL_KERNEL(true);
} else {
CALL_KERNEL(false);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#undef CALL_ADD_PADDING_KERNEL
#undef CALL_KERNEL
auto q_k_v_s =
at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
} // namespace native
} // namespace at
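For orientation, the launch configuration computed in the dispatch lambda above is plain arithmetic: B * T blocks, up to min(1024, D / TRANSFORM_BIAS_RESCALE_VEC) threads per block, and the vectorized kernel only when dim_per_head and the bias pointer are suitably aligned. A worked example; all values below are assumed for illustration, not taken from this diff:

# Assumed sizes for illustration only.
D, num_heads, B, T = 1024, 16, 8, 128     # embed_dim, heads, batch, padded seq len
VEC = 4                                   # TRANSFORM_BIAS_RESCALE_VEC
dim_per_head = D // num_heads             # 64
threads = max(min(1024, D // VEC), 1)     # 256 threads per block
blocks = B * T                            # 1024 blocks, one per (b, t)
aligned = dim_per_head % VEC == 0         # True -> vectorized kernel is eligible
                                          # (the bias pointer must also be aligned)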

View File

@ -1,137 +0,0 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NestedTensorImpl.h>
#include <torch/library.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at {
namespace native {
namespace {
Tensor linear_for_ffn(
const Tensor& bias,
const Tensor& mat1,
const Tensor& mat2,
c10::optional<bool> use_gelu) {
if (mat1.is_nested()) {
return NestedTensor_times_Tensor_plus_Tensor_addmm(
bias, mat1, mat2.t(), 1, 1, use_gelu);
}
auto mat1_ = mat1.view({mat1.sizes()[0] * mat1.sizes()[1], mat1.sizes()[2]});
Tensor result;
if (use_gelu.has_value()) {
result = at::_addmm_activation(bias, mat1_, mat2.t(), 1, 1, *use_gelu);
} else {
result = at::addmm(bias, mat1_, mat2.t());
}
return result.view({mat1.sizes()[0], mat1.sizes()[1], -1});
}
Tensor ffn(
const Tensor& input,
const Tensor& w1,
const Tensor& b1,
const Tensor& w2,
const Tensor& b2,
bool use_gelu,
bool add_norm) {
TORCH_CHECK(add_norm == false, "TODO add_norm to be supported in FFN");
TORCH_CHECK(input.dim() == 3, "batched input size should be 3");
TORCH_CHECK(w1.dim() == 2, "2d weights expected");
TORCH_CHECK(w2.dim() == 2, "2d weights expected");
Tensor res = linear_for_ffn(b1, input, w1, use_gelu);
res = linear_for_ffn(b2, res, w2, c10::nullopt);
return res;
}
} // namespace
Tensor transformer_encoder_layer_forward(
const Tensor& src,
const int64_t embed_dim,
const int64_t num_heads,
const Tensor& qkv_weight,
const Tensor& qkv_bias,
const Tensor& proj_weight,
const Tensor& proj_bias,
const bool use_gelu,
const bool norm_first,
const double layer_norm_eps,
const Tensor& layer_norm_weight_1,
const Tensor& layer_norm_bias_1,
const Tensor& layer_norm_weight_2,
const Tensor& layer_norm_bias_2,
const Tensor& ffn_weight_1,
const Tensor& ffn_bias_1,
const Tensor& ffn_weight_2,
const Tensor& ffn_bias_2,
const c10::optional<Tensor>& mask) {
{
const Tensor& check_for_empty = src.is_nested() ? get_nested_tensor_impl(src)->get_buffer() : src;
if (check_for_empty.numel() == 0) {
return src.is_nested()
? at::detail::make_tensor<NestedTensorImpl>(check_for_empty, get_nested_tensor_impl(src)->get_nested_size_tensor())
: src.clone();
}
}
TORCH_CHECK(!norm_first, "norm_first is not supported yet");
const bool use_nested_tensor = src.is_nested();
auto x = std::get<0>(native_multi_head_attention(
src,
src,
src,
embed_dim,
num_heads,
qkv_weight,
qkv_bias,
proj_weight,
proj_bias,
mask,
false /* need_weights */));
if (use_nested_tensor) {
NestedTensor_add_NestedTensor_in_place(x, src);
x = NestedTensor_layer_norm(
x, layer_norm_weight_1, layer_norm_bias_1, layer_norm_eps);
} else {
x.add_(src);
x = at::layer_norm(
x,
{embed_dim},
layer_norm_weight_1,
layer_norm_bias_1,
layer_norm_eps,
true);
}
auto pre_ffn_res = x;
x = ffn(
x,
ffn_weight_1,
ffn_bias_1,
ffn_weight_2,
ffn_bias_2,
use_gelu,
/* add_norm */ false);
if (use_nested_tensor) {
NestedTensor_add_NestedTensor_in_place(x, pre_ffn_res);
x = NestedTensor_layer_norm(
x, layer_norm_weight_2, layer_norm_bias_2, layer_norm_eps);
} else {
x.add_(pre_ffn_res);
x = at::layer_norm(
x,
{embed_dim},
layer_norm_weight_2,
layer_norm_bias_2,
layer_norm_eps,
true);
}
return x;
}
} // namespace native
} // namespace at
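Stripped of the nested-tensor branches and of the activation fused into _addmm_activation, the control flow of transformer_encoder_layer_forward above is a standard post-norm encoder layer (norm_first is rejected by a TORCH_CHECK). A dense-tensor sketch; the helper name and flat argument list are illustrative only:

import torch.nn.functional as F

def encoder_layer_reference(src, attn_out, embed_dim, use_gelu, eps,
                            ln_w1, ln_b1, ln_w2, ln_b2,
                            ffn_w1, ffn_b1, ffn_w2, ffn_b2):
    # attn_out: output of the fused multi-head self-attention over src
    x = F.layer_norm(src + attn_out, (embed_dim,), ln_w1, ln_b1, eps)
    h = F.linear(x, ffn_w1, ffn_b1)                   # first FFN matmul + bias
    h = F.gelu(h) if use_gelu else F.relu(h)          # activation fused in the C++ path
    h = F.linear(h, ffn_w2, ffn_b2)                   # second FFN matmul + bias
    return F.layer_norm(x + h, (embed_dim,), ln_w2, ln_b2, eps)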

View File

@ -1,306 +0,0 @@
# Owner(s): ["module: nn"]
import math
import torch
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
instantiate_device_type_tests,
onlyCUDA,
skipMeta,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class TestMHADeviceType(TestCase):
@torch.no_grad()
def _test_transform_bias_rescale_qkv_impl(
self, device, dtype, use_nt, use_padding=False
):
tests = [
(64, 4, 16, 8),
# dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
(24, 2, 4, 2),
# Make sure CUDA can handle small input sizes
(2, 2, 2, 2),
# dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
# causes alignment issues
(24, 4, 4, 2),
(48, 4, 16, 8),
]
for (embed_dim, num_heads, bs, sl) in tests:
with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
torch.manual_seed(9343)
dense_x = x = (
torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
)
if use_padding:
x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
if use_nt:
xs = list(torch.unbind(x))
if use_padding:
xs[0] = xs[0][:-1]
x = torch.nested_tensor(xs, device=device, dtype=dtype)
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
# We have to use inference_mode here because q/k/v are
# all views of the same Tensor, which autograd doesn't
# like. This is fine because this function is only
# exposed to Python for purposes of writing this test.
with torch.inference_mode():
(q, k, v) = torch._transform_bias_rescale_qkv(
x, qkv.bias, num_heads=num_heads
)
def simple_transform_bias_rescale_qkv(qkv, bias):
(q, k, v) = torch.split(qkv, embed_dim, dim=-1)
(q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)
return tuple(
x.reshape(
(bs, sl, num_heads, embed_dim // num_heads)
).transpose(2, 1)
for x in (
(q + q_bias) / math.sqrt(embed_dim // num_heads),
(k + k_bias),
(v + v_bias),
)
)
correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
dense_x, qkv.bias
)
if use_nt and use_padding:
for t in (correct_q, correct_k, correct_v):
t[t == float("-Inf")] = 0
self.assertEqual(q.size(), correct_q.size())
torch.testing.assert_close(q, correct_q)
torch.testing.assert_close(k, correct_k)
torch.testing.assert_close(v, correct_v)
@dtypesIfCUDA(torch.float)
@dtypes(torch.float)
@skipMeta
def test_transform_bias_rescale_qkv(self, device, dtype):
for use_padding in (False, True):
with self.subTest(use_padding=use_padding):
self._test_transform_bias_rescale_qkv_impl(
device, dtype, use_nt=False, use_padding=use_padding
)
@dtypesIfCUDA(torch.float)
@dtypes(torch.float)
@skipMeta
@onlyCUDA
def test_transform_bias_rescale_qkv_nested(self, device, dtype):
for use_padding in (False, True):
with self.subTest(use_padding=use_padding):
self._test_transform_bias_rescale_qkv_impl(
device, dtype, use_nt=True, use_padding=use_padding
)
def _test_multihead_attention_impl(
self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
):
embed_dim = 64
num_heads = 4
bs = 16
sl = 8
q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
if use_padding:
if pad_all:
for q_i in q:
q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype)
mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
for mask_i in mask:
mask_i[-1] = True
else:
q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype)
mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
mask[0][-1] = True
if mode == "self":
k = q
v = q
elif mode == "encdec":
k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
v = k
elif mode == "generic":
k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10
else:
self.fail(f"invalid mode `{mode}`!")
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
pt = torch.nn.MultiheadAttention(
embed_dim, num_heads, batch_first=True, device=device, dtype=dtype
)
pt.in_proj_weight = qkv.weight
pt.in_proj_bias = qkv.bias
pt.out_proj.weight = proj.weight
pt.out_proj.bias = proj.bias
class NativeMHA(torch.nn.Module):
def __init__(self, embed_dim, num_heads, qkv, proj):
super().__init__()
self.qkv = qkv
self.proj = proj
self.embed_dim = embed_dim
self.num_heads = num_heads
def forward(self, q, k, v, key_padding_mask):
return torch._native_multi_head_attention(
q,
k,
v,
self.embed_dim,
self.num_heads,
self.qkv.weight,
self.qkv.bias,
self.proj.weight,
self.proj.bias,
key_padding_mask,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
)
npt = NativeMHA(
embed_dim=embed_dim, num_heads=num_heads, qkv=qkv, proj=proj
).to(dtype)
if device == "cuda":
pt = pt.cuda()
npt = npt.cuda()
ypt, weight_pt = pt(
q,
k,
v,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
key_padding_mask=mask if use_padding else None,
)
if use_nt:
qs = list(torch.unbind(q))
if use_padding:
if pad_all:
qs = [x[:-1] for x in qs]
else:
qs[0] = qs[0][:-1]
q = torch.nested_tensor(qs, device=device, dtype=dtype)
if mode == "self":
k = v = q
elif mode == "encdec":
k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
v = k
else:
k = torch.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
v = torch.nested_tensor(torch.unbind(v), device=device, dtype=dtype)
ynpt, weight_npt = npt(
q, k, v, key_padding_mask=mask if use_padding and not use_nt else None
)
if use_nt:
ynpt = ynpt.to_padded_tensor(0)
if pad_all:
ynpt_final = torch.zeros_like(ypt)
ynpt_final[:, :ynpt.shape[1], :] = ynpt
ynpt = ynpt_final
def do_pad_all(tensors):
for t in tensors:
for t_i in t:
t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)
# PyTorch implementation returns non-zero junk in the padding
# locations; overwrite it so that the comparison works out.
if use_padding:
ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
if pad_all:
do_pad_all((ypt, ynpt))
# Zero the last row of each TxT weight matrix
if need_weights:
if average_attn_weights:
weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
if pad_all:
do_pad_all((weight_pt, weight_npt))
else:
for nh in range(num_heads):
weight_pt[0][nh][-1] = torch.zeros_like(weight_pt[0][nh][-1], device=device, dtype=dtype)
weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype)
if dtype == torch.half:
torch.testing.assert_close(ypt, ynpt, atol=1e-3, rtol=1e-3)
else:
# High rtol seems necessary for
# test_native_multihead_attention_cpu_float32 on Windows,
# otherwise 2e-4 would likely be fine.
torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)
if need_weights:
torch.testing.assert_close(weight_pt, weight_npt)
else:
self.assertEqual(weight_pt, weight_npt)
@dtypesIfCUDA(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
def test_native_multihead_self_attention(self, device, dtype):
for (use_padding, pad_all) in ((False, False), (True, False), (True, True)):
for use_nt in (False, True):
# Figuring out exactly which elements of the weights are garbage in this
# case eludes me, and it's not particularly enlightening to test anyway
# because padding doesn't especially affect the intermediate weights.
for need_weights in (False, not pad_all):
for average_attn_weights in (False, True):
with self.subTest(use_padding=use_padding, pad_all=pad_all,
use_nt=use_nt, need_weights=need_weights,
average_attn_weights=average_attn_weights):
self._test_multihead_attention_impl(
device,
dtype,
"self",
use_nt=use_nt,
use_padding=use_padding,
pad_all=pad_all,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
)
@dtypesIfCUDA(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
def test_native_multihead_encoder_decoder_attention(self, device, dtype):
self._test_multihead_attention_impl(
device,
dtype,
"encdec",
use_nt=False,
need_weights=False,
average_attn_weights=False,
)
@dtypesIfCUDA(torch.float, torch.half)
@dtypes(torch.float)
@skipMeta
@torch.no_grad()
def test_native_multihead_attention(self, device, dtype):
self._test_multihead_attention_impl(
device,
dtype,
"generic",
use_nt=False,
need_weights=False,
average_attn_weights=False,
)
instantiate_device_type_tests(TestMHADeviceType, globals())
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: nn"]
import contextlib
import math
import random
import string
@ -39,8 +38,8 @@ from torch.nn.parallel._functions import Broadcast
from torch.testing._internal.common_dtype import integral_types, floating_types_and, get_all_math_dtypes, \
floating_and_complex_types_and
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
download_file, get_function_arglist, load_tests, \
skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
get_function_arglist, load_tests, \
suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC, \
parametrize as parametrize_test, subtest, instantiate_parametrized_tests, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
@ -5766,7 +5765,9 @@ class TestNN(NNTestCase):
self.assertIsNone(mha.in_proj_bias)
self.assertIsNone(mha.out_proj.bias)
def _test_multihead_attn_invalid_shape_impl(self, mha):
def test_multihead_attn_invalid_shape(self):
mha = torch.nn.MultiheadAttention(3, 3)
# Batched (3D) query cases
query = torch.randn(3, 3, 3)
key = torch.randn(3, 3, 3)
@ -5822,113 +5823,6 @@ class TestNN(NNTestCase):
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.randn(4, 3, 3).bernoulli_().to(torch.bool))
def test_multihead_attn_invalid_shape(self):
mha = torch.nn.MultiheadAttention(3, 3)
self._test_multihead_attn_invalid_shape_impl(mha)
# Give the test a chance to hit the fast path. (Right now, it
# won't, but gating may be less restricted in the future.)
with torch.no_grad():
self._test_multihead_attn_invalid_shape_impl(mha.eval())
@torch.no_grad()
def test_multihead_attn_fast_path_invalid_shape(self):
mha = torch.nn.MultiheadAttention(3, 3, batch_first=True).eval()
# Batched (3D) query cases
query = torch.randn(3, 3, 3)
key = torch.randn(3, 3, 3)
value = torch.randn(3, 3, 3)
# Currently, this case will just go to the slow path and get
# the usual message because it fails the requirement to be
# batched.
msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively"
# 3D query, 2D key and 3D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, torch.randn(3, 3), value, need_weights=False)
# Currently, this case will just go to the slow path and get
# the usual message because it fails the requirement to be
# batched.
msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively"
# 3D query, 3D key and 2D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, torch.randn(3, 3), need_weights=False)
msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead"
# 3D query, 3D key, 3D value and 1D key_padding_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, key_padding_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False)
msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
# 3D query, 3D key, 3D value and 1D attn_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False)
# Unbatched (2D) query cases
# NOTE: error messages are the same as regular path because the fast path doesn't support 2D.
query = torch.randn(3, 3)
key = torch.randn(3, 3)
value = torch.randn(3, 3)
msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively"
# 2D query, 3D key and 2D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, torch.randn(3, 3, 3), value)
msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively"
# 2D query, 3D key and 2D value
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, torch.randn(3, 3, 3))
msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead"
# 2D query, 2D key, 2D value and 1D key_padding_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, key_padding_mask=torch.tensor([[False, True, True] * 2], dtype=torch.bool))
msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
# 2D query, 2D key, 2D value and 1D attn_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.tensor([False, True, True], dtype=torch.bool))
msg = r"Expected `attn_mask` shape to be \(3, 3, 3\)"
# 2D query, 2D key, 2D value and 3D incorrect attn_mask
with self.assertRaisesRegex(AssertionError, msg):
mha(query, key, value, attn_mask=torch.randn(4, 3, 3).bernoulli_().to(torch.bool))
def test_multihead_attn_nested_tensor_outside_fast_path(self):
mha = torch.nn.MultiheadAttention(3, 3, batch_first=True).eval()
nt = torch.nested_tensor([torch.randn(3, 3)])
# One tested platform (linux-bionic-py3.7-clang) has a torch_function for one
# or more of these. Take advantage of that to test the torch_function bailout.
has_torch_func = torch.overrides.has_torch_function(
(nt, mha.in_proj_weight, mha.in_proj_bias, mha.out_proj.weight, mha.out_proj.bias))
if has_torch_func:
msg = "MultiheadAttention does not support NestedTensor.*argument has_torch_function"
else:
msg = ("MultiheadAttention does not support NestedTensor outside of its fast path.*grad is " +
"enabled and.*or biases requires_grad")
with self.assertRaisesRegex(AssertionError, msg):
mha(nt, nt, nt)
if has_torch_func:
# Just give up, they're all going to fail with the same message.
return
with torch.no_grad():
mha(nt, nt, nt)
with torch.inference_mode():
mha(nt, nt, nt)
nt = torch.nested_tensor([torch.randn(3, 3, requires_grad=False)])
nt.requires_grad = False
with self.assertRaisesRegex(AssertionError, msg):
mha(nt, nt, nt)
mha.in_proj_weight.requires_grad = False
mha.in_proj_bias.requires_grad = False
mha.out_proj.weight.requires_grad = False
mha.out_proj.bias.requires_grad = False
mha(nt, nt, nt)
def test_normalize(self):
inputs = torch.randn(1, 3, 4, 4, requires_grad=True)
self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
@ -7824,7 +7718,7 @@ class TestNN(NNTestCase):
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
def _test(batch_first, training):
for batch_first in (True, False):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x
@ -7832,8 +7726,6 @@ class TestNN(NNTestCase):
batch_first=batch_first)
model = nn.TransformerEncoder(encoder_layer, 1).to(device)
if not training:
model = model.eval()
# deterministic input
encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
@ -7887,8 +7779,6 @@ class TestNN(NNTestCase):
# test case 2, multiple layers no norm
model = nn.TransformerEncoder(encoder_layer, 2).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
[2.419102, 0.017452, -0.608703, -0.085026]],
@ -7905,8 +7795,6 @@ class TestNN(NNTestCase):
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
model = nn.TransformerEncoder(encoder_layer, 6).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
[2.419101, 0.017453, -0.608704, -0.085025]],
@ -7926,8 +7814,6 @@ class TestNN(NNTestCase):
# d_model = 4
norm = nn.LayerNorm(4)
model = nn.TransformerEncoder(encoder_layer, 2, norm=norm).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
[1.695955, -0.357639, -0.893050, -0.445266]],
@ -7944,8 +7830,6 @@ class TestNN(NNTestCase):
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
model = nn.TransformerEncoder(encoder_layer, 6, norm=norm).to(device)
if not training:
model = model.eval()
result = model(encoder_input, src_key_padding_mask=mask)
ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
[1.695955, -0.357639, -0.893051, -0.445265]],
@ -7960,15 +7844,7 @@ class TestNN(NNTestCase):
)).to(device)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
for batch_first in (True, False):
for training in (True, False):
# Fast path requires inference mode.
if training:
cm = contextlib.nullcontext()
else:
cm = torch.no_grad()
with cm:
_test(batch_first, training)
def test_transformerdecoder(self):
def get_a_test_layer(use_cuda, activation, batch_first=False):
@ -13045,20 +12921,17 @@ class TestNNDeviceType(NNTestCase):
output.sum().backward()
self.assertEqualTypeString(output, input)
def _test_module_empty_input(self, module, inp, check_size=True, inference=False):
if not inference:
inp.requires_grad_(True)
def _test_module_empty_input(self, module, inp, check_size=True):
inp.requires_grad_(True)
out = module(inp)
if not inference:
gO = torch.rand_like(out)
out.backward(gO)
gO = torch.rand_like(out)
out.backward(gO)
if check_size:
self.assertEqual(out.size(), inp.size())
if not inference:
for p in module.parameters():
if p.requires_grad:
self.assertEqual(p.grad, torch.zeros_like(p.grad))
self.assertEqual(inp.grad, torch.zeros_like(inp))
for p in module.parameters():
if p.requires_grad:
self.assertEqual(p.grad, torch.zeros_like(p.grad))
self.assertEqual(inp.grad, torch.zeros_like(inp))
def _test_module_empty_inputs(self, module, inputs):
for _inp in inputs:
@ -14627,29 +14500,11 @@ class TestNNDeviceType(NNTestCase):
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@onlyNativeDeviceTypes
def test_TransformerEncoderLayer_empty(self, device):
for training in (True, False):
for batch_first, input_shape in [(True, (0, 10, 512)),
(False, (10, 0, 512))]:
input = torch.rand(*input_shape, device=device)
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
if not training:
encoder_layer = encoder_layer.eval()
with torch.no_grad():
self._test_module_empty_input(encoder_layer, input, check_size=False, inference=True)
if batch_first and not TEST_WITH_CROSSREF:
with torch.no_grad():
# A NestedTensor with no tensors inside it doesn't have dim 3 (or dim
# 2, for that matter) so it can't hit the fast path, nor can we give a
# result.
with self.assertRaisesRegex(
AssertionError, 'MultiheadAttention does not support NestedTensor outside'):
nt = torch.nested_tensor([], device=device)
self._test_module_empty_input(encoder_layer, nt, check_size=False, inference=True)
nt = torch.nested_tensor([torch.rand(0, 512, device=device)], device=device)
self._test_module_empty_input(encoder_layer, nt, check_size=False, inference=True)
else:
self._test_module_empty_input(encoder_layer, input, check_size=False)
for batch_first, input_shape in [(True, (0, 10, 512)),
(False, (10, 0, 512))]:
input = torch.rand(*input_shape, device=device)
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first).to(device)
self._test_module_empty_input(encoder_layer, input, check_size=False)
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@onlyNativeDeviceTypes
@ -18080,32 +17935,6 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(q.size(), out[0].size())
self.assertEqual(dtype, out[0].dtype)
@onlyCUDA
@dtypes(torch.half, torch.float, torch.double)
def test_multihead_attention_dtype_batch_first(self, device, dtype):
embed_dim = 128
num_heads = 8
sl = 10
bs = 8
# With batch_first=True, the native fast path can be hit when the
# module is in eval mode and autograd is disabled. Test both the
# fast and the regular path.
for training in (True, False):
model = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda().to(dtype)
if not training:
model = model.eval()
cm = torch.no_grad()
else:
cm = contextlib.nullcontext()
with cm:
q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
# fast path currently doesn't support weights
out = model(q, k, v, need_weights=False)
self.assertEqual(q.size(), out[0].size())
self.assertEqual(dtype, out[0].dtype)
@dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []))
@dtypes(torch.float)
def test_Conv2d_naive_groups(self, device, dtype):
@ -20069,7 +19898,7 @@ class TestNNDeviceType(NNTestCase):
m(output_size)(t)
@dtypes(torch.float)
@dtypesIfCUDA(torch.double, torch.float, torch.half)
@dtypesIfCUDA(torch.float, torch.half)
def test_transformerencoderlayer(self, device, dtype):
# this is a deterministic test for TransformerEncoderLayer
d_model = 4
@ -20084,17 +19913,13 @@ class TestNNDeviceType(NNTestCase):
atol = 1e-3
rtol = 1e-2
def _test(training, batch_first, atol, rtol):
for batch_first in (False, True):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x
model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
batch_first=batch_first, device=device, dtype=dtype)
if not training:
assert dropout == 0
model = model.eval()
# set constant weights of the model
for idx, p in enumerate(model.parameters()):
x = p.data
@ -20111,7 +19936,6 @@ class TestNNDeviceType(NNTestCase):
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
# 0 values are NOT masked. This shouldn't mask anything.
mask = torch.tensor([[0]], device=device) == 1
# TODO: enable fast path for calls with a mask!
result = model(encoder_input, src_key_padding_mask=mask)
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
@ -20166,7 +19990,6 @@ class TestNNDeviceType(NNTestCase):
[2.422901, 0.024187, -0.606178, -0.074929]]], device=device, dtype=dtype))
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
# all 0
mask = torch.zeros([2, 5], device=device) == 1
result = model(encoder_input, src_key_padding_mask=mask)
@ -20189,65 +20012,6 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(result.shape, ref_output.shape)
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
# NestedTensor is currently only supported by the fast path,
# which is not used in training mode.
if (batch_first and not training and
('cuda' in str(device) or 'cpu' in str(device)) and not TEST_WITH_CROSSREF):
encoder_input[0][-1] = torch.zeros_like(encoder_input[0][1])
mask = torch.zeros(encoder_input.shape[:-1], device=device, dtype=torch.bool)
mask[0][-1] = True
nt = torch.nested_tensor([encoder_input[0][:-1], encoder_input[1]], device=device)
result = model(nt)
ref_output = torch.tensor(
[
[
[2.4268184, 0.02042419, -0.603311, -0.08476824],
[2.423306, 0.01889652, -0.6057701, -0.08519465],
[2.431538, 0.02078694, -0.5999354, -0.08746159],
[2.4348664, 0.02212971, -0.5975677, -0.08733892],
[2.423133, 0.02097577, -0.60594773, -0.08113337],
],
[
[2.4279876, 0.02121329, -0.60249615, -0.08410317],
[2.4138637, 0.02221113, -0.6124869, -0.07249016],
[2.4251041, 0.01974815, -0.6045152, -0.08483928],
[2.4335563, 0.0218913, -0.59850943, -0.08683228],
[2.4229012, 0.02418739, -0.6061784, -0.07492948],
],
],
device=device, dtype=dtype
)
result = result.to_padded_tensor(0)
ref_output[0][-1] = torch.zeros_like(
ref_output[0][-1], device=device, dtype=dtype
)
result[0][-1] = torch.zeros_like(
result[0][-1], device=device, dtype=dtype
)
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
if 'cuda' in device:
if dtype == torch.float:
atol = 2e-4
rtol = 4e-3
else:
atol = 7e-4
rtol = 2e-2
torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
else:
torch.testing.assert_close(result, ref_output)
for batch_first in (True, False):
for training in (True, False):
if training:
cm = contextlib.nullcontext()
else:
# Fast path requires inference mode.
cm = torch.no_grad()
with cm:
_test(batch_first=batch_first, training=training, atol=atol, rtol=rtol)
@dtypes(torch.float)
@dtypesIfCUDA(torch.half, torch.float)
def test_transformerencoderlayer_gelu(self, device, dtype):
@ -20264,15 +20028,12 @@ class TestNNDeviceType(NNTestCase):
atol = 1e-3
rtol = 1e-2
def _test(activation, batch_first, training):
for activation, batch_first in product(('gelu', F.gelu, nn.GELU()), (True, False)):
def perm_fn(x):
return x.transpose(1, 0) if batch_first else x
model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, batch_first=batch_first, device=device, dtype=dtype)
if not training:
assert dropout == 0
model = model.eval()
# set constant weights of the model
for idx, p in enumerate(model.parameters()):
@ -20319,14 +20080,6 @@ class TestNNDeviceType(NNTestCase):
[[2.41383916, 0.02686345, -0.61256377, -0.06380707],
[2.42000277, 0.03800944, -0.60824798, -0.04754947]]], device=device, dtype=dtype))
torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
for activation, batch_first, training in product(('gelu', F.gelu, nn.GELU()), (True, False), (True, False)):
# Fast path requires inference mode.
if training:
cm = contextlib.nullcontext()
else:
cm = torch.no_grad()
with cm:
_test(activation=activation, batch_first=batch_first, training=training)
class TestModuleGlobalHooks(TestCase):

View File

@ -88,7 +88,6 @@ includes = [
"aten/src/ATen/native/nested/cuda/*",
"aten/src/ATen/native/sparse/cuda/*",
"aten/src/ATen/native/quantized/cuda/*",
"aten/src/ATen/native/transformers/cuda/*",
"aten/src/THC/*",
"aten/src/ATen/test/*",
# CMakeLists.txt isn't processed by default, but there are a few

View File

@ -1393,8 +1393,6 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/sparse/SparseTensorMath.cpp",
"aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
"aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
"aten/src/ATen/native/transformers/attention.cpp",
"aten/src/ATen/native/transformers/transformer.cpp",
"aten/src/ATen/native/utils/Factory.cpp",
"aten/src/ATen/native/xnnpack/Activation.cpp",
"aten/src/ATen/native/xnnpack/ChannelShuffle.cpp",

View File

@ -893,29 +893,6 @@ class MultiheadAttention(Module):
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``forward()`` will use a special optimized implementation if all of the following
conditions are met:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- dropout is 0
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
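A minimal sketch of a call shaped to satisfy the conditions above, assuming the fast path described here is present (the shapes and hyperparameters are illustrative):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True).eval()
x = torch.randn(8, 10, 128)                    # (batch, seq, embed_dim), batched input
with torch.no_grad():                          # autograd disabled
    out, _ = mha(x, x, x, need_weights=False)  # self attention, no masks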
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
@ -934,7 +911,6 @@ class MultiheadAttention(Module):
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ['batch_first']
bias_k: Optional[torch.Tensor]
@ -1058,67 +1034,6 @@ class MultiheadAttention(Module):
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
why_not_fast_path = ''
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
why_not_fast_path = "key_padding_mask and attn_mask were both supplied"
if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
why_not_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
f"The fast path was not hit because {why_not_fast_path}")
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:

View File

@ -289,29 +289,6 @@ class TransformerEncoderLayer(Module):
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
Fast path:
forward() will use a special optimized implementation if all of the following
conditions are met:
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
argument ``requires_grad``
- training is disabled (using ``.eval()``)
- batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
- norm_first is ``False`` (this restriction may be loosened in the future)
- activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
- at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
- if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
nor ``src_key_padding_mask`` is passed
- the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
unless the caller has manually modified one without modifying the other)
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
passed for ``src`` to represent padding more efficiently than using a padding
mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
returned, and an additional speedup proportional to the fraction of the input that
is padding can be expected.
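A minimal sketch of an invocation shaped to meet these conditions, assuming the fast path described here is present; it uses the same ``torch.nested_tensor`` constructor as the tests above, and the shapes are illustrative:

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True).eval()
# Two variable-length sequences packed as a NestedTensor instead of a padding mask.
nt = torch.nested_tensor([torch.randn(7, 512), torch.randn(10, 512)])
with torch.no_grad():
    out = layer(nt)   # a NestedTensor comes back when the fast path is taken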
"""
__constants__ = ['batch_first', 'norm_first']
@ -336,25 +313,16 @@ class TransformerEncoderLayer(Module):
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
# We can't test self.activation in forward() in TorchScript,
# so stash some information about it instead.
if activation is F.relu:
self.activation_relu_or_gelu = 1
elif activation is F.gelu:
self.activation_relu_or_gelu = 2
self.activation = _get_activation_fn(activation)
else:
self.activation_relu_or_gelu = 0
self.activation = activation
self.activation = activation
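The comment above explains why the activation is summarized as an integer at construction time rather than compared inside forward(); a minimal standalone sketch of that pattern (the class and attribute names are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, activation=F.relu):
        super().__init__()
        # Stash a TorchScript-friendly flag instead of testing the callable's
        # identity inside forward().
        if activation is F.relu:
            self.activation_flag = 1
        elif activation is F.gelu:
            self.activation_flag = 2
        else:
            self.activation_flag = 0
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Downstream code can branch on the integer flag rather than
        # comparing self.activation to F.relu / F.gelu here.
        return self.activation(x)

y = Block()(torch.randn(2, 3))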
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
@ -368,53 +336,6 @@ class TransformerEncoderLayer(Module):
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
if (not self.norm_first and not self.training and
self.self_attn.batch_first and src.dim() == 3 and self.self_attn._qkv_same_embed_dim and
self.activation_relu_or_gelu and self.norm1.eps == self.norm2.eps and
((src_mask is None and src_key_padding_mask is None)
if src.is_nested
else (src_mask is None or src_key_padding_mask is None))):
tensor_args = (
src,
self.self_attn.in_proj_weight,
self.self_attn.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.norm1.weight,
self.norm1.bias,
self.norm2.weight,
self.norm2.bias,
self.linear1.weight,
self.linear1.bias,
self.linear2.weight,
self.linear2.bias,
)
if (not torch.overrides.has_torch_function(tensor_args) and
# We have to use a list comprehension here because TorchScript
# doesn't support generator expressions.
all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]) and
(not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]))):
return torch._transformer_encoder_layer_fwd(
src,
self.self_attn.embed_dim,
self.self_attn.num_heads,
self.self_attn.in_proj_weight,
self.self_attn.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.activation_relu_or_gelu == 2,
False, # norm_first, currently not supported
self.norm1.eps,
self.norm1.weight,
self.norm1.bias,
self.norm2.weight,
self.norm2.bias,
self.linear1.weight,
self.linear1.bias,
self.linear2.weight,
self.linear2.bias,
src_mask if src_mask is not None else src_key_padding_mask,
)
x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
@ -567,7 +488,7 @@ def _get_clones(module, N):
return ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
def _get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":

View File

@ -811,7 +811,6 @@ def preprocessor(
or f.startswith("ATen/native/nested/cuda")
or f.startswith("ATen/native/quantized/cuda")
or f.startswith("ATen/native/sparse/cuda")
or f.startswith("ATen/native/transformers/cuda")
or f.startswith("THC/")
or (f.startswith("THC") and not f.startswith("THCP"))
):