Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 02:45:19 +08:00)

Compare commits (1 commit): ciflow/tru ... update_sub

| Author | SHA1 | Date |
|---|---|---|
| | ed2dcd679c | |
@@ -96,6 +96,7 @@ function pip_build_and_install() {
    python3 -m pip wheel \
      --no-build-isolation \
      --no-deps \
      --no-use-pep517 \
      -w "${wheel_dir}" \
      "${build_target}"
  fi

@@ -1,4 +1,4 @@
name: docker-cache-rocm
name: docker-cache-mi300

on:
  workflow_run:
@@ -31,7 +31,7 @@ jobs:
      - name: Download artifacts
        uses: actions/download-artifact@v4.1.7
        with:
          run-id: ${{ github.event.workflow_run.id }}
          run_id: ${{ github.event.workflow_run.id }}
          path: ./docker-builds-artifacts
          merge-multiple: true
          github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
    const int64_t out_features) {
  auto M = inp.size(0);
  TORCH_CHECK(
      inp.dtype() == kFloat,
      inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
      __func__,
      " : expect input to be 32-bit float tensor.");
      " : expect input to be float32 or bfloat16 tensor.");
  TORCH_CHECK(
      block_size == in_features ||
          (!(block_size % 32) && !(in_features % block_size)),

@@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>

@@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif

static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
    size_t m,
    size_t n,
    size_t k,
    const uint16_t* lhs_bf16,
    const uint8_t* rhs_qs4cx,
    const float* rhs_scales,
    uint16_t* dst_bf16,
    float scalar_min,
    float scalar_max,
    const float* bias) {
  // Roundup lambda for internal stride calculations
  auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };

  // Cast bfloat16 to float32 inline
  auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
    uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
    float f;
    std::memcpy(&f, &tmp, sizeof(f));
    return f;
  };

  // Cast float32 to bfloat16 inline
  auto cast_f32_to_bf16 = [](float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
  };

  // Quantization pack lambda (channelwise QA8DX)
  auto quant_pack_8bit_channelwise =
      [&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
        constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
        constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();

        const size_t dst_stride =
            K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
        for (size_t i = 0; i < M; ++i) {
          const uint16_t* row_ptr = src_bf16 + i * K;
          // find min/max
          float mn = FLT_MAX, mx = -FLT_MAX;
          for (size_t j = 0; j < K; ++j) {
            float v = cast_bf16_to_f32(row_ptr[j]);
            mn = std::min(mn, v);
            mx = std::max(mx, v);
          }
          float rmin = std::min(0.0f, mn);
          float rmax = std::max(0.0f, mx);
          constexpr float qmin = static_cast<float>(kI8Min);
          constexpr float qmax = static_cast<float>(kI8Max);
          float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
          float recip = scale ? 1.0f / scale : 0.0f;
          int32_t zp;
          float des_min = rmin * scale;
          float des_max = rmax * scale;
          float err_min = qmin + des_min;
          float err_max = qmax + des_max;
          float zp_f =
              (err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
          zp_f = std::clamp(zp_f, qmin, qmax);
          zp = std::lrintf(zp_f);
          int8_t* out_ptr = dst_qa8dx + i * dst_stride;
          // store header
          *reinterpret_cast<float*>(out_ptr) = recip;
          *reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
          out_ptr += sizeof(float) + sizeof(int32_t);
          // quantize
          for (size_t j = 0; j < K; ++j) {
            float v = cast_bf16_to_f32(row_ptr[j]);
            int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
            q = std::clamp(
                q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
            *out_ptr++ = static_cast<int8_t>(q);
          }
        }
      };

  // MatMul lambda (MXN x MXK -> MNXK BF16)
  auto matmul_kernel = [&](size_t M,
                           size_t N,
                           size_t K,
                           const int8_t* lhs,
                           const uint8_t* rhs,
                           const float* scales,
                           uint16_t* dst,
                           float lo,
                           float hi) {
    const size_t lhs_stride =
        K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
    const size_t rhs_stride = roundup(K, 2) / 2;
    for (size_t i = 0; i < M; ++i) {
      const int8_t* lhs_row = lhs + i * lhs_stride;
      for (size_t j = 0; j < N; ++j) {
        int32_t acc = 0;
        const int8_t* lptr = lhs_row;
        const uint8_t* rptr = rhs + j * rhs_stride;
        float lhs_scale = *reinterpret_cast<const float*>(lptr);
        int32_t lhs_off =
            *reinterpret_cast<const int32_t*>(lptr + sizeof(float));
        lptr += sizeof(float) + sizeof(int32_t);
        for (size_t t = 0; t < K; ++t) {
          int32_t lv = static_cast<int32_t>(lptr[t]);
          uint8_t bv = rptr[t / 2];
          int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
                                      : (static_cast<int32_t>(bv >> 4) - 8);
          acc += lv * rv + lhs_off * rv;
        }
        float res = static_cast<float>(acc) * scales[j] * lhs_scale;
        if (bias) {
          res += bias[j];
        }
        res = std::clamp(res, lo, hi);
        *dst++ = cast_f32_to_bf16(res);
      }
    }
  };

  // allocate and run
  std::unique_ptr<int8_t[]> packed(
      new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
  quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
  matmul_kernel(
      m,
      n,
      k,
      packed.get(),
      rhs_qs4cx,
      rhs_scales,
      dst_bf16,
      scalar_min,
      scalar_max);
}

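The quant_pack_8bit_channelwise lambda in the BF16 reference kernel above derives one fp32 scale and one int8 zero point per row from that row's min/max before quantizing. A minimal standalone sketch of the same parameter math, with illustrative inputs (not part of the diff; the row statistics below are made up for the example):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // Illustrative row statistics; a real row's min/max would come from the data.
  const float mn = -0.5f, mx = 2.0f;
  const float rmin = std::min(0.0f, mn), rmax = std::max(0.0f, mx);
  const float qmin = -128.0f, qmax = 127.0f;

  // Same formulas as the kernel: map [rmin, rmax] onto [qmin, qmax], then pick
  // the zero point that minimizes the end-point error, clamped to int8 range.
  const float scale = (rmin == rmax) ? 1.0f : (qmax - qmin) / (rmax - rmin);
  const float des_min = rmin * scale;
  const float des_max = rmax * scale;
  const float err_min = qmin + des_min;
  const float err_max = qmax + des_max;
  float zp_f = (err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
  zp_f = std::clamp(zp_f, qmin, qmax);
  const long zp = std::lrintf(zp_f);

  std::printf("scale=%f zero_point=%ld\n", scale, zp);  // scale=102, zero_point=-77
  return 0;
}
```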
/**
 * The Int4 quantized weights must be represented as a uint8 tensor
 * For matrix multiplication with a weight shape of (N x K)
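As the comment above notes, the int4 weights are stored two per uint8 byte; the reference kernels in this file read the even k index from the low nibble and the odd one from the high nibble, subtracting 8 to recover signed values in [-8, 7]. A small sketch of that convention (the pack_int4_pair helper is illustrative only, not defined in this file):

```cpp
#include <cstdint>
#include <cstdio>

// Pack two signed 4-bit values (each in [-8, 7]) into one byte, low nibble first.
static uint8_t pack_int4_pair(int32_t even, int32_t odd) {
  return static_cast<uint8_t>(((odd + 8) << 4) | ((even + 8) & 0xF));
}

int main() {
  const uint8_t byte = pack_int4_pair(-3, 5);
  // Decode exactly the way the reference matmul kernels do.
  const int32_t even = static_cast<int32_t>(byte & 0xF) - 8;  // -> -3
  const int32_t odd = static_cast<int32_t>(byte >> 4) - 8;    // -> 5
  std::printf("packed=0x%02X even=%d odd=%d\n", byte, even, odd);
  return 0;
}
```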
@@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
  if (can_use_kleidiai(scales_zeros, K, block_size)) {
    const int64_t weight_packed_size =
        kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
        kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
    packed_weights.resize_({weight_packed_size});
    kleidiai::kai_pack_int4_rhs(
        packed_weights, weights, scales_zeros, bias, N, K, block_size);
  } else
#endif
  {
    TORCH_CHECK(
        bias.has_value() == 0,
        __func__,
        " : Bias is unsupported in reference implementation");
    packed_weights = packed_weights.to(kFloat);
    auto weight_reshaped = weights.view({-1}).to(kFloat);
    auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
    auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
    auto weight_reshaped = weights.reshape({-1}).to(kFloat);
    auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
    std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
    if (bias.has_value()) {
      tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
    }
    auto res = at::cat(tensors_to_cat, 0);
    packed_weights.resize_(res.sizes()).copy_(res);
  }
}
@@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
    const float* rhs_scales_f32,
    float* dst_f32,
    float scalar_min,
    float scalar_max) {
    float scalar_max,
    const float* bias) {
  const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));

  auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
  // required format for matmul
  auto input_quant_pack_8bit_channelwise =
      [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
        constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
        constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();

        const size_t dst_stride =
            (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));

@@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
        }

        // Maximum/minimum int8 values
        const float qmin = (float)INT8_MIN;
        const float qmax = (float)INT8_MAX;
        constexpr float qmin = static_cast<float>(kI8Min);
        constexpr float qmax = static_cast<float>(kI8Max);

        const float rmin0 = std::min(0.0f, min0);
        const float rmax0 = std::max(0.0f, max0);
@@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
        zero_point0 = std::min(zero_point0, qmax);

        // Round to nearest integer
        const int32_t nudged_zero_point0 = lrintf(zero_point0);
        const int32_t nudged_zero_point0 = std::lrintf(zero_point0);

        int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;

@@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
          int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));

          v0_s32 = v0_s32 + nudged_zero_point0;
          v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
          v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
          v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
          v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
          dst_ptr[0] = (int8_t)v0_s32;
          dst_ptr += sizeof(int8_t);
        }
@@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(

      main_acc = main_acc * lhs_scale;

      if (bias) {
        main_acc += bias[n_idx];
      }

      // Clamp (min-max) operation
      main_acc = std::max(main_acc, scalar_min);
      main_acc = std::min(main_acc, scalar_max);
@@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
    const float* rhs_scales_fp32,
    float* dst_f32,
    float scalar_min,
    float scalar_max) {
    float scalar_max,
    const float* bias) {
  // Lambda for LHS quantization
  auto lhs_quant_pack = [&](size_t m,
                            size_t k,
                            const float* lhs_f32,
                            int8_t* lhs_qa8dx) {
    constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
    constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();

    const size_t dst_stride =
        (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));

@@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
      min0 = std::min(src0_0, min0);
    }

    const float qmin = (float)INT8_MIN;
    const float qmax = (float)INT8_MAX;
    constexpr float qmin = static_cast<float>(kI8Min);
    constexpr float qmax = static_cast<float>(kI8Max);

    const float rmin0 = std::min(0.0f, min0);
    const float rmax0 = std::max(0.0f, max0);
@@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(

    zero_point0 = std::max(zero_point0, qmin);
    zero_point0 = std::min(zero_point0, qmax);
    const int32_t nudged_zero_point0 = lrintf(zero_point0);
    const int32_t nudged_zero_point0 = std::lrintf(zero_point0);

    int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;

@@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
      const float src0_0 = src_ptr[k_idx];
      int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
      v0_s32 = std::max(
          std::min(
              v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
          static_cast<int32_t>(INT8_MIN));
          std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
          static_cast<int32_t>(kI8Min));
      dst_ptr[0] = (int8_t)v0_s32;
      dst_ptr += sizeof(int8_t);
    }
@@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
      }

      main_acc = main_acc * lhs_scale;

      if (bias) {
        main_acc += bias[col_idx];
      }

      main_acc = std::max(main_acc, scalar_min);
      main_acc = std::min(main_acc, scalar_max);

@@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}

/**
 * Dynamic Input Quant 4 bit weights matmul execution flow

      (INT4 Weights + FP scales + FP32 Bias)
  FP32 Input          Packed Buffer
      |                    |
   Quantize              Cast
   to INT8              to INT8
      |                    |
      v                    v
  INT8 Input          INT8 Weights
       \                  /
        \                /
         \              /
      INT8 Matrix Multiplication
                |
                v
  FP32 Dequantized and Accumulate in FP32
                |
                v
        FP32 Final Output

 * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
 * Float32 Scales. If not provided, we will use fallback implementation.
 * Dynamic INT4 weight-only MatMul with per-row input quantization.
 *
 * Execution Flow:
 *
 *        (INT4 Weights + FP Scales [+ optional Bias])
 *
 *   Input (FP32 or BF16)           Packed Weight Buffer
 *           |                              |
 *   Row-wise Quantization (INT8)           |
 *           |                              |
 *   INT8 Input Activation      INT4 Quantized Weights + Scales
 *            \                            /
 *             \                          /
 *             Quantized Matrix Multiply
 *                        |
 *             Output Tensor (BF16 or FP32)
 *
 * Notes:
 * - Groupwise kernels expect BF16 scales
 * - Channelwise kernels expect FP32 scales
 * - Bias is currently unsupported in fallback path
 */
void dyn_quant_matmul_4bit_kernel(
    const Tensor& output,
@@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
    const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
  const int64_t weight_packed_size =
      kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
      kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
  if (weight_packed_size == packed_weights.numel()) {
    // KleidiAI interface internally handles the Channelwise and groupwise
    // distinction
    kleidiai::kai_quant_pack_lhs_int4_mm(
        output, inp, packed_weights, M, N, K, block_size);
    kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
  } else
#endif
  {
    float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
    const auto weights_size = N * K / 2;
    // The weights needs to be in uint8_t data type after quantization
    auto extracted_weights =
        (packed_weights.narrow(0, 0, weights_size)).to(kByte);
    auto float32_scales =
        (packed_weights.narrow(
             0, weights_size, packed_weights.size(0) - weights_size))
            .to(kFloat);
    uint8_t* rhs_4bit =
        reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
    float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
    float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
    if (block_size == K) {
      ref_dyn_quant_matmul_4bit_channelwise_kernel(
          M,
          N,
          K,
          lhs_f32,
          rhs_4bit,
          rhs_scales_f32,
          dst_f32,
          -FLT_MAX,
          FLT_MAX);
    } else if (!(block_size % 32) && !(K % block_size)) {
      ref_dyn_quant_matmul_4bit_groupwise_kernel(
          M,
          N,
          K,
          block_size,
          lhs_f32,
          rhs_4bit,
          rhs_scales_f32,
          dst_f32,
          -FLT_MAX,
          FLT_MAX);
    } else {
      TORCH_CHECK(
          block_size == K || (!(block_size % 32) && !(K % block_size)),
          __func__,
          ": Group size should be multiple 32 or in_features [",
          K,
          "]. Provided ",
          block_size);
  {
    void* input = inp.data_ptr();
    void* dst = output.data_ptr();

    // Extract weights, scales and bias from the packed tensor
    const int weights_elements = N * K / 2;
    const int scale_elements = N * (K / block_size);
    TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");

    auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
    auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
    auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);

    int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
    float* weight_scales = float32_scales.data_ptr<float>();

    void* bias_data = nullptr;
    if (bias_elements) {
      auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
      TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
      bias_data = float32_bias.data_ptr();

    }
    // 2 elements of 4 bit weights are packed into 1 uint8 packet
    uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());

    // Dispatch to reference kernels
    if (inp.scalar_type() == at::kBFloat16) {
      // BF16 input, BF16 output
      constexpr float BF16_MAX = 3.38953139e+38f;
      constexpr float BF16_MIN = -BF16_MAX;
      if (block_size == K) {
        ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
            M, N, K,
            (uint16_t*)input, weights_4bit, weight_scales,
            (uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
      } else {
        TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
      }
    } else if (inp.scalar_type() == at::kFloat) {
      // FP32 input, FP32 output
      if (block_size == K) {
        ref_dyn_quant_matmul_4bit_channelwise_kernel(
            M, N, K,
            (float*)input, weights_4bit, weight_scales,
            (float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
      } else if (!(block_size % 32) && !(K % block_size)) {
        ref_dyn_quant_matmul_4bit_groupwise_kernel(
            M, N, K, block_size,
            (float*)input, weights_4bit, weight_scales,
            (float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
      } else {
        TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
      }
    } else {
      TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
    }
  }
}

}
} // anonymous namespace

}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)

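The BF16 fallback above (ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16) relies on bfloat16 being the upper 16 bits of an IEEE-754 float32: widening is exact and narrowing simply truncates the low mantissa bits, exactly as its cast_bf16_to_f32 / cast_f32_to_bf16 lambdas do. A self-contained sketch of that round trip (helper names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 -> f32: place the stored 16 bits in the high half of a float32 bit pattern.
static float bf16_bits_to_f32(uint16_t bf16_bits) {
  uint32_t bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// f32 -> bf16: keep the top 16 bits (truncation, no rounding), as in the kernel.
static uint16_t f32_to_bf16_bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  const float x = 1.5f;  // exactly representable in bfloat16
  const uint16_t b = f32_to_bf16_bits(x);
  std::printf("%f -> 0x%04X -> %f\n", x, b, bf16_bits_to_f32(b));  // round-trips exactly
  return 0;
}
```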
@@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
    const int64_t n,
    const int64_t k,
    const int64_t bl) {
  // Prefer Channelwise kernel over Groupwise kernel for conflicting cases
  if (bl == k) {
    // Channelwise
    auto kernel_packet = kai_select_channelwise_matmul_ukernel(
        kai_kernel_id::
            matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
    auto& params = kernel_packet.rhs_pack_params;
    params.lhs_zero_point = 1;
    params.rhs_zero_point = 8;

    kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
        kernel_packet, weight_packed, weight, scales, bias, n, k);
    if (weight.scalar_type() == at::kBFloat16) {
      auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
          kai_kernel_id::
              matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
      auto& params = kernel_packet.rhs_pack_params;
      params.lhs_zero_point = 1;
      params.rhs_zero_point = 8;
      kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
          kernel_packet, weight_packed, weight, scales, bias, n, k);
    } else {
      auto kernel_packet = kai_select_channelwise_matmul_ukernel(
          kai_kernel_id::
              matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
      auto& params = kernel_packet.rhs_pack_params;
      params.lhs_zero_point = 1;
      params.rhs_zero_point = 8;
      kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
          kernel_packet, weight_packed, weight, scales, bias, n, k);
    }
  } else if (!(bl % 32) && !(k % bl)) {
    // Groupwise
    auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
    const int64_t n,
    const int64_t k,
    const int64_t bl) {
    const int64_t bl,
    at::ScalarType tensor_dtype) {
  size_t packed_size = n * k;
  // Prefer Channelwise kernel over Groupwise kernel for conflicting cases
  if (bl == k) {
    // Channelwise
    auto kernel_packet = kai_select_channelwise_matmul_ukernel(
        kai_kernel_id::
            matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
    const auto& ukernel = kernel_packet.ukernel;
    const size_t nr = ukernel.get_nr();
    const size_t kr = ukernel.get_kr();
    const size_t sr = ukernel.get_sr();
    packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
    if (tensor_dtype == at::kBFloat16) {
      auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
          kai_kernel_id::
              matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
      const auto& ukernel = kernel_packet.ukernel;
      const size_t nr = ukernel.get_nr();
      const size_t kr = ukernel.get_kr();
      const size_t sr = ukernel.get_sr();
      packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
    } else {
      auto kernel_packet = kai_select_channelwise_matmul_ukernel(
          kai_kernel_id::
              matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
      const auto& ukernel = kernel_packet.ukernel;
      const size_t nr = ukernel.get_nr();
      const size_t kr = ukernel.get_kr();
      const size_t sr = ukernel.get_sr();
      packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
    }
  } else if (!(bl % 32) && !(k % bl)) {
    // Groupwise
    auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
    const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
    const int64_t m_idx = thread_id * vec_per_thread;
    auto lhs_packed_ptr = lhs_packed_base +
        kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
            m_idx, k, mr, kr, sr);
        kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
    const int64_t vec_num = (thread_id == num_threads - 1)
        ? (m - vec_per_thread * thread_id)
        : vec_per_thread;
@@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
    const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
    const int64_t m_idx = thread_id * vec_per_thread;
    auto lhs_packed_ptr = lhs_packed_base +
        kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
            m_idx, k, mr, kr, sr);
        kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
    const int64_t vec_num = (thread_id == num_threads - 1)
        ? (m - vec_per_thread * thread_id)
        : vec_per_thread;
@@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
      });
}

void kai_quant_pack_lhs_int4_mm(
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const int64_t m,
    const int64_t n,
    const int64_t k) {
  // Kernel IDs for GEMM and GEMV
  constexpr kai_kernel_id gemm_id =
      kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
  constexpr kai_kernel_id gemv_id =
      kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;

  // Get total threads and select kernel
  const int64_t total_threads = at::get_num_threads();
  auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
  if (cpuinfo_has_arm_i8mm() && m > 1) {
    kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
  }

  // Thread blocking parameters
  const int64_t n_step = kernel_packet.ukernel.get_n_step();
  const size_t mr = kernel_packet.ukernel.get_mr();
  const size_t kr = kernel_packet.ukernel.get_kr();
  const size_t sr = kernel_packet.ukernel.get_sr();

  const size_t lhs_packed_size =
      kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
  uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
  const uint8_t* lhs_native_mtx_bf16 =
      reinterpret_cast<const uint8_t*>(input.data_ptr());
  const uint8_t* rhs_packed_mtx_qs4cx =
      reinterpret_cast<const uint8_t*>(weight.data_ptr());
  uint8_t* lhs_packed_base = lhs_packed.get();

  constexpr int32_t element_size = sizeof(uint16_t);
  const size_t lhs_stride = k * element_size;
  const size_t dst_stride = n * element_size;

  // LHS quantization packing
  int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
  int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
  const size_t src_stride = vec_per_thread * lhs_stride;

  auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
    const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
    const int64_t m_idx = thread_id * vec_per_thread;
    auto lhs_packed_ptr = lhs_packed_base +
        kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
    const int64_t vec_num = (thread_id == num_threads - 1)
        ? (m - vec_per_thread * thread_id)
        : vec_per_thread;

    kernel_packet.kai_run_lhs_quant_pack(
        vec_num,
        k,
        mr,
        kr,
        sr,
        0,
        (const uint16_t*)lhs_src_ptr,
        lhs_stride,
        lhs_packed_ptr);
  };

  at::parallel_for(
      0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
        for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
          lhs_quant_pack(thread_id);
        }
      });

  // Matrix multiplication
  vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
  num_threads = (n + vec_per_thread - 1) / vec_per_thread;

  auto mm = [=, &kernel_packet](int64_t thread_id) {
    const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
        kernel_packet.ukernel.get_rhs_packed_offset(
            thread_id * vec_per_thread, k);
    auto dst_ptr = dst_act_mtx_bf16 +
        kernel_packet.ukernel.get_dst_offset(
            0, thread_id * vec_per_thread, dst_stride);
    const int64_t vec_num = (thread_id == num_threads - 1)
        ? (n - vec_per_thread * thread_id)
        : vec_per_thread;

    kernel_packet.ukernel.run_matmul(
        m,
        vec_num,
        k,
        lhs_packed_base,
        rhs_packed_ptr,
        (uint16_t*)dst_ptr,
        dst_stride,
        element_size, // dst_stride_col
        -FLT_MAX,
        FLT_MAX);
  };

  at::parallel_for(
      0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
        for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
          mm(thread_id);
        }
      });
}
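The BF16 channelwise path above splits the m rows (and later the n output columns) across threads in vec_per_thread-sized chunks, with the last thread taking whatever remains. get_vec_per_thread itself is not part of this diff; the sketch below assumes it rounds the per-thread count up to a multiple of the kernel step, which matches how the offsets and remainders are computed above:

```cpp
#include <cstdint>
#include <cstdio>

// Assumed behaviour: split `total` rows over `threads`, rounding the per-thread
// count up to a multiple of `step` (e.g. the ukernel's mr or n_step).
static int64_t get_vec_per_thread(int64_t total, int64_t threads, int64_t step) {
  const int64_t per_thread = (total + threads - 1) / threads;
  return ((per_thread + step - 1) / step) * step;
}

int main() {
  const int64_t m = 100, total_threads = 8, mr = 4;
  const int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
  const int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
  for (int64_t t = 0; t < num_threads; ++t) {
    // The last thread handles the remainder, mirroring the vec_num computation above.
    const int64_t vec_num =
        (t == num_threads - 1) ? (m - vec_per_thread * t) : vec_per_thread;
    std::printf("thread %lld rows [%lld, %lld)\n",
                static_cast<long long>(t),
                static_cast<long long>(t * vec_per_thread),
                static_cast<long long>(t * vec_per_thread + vec_num));
  }
  return 0;
}
```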
void kai_quant_pack_lhs_int4_mm(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    const int64_t m,
    const int64_t n,
    const int64_t k,
    const int64_t bl) {
  // Prefer Channelwise kernel over Groupwise kernel for conflicting cases
  if (bl == k) {
    kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
        output, input, weight, m, n, k);
  } else if (!(bl % 32) && !(k % bl)) {
    const auto input_dtype = input.dtype();

    if (input_dtype == at::kBFloat16) {
      if (cpuinfo_has_arm_bf16()) {
        kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
            output, input, weight, m, n, k);
      } else {
        TORCH_CHECK(
            false,
            "BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
      }
    } else if (input_dtype == at::kFloat) {
      kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
          output, input, weight, m, n, k);
    } else {
      TORCH_CHECK(
          false,
          "Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
    }
  } else if ((bl % 32 == 0) && (k % bl == 0)) {
    kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
        output, input, weight, m, n, k, bl);
  }

@@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
    const int64_t n,
    const int64_t k,
    const int64_t bl);
    const int64_t bl,
    at::ScalarType tensor_dtype = at::kFloat);

/**
 * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

@@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
    AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
  }

  float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
  float* bias_ptr =
      bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
  auto& params = kernel.rhs_pack_params;

  kernel.kai_run_rhs_pack(
@@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
  auto weight_packed_data =
      reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto scales_data = scales.data_ptr<float>();

  const auto scales_data = scales.to(kFloat).data_ptr<float>();

  if (weight_data == nullptr) {
    AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
    AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
  }

  float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
  float* bias_ptr =
      bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
  auto& params = kernel.rhs_pack_params;

  kernel.kai_run_rhs_pack(

@@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
    const kai_kernel_id id) {
  return channelwise_8bit_4bit_kernels.at(id);
}

// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
    bf16_channelwise_8bit_4bit_kernels = {
        {kai_kernel_id::
             matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
         {{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
           kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
        {kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
         {{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
           kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};

kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
    const kai_kernel_id id) {
  return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

@@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>

namespace at::native::kleidiai {

enum class kai_kernel_id {
  // FP32 inputs, 4-bit weights, FP32 output
  matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
      0, // Groupwise 4 bit GEMV
      0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
  matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
      1, // Groupwise 4 bit GEMM
      1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
  matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
      2, // Channelwise 4 bit GEMV
      2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
  matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
      3 // Channelwise 4 bit GEMM
      3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)

  // BF16 inputs, 4-bit weights, BF16 output
  matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
      4, // Channelwise 4-bit GEMV with BF16 input/output
  matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
      5 // Channelwise 4-bit GEMM with BF16 input/output
};

// Channelwise Kernel mapping
@@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
      void* rhs_packed,
      size_t extra_bytes,
      const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
  size_t(*kai_get_lhs_quant_pack_offset)(
      size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
  );

  kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
      const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
        kai_get_rhs_packed_size(
            &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
        kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
        kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
};

struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);

// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
  struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
  struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
  size_t (*kai_get_lhs_packed_size)(
      size_t m,
      size_t k,
      size_t mr,
      size_t kr,
      size_t sr);
  size_t (*kai_get_rhs_packed_size)(
      size_t n,
      size_t k,
      size_t nr,
      size_t kr,
      size_t sr);
  void (*kai_run_lhs_quant_pack)(
      size_t m,
      size_t k,
      size_t mr,
      size_t kr,
      size_t sr,
      size_t m_idx_start,
      const void* lhs,
      size_t lhs_stride,
      void* lhs_packed);
  void (*kai_run_rhs_pack)(
      size_t num_groups,
      size_t n,
      size_t k,
      size_t nr,
      size_t kr,
      size_t sr,
      const uint8_t* rhs,
      const float* bias,
      const float* scale,
      void* rhs_packed,
      size_t extra_bytes,
      const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
  size_t(*kai_get_lhs_quant_pack_offset)(
      size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
  );

  kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
      const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
      : ukernel(kernel),
        kai_get_lhs_packed_size(
            &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
        kai_get_rhs_packed_size(
            &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
        kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
        kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};

struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);

// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
  struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
      void* rhs_packed,
      size_t extra_bytes,
      const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
  size_t(*kai_get_lhs_quant_pack_offset)(
      size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
  );

  kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
      const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
        kai_get_rhs_packed_size(
            &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
        kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
        kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
};

struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

@@ -1426,9 +1426,6 @@ static at::Tensor _fp8_convolution_onednn_ref(
  w_scales_new_shape[0] = -1;
  auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
  auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
  if (bias.has_value()){
    bias = bias.value().to(at::kFloat);
  }
  auto y_f32 = at::convolution(
      dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
  );

@@ -50,7 +50,7 @@ nfnet_l0,pass,7



repvgg_a2,pass,7
repvgg_a2,fail_accuracy,7



@@ -14,10 +14,6 @@ Utils

    sdpa_kernel
    SDPBackend
    register_flash_attention_impl
    activate_flash_attention_impl
    list_flash_attention_impls
    current_flash_attention_impl

Submodules
----------

@@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja

# Install onnx
pip install -e "$tp2_dir/onnx"
pip install --no-use-pep517 -e "$tp2_dir/onnx"

# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

@@ -180,47 +180,6 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
        del model
        del optim

    def _test_tracker_multihandler_hook(self):
        """Should run without KeyError."""

        class TestModule(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.norm1 = nn.RMSNorm(dim)
                self.output1 = nn.Linear(dim, dim)
                self.norm2 = nn.RMSNorm(dim)
                self.output2 = nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.norm1(x)
                x = self.output1(x)
                x = self.norm2(x)
                x = self.output2(x)
                return x

        gc.collect()
        torch.manual_seed(42)
        dev = torch.device(torch.accelerator.current_device_index())

        with torch.device(dev):
            model = TestModule(128)

        mesh = init_device_mesh(dev.type, (self.world_size,))
        fully_shard([model.norm1, model.output1], mesh=mesh)
        fully_shard([model.norm2, model.output2], mesh=mesh)
        fully_shard(model, mesh=mesh)

        fmt = FSDPMemTracker(model)

        with fmt:
            inp = torch.randn(16, 128, device=dev)
            y = model(inp)
            loss = y.sum()
            loss.backward()

        del inp
        del model


class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
    @property

@@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import unittest

import torch
import torch.distributed as dist
@@ -372,7 +371,6 @@ class DTensorExportTest(TestCase):

    # aot_export_joint_with_descriptors on strict-exported exported_program.module()
    # is producing a joint graph with backward region missing
    @unittest.expectedFailure
    def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)


@@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
@@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(

        run(export_compiler)

    def test_tags_function(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_function_via_global_checkpoint(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function_via_global_checkpoint(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_function_with_kwargs(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function_with_kwargs(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
                gn, torch.sin(x), y, use_reentrant=False
            )

        x = torch.randn(4, 4, device=device, requires_grad=True)
@@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_sequential_layers(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_sequential_layers(self, device, partition_fn):
        def gn(x):
            x = x.cos()
            for _ in range(3):
@@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
            freqs=[2, 18],
            ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    def test_tags_multiple_checkpoints(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_multiple_checkpoints(self, device, partition_fn):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=6, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    def test_tags_module(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_module(self, device, partition_fn):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
@@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    def test_tags_decomps(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_decomps(self, device, partition_fn):
        # Ensures that tags are passed on through decompositions as well
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
@@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
@@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_recompute(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
        def context_fn_must_recompute_mm():
            must_recompute_list = [
                torch.ops.aten.mm.default,
@@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            ),
        )

        def _test(context_fn, bw_compiler):
        def _test(context_fn, bw_compiler, partition_fn):
            def gn(x):
                return torch.sigmoid(torch.matmul(x, x))
                return torch.cos(torch.sin(torch.matmul(x, x) @ x))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
@@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

            fw_compiler = functools.partial(
                count_ops,
                freq=1,
                freq=2,
                op=torch.ops.aten.mm.default,
            )

            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
                partition_fn=partition_fn,
            )
            self._validate(fn, backend, x)

@@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            context_fn=context_fn_must_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=3,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
                freq=6,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
                op=torch.ops.aten.mm.default,
            ),
            partition_fn=partition_fn,
        )
        _test(
            context_fn=context_fn_no_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=2,  # 2 bwd mm ops per fwd matmul
                freq=4,  # 2 bwd mm ops per fwd matmul
                op=torch.ops.aten.mm.default,
            ),
            partition_fn=partition_fn,
        )

    def test_sac_with_partial_context_fn(self):
@@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_not_recompute_gemm(
        self, device, partition_fn
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
        self, device
        self, device, partition_fn
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
@@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
            disable_functionalization=True,
        )
        self._validate(fn, backend, x, y)
@@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_triton_kernel(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
        # Copy of the above test, but make sure that having a triton kernel in the
        # region does not error.
        def add_one(x):
@@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_tensor_subclass(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_custom_rule(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
        def _get_custom_policy(meta):
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
        def selective_checkpointing_context_fn(no_recompute_list):
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
@@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_outplace_op(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
            partition_fn=partition_fn,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_list_ops(self, device):
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def selective_checkpointing_context_fn():
|
||||
# recompute everything
|
||||
no_recompute_list = []
|
||||
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
"requires TorchDispatchMode + torch.compile work to complete"
|
||||
)
|
||||
@requires_cuda_and_triton
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@torch._inductor.config.patch(fallback_random=True)
|
||||
def test_compile_selective_checkpoint_random_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
|
||||
for preserve_rng_state in [True, False]:
|
||||
|
||||
def selective_checkpointing_context_fn():
|
||||
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
|
||||
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_invalid_context(self):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y)) * y
|
||||
|
||||
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "must generate a tuple of two `TorchDispatchMode`s"
|
||||
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
||||
def test_compile_selective_checkpoint_parametrization(self):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
|
||||
def sac_policy():
|
||||
def _recomp_policy():
|
||||
def _custom_policy(ctx, func, *args, **kwargs):
|
||||
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
bw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freqs=[
|
||||
2, # 1 from mul recompute, 1 from mul backward
|
||||
# 1 from mul recompute, 1 from mul backward
|
||||
# w/o CSE, we have one extra mul
|
||||
3 if partition_fn is default_partition else 2,
|
||||
1,
|
||||
],
|
||||
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
|
||||
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
model = MLPModule()
|
||||
|
||||
@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
|
||||
return grad_output * x, grad_output * x
|
||||
|
||||
def f(a, b):
|
||||
return FwBwMutation.apply(a, b)
|
||||
return FwBwMutation.apply(a, b).sin_().clone()
|
||||
|
||||
inps = [
|
||||
torch.ones(3, 3, requires_grad=True),
|
||||
@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
|
||||
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
|
||||
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
|
||||
return (mul, add)""",
|
||||
clone = torch.ops.aten.clone.default(mul)
|
||||
sin_ = torch.ops.aten.sin_.default(mul); mul = None
|
||||
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
|
||||
return (clone_1, add, clone)""",
|
||||
)
|
||||
|
||||
# important bit: there is 1 mutation in the bw
|
||||
self.assertExpectedInline(
|
||||
bw_graph[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, add, tangents_1):
|
||||
def forward(self, add, clone, tangents_1):
|
||||
cos = torch.ops.aten.cos.default(clone); clone = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
|
||||
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
|
||||
return (mul_1, None)""",
|
||||
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
|
||||
return (mul_2, None)""",
|
||||
)
|
||||
|
||||
def test_fw_bw_mutation_no_functionalization2(self):
|
||||
|
||||
@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
|
||||
op="call_function", target=torch.ops.aten.mm.default
|
||||
)
|
||||
self.assertEqual(len(mm_nodes), 4)
|
||||
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
|
||||
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
|
||||
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
|
||||
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
|
||||
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
|
||||
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
|
||||
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)
|
||||
|
||||
@ -2476,12 +2476,11 @@ class CommonTemplate:
|
||||
b_int8pack, b_scales = convert_weight_to_int8pack(b)
|
||||
self.common(fn, (a, b_int8pack, b_scales, c))
|
||||
|
||||
@xfail_if_mps_unimplemented
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
|
||||
def test__dyn_quant_pack_4bit_weight(self):
|
||||
def test__dyn_quant_pack_4bit_weight_fp32(self):
|
||||
q_group = 32
|
||||
k = 128
|
||||
n = 128
|
||||
@ -2512,12 +2511,46 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (b, in_features, out_features))
|
||||
|
||||
@xfail_if_mps_unimplemented
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
|
||||
def test__dyn_quant_pack_4bit_weight_bf16(self):
|
||||
q_group = 32
|
||||
k = 128
|
||||
n = 128
|
||||
|
||||
torch.manual_seed(1)
|
||||
b = torch.rand((k, n), dtype=torch.bfloat16)
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def fn(b, in_features, out_features):
|
||||
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
|
||||
return b_int4pack
|
||||
|
||||
self.common(fn, (b, in_features, out_features))
|
||||
|
||||
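The packing helpers repeated in these tests all pick the scale dtype from the group size. A minimal illustrative sketch of that rule follows; the helper name is ours, not part of the test suite.

import torch

def choose_scale_dtype(q_group: int, in_features: int) -> torch.dtype:
    # Channelwise quantization (a single group spanning all of in_features)
    # keeps float32 scales; otherwise the scales follow the bfloat16 weights.
    return torch.float32 if q_group == in_features else torch.bfloat16

# Matches the branches used by dyn_quant_pack_4bit_weight in the tests above.
assert choose_scale_dtype(128, 128) is torch.float32
assert choose_scale_dtype(32, 128) is torch.bfloat16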
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
|
||||
def test__dyn_quant_matmul_4bit(self):
|
||||
def test__dyn_quant_matmul_4bit_fp32_input(self):
|
||||
q_group = 32
|
||||
m = 32
|
||||
k = 128
|
||||
@ -2557,6 +2590,60 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (a, q_group, in_features, out_features))
|
||||
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
|
||||
def test__dyn_quant_matmul_4bit_bf16_input(self):
|
||||
m = 32
|
||||
k = 128
|
||||
n = 128
|
||||
q_group = k
|
||||
|
||||
torch.manual_seed(1)
|
||||
a = torch.rand((m, k), dtype=torch.bfloat16)
|
||||
b = torch.rand((k, n), dtype=torch.bfloat16)
|
||||
|
||||
# codegen_dynamic_shape test fails without explicitly marking these dynamic
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(b, 1)
|
||||
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
|
||||
if not self.is_dtype_supported(torch.bfloat16):
|
||||
raise unittest.SkipTest(
|
||||
f"torch.bfloat16 not supported for device {self.device}"
|
||||
)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def fn(a, q_group, in_features, out_features):
|
||||
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
|
||||
res = torch.ops.aten._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
return res
|
||||
|
||||
self.common(fn, (a, q_group, in_features, out_features), atol=1, rtol=0.5)
|
||||
|
||||
def test_expanded_reduction(self):
|
||||
def fn(x, y):
|
||||
z = x * y
|
||||
|
||||
@ -7798,7 +7798,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
@parametrize("m", [1, 32])
|
||||
@parametrize("k", [64, 128])
|
||||
@parametrize("n", [4096, 11008])
|
||||
def test__dyn_quant_matmul_4bit(self, device, m, k, n):
|
||||
def test__dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA is unsupported")
|
||||
|
||||
@ -7870,7 +7870,86 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
@parametrize("m", [1, 32])
|
||||
@parametrize("k", [64, 128])
|
||||
@parametrize("n", [4096, 11008])
|
||||
def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
|
||||
def test__dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA is unsupported")
|
||||
|
||||
|
||||
torch.manual_seed(1)
|
||||
a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
|
||||
b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
|
||||
in_features = b_bfloat16.size(0)
|
||||
out_features = b_bfloat16.size(1)
|
||||
q_group = in_features
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def dyn_quant_matmul_4bit(
|
||||
a, b_int4pack, q_group, in_features, out_features
|
||||
):
|
||||
return torch.ops.aten._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
|
||||
b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
|
||||
b_bfloat16, in_features, out_features
|
||||
)
|
||||
|
||||
dtypes = [torch.bfloat16]
|
||||
|
||||
for dtype in dtypes:
|
||||
a = a_bfloat16.to(dtype=dtype)
|
||||
b = b_bfloat16.to(dtype=dtype)
|
||||
ref = torch.mm(a, b)
|
||||
res = dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
|
||||
# Mean relative error check
|
||||
expected_mean_err = 0.00952
|
||||
mean_err_tol = 0.005 # allow small deviation (±0.005)
|
||||
mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
|
||||
self.assertTrue(
|
||||
abs(mean_err - expected_mean_err) < mean_err_tol,
|
||||
f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
|
||||
)
|
||||
|
||||
# Elementwise relative error check
|
||||
elementwise_diff = (res - ref).abs()
|
||||
elementwise_relative_error = elementwise_diff / ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
|
||||
self.assertTrue(
|
||||
torch.all(elementwise_relative_error < 0.070),
|
||||
"Some elements have relative error >= 7%"
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
|
||||
@onlyNativeDeviceTypes
|
||||
@parametrize("m", [1, 32])
|
||||
@parametrize("k", [64, 128])
|
||||
@parametrize("n", [4096, 11008])
|
||||
def test_compile_dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA is unsupported")
|
||||
|
||||
@ -7928,6 +8007,83 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
|
||||
@onlyNativeDeviceTypes
|
||||
@parametrize("m", [1, 32])
|
||||
@parametrize("k", [64, 128])
|
||||
@parametrize("n", [4096, 11008])
|
||||
def test_compile_dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA is unsupported")
|
||||
|
||||
|
||||
torch.manual_seed(1)
|
||||
a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
|
||||
b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
|
||||
in_features = b_bfloat16.size(0)
|
||||
out_features = b_bfloat16.size(1)
|
||||
q_group = in_features
|
||||
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b_bfloat16, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
|
||||
|
||||
@torch.compile
|
||||
def dyn_quant_matmul_4bit(
|
||||
a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
|
||||
):
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
return torch._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
|
||||
res = dyn_quant_matmul_4bit(
|
||||
a_bfloat16,
|
||||
b_uint8,
|
||||
b_scales_and_zeros,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
ref = torch.mm(a_bfloat16, b_bfloat16)
|
||||
|
||||
        # === Accuracy checks ===

        # Mean relative error check
        expected_mean_err = 0.00952
        mean_err_tol = 0.005 # allow small deviation (±0.005)
        mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
        self.assertTrue(
            abs(mean_err - expected_mean_err) < mean_err_tol,
            f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
        )

        # Avoid divide-by-zero with clamp
        denominator = ref.abs().clamp(min=torch.finfo(ref.dtype).eps)

        # Compute elementwise relative error — always non-negative
        elementwise_relative_error = (res - ref).abs() / denominator

        # Check that all elements are within 7% error
        assert torch.all(elementwise_relative_error >= 0), "Relative error should never be negative"
        self.assertTrue(
            torch.all(elementwise_relative_error < 0.070),
            "Some elements have relative error >= 7%"
        )

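The two checks above can be folded into one small helper; the sketch below is illustrative only, and the helper name is ours rather than part of the test file.

import torch

def relative_errors(res: torch.Tensor, ref: torch.Tensor) -> tuple[float, float]:
    # Mean and worst-case elementwise relative error, clamping the denominator
    # to avoid divide-by-zero, mirroring the assertions above.
    denom = ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
    rel = (res - ref).abs() / denom
    return rel.mean().item(), rel.max().item()

# Hypothetical usage: compare both numbers against fixed tolerances.
mean_err, max_err = relative_errors(torch.rand(4, 4), torch.rand(4, 4) + 0.5)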
@onlyCPU
|
||||
@parametrize("m", [32, 64])
|
||||
@parametrize("k", [32, 64])
|
||||
@parametrize("n", [48, 64])
|
||||
|
||||
Submodule third_party/kineto updated: 57c561f4ca...1725f1a4d2
@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
    _proxy_tensor_disable_update_tensor_tracker,
    get_proxy_mode,
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)
@ -295,6 +296,10 @@ def create_joint(
        (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
            fn, primals
        )
        mode = get_proxy_mode()
        assert mode is not None
        for node in mode.tracer.graph.nodes:
            node.meta["partitioner_tag"] = "is_forward"

        # TODO: I think this hook can also be eliminated now
        if joint_fn_handle and joint_fn_handle.post_forward:

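The hunk above stamps every node traced so far with a forward tag; the toy sketch below shows how that tag is written and read back. The traced function and helpers here are illustrative, not the real AOTAutograd joint graph or partitioner code.

import torch
import torch.fx as fx

def _has_tag_is_forward(node: fx.Node) -> bool:
    # Same shape as the predicate added to the partitioner in a later hunk.
    return node.meta.get("partitioner_tag", None) == "is_forward"

def tag_forward_nodes(gm: fx.GraphModule) -> None:
    # Everything traced before the backward starts is forward-only.
    for node in gm.graph.nodes:
        node.meta["partitioner_tag"] = "is_forward"

def f(x):
    return torch.sigmoid(x) * x

gm = fx.symbolic_trace(f)
tag_forward_nodes(gm)
assert all(_has_tag_is_forward(n) for n in gm.graph.nodes)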
@ -51,6 +51,7 @@ from ._activation_checkpointing.knapsack import (
|
||||
)
|
||||
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
|
||||
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
|
||||
from ._aot_autograd.functional_utils import assert_functional_graph
|
||||
from ._aot_autograd.logging_utils import get_aot_graph_name
|
||||
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
|
||||
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
|
||||
@ -297,6 +298,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_backward"
|
||||
|
||||
|
||||
def _has_tag_is_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_forward"
|
||||
|
||||
|
||||
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
|
||||
|
||||
@ -1021,105 +1026,95 @@ def default_partition(
|
||||
Returns:
|
||||
Returns the generated forward and backward Fx graph modules.
|
||||
"""
|
||||
if has_recomputable_ops(joint_module):
|
||||
return min_cut_rematerialization_partition(
|
||||
joint_module,
|
||||
_joint_inputs,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
static_lifetime_input_indices=static_lifetime_input_indices,
|
||||
)
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
# Respect the original placement of ops rather than rely on dataflow.
|
||||
forward_nodes = []
|
||||
last_node = None
|
||||
for node in joint_module.graph.nodes:
|
||||
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
|
||||
last_node = node
|
||||
assert last_node is not None
|
||||
for node in joint_module.graph.nodes:
|
||||
if not _is_tangent(node):
|
||||
forward_nodes.append(node)
|
||||
if node is last_node:
|
||||
break
|
||||
forward_node_names = OrderedSet(
|
||||
node.name for node in forward_only_graph.nodes if node.op != "output"
|
||||
node.name for node in forward_nodes if node.op != "output"
|
||||
)
|
||||
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
assert_functional_graph(joint_module.graph)
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
|
||||
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
saved_values = []
|
||||
saved_sym_nodes = []
|
||||
|
||||
def is_mutated_later_in_fw(node):
|
||||
if _has_tag_is_backward(node):
|
||||
return False
|
||||
tensor_arg_aliases = [
|
||||
x
|
||||
for x in node.args
|
||||
if isinstance(x, fx.Node)
|
||||
and "val" in x.meta
|
||||
and isinstance(x.meta["val"], torch.Tensor)
|
||||
]
|
||||
while len(tensor_arg_aliases) > 0:
|
||||
a = tensor_arg_aliases.pop()
|
||||
for u in a.users:
|
||||
if not isinstance(u.target, torch._ops.OpOverload):
|
||||
continue
|
||||
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
|
||||
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
|
||||
if (
|
||||
# one of the args was mutated
|
||||
u.target._schema.is_mutable
|
||||
# and the mutation happens "later"
|
||||
and order[u] > order[node]
|
||||
# and the mutation happened during the forward
|
||||
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
|
||||
):
|
||||
for idx, alias_info in enumerate(u.target._schema.arguments):
|
||||
if alias_info.is_write and u.args[idx] is a:
|
||||
return True
|
||||
elif u.target.is_view:
|
||||
tensor_arg_aliases.append(u)
|
||||
return False
|
||||
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.name not in forward_node_names:
|
||||
# if a node isn't "required" to be in the forward, but any of its arguments
|
||||
# are later mutated in the forward, then it must have been run in the forward
|
||||
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
|
||||
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
|
||||
if is_mutated_later_in_fw(node):
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if is_sym_node(node):
|
||||
# Symints must be kept separate from tensors so that PythonFunction only calls
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
saved_sym_nodes.append(node)
|
||||
elif (
|
||||
continue
|
||||
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if node.is_impure(impure_random=False) and node.op not in (
|
||||
"placeholder",
|
||||
"output",
|
||||
):
|
||||
# See is_impure in torch/fx/node.py
|
||||
assert not graph_has_recomputable_ops, (
|
||||
"Trying to apply AC on a graph with impure op",
|
||||
node,
|
||||
node.target,
|
||||
)
|
||||
saved_values.append(node)
|
||||
continue
|
||||
backward_usages = [n for n in node.users if n.name not in forward_node_names]
|
||||
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
continue
|
||||
if (
|
||||
"tensor_meta" not in node.meta
|
||||
and node.op == "call_function"
|
||||
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
|
||||
):
|
||||
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
|
||||
users = node.users
|
||||
assert all(user.target is operator.getitem for user in users)
|
||||
saved_values.extend(users)
|
||||
else:
|
||||
backward_usages = [
|
||||
n for n in node.users if n.name not in forward_node_names
|
||||
]
|
||||
if "tensor_meta" in node.meta and all(
|
||||
is_sym_node(n) for n in backward_usages
|
||||
):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
else:
|
||||
saved_values.append(node)
|
||||
assert all(user.target == operator.getitem for user in node.users)
|
||||
continue
|
||||
if not must_recompute(node):
|
||||
saved_values.append(node)
|
||||
|
||||
saved_values = list(dict.fromkeys(saved_values).keys())
|
||||
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
|
||||
|
||||
return _extract_fwd_bwd_modules(
|
||||
if config._sync_decision_cross_ranks:
|
||||
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
|
||||
|
||||
if static_lifetime_input_nodes is None:
|
||||
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
|
||||
fw_module, bw_module = _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
@ -1127,6 +1122,24 @@ def default_partition(
|
||||
static_lifetime_input_nodes=static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if graph_has_recomputable_ops:
|
||||
if graph_has_recomputable_rng_ops:
|
||||
fw_module, bw_module = functionalize_rng_ops(
|
||||
joint_module, fw_module, bw_module, len(saved_sym_nodes)
|
||||
)
|
||||
bw_module = reordering_to_mimic_autograd_engine(bw_module)
|
||||
|
||||
# raise all getitem ops to as early as possible
|
||||
# this is helpful for memory, especially in the case of aot_eager backend
|
||||
fw_module = raise_getitems(fw_module)
|
||||
bw_module = raise_getitems(bw_module)
|
||||
|
||||
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
|
||||
if len(node_info.required_bw_nodes) > 0:
|
||||
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
|
||||
|
||||
return fw_module, bw_module
|
||||
|
||||
|
||||
INT_INF = int(1e6)
|
||||
|
||||
@ -1621,7 +1634,9 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
|
||||
break
|
||||
|
||||
|
||||
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
def cleanup_recompute_tags(
|
||||
joint_module: fx.GraphModule, *, is_default_partition: bool
|
||||
) -> fx.GraphModule:
|
||||
"""
|
||||
If there are two consecutive checkpointed blocks with no operator in
|
||||
between, we would still want to stash the tensor at the boundary of
|
||||
@ -1658,6 +1673,16 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
|
||||
# in forward graph outputs. With this, we can break the above circular dependency.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
elif (
|
||||
"ac_graph_id" not in node.meta
|
||||
and any(must_recompute(user) for user in node.users)
|
||||
and is_default_partition
|
||||
):
|
||||
# This node is not part of the AC region and a user is marked as recompute.
|
||||
# This means it's an input to the AC region and we should save it.
|
||||
# For ease of landing, gate this to default partitioner only, but we should think
|
||||
# about flipping the switch in general as well.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
return joint_module
|
||||
|
||||
|
||||
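A standalone sketch of the rule this branch adds for the default partitioner: a node outside any AC region whose user is tagged for recompute is an input to that region and gets force-saved. The must_recompute below is a simplified stand-in, not the partitioner's actual helper.

import torch.fx as fx
from torch.utils.checkpoint import CheckpointPolicy

def must_recompute(node: fx.Node) -> bool:
    # Simplified stand-in for the partitioner helper of the same name.
    return node.meta.get("recompute", None) in (
        CheckpointPolicy.MUST_RECOMPUTE,
        CheckpointPolicy.PREFER_RECOMPUTE,
    )

def save_ac_region_inputs(gm: fx.GraphModule) -> None:
    # A node with no "ac_graph_id" sits outside every AC region; if one of its
    # users is marked for recompute, the node feeds the region, so save it.
    for node in gm.graph.nodes:
        if "ac_graph_id" not in node.meta and any(
            must_recompute(user) for user in node.users
        ):
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE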
@ -2765,6 +2790,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
|
||||
return module
|
||||
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
|
||||
def min_cut_rematerialization_partition(
|
||||
joint_module: fx.GraphModule,
|
||||
_joint_inputs,
|
||||
@ -2813,68 +2891,16 @@ def min_cut_rematerialization_partition(
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
joint_module = cleanup_recompute_tags(joint_module)
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(
|
||||
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
|
||||
)
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
# networkx blows up on graphs with no required backward nodes
|
||||
# Since there's nothing to partition anyway, and the default partitioner can "handle"
|
||||
|
||||
@ -7099,19 +7099,13 @@ def sym_constrain_range(a, min=None, max=None):
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
    val = V.graph.current_node.meta["val"]
    if isinstance(val, torch.SymInt):
        return val.node.expr
    else:
        return int(val)
    return val.node.expr


@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
    val = V.graph.current_node.meta["val"]
    if isinstance(val, torch.SymInt):
        return val.node.expr
    else:
        return int(val)
    return val.node.expr


@register_lowering(aten.sym_numel)

@ -3722,6 +3722,7 @@ def kai_roundup(a: int, b: int) -> int:

def get_kai_packed_weight_size(n_bits, N, K, groupsize):
    if n_bits == 4:
        # Works for both fp32 and bf16 Kernels
        if groupsize == K: # channelwise
            # dotprod params only [1x8x32_neon_dotprod]
            kai_nr = 8
@ -3851,6 +3852,8 @@ def meta__dyn_quant_pack_4bit_weight(
    )
    return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
    packed_weight_size = weights.numel() + scales_zeros.numel()
    if bias is not None:
        packed_weight_size += bias.numel()
    return weights.new_empty(packed_weight_size, dtype=torch.float)


@ -3864,8 +3867,12 @@ def meta__dyn_quant_matmul_4bit(
):
    torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
    torch._check(
        inp.dtype == torch.float32,
        lambda: f"expected input to be f32, got {inp.dtype}",
        (inp.dtype == torch.float32)
        or (inp.dtype == torch.bfloat16 and block_size == in_features),
        lambda: (
            f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), "
            f"got {inp.dtype} with block_size={block_size} and in_features={in_features}"
        ),
    )
    M = inp.size(0)
    return inp.new_empty(M, out_features, dtype=inp.dtype)

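The relaxed meta check above can be read as a small standalone validator. The function below is a hypothetical restatement of the same condition, not a real torch API.

import torch

def valid_dyn_quant_matmul_4bit_input(
    inp: torch.Tensor, block_size: int, in_features: int
) -> bool:
    # fp32 inputs are always accepted; bf16 only in the channelwise case,
    # where a single quantization group spans all of in_features.
    return inp.dtype == torch.float32 or (
        inp.dtype == torch.bfloat16 and block_size == in_features
    )

assert valid_dyn_quant_matmul_4bit_input(
    torch.rand(2, 128, dtype=torch.bfloat16), block_size=128, in_features=128
)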
@ -702,7 +702,7 @@ def exp2(a):
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promoting_args=("a,"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:

@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from enum import auto, Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
|
||||
|
||||
import torch
|
||||
@ -17,9 +17,6 @@ from torch.utils._pytree import tree_map_only
|
||||
from torch.utils.weak import WeakIdKeyDictionary, weakref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
_TOTAL_KEY = "Total"
|
||||
|
||||
__all__ = ["FSDPMemTracker"]
|
||||
@ -368,28 +365,14 @@ class FSDPMemTracker(MemTracker):
|
||||
# `FSDPParamGroup.post_forward` because during AC these won't be called.
|
||||
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
|
||||
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
|
||||
|
||||
# get the unique _MultiHandlers/RemoveHandlers and store in dictionary
|
||||
# the _MultiHandlers object will only need to be grabbed once.
|
||||
unique_handlers: dict[RemovableHandle, bool] = {}
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
|
||||
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
|
||||
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
|
||||
unique_handlers[fsdp_state._post_forward_hook_handle] = True
|
||||
# call remove on the handles once
|
||||
for f_hook_handle in unique_handlers.keys():
|
||||
f_hook_handle.remove()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
|
||||
fsdp_state._pre_forward_hook_handle.remove()
|
||||
fsdp_state._post_forward_hook_handle.remove()
|
||||
fsdp_state._pre_forward_hook_handle = (
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
module.register_forward_pre_hook(
|
||||
|
||||
@ -259,13 +259,14 @@ else:
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_rank_map = tuple(self._rank_map.tolist())
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
self._thread_id = None
|
||||
# Initialize instance-specific flatten mapping
|
||||
self._flatten_mapping = {}
|
||||
|
||||
# Skip process group initialization if xla device or init backend is False
|
||||
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
|
||||
self._thread_id = None
|
||||
if device_type != "xla":
|
||||
# always try to create default (world) pg, even if it is not initialized
|
||||
# already. The world pg is used for device mesh identity (rank) on each
|
||||
@ -296,6 +297,11 @@ else:
|
||||
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_rank_map = tuple(self._rank_map.tolist())
|
||||
# Initialize instance-specific flatten mapping
|
||||
self._flatten_mapping = {}
|
||||
|
||||
@property
|
||||
def device_type(self) -> str:
|
||||
"""Returns the device type of the mesh."""
|
||||
|
||||
@ -19,13 +19,8 @@ __all__: list[str] = [
    "SDPBackend",
    "sdpa_kernel",
    "WARN_FOR_UNFUSED_KERNELS",
    "register_flash_attention_impl",
    "activate_flash_attention_impl",
    "list_flash_attention_impls",
    "current_flash_attention_impl",
]


# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only effects users of bias subclasses
@ -167,23 +162,3 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
def _get_flash_version() -> str:
    """This returns the closest matching tag for the flash attention backend"""
    return "2.5.7"


from . import _registry


# Re-export registry types and functions for public API
_FlashAttentionImpl = _registry._FlashAttentionImpl
_RegisterFn = _registry._RegisterFn
register_flash_attention_impl = _registry.register_flash_attention_impl
activate_flash_attention_impl = _registry.activate_flash_attention_impl
list_flash_attention_impls = _registry.list_flash_attention_impls
current_flash_attention_impl = _registry.current_flash_attention_impl

register_flash_attention_impl.__module__ = __name__
activate_flash_attention_impl.__module__ = __name__
list_flash_attention_impls.__module__ = __name__
current_flash_attention_impl.__module__ = __name__

# Import built-in implementations to trigger self-registration
from . import _fa4 # noqa: F401

@ -1,444 +0,0 @@
|
||||
"""UBER PROTOTYPE!!!"""
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
from . import _registry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register_flash_attention_fa4",
|
||||
]
|
||||
|
||||
|
||||
_FA4_MODULE_PATH: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FA4Handle:
|
||||
library: Library | None
|
||||
|
||||
def remove(self) -> None:
|
||||
self.library = None
|
||||
|
||||
|
||||
@cache
|
||||
def _get_device_major(device: torch.device) -> int:
|
||||
major, _ = torch.cuda.get_device_capability(device)
|
||||
return major
|
||||
|
||||
|
||||
def register_flash_attention_fa4(
|
||||
module_path: str = "flash_attn.cute.interface",
|
||||
) -> _FA4Handle:
|
||||
"""
|
||||
Register FA4 flash attention kernels with the PyTorch dispatcher.
|
||||
|
||||
Args:
|
||||
module_path: Python module path to the FA4 implementation.
|
||||
"""
|
||||
global _FA4_MODULE_PATH
|
||||
_ = _fa4_import_module(module_path)
|
||||
_FA4_MODULE_PATH = module_path
|
||||
return _FA4Handle(_fa4_register_kernels())
|
||||
|
||||
|
||||
@cache
|
||||
def _fa4_import_module(module_path: str) -> ModuleType:
|
||||
module = importlib.import_module(module_path)
|
||||
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
|
||||
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
|
||||
return module
|
||||
|
||||
|
||||
def _fa4_register_kernels() -> Library:
|
||||
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
|
||||
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
|
||||
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
|
||||
lib.impl(
|
||||
"_scaled_dot_product_flash_attention",
|
||||
_fa4_scaled_dot_product_flash_attention_forward_impl,
|
||||
"CUDA",
|
||||
)
|
||||
lib.impl(
|
||||
"_scaled_dot_product_flash_attention_backward",
|
||||
_fa4_scaled_dot_product_flash_attention_backward_impl,
|
||||
"CUDA",
|
||||
)
|
||||
return lib
|
||||
|
||||
|
||||
def _fa4_common_support_error(
|
||||
query: torch.Tensor,
|
||||
tensors: tuple[torch.Tensor, ...],
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
|
||||
) -> str | None:
|
||||
if not all(t.is_cuda for t in tensors):
|
||||
return "inputs must be CUDA tensors"
|
||||
if len({t.device for t in tensors}) != 1:
|
||||
return "inputs must share device"
|
||||
if query.dtype not in (torch.float16, torch.bfloat16):
|
||||
return "query dtype must be float16 or bfloat16"
|
||||
for name, tensor in require_fp32:
|
||||
if tensor.dtype != torch.float32:
|
||||
return f"{name} dtype must be float32"
|
||||
if cum_seq_q is None and query.dim() != 4:
|
||||
return "dense query must be 4D"
|
||||
if cum_seq_q is not None and query.dim() != 3:
|
||||
return "ragged query must be 3D"
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA not available"
|
||||
if _get_device_major(query.device) not in (9, 10):
|
||||
return "FA4 requires compute capability 9.0 or 10.0"
|
||||
return None
|
||||
|
||||
|
||||
def _fa4_forward_support_error(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float,
|
||||
return_debug_mask: bool,
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
seqused_k: torch.Tensor | None,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
) -> str | None:
|
||||
if dropout_p != 0.0:
|
||||
return "dropout_p must be 0"
|
||||
if return_debug_mask:
|
||||
return "return_debug_mask must be False"
|
||||
if alibi_slopes is not None:
|
||||
return "alibi_slopes not supported"
|
||||
if seqused_k is not None:
|
||||
if seqused_k.dtype != torch.int32:
|
||||
return "seqused_k must be int32"
|
||||
if not seqused_k.is_cuda:
|
||||
return "seqused_k must be CUDA"
|
||||
error = _fa4_common_support_error(
|
||||
query,
|
||||
(query, key, value),
|
||||
cum_seq_q,
|
||||
)
|
||||
if error is not None:
|
||||
if error == "inputs must share device":
|
||||
return "query, key, value must be on same device"
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
def _fa4_backward_support_error(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
dropout_p: float,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
window_size_left: int | None,
|
||||
window_size_right: int | None,
|
||||
) -> str | None:
|
||||
if dropout_p != 0.0:
|
||||
return "dropout_p must be 0"
|
||||
if window_size_left is not None or window_size_right is not None:
|
||||
return "windowed attention not supported"
|
||||
error = _fa4_common_support_error(
|
||||
query,
|
||||
(grad_out, query, key, value, out, logsumexp),
|
||||
cum_seq_q,
|
||||
require_fp32=(("logsumexp", logsumexp),),
|
||||
)
|
||||
if error is not None:
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
Ts = TypeVarTuple("Ts")
|
||||
|
||||
|
||||
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
|
||||
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _fa4_run_forward(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor | None,
|
||||
cu_seq_k: torch.Tensor | None,
|
||||
scale: float | None,
|
||||
is_causal: bool,
|
||||
window_size_left: int | None,
|
||||
window_size_right: int | None,
|
||||
seqused_k: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if _FA4_MODULE_PATH is None:
|
||||
raise RuntimeError("FA4 not registered")
|
||||
module = _fa4_import_module(_FA4_MODULE_PATH)
|
||||
kwargs: dict[str, Any] = {
|
||||
"softmax_scale": scale,
|
||||
"causal": is_causal,
|
||||
"window_size_left": window_size_left,
|
||||
"window_size_right": window_size_right,
|
||||
"return_lse": True,
|
||||
"cu_seqlens_q": cu_seq_q,
|
||||
"cu_seqlens_k": cu_seq_k,
|
||||
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
|
||||
}
|
||||
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
|
||||
return out, lse.contiguous()
|
||||
|
||||
|
||||
def _fa4_run_backward(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor | None,
|
||||
cu_seq_k: torch.Tensor | None,
|
||||
scale: float | None,
|
||||
is_causal: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if _FA4_MODULE_PATH is None:
|
||||
raise RuntimeError("FA4 not registered")
|
||||
module = _fa4_import_module(_FA4_MODULE_PATH)
|
||||
dq, dk, dv = module._flash_attn_bwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
grad_out,
|
||||
logsumexp.contiguous(),
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
cu_seqlens_q=cu_seq_q,
|
||||
cu_seqlens_k=cu_seq_k,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
def _fa4_flash_attention_forward_impl(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
cum_seq_k: torch.Tensor | None,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
return_debug_mask: bool,
|
||||
*,
|
||||
scale: float | None = None,
|
||||
window_size_left: int | None = None,
|
||||
window_size_right: int | None = None,
|
||||
seqused_k: torch.Tensor | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
):
|
||||
error = _fa4_forward_support_error(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
return_debug_mask,
|
||||
alibi_slopes,
|
||||
seqused_k,
|
||||
cum_seq_q,
|
||||
)
|
||||
if error is not None:
|
||||
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
|
||||
out, lse = _fa4_run_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
seqused_k,
|
||||
)
|
||||
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
|
||||
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
|
||||
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
|
||||
return out, lse, rng_state, philox_offset, debug_mask
|
||||
|
||||
|
||||
def _fa4_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    rng_state: torch.Tensor,
    unused: torch.Tensor,
    *,
    scale: float | None = None,
    window_size_left: int | None = None,
    window_size_right: int | None = None,
):
    error = _fa4_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        cum_seq_q,
        window_size_left,
        window_size_right,
    )
    if error is not None:
        raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
    dq, dk, dv = _fa4_run_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        cum_seq_q,
        cum_seq_k,
        scale,
        is_causal,
    )
    return dq, dk, dv

def _fa4_scaled_dot_product_flash_attention_forward_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: float | None = None,
):
    error = _fa4_forward_support_error(
        query,
        key,
        value,
        dropout_p,
        return_debug_mask,
        None,
        None,
        None,
    )
    if error is not None:
        raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
    q, k, v = _transpose_dense(query, key, value)

    max_q_flash = q.size(1)
    max_k_flash = k.size(1)
    out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
        q,
        k,
        v,
        None,
        None,
        max_q_flash,
        max_k_flash,
        dropout_p,
        is_causal,
        return_debug_mask,
        scale=scale,
    )
    (out,) = _transpose_dense(out)
    max_q = query.size(2)
    max_k = key.size(2)
    return (
        out,
        lse,
        None,
        None,
        max_q,
        max_k,
        rng_state,
        philox_offset,
        debug_mask,
    )

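The SDPA wrappers above convert between PyTorch's (batch, heads, seq, head_dim) layout and the (batch, seq, heads, head_dim) layout the flash kernels consume. The `_transpose_dense` helper they call is defined earlier in this file and is not part of this diff; the following is only a minimal sketch, assuming the helper simply swaps the head and sequence dimensions and makes each tensor contiguous (the name `_transpose_dense_sketch` is hypothetical):

import torch

def _transpose_dense_sketch(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # Assumed behaviour of `_transpose_dense`: swap dim 1 (heads) and dim 2
    # (sequence) and force a contiguous layout, so the flash kernels see
    # (batch, seq, heads, head_dim) inputs and SDPA callers get their original
    # (batch, heads, seq, head_dim) layout back on the return path.
    return tuple(t.transpose(1, 2).contiguous() for t in tensors)

Under that assumption, it is consistent that the forward wrapper reads `max_q_flash` from `q.size(1)` after the transpose while still reporting `max_q` from `query.size(2)` in the caller's layout.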
def _fa4_scaled_dot_product_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: float | None = None,
):
    error = _fa4_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        None,
        None,
        None,
    )
    if error is not None:
        raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
    q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
    max_q = query.size(2)
    max_k = key.size(2)
    dq, dk, dv = _fa4_flash_attention_backward_impl(
        go,
        q,
        k,
        v,
        o,
        logsumexp,
        None,
        None,
        max_q,
        max_k,
        dropout_p,
        is_causal,
        philox_seed,
        philox_offset,
        scale=scale,
    )
    dq, dk, dv = _transpose_dense(dq, dk, dv)
    return dq, dk, dv


_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)
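With the FA4 implementation registered under the "FA4" name, activation remains an explicit, user-driven step. A minimal usage sketch, assuming the registry's activation function is exposed from torch.nn.attention (an assumption inferred from the docstring cross-references below; the exact public import path is not shown in this diff):

import torch
import torch.nn.functional as F
from torch.nn.attention import activate_flash_attention_impl  # assumed export path

# Explicitly opt in to the FA4 kernels before running SDPA.
activate_flash_attention_impl("FA4")

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)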
@ -1,108 +0,0 @@
# mypy: allow-untyped-defs
"""Registry for flash attention implementations.

This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""

from typing import Callable, Literal, Protocol


class FlashAttentionHandle(Protocol):
    def remove(self) -> None: ...


_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]

_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}

_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}


def register_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
    *,
    register_fn: _RegisterFn,
) -> None:
    """
    Register the callable that activates a flash attention impl.

    .. note::
        This function is intended for SDPA backend providers to register their
        implementations. End users should use :func:`activate_flash_attention_impl`
        to activate a registered implementation.

    Args:
        impl: Implementation identifier (e.g., ``"FA4"``).
        register_fn: Callable that performs the actual dispatcher registration.
            This function will be invoked by :func:`activate_flash_attention_impl`
            and should register custom kernels with the PyTorch dispatcher.
            It may optionally return a handle implementing
            :class:`FlashAttentionHandle` to keep any necessary state alive.

    Example:
        >>> def my_impl_register(module_path: str = "my_flash_impl"):
        ...     # Register custom kernels with torch dispatcher
        ...     pass  # doctest: +SKIP
        >>> register_flash_attention_impl(
        ...     "MyImpl", register_fn=my_impl_register
        ... )  # doctest: +SKIP
    """
    _FLASH_ATTENTION_IMPLS[impl] = register_fn


def activate_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
) -> None:
    """
    Activate a previously registered flash attention impl in the dispatcher.

    .. note::
        Backend providers should NOT automatically activate their implementation
        on import. Users should explicitly opt in by calling this function or via
        environment variables, so that multiple provider libraries can coexist.

    Args:
        impl: Implementation identifier to activate. See
            :func:`~torch.nn.attention.list_flash_attention_impls` for available
            implementations.
            If the backend's :func:`register_flash_attention_impl` callable
            returns a :class:`FlashAttentionHandle`, the registry keeps that
            handle alive for the lifetime of the process (until explicit
            uninstall support exists).

    Example:
        >>> activate_flash_attention_impl("FA4")  # doctest: +SKIP
    """
    global _FLASH_ATTENTION_ACTIVE
    register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
    if register_fn is None:
        raise ValueError(
            f"Unknown flash attention impl '{impl}'. "
            f"Available implementations: {list_flash_attention_impls()}"
        )
    # TODO: The only way to actually register a new impl is to unregister the
    # current impl, reinstall the default impl, and then register the new impl.
    if _FLASH_ATTENTION_ACTIVE == impl:
        return

    handle = register_fn()
    if handle is not None:
        _FLASH_ATTENTION_HANDLES[impl] = handle
    _FLASH_ATTENTION_ACTIVE = impl


def list_flash_attention_impls() -> list[str]:
    """Return the names of all available flash attention implementations."""
    return sorted(_FLASH_ATTENTION_IMPLS.keys())


def current_flash_attention_impl() -> str | None:
    """
    Return the currently activated flash attention impl name, if any.

    ``None`` indicates that no custom impl has been activated.
    """
    return _FLASH_ATTENTION_ACTIVE
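Taken together, these registry functions define a small provider protocol: a backend registers a callable, the user activates it, and the registry keeps any returned handle alive. A self-contained sketch using only the functions defined in this module; the provider stub `_my_register_fn` and handle class `_MyHandle` are hypothetical:

class _MyHandle:
    # Hypothetical handle; `remove` satisfies the FlashAttentionHandle protocol.
    def remove(self) -> None:
        pass


def _my_register_fn() -> _MyHandle:
    # A real provider would register its custom kernels with the PyTorch
    # dispatcher here; this stub only returns a handle to keep state alive.
    return _MyHandle()


register_flash_attention_impl("MyImpl", register_fn=_my_register_fn)
print(list_flash_attention_impls())      # ['MyImpl'], plus 'FA4' if registered
activate_flash_attention_impl("MyImpl")  # invokes _my_register_fn()
print(current_flash_attention_impl())    # 'MyImpl'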