Compare commits


1 Commit

Author SHA1 Message Date
ed2dcd679c Automated submodule update: kineto 2025-11-11 12:07:07 -08:00
31 changed files with 1227 additions and 997 deletions

View File

@ -96,6 +96,7 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi

View File

@ -1,4 +1,4 @@
name: docker-cache-rocm
name: docker-cache-mi300
on:
workflow_run:
@ -31,7 +31,7 @@ jobs:
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
with:
run-id: ${{ github.event.workflow_run.id }}
run_id: ${{ github.event.workflow_run.id }}
path: ./docker-builds-artifacts
merge-multiple: true
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
__func__,
" : expect input to be 32-bit float tensor.");
" : expect input to be float32 or bfloat16 tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),
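
For reference, a minimal Python sketch of what the relaxed check now admits: bf16 inputs are only accepted on the channelwise path (block_size == in_features), while fp32 inputs work on either path. It mirrors the tests added later in this change; the _group_quantize_tensor_symmetric helper and its import path are assumptions taken from those tests, not from this hunk.

import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric  # import path assumed

m, k, n = 32, 128, 128
q_group = k  # channelwise: bf16 input is only accepted when block_size == in_features
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((k, n), dtype=torch.bfloat16)
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(b, n_bit=4, groupsize=q_group)
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)  # channelwise kernels expect fp32 scales
b_int4pack = torch._dyn_quant_pack_4bit_weight(b_uint8, b_scales_and_zeros, None, q_group, k, n)
res = torch.ops.aten._dyn_quant_matmul_4bit(a, b_int4pack, q_group, k, n)  # bf16 output of shape (m, n)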

View File

@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
size_t m,
size_t n,
size_t k,
const uint16_t* lhs_bf16,
const uint8_t* rhs_qs4cx,
const float* rhs_scales,
uint16_t* dst_bf16,
float scalar_min,
float scalar_max,
const float* bias) {
// Roundup lambda for internal stride calculations
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};
// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
// Quantization pack lambda (channelwise QA8DX)
auto quant_pack_8bit_channelwise =
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
for (size_t i = 0; i < M; ++i) {
const uint16_t* row_ptr = src_bf16 + i * K;
// find min/max
float mn = FLT_MAX, mx = -FLT_MAX;
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
mn = std::min(mn, v);
mx = std::max(mx, v);
}
float rmin = std::min(0.0f, mn);
float rmax = std::max(0.0f, mx);
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
float recip = scale ? 1.0f / scale : 0.0f;
int32_t zp;
float des_min = rmin * scale;
float des_max = rmax * scale;
float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f =
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::clamp(zp_f, qmin, qmax);
zp = std::lrintf(zp_f);
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
// store header
*reinterpret_cast<float*>(out_ptr) = recip;
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
out_ptr += sizeof(float) + sizeof(int32_t);
// quantize
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
q = std::clamp(
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
*out_ptr++ = static_cast<int8_t>(q);
}
}
};
// MatMul lambda: (M x K) int8 LHS x (N x K) packed int4 RHS -> (M x N) BF16 output
auto matmul_kernel = [&](size_t M,
size_t N,
size_t K,
const int8_t* lhs,
const uint8_t* rhs,
const float* scales,
uint16_t* dst,
float lo,
float hi) {
const size_t lhs_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = roundup(K, 2) / 2;
for (size_t i = 0; i < M; ++i) {
const int8_t* lhs_row = lhs + i * lhs_stride;
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
const int8_t* lptr = lhs_row;
const uint8_t* rptr = rhs + j * rhs_stride;
float lhs_scale = *reinterpret_cast<const float*>(lptr);
int32_t lhs_off =
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
lptr += sizeof(float) + sizeof(int32_t);
for (size_t t = 0; t < K; ++t) {
int32_t lv = static_cast<int32_t>(lptr[t]);
uint8_t bv = rptr[t / 2];
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
: (static_cast<int32_t>(bv >> 4) - 8);
acc += lv * rv + lhs_off * rv;
}
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
if (bias) {
res += bias[j];
}
res = std::clamp(res, lo, hi);
*dst++ = cast_f32_to_bf16(res);
}
}
};
// allocate and run
std::unique_ptr<int8_t[]> packed(
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
matmul_kernel(
m,
n,
k,
packed.get(),
rhs_qs4cx,
rhs_scales,
dst_bf16,
scalar_min,
scalar_max);
}
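
For illustration, a rough Python transcription of the per-row quantization performed by quant_pack_8bit_channelwise above (in the kernel the bf16 input is first widened to fp32 by a 16-bit shift). This is a sketch of the scale/zero-point arithmetic only, not the packed byte layout or the actual implementation.

import numpy as np

def quantize_row_qa8dx(row_f32):
    # Asymmetric per-row int8 quantization: extend the range to include 0, map
    # [rmin, rmax] onto [-128, 127], then nudge and clamp the zero point.
    qmin, qmax = -128.0, 127.0
    rmin = min(0.0, float(row_f32.min()))
    rmax = max(0.0, float(row_f32.max()))
    scale = 1.0 if rmin == rmax else (qmax - qmin) / (rmax - rmin)
    recip = 1.0 / scale if scale else 0.0
    err_min = qmin + rmin * scale
    err_max = qmax + rmax * scale
    zp = qmin - rmin * scale if (err_min + err_max) > 0 else qmax - rmax * scale
    zp = int(np.rint(np.clip(zp, qmin, qmax)))
    q = np.clip(np.rint(row_f32 * scale) + zp, qmin, qmax).astype(np.int8)
    return q, np.float32(recip), np.int32(-zp)  # int8 row plus its header (1/scale, negated zero point)

q, recip, neg_zp = quantize_row_qa8dx(np.array([0.25, -1.5, 3.0], dtype=np.float32))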
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
if (bias.has_value()) {
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
}
auto res = at::cat(tensors_to_cat, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
zero_point0 = std::min(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[n_idx];
}
// Clamp (min-max) operation
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
min0 = std::min(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
zero_point0 = std::max(zero_point0, qmin);
zero_point0 = std::min(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = std::max(
std::min(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
static_cast<int32_t>(kI8Min));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[col_idx];
}
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
* Dynamic INT4 weight-only MatMul with per-row input quantization.
*
* Execution Flow:
*
* (INT4 Weights + FP Scales [+ optional Bias])
*
* Input (FP32 or BF16) Packed Weight Buffer
* | |
* Row-wise Quantization (INT8) |
* | |
* INT8 Input Activation INT4 Quantized Weights + Scales
* \ /
* \ /
* Quantized Matrix Multiply
* |
* Output Tensor (BF16 or FP32)
*
* Notes:
* - Groupwise kernels expect BF16 scales
* - Channelwise kernels expect FP32 scales
* - Bias is currently unsupported in fallback path
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
{
void* input = inp.data_ptr();
void* dst = output.data_ptr();
// Extract weights, scales and biases from the packed tensor
const int weights_elements = N * K / 2;
const int scale_elements = N * (K / block_size);
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
float* weight_scales = float32_scales.data_ptr<float>();
void* bias_data = nullptr;
if (bias_elements) {
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
bias_data = float32_bias.data_ptr();
}
// Two 4-bit weight elements are packed into each uint8 value
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
// Dispatch to reference kernels
if (inp.scalar_type() == at::kBFloat16) {
// BF16 input, BF16 output
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
M, N, K,
(uint16_t*)input, weights_4bit, weight_scales,
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
}
} else if (inp.scalar_type() == at::kFloat) {
// FP32 input, FP32 output
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M, N, K,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M, N, K, block_size,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
}
} else {
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
}
}
}
}
} // anonymous namespace
}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
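
The fallback above slices one flat packed buffer; a small sketch of that offset arithmetic, assuming the layout produced by dyn_quant_pack_4bit_weight_kernel (packed 4-bit weights, then fp32 scales, then an optional fp32 bias):

def fallback_unpack_offsets(N, K, block_size, packed_numel):
    # Two 4-bit values per byte, one scale per (K / block_size) group per output
    # channel; whatever remains is treated as the optional bias.
    weights_elements = N * K // 2
    scale_elements = N * (K // block_size)
    bias_elements = packed_numel - (weights_elements + scale_elements)
    return weights_elements, scale_elements, bias_elements

# Channelwise example (block_size == K), N = K = 128, with a bias of length N appended:
print(fallback_unpack_offsets(128, 128, 128, 128 * 128 // 2 + 128 + 128))  # (8192, 128, 128)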

View File

@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
if (weight.scalar_type() == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
const int64_t bl,
at::ScalarType tensor_dtype) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
if (tensor_dtype == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
});
}
void kai_quant_pack_lhs_int4_mm(
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
constexpr kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
constexpr kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
// Get total threads and select kernel
const int64_t total_threads = at::get_num_threads();
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm() && m > 1) {
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
}
// Thread blocking parameters
const int64_t n_step = kernel_packet.ukernel.get_n_step();
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_bf16 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
const uint8_t* rhs_packed_mtx_qs4cx =
reinterpret_cast<const uint8_t*>(weight.data_ptr());
uint8_t* lhs_packed_base = lhs_packed.get();
constexpr int32_t element_size = sizeof(uint16_t);
const size_t lhs_stride = k * element_size;
const size_t dst_stride = n * element_size;
// LHS quantization packing
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
const size_t src_stride = vec_per_thread * lhs_stride;
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.kai_run_lhs_quant_pack(
vec_num,
k,
mr,
kr,
sr,
0,
(const uint16_t*)lhs_src_ptr,
lhs_stride,
lhs_packed_ptr);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
lhs_quant_pack(thread_id);
}
});
// Matrix multiplication
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
auto mm = [=, &kernel_packet](int64_t thread_id) {
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(
thread_id * vec_per_thread, k);
auto dst_ptr = dst_act_mtx_bf16 +
kernel_packet.ukernel.get_dst_offset(
0, thread_id * vec_per_thread, dst_stride);
const int64_t vec_num = (thread_id == num_threads - 1)
? (n - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.ukernel.run_matmul(
m,
vec_num,
k,
lhs_packed_base,
rhs_packed_ptr,
(uint16_t*)dst_ptr,
dst_stride,
element_size, // dst_stride_col
-FLT_MAX,
FLT_MAX);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
mm(thread_id);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
const auto input_dtype = input.dtype();
if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}
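
The new bf16 path reuses the same row/column work splitting as the fp32 kernels; a hedged Python sketch of that blocking arithmetic follows. get_vec_per_thread is not shown in this diff, so its rounding (ceil-divide, then round up to a multiple of mr or n_step) is an assumption; the remainder handling mirrors the loops above.

def get_vec_per_thread(work, total_threads, multiple):
    per_thread = -(-work // total_threads)            # ceil division (assumed)
    return -(-per_thread // multiple) * multiple      # rounded up to a multiple (assumed)

def split_work(total, total_threads, multiple):
    vec_per_thread = get_vec_per_thread(total, total_threads, multiple)
    num_threads = (total + vec_per_thread - 1) // vec_per_thread
    chunks = []
    for thread_id in range(num_threads):
        start = thread_id * vec_per_thread
        vec_num = total - start if thread_id == num_threads - 1 else vec_per_thread
        chunks.append((start, vec_num))               # last thread takes the remainder
    return chunks

print(split_work(37, 8, 4))  # [(0, 8), (8, 8), (16, 8), (24, 8), (32, 5)]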

View File

@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
const int64_t bl,
at::ScalarType tensor_dtype = at::kFloat);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

View File

@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
const auto scales_data = scales.to(kFloat).data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(

View File

@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
bf16_channelwise_8bit_4bit_kernels = {
{kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

View File

@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
};
// Channelwise Kernel mapping
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const void* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

View File

@ -1426,9 +1426,6 @@ static at::Tensor _fp8_convolution_onednn_ref(
w_scales_new_shape[0] = -1;
auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
if (bias.has_value()){
bias = bias.value().to(at::kFloat);
}
auto y_f32 = at::convolution(
dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
);

View File

@ -50,7 +50,7 @@ nfnet_l0,pass,7
repvgg_a2,pass,7
repvgg_a2,fail_accuracy,7

name accuracy graph_breaks

View File

@ -14,10 +14,6 @@ Utils
sdpa_kernel
SDPBackend
register_flash_attention_impl
activate_flash_attention_impl
list_flash_attention_impls
current_flash_attention_impl
Submodules
----------

View File

@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja
# Install onnx
pip install -e "$tp2_dir/onnx"
pip install --no-use-pep517 -e "$tp2_dir/onnx"
# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

View File

@ -180,47 +180,6 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
del model
del optim
def _test_tracker_multihandler_hook(self):
"""Should run without KeyError."""
class TestModule(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.norm1 = nn.RMSNorm(dim)
self.output1 = nn.Linear(dim, dim)
self.norm2 = nn.RMSNorm(dim)
self.output2 = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm1(x)
x = self.output1(x)
x = self.norm2(x)
x = self.output2(x)
return x
gc.collect()
torch.manual_seed(42)
dev = torch.device(torch.accelerator.current_device_index())
with torch.device(dev):
model = TestModule(128)
mesh = init_device_mesh(dev.type, (self.world_size,))
fully_shard([model.norm1, model.output1], mesh=mesh)
fully_shard([model.norm2, model.output2], mesh=mesh)
fully_shard(model, mesh=mesh)
fmt = FSDPMemTracker(model)
with fmt:
inp = torch.randn(16, 128, device=dev)
y = model(inp)
loss = y.sum()
loss.backward()
del inp
del model
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
@property

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
@ -372,7 +371,6 @@ class DTensorExportTest(TestCase):
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

View File

@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler)
def test_tags_function(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_via_global_checkpoint(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_with_kwargs(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
gn, torch.sin(x), y, use_reentrant=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_sequential_layers(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def gn(x):
x = x.cos()
for _ in range(3):
@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_multiple_checkpoints(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_module(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_decomps(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def _test(context_fn, bw_compiler):
def _test(context_fn, bw_compiler, partition_fn):
def gn(x):
return torch.sigmoid(torch.matmul(x, x))
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial(
count_ops,
freq=1,
freq=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=2, # 2 bwd mm ops per fwd matmul
freq=4, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
def test_sac_with_partial_context_fn(self):
@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_list_ops(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
def test_compile_selective_checkpoint_inplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial(
count_ops,
freqs=[
2, # 1 from mul recompute, 1 from mul backward
# 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
model = MLPModule()
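
The same mechanical change is applied throughout this file; below is a self-contained sketch of the knob being exercised (helper names here are illustrative, not taken from the test module): the aot_autograd backend is built with partition_fn switched between min-cut rematerialization and the default partitioner.

import torch
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

def noop_compiler(gm, example_inputs):
    # Return the traced graph module unchanged; only the partitioning is of interest here.
    return gm

def fn(x, y):
    def gn(a, b):
        return torch.sigmoid(torch.matmul(a, b))
    return torch.utils.checkpoint.checkpoint(gn, torch.sin(x), y, use_reentrant=False)

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
for partition_fn in (min_cut_rematerialization_partition, default_partition):
    backend = aot_autograd(
        fw_compiler=noop_compiler, bw_compiler=noop_compiler, partition_fn=partition_fn
    )
    out = torch.compile(fn, backend=backend)(x, y)
    out.sum().backward()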

View File

@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b)
return FwBwMutation.apply(a, b).sin_().clone()
inps = [
torch.ones(3, 3, requires_grad=True),
@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):

View File

@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

View File

@ -2476,12 +2476,11 @@ class CommonTemplate:
b_int8pack, b_scales = convert_weight_to_int8pack(b)
self.common(fn, (a, b_int8pack, b_scales, c))
@xfail_if_mps_unimplemented
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
def test__dyn_quant_pack_4bit_weight(self):
def test__dyn_quant_pack_4bit_weight_fp32(self):
q_group = 32
k = 128
n = 128
@ -2512,12 +2511,46 @@ class CommonTemplate:
self.common(fn, (b, in_features, out_features))
@xfail_if_mps_unimplemented
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
def test__dyn_quant_pack_4bit_weight_bf16(self):
q_group = 32
k = 128
n = 128
torch.manual_seed(1)
b = torch.rand((k, n), dtype=torch.bfloat16)
in_features = b.size(0)
out_features = b.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(b, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
return b_int4pack
self.common(fn, (b, in_features, out_features))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
def test__dyn_quant_matmul_4bit(self):
def test__dyn_quant_matmul_4bit_fp32_input(self):
q_group = 32
m = 32
k = 128
@ -2557,6 +2590,60 @@ class CommonTemplate:
self.common(fn, (a, q_group, in_features, out_features))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
def test__dyn_quant_matmul_4bit_bf16_input(self):
m = 32
k = 128
n = 128
q_group = k
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((k, n), dtype=torch.bfloat16)
# codegen_dynamic_shape test fails without explicitly marking these dynamic
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(b, 1)
in_features = b.size(0)
out_features = b.size(1)
if not self.is_dtype_supported(torch.bfloat16):
raise unittest.SkipTest(
f"torch.bfloat16 not supported for device {self.device}"
)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(a, q_group, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
res = torch.ops.aten._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
return res
self.common(fn, (a, q_group, in_features, out_features), atol=1, rtol=0.5)
def test_expanded_reduction(self):
def fn(x, y):
z = x * y

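The bf16 tests above exercise the new channelwise 4-bit path end to end but spread it across helpers. Condensed into a standalone sketch (same ops the tests call, illustrative sizes; the quantization helper is assumed to come from torch.testing._internal.common_quantization, and a reference CPU fallback is assumed to be available), the flow is roughly:

import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

m, k, n = 32, 128, 128
q_group = k  # bf16 activations require channelwise quantization: q_group == in_features
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((k, n), dtype=torch.bfloat16)

b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(b, n_bit=4, groupsize=q_group)
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)  # channelwise scales stay in fp32
b_int4pack = torch._dyn_quant_pack_4bit_weight(
    b_uint8, b_scales_and_zeros, None, q_group, k, n
)
res = torch.ops.aten._dyn_quant_matmul_4bit(a, b_int4pack, q_group, k, n)
assert res.dtype == torch.bfloat16 and res.shape == (m, n)

Note that a bf16 input is only accepted when q_group equals in_features, which is why both tests pin q_group = k.
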
View File

@ -7798,7 +7798,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test__dyn_quant_matmul_4bit(self, device, m, k, n):
def test__dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
@ -7870,7 +7870,86 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
def test__dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
torch.manual_seed(1)
a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
in_features = b_bfloat16.size(0)
out_features = b_bfloat16.size(1)
q_group = in_features
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def dyn_quant_matmul_4bit(
a, b_int4pack, q_group, in_features, out_features
):
return torch.ops.aten._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
b_bfloat16, in_features, out_features
)
dtypes = [torch.bfloat16]
for dtype in dtypes:
a = a_bfloat16.to(dtype=dtype)
b = b_bfloat16.to(dtype=dtype)
ref = torch.mm(a, b)
res = dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
# Mean relative error check
expected_mean_err = 0.00952
mean_err_tol = 0.005 # allow small deviation (±0.005)
mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
self.assertTrue(
abs(mean_err - expected_mean_err) < mean_err_tol,
f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
)
# Elementwise relative error check
elementwise_diff = (res - ref).abs()
elementwise_relative_error = elementwise_diff / ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
self.assertTrue(
torch.all(elementwise_relative_error < 0.070),
"Some elements have relative error >= 7%"
)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
@ -7928,6 +8007,83 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
)
@onlyNativeDeviceTypes
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
torch.manual_seed(1)
a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
in_features = b_bfloat16.size(0)
out_features = b_bfloat16.size(1)
q_group = in_features
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b_bfloat16, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
@torch.compile
def dyn_quant_matmul_4bit(
a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
):
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
res = dyn_quant_matmul_4bit(
a_bfloat16,
b_uint8,
b_scales_and_zeros,
q_group,
in_features,
out_features,
)
ref = torch.mm(a_bfloat16, b_bfloat16)
# === Accuracy checks ===
# Mean relative error check
expected_mean_err = 0.00952
mean_err_tol = 0.005 # allow small deviation (±0.005)
mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
self.assertTrue(
abs(mean_err - expected_mean_err) < mean_err_tol,
f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
)
# Avoid divide-by-zero with clamp
denominator = ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
# Compute elementwise relative error — always non-negative
elementwise_relative_error = (res - ref).abs() / denominator
# Check if all elements are within 7% error
assert torch.all(elementwise_relative_error >= 0), "Relative error should never be negative"
self.assertTrue(
torch.all(elementwise_relative_error < 0.070),
"Some elements have relative error >= 7%"
)
@onlyCPU
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])

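Both new bf16 tests repeat the same two-part tolerance check: a mean relative error close to the empirically measured 0.00952 (±0.005) and a 7% elementwise bound. As a hedged sketch, that shared logic could be factored into a hypothetical helper like this:

import torch

def check_bf16_matmul_error(res, ref, expected_mean_err=0.00952,
                            mean_err_tol=0.005, elem_err_bound=0.070):
    # mean relative error, clamping tiny reference magnitudes
    mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
    assert abs(mean_err - expected_mean_err) < mean_err_tol, f"mean rel err {mean_err:.6f}"
    # elementwise relative error, clamping the denominator to eps to avoid divide-by-zero
    denom = ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
    assert torch.all((res - ref).abs() / denom < elem_err_bound), "elementwise rel err >= 7%"
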
View File

@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
get_proxy_mode,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -295,6 +296,10 @@ def create_joint(
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
)
mode = get_proxy_mode()
assert mode is not None
for node in mode.tracer.graph.nodes:
node.meta["partitioner_tag"] = "is_forward"
# TODO: I think this hook can also be eliminated now
if joint_fn_handle and joint_fn_handle.post_forward:

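The hunk above stamps every node traced during the forward with node.meta["partitioner_tag"] = "is_forward" before the backward is traced; the partitioner (and the strengthened dynamo assertions earlier in this diff) then read that tag back. A minimal sketch of the reader side, where split_by_partitioner_tag is a hypothetical helper, not part of this change:

import torch.fx as fx

def _has_tag_is_forward(node: fx.Node) -> bool:
    # mirrors the predicate added to the partitioner in this change set
    return node.meta.get("partitioner_tag", None) == "is_forward"

def split_by_partitioner_tag(graph: fx.Graph) -> dict[str, list[fx.Node]]:
    # hypothetical helper: bucket joint-graph nodes by the tag assigned during tracing
    buckets: dict[str, list[fx.Node]] = {}
    for node in graph.nodes:
        buckets.setdefault(node.meta.get("partitioner_tag", "untagged"), []).append(node)
    return buckets
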
View File

@ -51,6 +51,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import assert_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -297,6 +298,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1021,105 +1026,95 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
node.name for node in forward_nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
assert_functional_graph(joint_module.graph)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif (
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if node.is_impure(impure_random=False) and node.op not in (
"placeholder",
"output",
):
# See is_impure in torch/fx/node.py
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op",
node,
node.target,
)
saved_values.append(node)
continue
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target is operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
assert all(user.target == operator.getitem for user in node.users)
continue
if not must_recompute(node):
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules(
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@ -1127,6 +1122,24 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@ -1621,7 +1634,9 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
@ -1658,6 +1673,16 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and is_default_partition
):
# This node is not part of the AC region and a user is marked as recompute.
# This means it's an input to the AC region and we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
@ -2765,6 +2790,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition(
joint_module: fx.GraphModule,
_joint_inputs,
@ -2813,68 +2891,16 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
# networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle"

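Per this hunk, default_partition appears to handle recomputable (activation-checkpointed) ops itself rather than deferring to the min-cut partitioner, and it now shares classify_nodes with min_cut_rematerialization_partition. A hedged sketch of driving either partitioner through functorch's AOT API (import paths assumed from functorch.compile; the toy function and printing compilers are placeholders, not part of the diff):

import torch
from functorch.compile import (
    aot_function,
    default_partition,
    min_cut_rematerialization_partition,
)

def f(x):
    return torch.sigmoid(x * x).sum()

def print_graph(name):
    def compiler(gm, example_inputs):
        print(f"== {name} ==")
        gm.graph.print_tabular()
        return gm
    return compiler

for partition_fn in (default_partition, min_cut_rematerialization_partition):
    compiled = aot_function(
        f,
        fw_compiler=print_graph("fw"),
        bw_compiler=print_graph("bw"),
        partition_fn=partition_fn,
    )
    compiled(torch.randn(8, requires_grad=True)).backward()
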
View File

@ -7099,19 +7099,13 @@ def sym_constrain_range(a, min=None, max=None):
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
return val.node.expr
@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
return val.node.expr
@register_lowering(aten.sym_numel)

View File

@ -3722,6 +3722,7 @@ def kai_roundup(a: int, b: int) -> int:
def get_kai_packed_weight_size(n_bits, N, K, groupsize):
if n_bits == 4:
# Works for both fp32 and bf16 Kernels
if groupsize == K: # channelwise
# dotprod params only [1x8x32_neon_dotprod]
kai_nr = 8
@ -3851,6 +3852,8 @@ def meta__dyn_quant_pack_4bit_weight(
)
return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
packed_weight_size = weights.numel() + scales_zeros.numel()
if bias is not None:
packed_weight_size += bias.numel()
return weights.new_empty(packed_weight_size, dtype=torch.float)
@ -3864,8 +3867,12 @@ def meta__dyn_quant_matmul_4bit(
):
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
torch._check(
inp.dtype == torch.float32,
lambda: f"expected input to be f32, got {inp.dtype}",
(inp.dtype == torch.float32)
or (inp.dtype == torch.bfloat16 and block_size == in_features),
lambda: (
f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), "
f"got {inp.dtype} with block_size={block_size} and in_features={in_features}"
),
)
M = inp.size(0)
return inp.new_empty(M, out_features, dtype=inp.dtype)

View File

@ -702,7 +702,7 @@ def exp2(a):
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promoting_args=("a,"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:

View File

@ -2,7 +2,7 @@ from collections.abc import Callable
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, NamedTuple, Optional, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
@ -17,9 +17,6 @@ from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
_TOTAL_KEY = "Total"
__all__ = ["FSDPMemTracker"]
@ -368,28 +365,14 @@ class FSDPMemTracker(MemTracker):
# `FSDPParamGroup.post_forward` because during AC these won't be called.
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
# get the unique _MultiHandlers/RemoveHandlers and store in dictionary
# the _MultiHandlers object will only need to be grabbed once.
unique_handlers: dict[RemovableHandle, bool] = {}
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
unique_handlers[fsdp_state._post_forward_hook_handle] = True
# call remove on the handles once
for f_hook_handle in unique_handlers.keys():
f_hook_handle.remove()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
fsdp_state._pre_forward_hook_handle.remove()
fsdp_state._post_forward_hook_handle.remove()
fsdp_state._pre_forward_hook_handle = (
# pyrefly: ignore [missing-attribute]
module.register_forward_pre_hook(

View File

@ -259,13 +259,14 @@ else:
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
@ -296,6 +297,11 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""

View File

@ -19,13 +19,8 @@ __all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"register_flash_attention_impl",
"activate_flash_attention_impl",
"list_flash_attention_impls",
"current_flash_attention_impl",
]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only effects users of bias subclasses
@ -167,23 +162,3 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
def _get_flash_version() -> str:
"""This returns the closest matching tag for the flash attention backend"""
return "2.5.7"
from . import _registry
# Re-export registry types and functions for public API
_FlashAttentionImpl = _registry._FlashAttentionImpl
_RegisterFn = _registry._RegisterFn
register_flash_attention_impl = _registry.register_flash_attention_impl
activate_flash_attention_impl = _registry.activate_flash_attention_impl
list_flash_attention_impls = _registry.list_flash_attention_impls
current_flash_attention_impl = _registry.current_flash_attention_impl
register_flash_attention_impl.__module__ = __name__
activate_flash_attention_impl.__module__ = __name__
list_flash_attention_impls.__module__ = __name__
current_flash_attention_impl.__module__ = __name__
# Import built-in implementations to trigger self-registration
from . import _fa4 # noqa: F401

View File

@ -1,444 +0,0 @@
"""UBER PROTOTYPE!!!"""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from dataclasses import dataclass
from functools import cache
from typing import Any, TYPE_CHECKING
from typing_extensions import TypeVarTuple, Unpack
from . import _registry
if TYPE_CHECKING:
from types import ModuleType
import torch
from torch.library import Library
__all__ = [
"register_flash_attention_fa4",
]
_FA4_MODULE_PATH: str | None = None
@dataclass
class _FA4Handle:
library: Library | None
def remove(self) -> None:
self.library = None
@cache
def _get_device_major(device: torch.device) -> int:
major, _ = torch.cuda.get_device_capability(device)
return major
def register_flash_attention_fa4(
module_path: str = "flash_attn.cute.interface",
) -> _FA4Handle:
"""
Register FA4 flash attention kernels with the PyTorch dispatcher.
Args:
module_path: Python module path to the FA4 implementation.
"""
global _FA4_MODULE_PATH
_ = _fa4_import_module(module_path)
_FA4_MODULE_PATH = module_path
return _FA4Handle(_fa4_register_kernels())
@cache
def _fa4_import_module(module_path: str) -> ModuleType:
module = importlib.import_module(module_path)
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
return module
def _fa4_register_kernels() -> Library:
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
lib.impl(
"_scaled_dot_product_flash_attention",
_fa4_scaled_dot_product_flash_attention_forward_impl,
"CUDA",
)
lib.impl(
"_scaled_dot_product_flash_attention_backward",
_fa4_scaled_dot_product_flash_attention_backward_impl,
"CUDA",
)
return lib
def _fa4_common_support_error(
query: torch.Tensor,
tensors: tuple[torch.Tensor, ...],
cum_seq_q: torch.Tensor | None,
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
) -> str | None:
if not all(t.is_cuda for t in tensors):
return "inputs must be CUDA tensors"
if len({t.device for t in tensors}) != 1:
return "inputs must share device"
if query.dtype not in (torch.float16, torch.bfloat16):
return "query dtype must be float16 or bfloat16"
for name, tensor in require_fp32:
if tensor.dtype != torch.float32:
return f"{name} dtype must be float32"
if cum_seq_q is None and query.dim() != 4:
return "dense query must be 4D"
if cum_seq_q is not None and query.dim() != 3:
return "ragged query must be 3D"
if not torch.cuda.is_available():
return "CUDA not available"
if _get_device_major(query.device) not in (9, 10):
return "FA4 requires compute capability 9.0 or 10.0"
return None
def _fa4_forward_support_error(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float,
return_debug_mask: bool,
alibi_slopes: torch.Tensor | None,
seqused_k: torch.Tensor | None,
cum_seq_q: torch.Tensor | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if return_debug_mask:
return "return_debug_mask must be False"
if alibi_slopes is not None:
return "alibi_slopes not supported"
if seqused_k is not None:
if seqused_k.dtype != torch.int32:
return "seqused_k must be int32"
if not seqused_k.is_cuda:
return "seqused_k must be CUDA"
error = _fa4_common_support_error(
query,
(query, key, value),
cum_seq_q,
)
if error is not None:
if error == "inputs must share device":
return "query, key, value must be on same device"
return error
return None
def _fa4_backward_support_error(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
dropout_p: float,
cum_seq_q: torch.Tensor | None,
window_size_left: int | None,
window_size_right: int | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if window_size_left is not None or window_size_right is not None:
return "windowed attention not supported"
error = _fa4_common_support_error(
query,
(grad_out, query, key, value, out, logsumexp),
cum_seq_q,
require_fp32=(("logsumexp", logsumexp),),
)
if error is not None:
return error
return None
Ts = TypeVarTuple("Ts")
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
def _fa4_run_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
window_size_left: int | None,
window_size_right: int | None,
seqused_k: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
kwargs: dict[str, Any] = {
"softmax_scale": scale,
"causal": is_causal,
"window_size_left": window_size_left,
"window_size_right": window_size_right,
"return_lse": True,
"cu_seqlens_q": cu_seq_q,
"cu_seqlens_k": cu_seq_k,
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
}
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
return out, lse.contiguous()
def _fa4_run_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
dq, dk, dv = module._flash_attn_bwd(
query,
key,
value,
out,
grad_out,
logsumexp.contiguous(),
softmax_scale=scale,
causal=is_causal,
cu_seqlens_q=cu_seq_q,
cu_seqlens_k=cu_seq_k,
)
return dq, dk, dv
def _fa4_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
return_debug_mask: bool,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
seqused_k: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
alibi_slopes,
seqused_k,
cum_seq_q,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
out, lse = _fa4_run_forward(
query,
key,
value,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
window_size_left,
window_size_right,
seqused_k,
)
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
return out, lse, rng_state, philox_offset, debug_mask
def _fa4_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
rng_state: torch.Tensor,
unused: torch.Tensor,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
cum_seq_q,
window_size_left,
window_size_right,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
dq, dk, dv = _fa4_run_backward(
grad_out,
query,
key,
value,
out,
logsumexp,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
)
return dq, dk, dv
def _fa4_scaled_dot_product_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
q, k, v = _transpose_dense(query, key, value)
max_q_flash = q.size(1)
max_k_flash = k.size(1)
out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
q,
k,
v,
None,
None,
max_q_flash,
max_k_flash,
dropout_p,
is_causal,
return_debug_mask,
scale=scale,
)
(out,) = _transpose_dense(out)
max_q = query.size(2)
max_k = key.size(2)
return (
out,
lse,
None,
None,
max_q,
max_k,
rng_state,
philox_offset,
debug_mask,
)
def _fa4_scaled_dot_product_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: float | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
max_q = query.size(2)
max_k = key.size(2)
dq, dk, dv = _fa4_flash_attention_backward_impl(
go,
q,
k,
v,
o,
logsumexp,
None,
None,
max_q,
max_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
dq, dk, dv = _transpose_dense(dq, dk, dv)
return dq, dk, dv
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)

View File

@ -1,108 +0,0 @@
# mypy: allow-untyped-defs
"""Registry for flash attention implementations.
This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""
from typing import Callable, Literal, Protocol
class FlashAttentionHandle(Protocol):
def remove(self) -> None: ...
_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}
def register_flash_attention_impl(
impl: str | _FlashAttentionImpl,
*,
register_fn: _RegisterFn,
) -> None:
"""
Register the callable that activates a flash attention impl.
.. note::
This function is intended for SDPA backend providers to register their
implementations. End users should use :func:`activate_flash_attention_impl`
to activate a registered implementation.
Args:
impl: Implementation identifier (e.g., ``"FA4"``).
register_fn: Callable that performs the actual dispatcher registration.
This function will be invoked by :func:`activate_flash_attention_impl`
and should register custom kernels with the PyTorch dispatcher.
It may optionally return a handle implementing
:class:`FlashAttentionHandle` to keep any necessary state alive.
Example:
>>> def my_impl_register(module_path: str = "my_flash_impl"):
... # Register custom kernels with torch dispatcher
... pass # doctest: +SKIP
>>> register_flash_attention_impl(
... "MyImpl", register_fn=my_impl_register
... ) # doctest: +SKIP
"""
_FLASH_ATTENTION_IMPLS[impl] = register_fn
def activate_flash_attention_impl(
impl: str | _FlashAttentionImpl,
) -> None:
"""
Activate into the dispatcher a previously registered flash attention impl.
.. note::
Backend providers should NOT automatically activate their implementation
on import. Users should explicitly opt-in by calling this function or via
environment variables to ensure multiple provider libraries can coexist.
Args:
impl: Implementation identifier to activate. See
:func:`~torch.nn.attention.list_flash_attention_impls` for available
implementations.
If the backend's :func:`register_flash_attention_impl` callable
returns a :class:`FlashAttentionHandle`, the registry keeps that
handle alive for the lifetime of the process (until explicit
uninstall support exists).
Example:
>>> activate_flash_attention_impl("FA4") # doctest: +SKIP
"""
global _FLASH_ATTENTION_ACTIVE
register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
if register_fn is None:
raise ValueError(
f"Unknown flash attention impl '{impl}'. "
f"Available implementations: {list_flash_attention_impls()}"
)
# TODO: The only way to actually register a new impl is to unregister the current impl
# reinstall the default impl and then register the new impl
if _FLASH_ATTENTION_ACTIVE == impl:
return
handle = register_fn()
if handle is not None:
_FLASH_ATTENTION_HANDLES[impl] = handle
_FLASH_ATTENTION_ACTIVE = impl
def list_flash_attention_impls() -> list[str]:
"""Return the names of all available flash attention implementations."""
return sorted(_FLASH_ATTENTION_IMPLS.keys())
def current_flash_attention_impl() -> str | None:
"""
Return the currently activated flash attention impl name, if any.
``None`` indicates that no custom impl has been activated.
"""
return _FLASH_ATTENTION_ACTIVE