Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-13 21:59:56 +08:00

Compare commits
12 commits: ciflow/tru...fa_u8_brge
| SHA1 |
|---|
| f10edf1ecd |
| 9d645a6025 |
| 676da3c16a |
| 46769004e5 |
| 9cb324d903 |
| 325db8f2a3 |
| 97922c4754 |
| d72ab195da |
| 43b5c4101d |
| b640cf15ab |
| 67ccb2ce72 |
| 4a2715e652 |
@@ -133,6 +133,32 @@ struct VecConvert<int32_t, 1, uint8_t, 1> {
   }
 };
 
+template <>
+struct VecConvert<int32_t, 1, float, 1> {
+  static inline VectorizedN<int32_t, 1> apply(
+      const VectorizedN<float, 1>& src) {
+    return Vectorized<int32_t>(_mm256_cvttps_epi32(src[0]));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, int32_t, 1> {
+  static inline VectorizedN<float, 1> apply(
+      const VectorizedN<int32_t, 1>& src) {
+    return Vectorized<float>(_mm256_cvtepi32_ps(src[0]));
+  }
+};
+
+template <>
+struct VecConvert<int16_t, 1, uint8_t, 1> {
+  static inline VectorizedN<int16_t, 1> apply(
+      const VectorizedN<uint8_t, 1>& src) {
+    auto src128 = _mm256_castsi256_si128(src[0]);
+    return Vectorized<int16_t>(_mm256_cvtepu8_epi16(src128));
+  }
+};
+
 template <typename dst_t, typename src_t>
 struct VecConvert<
     dst_t,
@@ -246,6 +246,12 @@ public:
     return _mm256_floor_pd(values);
   }
   Vectorized<double> frac() const;
+  double reduce_add() const {
+    return values[0];
+  }
+  double reduce_max() const {
+    return values[0];
+  }
   Vectorized<double> neg() const {
     return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
   }
@@ -342,6 +342,12 @@ public:
     }
     return loadu(tmp);
   }
+  float reduce_add() const {
+    return values[0];
+  }
+  float reduce_max() const {
+    return values[0];
+  }
   Vectorized<float> neg() const {
     return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
   }
@@ -241,6 +241,12 @@ public:
   Vectorized<int32_t> abs() const {
     return _mm256_abs_epi32(values);
   }
+  int32_t reduce_add() const {
+    return values[0];
+  }
+  int32_t reduce_max() const {
+    return values[0];
+  }
   Vectorized<int32_t> real() const {
     return *this;
   }
@@ -11,6 +11,7 @@
 #define SLEEF_STATIC_LIBS
 #include <sleef.h>
 #endif
+#include <iostream>
 
 namespace at {
 namespace vec {
@@ -43,6 +44,9 @@ static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
 }
 
 static inline __m256i cvtfp32_bf16(const __m512& src) {
+  // #if defined(CPU_CAPABILITY_AVX512_BF16)
+  //   return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src));
+  // #else
   __m512i value = _mm512_castps_si512(src);
   __m512i nan = _mm512_set1_epi32(0xffff);
   auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
@@ -59,6 +63,7 @@ static inline __m256i cvtfp32_bf16(const __m512& src) {
   // Check NaN before converting back to bf16
   t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
   return _mm512_cvtusepi32_epi16(t_value);
+  // #endif
 }
 
 static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
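For reference, the scalar effect of this fp32-to-bf16 helper: keep the top 16 bits after a rounding bias, and force NaNs to the canonical 0xffff pattern (the `nan` constant and `_CMP_ORD_Q` mask above). A Python sketch; the round-to-nearest-even bias is an assumption about the lines elided between the two hunks:

```python
import struct

def fp32_to_bf16_bits(x: float) -> int:
    # Unordered with itself <=> NaN; the vector code blends in 0xffff here.
    if x != x:
        return 0xFFFF
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Assumed round-to-nearest-even bias (the elided middle of cvtfp32_bf16).
    bits += 0x7FFF + ((bits >> 16) & 1)
    return bits >> 16  # keep sign, exponent, and top 7 mantissa bits
```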
@@ -117,6 +117,49 @@ struct VecConvert<int32_t, 1, uint8_t, 1> {
   }
 };
 
+template <>
+struct VecConvert<int32_t, 1, float, 1> {
+  static inline VectorizedN<int32_t, 1> apply(
+      const VectorizedN<float, 1>& src) {
+    return Vectorized<int32_t>(_mm512_cvttps_epi32(src[0]));
+  }
+};
+
+template <>
+struct VecConvert<float, 1, int32_t, 1> {
+  static inline VectorizedN<float, 1> apply(
+      const VectorizedN<int32_t, 1>& src) {
+    return Vectorized<float>(_mm512_cvtepi32_ps(src[0]));
+  }
+};
+
+template <>
+struct VecConvert<int16_t, 1, uint8_t, 1> {
+  static inline VectorizedN<int16_t, 1> apply(
+      const VectorizedN<uint8_t, 1>& src) {
+    auto src256 = _mm512_castsi512_si256(src[0]);
+    return Vectorized<int16_t>(_mm512_cvtepu8_epi16(src256));
+  }
+};
+
+template <>
+struct VecConvert<int8_t, 1, int32_t, 1> {
+  static inline VectorizedN<int8_t, 1> apply(
+      const VectorizedN<int32_t, 1>& src) {
+    auto src128 = _mm512_cvtepi32_epi8(src[0]);
+    return Vectorized<int8_t>(_mm512_castsi128_si512(src128));
+  }
+};
+
+template <>
+struct VecConvert<int8_t, 1, int16_t, 1> {
+  static inline VectorizedN<int8_t, 1> apply(
+      const VectorizedN<int16_t, 1>& src) {
+    auto src256 = _mm512_cvtepi16_epi8(src[0]);
+    return Vectorized<int8_t>(_mm512_castsi256_si512(src256));
+  }
+};
+
 template <typename dst_t, typename src_t>
 struct VecConvert<
     dst_t,
@@ -255,6 +255,12 @@ public:
     return _mm512_floor_pd(values);
   }
   Vectorized<double> frac() const;
+  double reduce_add() const {
+    return values[0];
+  }
+  double reduce_max() const {
+    return values[0];
+  }
   Vectorized<double> neg() const {
     return _mm512_xor_pd(_mm512_set1_pd(-0.), values);
   }
@@ -236,27 +236,27 @@ public:
   }
   Vectorized<float> exp_u20() const {
     // A faster version of exp with ULP=20
-    static __m512 vec_factorial_1 =
+    const __m512 vec_factorial_1 =
         _mm512_set1_ps(0.999999701f); // 1/factorial(1)
-    static __m512 vec_factorial_2 =
+    const __m512 vec_factorial_2 =
         _mm512_set1_ps(0.499991506f); // 1/factorial(2)
-    static __m512 vec_factorial_3 =
+    const __m512 vec_factorial_3 =
         _mm512_set1_ps(0.166676521f); // 1/factorial(3)
-    static __m512 vec_factorial_4 =
+    const __m512 vec_factorial_4 =
        _mm512_set1_ps(0.0418978221f); // 1/factorial(4)
-    static __m512 vec_factorial_5 =
+    const __m512 vec_factorial_5 =
        _mm512_set1_ps(0.00828929059f); // 1/factorial(5)
-    static __m512 vec_exp_log2ef =
+    const __m512 vec_exp_log2ef =
        _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e)
-    static __m512 vec_half = _mm512_set1_ps(0.5f);
-    static __m512 vec_one = _mm512_set1_ps(1.f);
-    static __m512 vec_zero = _mm512_set1_ps(0.f);
-    static __m512 vec_two = _mm512_set1_ps(2.f);
-    static __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2)
-    static __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
-    static __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
-    static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
-    static int n_mantissa_bits = 23;
+    const __m512 vec_half = _mm512_set1_ps(0.5f);
+    const __m512 vec_one = _mm512_set1_ps(1.f);
+    const __m512 vec_zero = _mm512_set1_ps(0.f);
+    const __m512 vec_two = _mm512_set1_ps(2.f);
+    const __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2)
+    const __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
+    const __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
+    const __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
+    const int n_mantissa_bits = 23;
 
     // exp(x) =
     // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
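These constants implement the usual range reduction exp(x) = 2^n * exp(r) with a degree-5 polynomial for exp(r), as the comment sketches. A scalar Python rendering using the coefficients and bit patterns above (an illustration of the math, not the intrinsic sequence):

```python
import math
import struct

def f32(bits: int) -> float:
    # Reinterpret an int bit pattern as float32, as _mm512_castsi512_ps does.
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Coefficients copied from the diff (1/k!, nudged for ULP <= 20 accuracy).
P1, P2, P3, P4, P5 = 0.999999701, 0.499991506, 0.166676521, 0.0418978221, 0.00828929059
LOG2E = f32(0x3fb8aa3b)        # vec_exp_log2ef
LN2 = f32(0x3f317218)          # vec_ln2f
LN_FLT_MIN = f32(0xc2aeac50)   # below this, exp underflows to 0 in float32
LN_FLT_MAX = f32(0x42b17218)   # above this, exp overflows in float32

def exp_u20(x: float) -> float:
    # Range reduction: exp(x) = 2**n * exp(r) with n = round(x / ln 2).
    x = min(max(x, LN_FLT_MIN), LN_FLT_MAX)
    n = math.floor(x * LOG2E + 0.5)
    r = x - n * LN2  # remainder, |r| <= ln(2) / 2
    # Degree-5 polynomial for exp(r) in Horner form (an FMA chain in AVX-512).
    p = 1.0 + r * (P1 + r * (P2 + r * (P3 + r * (P4 + r * P5))))
    return math.ldexp(p, n)  # scale by 2**n via the exponent bits
```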
@@ -364,6 +364,12 @@ public:
     }
     return loadu(tmp);
   }
+  float reduce_add() const {
+    return _mm512_reduce_add_ps(values);
+  }
+  float reduce_max() const {
+    return _mm512_reduce_max_ps(values);
+  }
   Vectorized<float> neg() const {
     return _mm512_xor_ps(_mm512_set1_ps(-0.f), values);
   }
@@ -473,26 +479,26 @@ inline Vectorized<float> Vectorized<float>::frac() const {
 // either input is a NaN.
 template <>
 Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
-  auto zero_vec = _mm512_set1_epi32(0);
-  auto max = _mm512_max_ps(a, b);
-  auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
-  auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
-                                                          0xFFFFFFFF));
-  // Exploit the fact that all-ones is a NaN.
-  return _mm512_or_ps(max, isnan);
+  // auto zero_vec = _mm512_set1_epi32(0);
+  return _mm512_max_ps(a, b);
+  // auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
+  // auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
+  //                                                         0xFFFFFFFF));
+  // // Exploit the fact that all-ones is a NaN.
+  // return _mm512_or_ps(max, isnan);
 }
 
 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
 // either input is a NaN.
 template <>
 Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
-  auto zero_vec = _mm512_set1_epi32(0);
-  auto min = _mm512_min_ps(a, b);
-  auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
-  auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
-                                                          0xFFFFFFFF));
+  // auto zero_vec = _mm512_set1_epi32(0);
+  return _mm512_min_ps(a, b);
+  // auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
+  // auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
+  //                                                         0xFFFFFFFF));
   // Exploit the fact that all-ones is a NaN.
-  return _mm512_or_ps(min, isnan);
+  // return _mm512_or_ps(min, isnan);
 }
 
 template <>
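Note the behavioral change here: bare `_mm512_max_ps`/`_mm512_min_ps` return the second operand when the comparison is unordered, so NaNs no longer propagate the way the surrounding comments promise. What the commented-out code implemented, as a scalar sketch:

```python
import math

def ieee_maximum(a: float, b: float) -> float:
    # _CMP_UNORD_Q: true if either operand is NaN.
    if math.isnan(a) or math.isnan(b):
        return float("nan")  # OR-ing an all-ones lane into the result yields NaN
    return max(a, b)         # otherwise plain _mm512_max_ps semantics
```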
@@ -267,6 +267,12 @@ public:
   Vectorized<int32_t> abs() const {
     return _mm512_abs_epi32(values);
   }
+  int32_t reduce_add() const {
+    return _mm512_reduce_add_epi32(values);
+  }
+  int32_t reduce_max() const {
+    return _mm512_reduce_max_epi32(values);
+  }
   Vectorized<int32_t> real() const {
     return *this;
   }
@@ -542,6 +542,12 @@ public:
     // We do not use std::round because we would like to round midway numbers to the nearest even integer.
     return map(at::native::round_impl);
   }
+  T reduce_add() const {
+    return values[0];
+  }
+  T reduce_max() const {
+    return values[0];
+  }
   Vectorized<T> sin() const {
     return map(std::sin);
   }
(File diff suppressed because it is too large.)
@@ -14708,7 +14708,7 @@
     CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
   autogen: _native_multi_head_attention.out
 
-- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, int q_zp=0, float q_scale=0.0, int k_zp=0, float k_scale=0.0, int v_zp=0, float v_scale=0.0, int a_zp=0, float a_scale=0.0, int o_zp=0, float o_scale=0.0) -> Tensor
   python_module: nn
   variants: function
   autogen: scaled_dot_product_attention.out
@@ -14722,7 +14722,7 @@
     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
+- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, int q_zp=0, float q_scale=0.0, int k_zp=0, float k_scale=0.0, int v_zp=0, float v_scale=0.0, int a_zp=0, float a_scale=0.0, int o_zp=0, float o_scale=0.0) -> (Tensor, Tensor)
   variants: function
   tags: nondeterministic_seeded
 
@@ -14732,7 +14732,7 @@
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
+- func: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None, int q_zp=0, float q_scale=0.0, int k_zp=0, float k_scale=0.0, int v_zp=0, float v_scale=0.0, int a_zp=0, float a_scale=0.0, int o_zp=0, float o_scale=0.0) -> (Tensor output, Tensor logsumexp)
   dispatch:
     CPU: _scaled_dot_product_flash_attention_cpu
   tags: nondeterministic_seeded
@@ -22,6 +22,7 @@
 #include <type_traits>
 #include <limits>
 #include <utility>
+#include <iostream>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -48,6 +49,9 @@
 #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
 #include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_native.h>
 #include <ATen/ops/_softmax.h>
+#include <ATen/ops/clamp_max.h>
+#include <ATen/ops/clamp_min.h>
+#include <ATen/ops/round.h>
 #include <ATen/ops/_transform_bias_rescale_qkv.h>
 #include <ATen/ops/_transform_bias_rescale_qkv_native.h>
 #include <ATen/ops/_triton_multi_head_attention_native.h>
@@ -503,7 +507,7 @@ inline void validate_sdpa_input(
       query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
   if (attn_mask_.has_value()){
     auto mask_dtype = attn_mask_->dtype();
-    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == at::kFloat || mask_dtype == query_.dtype(),
+    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == at::kFloat || mask_dtype == at::kBFloat16 || mask_dtype == query_.dtype(),
       "Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: ",
       mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
   TORCH_CHECK(
@@ -647,7 +651,17 @@ Tensor scaled_dot_product_attention(
     const std::optional<Tensor>& attn_mask_,
     double dropout_p,
     bool is_causal,
-    std::optional<double> scale) {
+    std::optional<double> scale,
+    long q_zp,
+    double q_scale,
+    long k_zp,
+    double k_scale,
+    long v_zp,
+    double v_scale,
+    long a_zp,
+    double a_scale,
+    long o_zp,
+    double o_scale) {
   validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
@@ -677,7 +691,17 @@ Tensor scaled_dot_product_attention(
       }
       // For the CPU case we do not need to pad the last dim
       return std::get<0>(at::_scaled_dot_product_flash_attention_for_cpu(
-          query_, key, value, dropout_p, is_causal, attn_mask, scale));
+          query_, key, value, dropout_p, is_causal, attn_mask, scale,
+          q_zp,
+          q_scale,
+          k_zp,
+          k_scale,
+          v_zp,
+          v_scale,
+          a_zp,
+          a_scale,
+          o_zp,
+          o_scale));
     }
     case sdp::SDPBackend::efficient_attention: {
       bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
@@ -693,7 +717,7 @@ Tensor scaled_dot_product_attention(
           query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
       return std::get<0>(out_lse_softmax);
     }
-    case sdp::SDPBackend::math:
+    case sdp::SDPBackend::math: {
       return std::get<0>(at::_scaled_dot_product_attention_math(
           query_,
           key,
@@ -702,7 +726,18 @@ Tensor scaled_dot_product_attention(
           dropout_p,
           is_causal,
           c10::nullopt, /*dropout_mask*/
-          scale));
+          scale,
+          q_zp,
+          q_scale,
+          k_zp,
+          k_scale,
+          v_zp,
+          v_scale,
+          a_zp,
+          a_scale,
+          o_zp,
+          o_scale));
+    }
     default:
       TORCH_CHECK(
           false,
@@ -714,34 +749,58 @@ Tensor scaled_dot_product_attention(
 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     const Tensor& query_, const Tensor& key, const Tensor& value,
     const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
-    const std::optional<Tensor>& dropout_mask, std::optional<double> scale) {
+    const std::optional<Tensor>& dropout_mask, std::optional<double> scale,
+    long q_zp,
+    double q_scale,
+    long k_zp,
+    double k_scale,
+    long v_zp,
+    double v_scale,
+    long a_zp,
+    double a_scale,
+    long o_zp,
+    double o_scale) {
+  // std::cout << "enter _scaled_dot_product_attention_math" << std::endl;
   C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
+  auto dtype = query_.scalar_type();
   if (query_.is_nested() || key.is_nested() || value.is_nested()) {
     TORCH_CHECK(
         query_.is_contiguous() && key.is_contiguous() &&
             value.is_contiguous(),
         "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
   }
+  auto q = query_;
+  auto k = key;
+  auto v = value;
+  // dequantize
+  if (dtype == ScalarType::Byte) {
+    q = (query_.to(at::kFloat) - q_zp) * q_scale;
+    k = (key.to(at::kFloat) - k_zp) * k_scale;
+    v = (value.to(at::kFloat) - v_zp) * v_scale;
+  }
   auto attn_mask = attn_mask_;
+  if (attn_mask.has_value()) {
+    *attn_mask = (*attn_mask).to(at::kFloat);
+  }
   // Naive, composite implementation defined here.
 
   // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
   bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
-  const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();
+  const auto scaling_factor = sdp::calculate_scale(q, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();
 
-  const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
+  q = q * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
   if (is_causal) {
     TORCH_CHECK(!attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
+    TORCH_CHECK(!q.is_nested() && !k.is_nested(),
         "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
 
     // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.sym_size(-2), S = key.sym_size(-2);
-    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
-    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
+    const auto L = q.sym_size(-2), S = k.sym_size(-2);
+    attn_mask = at::ones_symint({L, S}, q.options().dtype(at::kBool)).tril();
+    attn_mask = convert_boolean_attn_mask(attn_mask, q.dtype());
   }
-  auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
+  auto attn = at::matmul(q, k.transpose(-2, -1) * scaling_factor);
   if (attn_mask.has_value()) {
     if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
       attn = attn.add(*attn_mask);
@@ -757,13 +816,25 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
       TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
       attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
       auto dropout_scaling = 1.0 / (1 - dropout_p);
-      return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
+      return std::make_tuple(at::matmul(attn, v * dropout_scaling), attn);
     } else {
       attn = at::dropout(attn, dropout_p, true);
     }
   }
+  if (dtype == ScalarType::Byte) {
+    attn = at::clamp_max(
+        at::clamp_min(at::round(attn / a_scale) + a_zp, 0), 255
+    );
+    attn = (attn - a_zp) * a_scale;
+  }
 
-  return std::make_tuple(at::matmul(attn, value), attn);
+  auto res = at::matmul(attn, v);
+  if (dtype == ScalarType::Byte) {
+    res = at::clamp_max(
+        at::clamp_min(at::round(res / o_scale) + o_zp, 0), 255
+    ).to(at::kByte);
+  }
+  return std::make_tuple(res, attn);
 }
 
 std::tuple<at::Tensor, at::Tensor>
@@ -774,15 +845,26 @@ _scaled_dot_product_flash_attention_cpu(
     double dropout_p,
     bool is_causal,
     const std::optional<Tensor>& attn_mask,
-    std::optional<double> scale) {
+    std::optional<double> scale,
+    long q_zp,
+    double q_scale,
+    long k_zp,
+    double k_scale,
+    long v_zp,
+    double v_scale,
+    long a_zp,
+    double a_scale,
+    long o_zp,
+    double o_scale) {
+  // std::cout << "enter _scaled_dot_product_flash_attention_cpu" << std::endl;
   const auto dtype = query.scalar_type();
   int64_t batchSize = query.size(0);
   int64_t qSize = query.size(2);
   int64_t num_head = query.size(1);
   int64_t headSize = query.size(3);
 
-  TORCH_CHECK(c10::isFloatingType(dtype),
-    "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, but got ", dtype, " instead.");
+  TORCH_CHECK(c10::isFloatingType(dtype) || dtype == ScalarType::Byte,
+    "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, U8, but got ", dtype, " instead.");
   TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
     "scaled_dot_product_attention_flash_attention: Accept only 4 dims inputs shape of {B, H, T, K}");
   TORCH_CHECK(dropout_p == 0.0,
@@ -791,6 +873,7 @@ _scaled_dot_product_flash_attention_cpu(
     "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
   TORCH_CHECK(!attn_mask.has_value() ||
     attn_mask.value().scalar_type() == at::kFloat ||
+    attn_mask.value().scalar_type() == at::kBFloat16 ||
     dtype == attn_mask.value().scalar_type(),
     "scaled_dot_product_attention_flash_attention: Attention mask is the same data type as query");
   TORCH_CHECK(!attn_mask.has_value() ||
@@ -798,16 +881,27 @@ _scaled_dot_product_flash_attention_cpu(
     "scaled_dot_product_attention_flash_attention: Attention mask dim in {2, 4}");
 
   at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options());
-  const auto accumulate_dtype = toOpMathType(dtype);
+  const auto accumulate_dtype = dtype == ScalarType::Byte ? at::kFloat : toOpMathType(dtype);
   at::Tensor logsumexp = at::empty({batchSize, qSize, num_head},
       query.options().dtype(accumulate_dtype));
 
   flash_attention_kernel(kCPU, output, logsumexp,
-      query, key, value, dropout_p, is_causal, attn_mask, scale);
+      query, key, value, dropout_p, is_causal, attn_mask, scale,
+      q_zp,
+      q_scale,
+      k_zp,
+      k_scale,
+      v_zp,
+      v_scale,
+      a_zp,
+      a_scale,
+      o_zp,
+      o_scale);
 
   output = output.transpose(1, 2);
   logsumexp = logsumexp.transpose(1, 2);
 
+  // std::cout << "output: " << output << std::endl;
   return std::make_tuple(std::move(output), std::move(logsumexp));
 }
 
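Taken together, the math fallback above dequantizes u8 Q/K/V with per-tensor (zero_point, scale) pairs, runs the fp32 computation, fake-quantizes the attention probabilities, and requantizes the output. A compact PyTorch sketch of that flow (single-scale application instead of the sqrt-split scaling, shapes unchecked):

```python
import math
import torch

def u8_sdpa_math_ref(q_u8, k_u8, v_u8,
                     q_zp, q_scale, k_zp, k_scale, v_zp, v_scale,
                     a_zp, a_scale, o_zp, o_scale, scale=None):
    # Dequantize: x_fp32 = (x_u8 - zero_point) * scale, as in the Byte branch.
    q = (q_u8.to(torch.float) - q_zp) * q_scale
    k = (k_u8.to(torch.float) - k_zp) * k_scale
    v = (v_u8.to(torch.float) - v_zp) * v_scale
    scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    # Fake-quantize the probabilities (round/clamp to u8 range, then dequantize).
    attn = (torch.clamp(torch.round(attn / a_scale) + a_zp, 0, 255) - a_zp) * a_scale
    out = attn @ v
    # Requantize the output to u8, mirroring the o_zp/o_scale epilogue.
    return torch.clamp(torch.round(out / o_scale) + o_zp, 0, 255).to(torch.uint8)
```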
@@ -54,7 +54,17 @@ using flash_attention_fn = void (*)(
     const Tensor& query, const Tensor& key, const Tensor& value,
     double dropout_p, bool is_causal,
     std::optional<Tensor> attn_mask,
-    std::optional<double> scale);
+    std::optional<double> scale,
+    long q_zp,
+    double q_scale,
+    long k_zp,
+    double k_scale,
+    long v_zp,
+    double v_scale,
+    long a_zp,
+    double a_scale,
+    long o_zp,
+    double o_scale);
 
 using flash_attention_backward_fn = void (*)(
     const Tensor& grad_q, const Tensor& grad_k,
@@ -34,7 +34,7 @@ bool check_head_dim_size_cpp(sdp_params const& params, bool debug) {
 
 bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
   constexpr auto cpp_supported_flash_dtypes =
-      array_of<at::ScalarType>(at::kFloat, at::kDouble, at::kBFloat16, at::kHalf);
+      array_of<at::ScalarType>(at::kFloat, at::kDouble, at::kBFloat16, at::kHalf, at::kByte);
 
   // Define gate functions that determine if a flash kernel can be run
   constexpr auto constraints = array_of<bool (*)(sdp_params const&, bool)>(
@@ -273,7 +273,7 @@ if(INTERN_BUILD_ATEN_OPS)
     if(MSVC)
       list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512")
     else(MSVC)
-      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma")
+      list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma")# -mavx512bf16")
     endif(MSVC)
   endif(CXX_AVX512_FOUND)
 
@@ -83,11 +83,11 @@ IF(NOT MKLDNN_FOUND)
 
   FIND_PACKAGE(BLAS)
   FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
-  FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl)
+  FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} ${MKLDNN_ROOT}/src PATH_SUFFIXES include/oneapi/dnnl)
   IF(NOT MKLDNN_INCLUDE_DIR)
     MESSAGE("MKLDNN_INCLUDE_DIR not found")
     EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
-    FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
+    FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} ${MKLDNN_ROOT}/src PATH_SUFFIXES include)
   ENDIF(NOT MKLDNN_INCLUDE_DIR)
   IF(BUILD_ONEDNN_GRAPH)
     FIND_PATH(LLGA_INCLUDE_DIR dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include/oneapi/dnnl)
@@ -135,6 +135,7 @@ IF(NOT MKLDNN_FOUND)
   SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
   SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
   SET(DNNL_GRAPH_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE)
+  SET(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE)
 
   IF(BUILD_ONEDNN_GRAPH)
     SET(DNNL_GRAPH_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
@@ -155,6 +156,7 @@ IF(NOT MKLDNN_FOUND)
   ENDIF()
 
   ADD_SUBDIRECTORY(${MKLDNN_ROOT})
+  INCLUDE_DIRECTORIES(${MKLDNN_ROOT}/src)
 
   IF(NOT TARGET dnnl)
     MESSAGE("Failed to include MKL-DNN target")
@@ -1,13 +1,14 @@
 # Owner(s): ["module: inductor"]
 import contextlib
 import functools
+import itertools
 import math
 
 import torch
 import torch._inductor.config
 import torch.utils.checkpoint
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.utils import counters
 from torch._inductor import config
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
 from torch.testing._internal.common_cuda import (
@@ -16,6 +17,12 @@ from torch.testing._internal.common_cuda import (
 )
 from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
+    X86InductorQuantizer,
+)
 
 
 def checkpoint_wrapper(fn):
@@ -24,7 +31,64 @@ def checkpoint_wrapper(fn):
 
     return inner
 
+
+# For int8 sdpa pattern match
+def _generate_qdq_quantized_model(mod, inputs, quantizer):
+    with torch.no_grad():
+        export_model = capture_pre_autograd_graph(mod, inputs)
+        prepare_model = prepare_pt2e(export_model, quantizer)
+        prepare_model(*inputs)
+        convert_model = convert_pt2e(prepare_model)
+        torch.ao.quantization.move_exported_model_to_eval(convert_model)
+        return convert_model
+
+
+class SelfAttnLikeModule(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        has_mask,
+        num_attention_heads=None,
+        attention_head_size=None,
+    ) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
+        self.softmax = torch.nn.Softmax(dim=-1)
+        assert num_attention_heads is not None
+        assert attention_head_size is not None
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = attention_head_size
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)
+        self.dropout = torch.nn.Dropout(0)
+        self.has_mask = has_mask
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
+        x = x.view(new_x_shape)
+        return x.permute([0, 2, 1, 3])
+
+    def forward(self, x, mask):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        q = self.transpose_for_scores(q)
+        k = self.transpose_for_scores(k)
+        v = self.transpose_for_scores(v)
+        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
+        if self.has_mask:
+            scores = scores + mask
+        attention = self.softmax(scores)
+        attention = self.dropout(attention)
+        context_layer = torch.matmul(attention, v)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        context_layer = context_layer.view(
+            context_layer.size()[:-2] + (self.all_head_size,)
+        )
+        return self.dense(context_layer)
+
+
 class TestSDPAPatternRewriterTemplate(TestCase):
     use_static_shapes = True
 
@@ -72,10 +136,12 @@ class TestSDPAPatternRewriterTemplate(TestCase):
 
         dropout_arg = [training] if has_dropout else []
         torch.manual_seed(1234)
+        # breakpoint()
        result1 = dot_prod_attention(*(args1 + dropout_arg))
 
        counters.clear()
        torch.manual_seed(1234)
+        # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):  # FLASH_ATTENTION, MATH
        result2, source_code = run_and_get_code(
            torch.compile(dot_prod_attention, fullgraph=True),
            *(args2 + dropout_arg),
@@ -92,6 +158,8 @@ class TestSDPAPatternRewriterTemplate(TestCase):
 
         # some tests configured with very low dropout where we still want to check equality
         if not has_dropout or override_check_equal:
+            # print("result1: ", result1)
+            # print("result2: ", result2)
             self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)
 
         if training:
@@ -960,6 +1028,38 @@ class TestSDPAPatternRewriterTemplate(TestCase):
             check_train=False,
         )
 
+    @skipIfRocm
+    @config.patch({"freezing": True})
+    def _test_sdpa_rewriter_20_to_23(self):
+        if not self.use_static_shapes:
+            self.skipTest("Causes IndexError. TODO: investigate")
+
+        # pattern is different for bs=1
+        for dtype, has_mask, bs in itertools.product([torch.float32, torch.bfloat16], [True, False], [56, 1]):
+            mod = SelfAttnLikeModule(
+                input_dim=64 * 16,
+                has_mask=has_mask,
+                num_attention_heads=16,
+                attention_head_size=64,
+            ).eval()
+            maybe_autocast = (
+                torch.cpu.amp.autocast()
+                if dtype == torch.bfloat16
+                else contextlib.nullcontext()
+            )
+            inputs = [
+                torch.randn((bs, 384, 64 * 16), device=self.device, dtype=dtype),
+                torch.randn((bs, 1, 1, 384), device=self.device) if has_mask else None,
+            ]
+            with torch.no_grad(), maybe_autocast:
+                quantizer = X86InductorQuantizer()
+                quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())  # is_reduce=True
+                quantizer.set_function_type_qconfig(
+                    torch.matmul, quantizer.get_global_quantization_config()
+                )
+                convert_model = _generate_qdq_quantized_model(mod, inputs, quantizer)
+                self._check_common(convert_model, args1=inputs, check_train=False, atol=1.0)
+
 
 if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:
@@ -1088,6 +1188,9 @@ if HAS_CPU:
         test_sdpa_rewriter_19_cpu = functools.partialmethod(
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
         )
+        test_sdpa_rewriter_20_to_23_cpu = functools.partialmethod(
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_20_to_23
+        )
 
     class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests):
         use_static_shapes = False
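The helper `_generate_qdq_quantized_model` above drives the standard PT2E flow: capture the graph, insert observers with `prepare_pt2e`, calibrate by running sample inputs, then lower to explicit quantize/dequantize ops with `convert_pt2e`. A hypothetical standalone usage (module and shapes are illustrative, taken from the test):

```python
mod = SelfAttnLikeModule(
    input_dim=64 * 16, has_mask=True,
    num_attention_heads=16, attention_head_size=64,
).eval()
inputs = [torch.randn(2, 384, 64 * 16), torch.randn(2, 1, 1, 384)]
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
quantizer.set_function_type_qconfig(
    torch.matmul, quantizer.get_global_quantization_config()
)
int8_mod = _generate_qdq_quantized_model(mod, inputs, quantizer)
compiled = torch.compile(int8_mod)  # inductor can now rewrite the q/dq SDPA pattern
```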
@@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
 import unittest
 from unittest.mock import patch, MagicMock, ANY
 import math
-import itertools
 import torch.optim as optim
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
 from typing import List, Tuple, Optional
@@ -1915,18 +1914,205 @@ class TestSDPA(NNTestCase):
         self.assertEqual(grad_k_actual, grad_k_ref, atol=atol, rtol=rtol)
         self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol)
 
+    def trace_handler(self, prof):
+        print(prof.key_averages().table(
+            sort_by="self_cpu_time_total", row_limit=-1))
+
     @onlyCPU
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
-    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
-    @parametrize("batch_size", [2])
-    @parametrize("q_seq_len", [267])
-    @parametrize("kv_seq_len", [514])
-    @parametrize("n_head", [3])
-    @parametrize("head_dim", [8])
-    @parametrize("mask_dim", [2, 4])
-    @parametrize("bool_mask", [0, 1])
-    @parametrize("train", [True, False])
-    def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
+    @parametrize("dtype", [torch.float32])
+    # @parametrize("batch_size", [120])
+    # @parametrize("q_seq_len", [384])
+    # @parametrize("kv_seq_len", [384])
+    # @parametrize("n_head", [16])
+    @parametrize("batch_size", [224])
+    @parametrize("q_seq_len", [197])
+    @parametrize("kv_seq_len", [197])
+    @parametrize("n_head", [12])
+    @parametrize("head_dim", [64])
+    @parametrize("mask_dim", [4])
+    @parametrize("bool_mask", [0])
+    @parametrize("train", [False])
+    def test_scaled_dot_product_fused_attention_mask_vs_math_cpu_u8(
         self,
         device,
         fused_kernel,
         dtype,
         batch_size,
         q_seq_len,
         kv_seq_len,
         n_head,
         head_dim,
         mask_dim,
         bool_mask,
         train,
     ):
+        import time
+        torch.set_printoptions(threshold=10_000)
+        tol = Tolerances(3.0, 5e-6)  # 2 for bf16 mask, 3 for big bs
+        # tol = Tolerances(1e-5, 5e-6)
+        # if dtype is torch.bfloat16:
+        #     tol = Tolerances(5e-2, 5e-2)
+        # if dtype is torch.float16:
+        #     tol = Tolerances(1e-2, 1e-2)
+        # for mask_shape in itertools.product(
+        #     [q_seq_len, 1], [kv_seq_len, 1]
+        # ) if mask_dim == 2 else itertools.product(
+        #     [batch_size, 1], [n_head, 1], [q_seq_len, 1], [kv_seq_len, 1]
+        # ):
+        q_zp = 127
+        q_scale = 1.7907238006591797
+        k_zp = 125
+        k_scale = 1.8039721250534058
+        v_zp = 127
+        v_scale = 1.839004635810852
+        a_zp = 120
+        a_scale = 0.003919653594493866
+        o_zp = 128
+        o_scale = 1.8191684484481812
+        mask_shape = [batch_size, 1, 1, kv_seq_len]
+        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+        q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
+        kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
+
+        mask_dtypes = [torch.bfloat16]  # [None, torch.bfloat16, torch.float]
+        for mask_dtype in mask_dtypes:
+            q = make_tensor(q_shape) * 100
+            k = make_tensor(kv_shape) * 100
+            v = make_tensor(kv_shape) * 100
+            q = q.to(torch.uint8)
+            k = k.to(torch.uint8)
+            v = v.to(torch.uint8)
+            q2, k2, v2 = q.clone(), k.clone(), v.clone()
+
+            # if train:
+            #     q.requires_grad_(True)
+            #     k.requires_grad_(True)
+            #     v.requires_grad_(True)
+            #     q2.requires_grad_(True)
+            #     k2.requires_grad_(True)
+            #     v2.requires_grad_(True)
+
+            # if dtype in [torch.bfloat16, torch.float16]:
+            #     q2, k2, v2 = q2.float(), k2.float(), v2.float()
+            # (B, nh, T, hs)
+            q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
+            k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+            v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+            attn_mask = torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype else None
+            q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
+            k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+            v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+
+            with sdpa_kernel(backends=[fused_kernel]):
+                actual = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
+                    q_zp=q_zp, q_scale=q_scale,
+                    k_zp=k_zp, k_scale=k_scale,
+                    v_zp=v_zp, v_scale=v_scale,
+                    a_zp=a_zp, a_scale=a_scale,
+                    o_zp=o_zp, o_scale=o_scale)
+            with sdpa_kernel(backends=[SDPBackend.MATH]):
+                if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
+                    attn_mask = attn_mask.float()
+                math_ref = torch.nn.functional.scaled_dot_product_attention(
+                    q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
+                    q_zp=q_zp, q_scale=q_scale,
+                    k_zp=k_zp, k_scale=k_scale,
+                    v_zp=v_zp, v_scale=v_scale,
+                    a_zp=a_zp, a_scale=a_scale,
+                    o_zp=o_zp, o_scale=o_scale)
+
+            ## for debugging
+            # q2 = q2.to(torch.float)
+            # k2 = k2.to(torch.float)
+            # v2 = v2.to(torch.float)
+            # scale_factor = 1 / math.sqrt(q2.size(-1))
+            # attn = q2 @ k2.transpose(-2, -1)
+            # # print("[math] qk: ", attn)
+            # attn = attn * scale_factor
+            # attn_max = attn.max(dim=-1, keepdim=True).values
+            # attn = attn - attn_max
+            # attn = torch.exp(attn)
+            # attn_sum = torch.sum(attn, dim=-1, keepdim=True)
+            # attn = attn / attn_sum
+            # math_ref = attn @ v2
+            # math_ref = math_ref.to(torch.uint8)
+
+            if dtype in [torch.bfloat16, torch.float16]:
+                math_ref = math_ref.to(dtype)
+
+            # print("actual", actual)
+            # print("math_ref", math_ref)
+
+            self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
+
+            iter_n = 20
+            with torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU],
+                schedule=torch.profiler.schedule(
+                    wait=2,
+                    warmup=iter_n,
+                    active=20),
+                on_trace_ready=self.trace_handler
+            ) as prof:
+                for _ in range(iter_n + 22):
+                    # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+                    # s = time.time()
+                    actual = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
+                        q_zp=q_zp, q_scale=q_scale,
+                        k_zp=k_zp, k_scale=k_scale,
+                        v_zp=v_zp, v_scale=v_scale,
+                        a_zp=a_zp, a_scale=a_scale,
+                        o_zp=o_zp, o_scale=o_scale)
+                    # print((time.time()-s)*1000)
+                    # print("iter",_, (time.time()-s)*1000)
+                    prof.step()
+
+            # for _ in range(100):
+            #     # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+            #     # s = time.time()
+            #     print("\n[iter]:", _, flush=True)
+            #     actual = torch.nn.functional.scaled_dot_product_attention(
+            #         q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
+            #         q_zp=q_zp, q_scale=q_scale,
+            #         k_zp=k_zp, k_scale=k_scale,
+            #         v_zp=v_zp, v_scale=v_scale,
+            #         a_zp=a_zp, a_scale=a_scale,
+            #         o_zp=o_zp, o_scale=o_scale)
+            #     # print((time.time()-s)*1000)
+            #     # print("iter",_, (time.time()-s)*1000)
+            #     # prof.step()
+
+            # if train:
+            #     actual.sum().backward()
+            #     math_ref.sum().backward()
+
+            #     grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
+            #     grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
+
+            #     self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
+            #     self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
+            #     self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
+
+    @onlyCPU
+    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
+    @parametrize("dtype", [torch.bfloat16])
+    # @parametrize("batch_size", [120])
+    # @parametrize("q_seq_len", [384])
+    # @parametrize("kv_seq_len", [384])
+    # @parametrize("n_head", [16])
+    @parametrize("batch_size", [224])
+    @parametrize("q_seq_len", [197])
+    @parametrize("kv_seq_len", [197])
+    @parametrize("n_head", [12])
+    @parametrize("head_dim", [64])
+    @parametrize("mask_dim", [4])
+    @parametrize("bool_mask", [0])
+    @parametrize("train", [False])
+    def test_scaled_dot_product_fused_attention_mask_vs_math_cpu_bf16(
         self,
         device,
         fused_kernel,
@@ -1945,65 +2131,74 @@ class TestSDPA(NNTestCase):
         tol = Tolerances(5e-2, 5e-2)
         if dtype is torch.float16:
             tol = Tolerances(1e-2, 1e-2)
-        for mask_shape in itertools.product(
-            [q_seq_len, 1], [kv_seq_len, 1]
-        ) if mask_dim == 2 else itertools.product(
-            [batch_size, 1], [n_head, 1], [q_seq_len, 1], [kv_seq_len, 1]
-        ):
-            make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
-            q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
-            kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
-            q = make_tensor(q_shape)
-            k = make_tensor(kv_shape)
-            v = make_tensor(kv_shape)
-            q2, k2, v2 = q.clone(), k.clone(), v.clone()
+        # for mask_shape in itertools.product(
+        #     [q_seq_len, 1], [kv_seq_len, 1]
+        # ) if mask_dim == 2 else itertools.product(
+        #     [batch_size, 1], [n_head, 1], [q_seq_len, 1], [kv_seq_len, 1]
+        # ):
+        mask_shape = [batch_size, 1, 1, kv_seq_len]
+        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+        q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
+        kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
+        q = make_tensor(q_shape)
+        k = make_tensor(kv_shape)
+        v = make_tensor(kv_shape)
+        q2, k2, v2 = q.clone(), k.clone(), v.clone()
 
-            if train:
-                q.requires_grad_(True)
-                k.requires_grad_(True)
-                v.requires_grad_(True)
-                q2.requires_grad_(True)
-                k2.requires_grad_(True)
-                v2.requires_grad_(True)
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
 
-            if dtype in [torch.bfloat16, torch.float16]:
-                q2, k2, v2 = q2.float(), k2.float(), v2.float()
-            # (B, nh, T, hs)
-            q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
-            k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            if bool_mask:
-                attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
-            else:
-                attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
-            q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
-            k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
-            v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+        if dtype in [torch.bfloat16, torch.float16]:
+            q2, k2, v2 = q2.float(), k2.float(), v2.float()
+        # (B, nh, T, hs)
+        q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
+        k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+        v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+        # if bool_mask:
+        #     attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
+        # else:
+        attn_mask = torch.randn(mask_shape, dtype=torch.bfloat16, device=device)
+        q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
+        k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
+        v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
 
-        with sdpa_kernel(backends=[fused_kernel]):
+        # with sdpa_kernel(backends=[fused_kernel]):
+        #     actual = torch.nn.functional.scaled_dot_product_attention(
+        #         q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+        # with sdpa_kernel(backends=[SDPBackend.MATH]):
+        #     if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
+        #         attn_mask = attn_mask.float()
+        #     math_ref = torch.nn.functional.scaled_dot_product_attention(
+        #         q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+
+        # if dtype in [torch.bfloat16, torch.float16]:
+        #     math_ref = math_ref.to(dtype)
+
+        # # # print("actual", actual)
+        # # # print("math_ref", math_ref)
+
+        # self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
+
+        iter_n = 20
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU],
+            schedule=torch.profiler.schedule(
+                wait=2,
+                warmup=iter_n,
+                active=20),
+            on_trace_ready=self.trace_handler
+        ) as prof:
+            for _ in range(iter_n + 22):
+                # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+                actual = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+                with sdpa_kernel(backends=[SDPBackend.MATH]):
+                    if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
+                        attn_mask = attn_mask.float()
+                    math_ref = torch.nn.functional.scaled_dot_product_attention(
+                        q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+
+                if dtype in [torch.bfloat16, torch.float16]:
+                    math_ref = math_ref.to(dtype)
+
+                self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
+
+                if train:
+                    actual.sum().backward()
+                    math_ref.sum().backward()
+
+                    grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
+                    grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
+
+                    self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
+                    self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
+                    self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
+                prof.step()
 
     @parametrize("kernel", [SDPBackend.MATH])
     def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
@@ -2839,7 +2839,7 @@
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
 
-- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
+- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None, int q_zp=0, float q_scale=0.0, int k_zp=0, float k_scale=0.0, int v_zp=0, float v_scale=0.0, int a_zp=0, float a_scale=0.0, int o_zp=0, float o_scale=0.0) -> (Tensor output, Tensor logsumexp)
   output_differentiability: [True, False]
   query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale)
 
@@ -4763,12 +4763,22 @@ def scaled_dot_product_flash_attention_for_cpu(
     *,
     attn_mask: Optional[Tensor] = None,
     scale: Optional[float] = None,
+    q_zp: int = 0,
+    q_scale: float = 0.0,
+    k_zp: int = 0,
+    k_scale: float = 0.0,
+    v_zp: int = 0,
+    v_scale: float = 0.0,
+    a_zp: int = 0,
+    a_scale: float = 0.0,
+    o_zp: int = 0,
+    o_scale: float = 0.0,
 ) -> Tuple[Tensor, Tensor]:
     dtype = query.dtype
-    torch._check(
-        torch.is_floating_point(query),
-        lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
-    )
+    # torch._check(
+    #     torch.is_floating_point(query),
+    #     lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
+    # )
     torch._check(
         query.dim() == 4 and key.dim() == 4 and value.dim() == 4,
         lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}",
@@ -4790,6 +4800,16 @@ def scaled_dot_product_flash_attention_for_cpu(
         is_causal=is_causal,
         dropout_mask=None,
         scale=scale,
+        q_zp=q_zp,
+        q_scale=q_scale,
+        k_zp=k_zp,
+        k_scale=k_scale,
+        v_zp=v_zp,
+        v_scale=v_scale,
+        a_zp=a_zp,
+        a_scale=a_scale,
+        o_zp=o_zp,
+        o_scale=o_scale,
     )
     # Why this change?
    # In pre-dispatch export scaled_dot_product_attention is executed via
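The inductor patterns that follow match graphs built from explicit `quantized_decomposed` quantize/dequantize ops. Their per-tensor scalar semantics, sketched in plain PyTorch (an illustration, not the library implementation):

```python
import torch

def quantize_per_tensor(x, scale, zero_point, qmin=0, qmax=255):
    # u8 affine quantization: round, shift by zero_point, saturate to [qmin, qmax].
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.uint8)

def dequantize_per_tensor(xq, scale, zero_point):
    # Inverse affine map back to float32.
    return (xq.to(torch.float) - zero_point) * scale
```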
@ -548,6 +548,359 @@ def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p):
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_pattern_20(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
inv_scale,
|
||||
q_zp,
|
||||
q_scale,
|
||||
k_zp,
|
||||
k_scale,
|
||||
v_zp,
|
||||
v_scale,
|
||||
a_zp,
|
||||
a_scale,
|
||||
o_zp,
|
||||
o_scale,
|
||||
dropout,
|
||||
):
|
||||
# int8-mix-fp32 QUANTIZED SDPA with mask
|
||||
q = query.permute([0, 2, 1, 3])
|
||||
q = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
q, float(q_scale), int(q_zp), 0, 255, torch.uint8
|
||||
)
|
||||
k = key.permute([0, 2, 1, 3]).transpose(-2, -1)
|
||||
k = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
k, float(k_scale), int(k_zp), 0, 255, torch.uint8
|
||||
)
|
||||
v = value.permute([0, 2, 1, 3])
|
||||
v = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
v, float(v_scale), int(v_zp), 0, 255, torch.uint8
|
||||
)
|
||||
a = torch.nn.functional.dropout(
|
||||
(torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1),
|
||||
dropout,
|
||||
)
|
||||
qa = torch.ops.quantized_decomposed.quantize_per_tensor.default(
|
||||
a, float(a_scale), int(a_zp), 0, 255, torch.uint8
|
||||
)
|
||||
a = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
qa, float(a_scale), int(a_zp), 0, 255, torch.uint8
|
||||
)
|
||||
o = a.matmul(v)
|
||||
o = o.permute(0, 2, 1, 3).contiguous()
|
||||
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
|
||||
o, float(o_scale), int(o_zp), 0, 255, torch.uint8
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_replacement_20(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
inv_scale,
|
||||
q_zp,
|
||||
q_scale,
|
||||
k_zp,
|
||||
k_scale,
|
||||
v_zp,
|
||||
v_scale,
|
||||
a_zp,
|
||||
a_scale,
|
||||
o_zp,
|
||||
o_scale,
|
||||
dropout,
|
||||
):
|
||||
counters["inductor"]["fuse_attention"] += 1
|
||||
print("hit sdpa pattern 20")
|
||||
res = _scaled_dot_product_attention(
|
||||
query.transpose(1, 2),
|
||||
key.transpose(1, 2),
|
||||
value.transpose(1, 2),
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout,
|
||||
is_causal=False,
|
||||
scale=1.0 / inv_scale,
|
||||
q_zp=q_zp,
|
||||
q_scale=q_scale,
|
||||
k_zp=k_zp,
|
||||
k_scale=k_scale,
|
||||
v_zp=v_zp,
|
||||
v_scale=v_scale,
|
||||
a_zp=a_zp,
|
||||
a_scale=a_scale,
|
||||
o_zp=o_zp,
|
||||
o_scale=o_scale,
|
||||
)
|
||||
return res.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
|
||||
def _sfdp_pattern_21(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
inv_scale,
|
||||
q_zp,
|
||||
q_scale,
|
||||
k_zp,
|
||||
k_scale,
|
||||
v_zp,
|
||||
v_scale,
|
||||
a_zp,
|
||||
a_scale,
|
||||
o_zp,
|
||||
o_scale,
|
||||
dropout,
|
||||
):
|
||||
# int8-mix-reduce QUANTIZED SDPA with mask
|
||||
q = query.permute([0, 2, 1, 3])
|
||||
q = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
q, float(q_scale), int(q_zp), 0, 255, torch.uint8
|
||||
).to(torch.float16)
|
||||
k = key.permute([0, 2, 1, 3]).transpose(-2, -1)
|
||||
k = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
k, float(k_scale), int(k_zp), 0, 255, torch.uint8
|
||||
).to(torch.float16)
|
||||
v = value.permute([0, 2, 1, 3])
|
||||
v = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
v, float(v_scale), int(v_zp), 0, 255, torch.uint8
|
||||
).to(torch.float16)
|
||||
a = torch.nn.functional.dropout(
|
||||
(torch.matmul(q, k).div(inv_scale) + attn_mask).softmax(dim=-1),
|
||||
dropout,
|
||||
)
|
||||
qa = torch.ops.quantized_decomposed.quantize_per_tensor.default(
|
||||
a, float(a_scale), int(a_zp), 0, 255, torch.uint8
|
||||
)
|
||||
a = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||
qa, float(a_scale), int(a_zp), 0, 255, torch.uint8
|
||||
).to(torch.float16)
|
||||
o = a.matmul(v)
|
||||
o = o.permute(0, 2, 1, 3).contiguous()
|
||||
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
|
||||
o, float(o_scale), int(o_zp), 0, 255, torch.uint8
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_replacement_21(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
inv_scale,
|
||||
q_zp,
|
||||
q_scale,
|
||||
k_zp,
|
||||
k_scale,
|
||||
v_zp,
|
||||
v_scale,
|
||||
a_zp,
|
||||
a_scale,
|
||||
o_zp,
|
||||
o_scale,
|
||||
dropout,
|
||||
):
|
||||
counters["inductor"]["fuse_attention"] += 1
|
||||
print("hit sdpa pattern 21")
|
||||
res = _scaled_dot_product_attention(
|
||||
query.transpose(1, 2),
|
||||
key.transpose(1, 2),
|
||||
value.transpose(1, 2),
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout,
|
||||
is_causal=False,
|
||||
scale=1.0 / inv_scale,
|
||||
q_zp=q_zp,
|
||||
q_scale=q_scale,
|
||||
k_zp=k_zp,
|
||||
k_scale=k_scale,
|
||||
v_zp=v_zp,
|
||||
v_scale=v_scale,
|
||||
a_zp=a_zp,
|
||||
a_scale=a_scale,
|
||||
o_zp=o_zp,
|
||||
o_scale=o_scale,
|
||||
)
|
||||
return res.permute(0, 2, 1, 3).contiguous()


def _sfdp_pattern_22(
    query,
    key,
    value,
    inv_scale,
    q_zp,
    q_scale,
    k_zp,
    k_scale,
    v_zp,
    v_scale,
    a_zp,
    a_scale,
    o_zp,
    o_scale,
    dropout,
):
    # int8-mix-fp32 QUANTIZED SDPA without mask
    q = query.permute([0, 2, 1, 3])
    q = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        q, float(q_scale), int(q_zp), 0, 255, torch.uint8
    )
    k = key.permute([0, 2, 1, 3]).transpose(-2, -1)
    k = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        k, float(k_scale), int(k_zp), 0, 255, torch.uint8
    )
    v = value.permute([0, 2, 1, 3])
    v = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        v, float(v_scale), int(v_zp), 0, 255, torch.uint8
    )
    a = torch.nn.functional.dropout(
        torch.matmul(q, k).div(inv_scale).softmax(dim=-1),
        dropout,
    )
    qa = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        a, float(a_scale), int(a_zp), 0, 255, torch.uint8
    )
    a = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        qa, float(a_scale), int(a_zp), 0, 255, torch.uint8
    )
    o = a.matmul(v)
    o = o.permute(0, 2, 1, 3).contiguous()
    return torch.ops.quantized_decomposed.quantize_per_tensor.default(
        o, float(o_scale), int(o_zp), 0, 255, torch.uint8
    )


def _sfdp_replacement_22(
    query,
    key,
    value,
    inv_scale,
    q_zp,
    q_scale,
    k_zp,
    k_scale,
    v_zp,
    v_scale,
    a_zp,
    a_scale,
    o_zp,
    o_scale,
    dropout,
):
    counters["inductor"]["fuse_attention"] += 1
    print("hit sdpa pattern 22")
    res = _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        dropout_p=dropout,
        is_causal=False,
        scale=1.0 / inv_scale,
        q_zp=q_zp,
        q_scale=q_scale,
        k_zp=k_zp,
        k_scale=k_scale,
        v_zp=v_zp,
        v_scale=v_scale,
        a_zp=a_zp,
        a_scale=a_scale,
        o_zp=o_zp,
        o_scale=o_scale,
    )
    return res.permute(0, 2, 1, 3).contiguous()


def _sfdp_pattern_23(
    query,
    key,
    value,
    inv_scale,
    q_zp,
    q_scale,
    k_zp,
    k_scale,
    v_zp,
    v_scale,
    a_zp,
    a_scale,
    o_zp,
    o_scale,
    dropout,
):
    # int8-mix-reduce QUANTIZED SDPA without mask
    q = query.permute([0, 2, 1, 3])
    q = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        q, float(q_scale), int(q_zp), 0, 255, torch.uint8
    ).to(torch.float16)
    k = key.permute([0, 2, 1, 3]).transpose(-2, -1)
    k = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        k, float(k_scale), int(k_zp), 0, 255, torch.uint8
    ).to(torch.float16)
    v = value.permute([0, 2, 1, 3])
    v = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        v, float(v_scale), int(v_zp), 0, 255, torch.uint8
    ).to(torch.float16)
    a = torch.nn.functional.dropout(
        torch.matmul(q, k).div(inv_scale).softmax(dim=-1),
        dropout,
    )
    qa = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        a, float(a_scale), int(a_zp), 0, 255, torch.uint8
    )
    a = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        qa, float(a_scale), int(a_zp), 0, 255, torch.uint8
    ).to(torch.float16)
    o = a.matmul(v)
    o = o.permute(0, 2, 1, 3).contiguous()
    return torch.ops.quantized_decomposed.quantize_per_tensor.default(
        o, float(o_scale), int(o_zp), 0, 255, torch.uint8
    )


def _sfdp_replacement_23(
    query,
    key,
    value,
    inv_scale,
    q_zp,
    q_scale,
    k_zp,
    k_scale,
    v_zp,
    v_scale,
    a_zp,
    a_scale,
    o_zp,
    o_scale,
    dropout,
):
    counters["inductor"]["fuse_attention"] += 1
    print("hit sdpa pattern 23")
    res = _scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        dropout_p=dropout,
        is_causal=False,
        scale=1.0 / inv_scale,
        q_zp=q_zp,
        q_scale=q_scale,
        k_zp=k_zp,
        k_scale=k_scale,
        v_zp=v_zp,
        v_scale=v_scale,
        a_zp=a_zp,
        a_scale=a_scale,
        o_zp=o_zp,
        o_scale=o_scale,
    )
    return res.permute(0, 2, 1, 3).contiguous()
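All four pattern pairs share the same replacement shape; the only structural difference between the "int8-mix-fp32" search functions (patterns 20/22) and the "int8-mix-reduce" ones (patterns 21/23) is a half-precision cast after every dequantize. A standalone sketch of just that difference (scale and zero-point arbitrary):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers the quantized_decomposed ops

dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
x_u8 = torch.randint(0, 256, (2, 4), dtype=torch.uint8)

q_fp32 = dq(x_u8, 0.02, 127, 0, 255, torch.uint8)                    # patterns 20/22
q_fp16 = dq(x_u8, 0.02, 127, 0, 255, torch.uint8).to(torch.float16)  # patterns 21/23
assert q_fp32.dtype == torch.float32 and q_fp16.dtype == torch.float16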


def _sfdp_params_check(match):
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
@@ -557,8 +910,9 @@ def _sfdp_params_check(match):
        query.device == key.device == value.device
    ):
        return False
    add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
    # Has attn_mask add.
    add_mask_node = [n for n in add_nodes if n.prev.target == torch.ops.aten.div.Tensor]
    if len(add_mask_node) > 0:
        attn_mask_node = add_mask_node[0].args[1]
        # attn_mask_node may be a float/int number.
@@ -641,10 +995,25 @@ def _get_sfdp_patterns():
    b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device)
    # inv_scale
    c_inp = functools.partial(torch.tensor, 2, device=device)
    zp_inp = functools.partial(torch.tensor, 127, device=device)
    scale_inp = functools.partial(torch.tensor, 0.018, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    d = {"dropout_p": 0.113377}
    d_u8 = {
        "dropout": 0.113377,
        "q_zp": 23,
        "q_scale": 0.0111541,
        "k_zp": 14,
        "k_scale": 0.0256212,
        "v_zp": 28,
        "v_scale": 0.0164518,
        "a_zp": 12,
        "a_scale": 0.0572114,
        "o_zp": 36,
        "o_scale": 0.0235489,
    }
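The builders above are `functools.partial` factories: calling one materializes a fresh example tensor for pattern tracing, and the `d_u8` scalars are deliberately odd "magic" values so each scalar argument stays distinguishable after tracing. A tiny standalone illustration of the builder mechanism (values mirror `zp_inp` above):

import functools

import torch

zp_inp = functools.partial(torch.tensor, 127)
zp = functools.partial(zp_inp, dtype=torch.int)  # specialized again per dtype below

t = zp()  # each call builds a new example tensor
assert t.item() == 127 and t.dtype == torch.int32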

    # we could also generate all these patterns in 3d.. TODO
    g_3d_inp = functools.partial(
@@ -675,6 +1044,10 @@ def _get_sfdp_patterns():
        m_bs1 = functools.partial(m_bs1_inp, dtype=dtype)
        m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float)
        m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool)
        g_u8 = functools.partial(g_inp, dtype=torch.uint8, requires_grad=False)
        g_u8_bs1 = functools.partial(g_bs1_inp, dtype=torch.uint8, requires_grad=False)
        zp = functools.partial(zp_inp, dtype=torch.int)
        scale = functools.partial(scale_inp, dtype=torch.float)

        candidates = [
            (
@@ -849,14 +1222,202 @@ def _get_sfdp_patterns():
                    _sfdp_extra_check(aten.div.Tensor, disable_cuda=True),
                )
            )
        # uint8 quantization
        if dtype == torch.float:
            # example inputs: q, k, v, (attn_mask,) inv_scale, then one
            # (zp, scale) pair each for q, k, v, attention probs and output
            int8_candidates = [
                (
                    _sfdp_pattern_20,
                    _sfdp_replacement_20,
                    [g_u8(), g_u8(), g_u8(), m(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_20,
                    _sfdp_replacement_20,
                    [g_u8_bs1(), g_u8_bs1(), g_u8_bs1(), m_bs1(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_21,
                    _sfdp_replacement_21,
                    [g_u8(), g_u8(), g_u8(), m(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_21,
                    _sfdp_replacement_21,
                    [g_u8_bs1(), g_u8_bs1(), g_u8_bs1(), m_bs1(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_22,
                    _sfdp_replacement_22,
                    [g_u8(), g_u8(), g_u8(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_22,
                    _sfdp_replacement_22,
                    [g_u8_bs1(), g_u8_bs1(), g_u8_bs1(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_23,
                    _sfdp_replacement_23,
                    [g_u8(), g_u8(), g_u8(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
                (
                    _sfdp_pattern_23,
                    _sfdp_replacement_23,
                    [g_u8_bs1(), g_u8_bs1(), g_u8_bs1(), c(),
                     zp(), scale(), zp(), scale(), zp(), scale(),
                     zp(), scale(), zp(), scale()],
                    d_u8,
                    _sfdp_extra_check(aten.div.Tensor),
                ),
            ]
            candidates.extend(int8_candidates)
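Each candidate tuple carries (search_fn, replace_fn, example_inputs, scalar_workaround, extra_check) and is eventually handed to Inductor's pattern matcher via the dictionaries yielded below. A hedged sketch of the consuming call, assuming `register_replacement` keeps the keyword names used in those dictionaries:

from torch._inductor.pattern_matcher import fwd_only, register_replacement

# Sketch only: `patterns` is the pass dict built elsewhere in this file.
register_replacement(
    search_fn=_sfdp_pattern_20,       # traced and searched for in the graph
    replace_fn=_sfdp_replacement_20,  # spliced in at every match
    example_inputs=args,              # drives tracing and dtype/shape checks
    trace_fn=fwd_only,                # inference trace (joint_fwd_bwd for training)
    pass_dicts=patterns,
    extra_check=extra_check,
    scalar_workaround=d_u8,           # recovers scalar args (dropout, zp, scale)
)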

        for pattern, replacement, args, workaround, extra_check in candidates:
            # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
            # gets serialized to a python file and does not require tracing at runtime.
            assert isinstance(workaround, dict)
            name = pattern.__name__
            is_uint8 = args[0].dtype == torch.uint8

            if is_uint8:
                name += "_u8"
            elif dtype != torch.float:
                name += "_half"
            if (
                any(p in name for p in mask_fp32_patterns)
@@ -866,25 +1427,35 @@ def _get_sfdp_patterns():
            if args[0].size(0) == 1:
                name += "_bs1"

            if not is_uint8:
                training_name = name + "_training"
                yield training_name, {
                    "search_fn": pattern,
                    "replace_fn": replacement,
                    "example_inputs": args,
                    "trace_fn": joint_fwd_bwd,
                    "pass_dicts": patterns,
                    "extra_check": extra_check,
                    "scalar_workaround": workaround,
                }

            if len(workaround) >= 1:
                if "dropout_p" in workaround:
                    # functools.partial insufficient because we look at signature downstream
                    pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
                    replacement = partialize_and_update_signature(
                        replacement, dropout_p=0.0
                    )
                    workaround = {}
                else:
                    # uint8 patterns carry workarounds besides dropout, so strip
                    # only the dropout entry to avoid influencing the other kwargs
                    pattern = partialize_and_update_signature(pattern, dropout=0.0)
                    replacement = partialize_and_update_signature(
                        replacement, dropout=0.0
                    )
                    if "dropout" in workaround:
                        del workaround["dropout"]

            inference_name = name + "_inference"
            yield inference_name, {

@@ -10,6 +10,7 @@ import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

from torch._inductor.pattern_matcher import (
   Arg,
@@ -0,0 +1,100 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
expand_default = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
expand_default_1 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, div_Tensor_1, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
expand_default_2 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
expand_default_3 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default_3 = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_20_u8_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default_3, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
expand_default = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
expand_default_1 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, div_Tensor_1, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
expand_default_2 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
expand_default_3 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_20_u8_bs1_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)
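Each serialized file spells out the traced graph as a tree of pattern-matcher nodes: every `CallFunction(target, args...)` matches one FX node, `KeywordArg('name')` binds a capture for the replacement, `Ignored()` matches anything, and `_users=N` constrains how many consumers a node has. A minimal, hypothetical pattern in the same DSL:

import torch

from torch._inductor.pattern_matcher import CallFunction, KeywordArg

aten = torch.ops.aten

# Matches div(bmm(q, k), inv_scale) -- the scaled-score prefix of the
# patterns above, with q, k and inv_scale captured by name.
bmm_node = CallFunction(aten.bmm.default, KeywordArg("q"), KeywordArg("k"))
scores = CallFunction(aten.div.Tensor, bmm_node, KeywordArg("inv_scale"))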
@@ -0,0 +1,108 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
expand_default = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, div_Tensor_1, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default_3 = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_21_u8_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default_3, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
expand_default = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, div_Tensor_1, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_21_u8_bs1_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)
@@ -0,0 +1,98 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
expand_default = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
expand_default_1 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, div_Tensor_1, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
expand_default_2 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
expand_default_3 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default_3 = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_22_u8_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default_3, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
expand_default = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
expand_default_1 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, div_Tensor_1, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
expand_default_2 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
expand_default_3 = CallFunction(aten.expand.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_22_u8_bs1_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)
@@ -0,0 +1,110 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
expand_default = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, convert_element_type_default_3, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_4, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_5, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default_3 = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_23_u8_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default_3, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
quantized_decomposed_dequantize_per_tensor_default = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default, KeywordArg('q_scale'), KeywordArg('q_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default, Ignored())
expand_default = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
quantized_decomposed_dequantize_per_tensor_default_1 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_2, KeywordArg('k_scale'), KeywordArg('k_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
quantized_decomposed_quantize_per_tensor_default = CallFunction(quantized_decomposed.quantize_per_tensor.default, convert_element_type_default_3, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
quantized_decomposed_dequantize_per_tensor_default_2 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed_quantize_per_tensor_default, KeywordArg('a_scale'), KeywordArg('a_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_2, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_4, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
quantized_decomposed_dequantize_per_tensor_default_3 = CallFunction(quantized_decomposed.dequantize_per_tensor.default, permute_default_3, KeywordArg('v_scale'), KeywordArg('v_zp'), Ignored(), Ignored(), Ignored())
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, quantized_decomposed_dequantize_per_tensor_default_3, Ignored())
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_5, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored())
clone_default = CallFunction(aten.clone.default, permute_default_4, memory_format=torch.contiguous_format)
_sfdp_pattern_23_u8_bs1_inference = CallFunction(quantized_decomposed.quantize_per_tensor.default, clone_default, KeywordArg('o_scale'), KeywordArg('o_zp'), Ignored(), Ignored(), Ignored(), _users=0)
@@ -10,6 +10,7 @@ import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

from torch._inductor.pattern_matcher import (
   Arg,
@@ -1398,6 +1398,7 @@ def _serialize_pattern(

aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

"""
).format(msg=auto_generated_msg)
@@ -5097,6 +5097,16 @@ def meta__scaled_dot_product_flash_attention_for_cpu(
    is_causal: bool = False,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
    q_zp: int = 0,
    q_scale: float = 0.0,
    k_zp: int = 0,
    k_scale: float = 0.0,
    v_zp: int = 0,
    v_scale: float = 0.0,
    a_zp: int = 0,
    a_scale: float = 0.0,
    o_zp: int = 0,
    o_scale: float = 0.0,
):
    batch_size = query.size(0)
    num_heads = query.size(1)
@@ -325,6 +325,7 @@ def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
def get_default_x86_inductor_quantization_config(
    is_qat: bool = False,
    is_dynamic: bool = False,
    # is_reduce: bool = False,
):
    extra_args: Dict[str, Any] = {"eps": 2**-12}
    if is_qat:
@@ -346,7 +347,7 @@ def get_default_x86_inductor_quantization_config(
    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,  # quant_max=127 if is_reduce else 255  # reduce_range=False
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
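For context on the spec above: with `qscheme=torch.per_tensor_affine`, an activation observer maps the observed float range onto `[quant_min, quant_max]`. A hedged sketch of the usual derivation (not code from this patch; real observers also handle symmetric schemes and epsilon clamping):

def qparams(min_val: float, max_val: float, qmin: int = 0, qmax: int = 255):
    # Keep zero exactly representable, as affine observers typically do.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - round(min_val / scale)
    return scale, int(max(qmin, min(qmax, zero_point)))

print(qparams(-1.0, 3.0))  # -> (~0.0157, 64)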
@@ -2374,40 +2374,40 @@ def gen_source_files(
            header=True,
            includes="",
        )
        # if update_aoti_c_shim:
        #     aoti_fm.write(
        #         header_file_name,
        #         lambda: new_header,
        #     )
        # else:
        #     try:
        #         with open(
        #             os.path.join(aoti_fm.install_dir, header_file_name)
        #         ) as old_file:
        #             old_header = old_file.read()
        #         assert (
        #             old_header == new_header
        #         ), """
        #
        # WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
        # indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
        # Only in a limited number of situations, this is allowed:
        #
        # 1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
        # If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
        # C shim header files.
        #
        # 2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
        # change in the AOTInductor land. In this case, you need to keep a manual copy of that existing
        # fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
        # number of that fallback op in the newly generated C shim files, and update the cpp wrapper
        # codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
        #
        # """
        #     except FileNotFoundError:
        #         print(
        #             f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
        #         )

        # cpp files are always generated on-the-fly
        def headers_for_aoti() -> str: