Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-18 01:15:12 +08:00)

Compare commits: 40 commits, `ciflow/tru...` → `bf/lite`
Commits in this comparison (40):

8690e80b9d, 24edea44b3, ce8672c24f, 402c465030, 573a79fffa, 4945180468, 1df723e6f5, f9b81e23e4, ffe6cc39c7, db1f3f6901, 43041f0a43, dc00842b81, f1a129a6d0, fad48ffa62, 8f0fa2e52d, 0f120be0b1, fee30c60a8, 5a8affeba1, 0932f1ff21, 814cb7024b, f2a7f85f11, 308414c1b6, 5c9a710c99, a24161475f, f79e116bfa, db30ee1a88, 5c17c94af8, 1d56ee10de, ccfc98f8a0, 7dddcb24c9, cd453f5e7c, 16d170b0ef, 8c23bb9fef, 08517c8556, b42d8bdc76, a68e77bdea, 9d977f1f68, 0bba0cdb9c, 7dba73a0a7, 8ec9f7de82
@@ -30,7 +30,6 @@ into a tarball, with the following structure:

More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.

Outputted binaries should be in the `output` folder.

## Pushing

Packages can be uploaded to an S3 bucket using:

.github/workflows/inductor-rocm-mi200.yml (vendored, 2 changes)

@@ -1,4 +1,4 @@
name: inductor-rocm
name: inductor-rocm-mi200

on:
  schedule:

.github/workflows/rocm-mi200.yml (vendored, 2 changes)

@@ -1,4 +1,4 @@
name: rocm
name: rocm-mi200

on:
  push:

@@ -18,6 +18,8 @@ Please report security issues using https://github.com/pytorch/pytorch/security/

All reports submitted through the security advisories mechanism will **either be made public or be dismissed by the team within 90 days of submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml), as it is still likely a valid issue within the framework.

**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.

Please refer to the following page for our responsible disclosure policy, reward guidelines, and the types of issues that should not be reported:

https://www.facebook.com/whitehat

@@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
    const int64_t out_features) {
  auto M = inp.size(0);
  TORCH_CHECK(
      inp.dtype() == kFloat,
      inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
      __func__,
      " : expect input to be 32-bit float tensor.");
      " : expect input to be float32 or bfloat16 tensor.");
  TORCH_CHECK(
      block_size == in_features ||
          (!(block_size % 32) && !(in_features % block_size)),

@@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
  auto batch_sizes_t = _batch_sizes.contiguous();
  checkLongTensor(batch_sizes_t);
  TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");

  int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
  int64_t max_batch_size = batch_sizes[0];

@@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>

@@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif
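
// Reference (non-KleidiAI) fallback added below: it dynamically quantizes each
// BF16 input row to int8 with a per-row scale and zero point, multiplies against
// channelwise 4-bit weights, and writes a clamped BF16 result.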
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
    size_t m,
    size_t n,
    size_t k,
    const uint16_t* lhs_bf16,
    const uint8_t* rhs_qs4cx,
    const float* rhs_scales,
    uint16_t* dst_bf16,
    float scalar_min,
    float scalar_max,
    const float* bias) {
  // Roundup lambda for internal stride calculations
  auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };

  // Cast bfloat16 to float32 inline
  auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
    uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
    float f;
    std::memcpy(&f, &tmp, sizeof(f));
    return f;
  };

  // Cast float32 to bfloat16 inline
  auto cast_f32_to_bf16 = [](float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
  };

// Quantization pack lambda (channelwise QA8DX)
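// Each packed row stores a small header followed by the quantized values:
//   [float dequant scale (1/scale)] [int32 negated zero point] [K int8 values]
// (layout inferred from the packing below and the reads in the matmul lambda)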
  auto quant_pack_8bit_channelwise =
      [&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const uint16_t* row_ptr = src_bf16 + i * K;
|
||||
// find min/max
|
||||
float mn = FLT_MAX, mx = -FLT_MAX;
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
mn = std::min(mn, v);
|
||||
mx = std::max(mx, v);
|
||||
}
|
||||
float rmin = std::min(0.0f, mn);
|
||||
float rmax = std::max(0.0f, mx);
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
|
||||
float recip = scale ? 1.0f / scale : 0.0f;
|
||||
int32_t zp;
|
||||
float des_min = rmin * scale;
|
||||
float des_max = rmax * scale;
|
||||
float err_min = qmin + des_min;
|
||||
float err_max = qmax + des_max;
|
||||
float zp_f =
|
||||
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
|
||||
zp_f = std::clamp(zp_f, qmin, qmax);
|
||||
zp = std::lrintf(zp_f);
|
||||
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
|
||||
// store header
|
||||
*reinterpret_cast<float*>(out_ptr) = recip;
|
||||
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
|
||||
out_ptr += sizeof(float) + sizeof(int32_t);
|
||||
// quantize
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
|
||||
q = std::clamp(
|
||||
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
|
||||
*out_ptr++ = static_cast<int8_t>(q);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// MatMul lambda (MXN x MXK -> MNXK BF16)
|
||||
auto matmul_kernel = [&](size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const int8_t* lhs,
|
||||
const uint8_t* rhs,
|
||||
const float* scales,
|
||||
uint16_t* dst,
|
||||
float lo,
|
||||
float hi) {
|
||||
const size_t lhs_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
const size_t rhs_stride = roundup(K, 2) / 2;
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const int8_t* lhs_row = lhs + i * lhs_stride;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
int32_t acc = 0;
|
||||
const int8_t* lptr = lhs_row;
|
||||
const uint8_t* rptr = rhs + j * rhs_stride;
|
||||
float lhs_scale = *reinterpret_cast<const float*>(lptr);
|
||||
int32_t lhs_off =
|
||||
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
|
||||
lptr += sizeof(float) + sizeof(int32_t);
|
||||
for (size_t t = 0; t < K; ++t) {
|
||||
int32_t lv = static_cast<int32_t>(lptr[t]);
|
||||
uint8_t bv = rptr[t / 2];
|
||||
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
|
||||
: (static_cast<int32_t>(bv >> 4) - 8);
|
||||
acc += lv * rv + lhs_off * rv;
|
||||
}
|
||||
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
|
||||
if (bias) {
|
||||
res += bias[j];
|
||||
}
|
||||
res = std::clamp(res, lo, hi);
|
||||
*dst++ = cast_f32_to_bf16(res);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// allocate and run
|
||||
std::unique_ptr<int8_t[]> packed(
|
||||
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
|
||||
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
|
||||
matmul_kernel(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
packed.get(),
|
||||
rhs_qs4cx,
|
||||
rhs_scales,
|
||||
dst_bf16,
|
||||
scalar_min,
|
||||
scalar_max);
|
||||
}
|
||||
|
||||
/**
|
||||
* The Int4 quantized weights must be represented as a uint8 tensor
|
||||
* For matrix multiplication with a weight shape of (N x K)
|
||||
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
if (can_use_kleidiai(scales_zeros, K, block_size)) {
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
|
||||
packed_weights.resize_({weight_packed_size});
|
||||
kleidiai::kai_pack_int4_rhs(
|
||||
packed_weights, weights, scales_zeros, bias, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
TORCH_CHECK(
|
||||
bias.has_value() == 0,
|
||||
__func__,
|
||||
" : Bias is unsupported in reference implementation");
|
||||
packed_weights = packed_weights.to(kFloat);
|
||||
auto weight_reshaped = weights.view({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
|
||||
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
|
||||
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
|
||||
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
|
||||
if (bias.has_value()) {
|
||||
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
|
||||
}
|
||||
auto res = at::cat(tensors_to_cat, 0);
|
||||
packed_weights.resize_(res.sizes()).copy_(res);
|
||||
}
|
||||
}
|
||||
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
const float* rhs_scales_f32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
|
||||
|
||||
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
|
||||
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
// required format for matmul
|
||||
auto input_quant_pack_8bit_channelwise =
|
||||
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
}
|
||||
|
||||
// Maximum/minimum int8 values
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
|
||||
// Round to nearest integer
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
|
||||
|
||||
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
|
||||
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[n_idx];
|
||||
}
|
||||
|
||||
// Clamp (min-max) operation
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float* rhs_scales_fp32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Lambda for LHS quantization
|
||||
auto lhs_quant_pack = [&](size_t m,
|
||||
size_t k,
|
||||
const float* lhs_f32,
|
||||
int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
min0 = std::min(src0_0, min0);
|
||||
}
|
||||
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
|
||||
zero_point0 = std::max(zero_point0, qmin);
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
|
||||
|
||||
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
v0_s32 = std::max(
|
||||
std::min(
|
||||
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||
static_cast<int32_t>(INT8_MIN));
|
||||
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
|
||||
static_cast<int32_t>(kI8Min));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[col_idx];
|
||||
}
|
||||
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
|
||||
@@ -1128,28 +1278,27 @@
}

/**
 * Dynamic Input Quant 4 bit weights matmul execution flow
          (INT4 Weights + FP scales + FP32 Bias)
     FP32 Input                    Packed Buffer
         |                               |
      Quantize                         Cast
      to INT8                        to INT8
         |                               |
         v                               v
     INT8 Input                    INT8 Weights
            \                         /
             \                       /
              \                     /
           INT8 Matrix Multiplication
                      |
                      v
     FP32 Dequantized and Accumulate in FP32
                      |
                      v
              FP32 Final Output

 * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
 * Float32 Scales. If not provided, we will use fallback implementation.
 * Dynamic INT4 weight-only MatMul with per-row input quantization.
 *
 * Execution Flow:
 *
 * (INT4 Weights + FP Scales [+ optional Bias])
 *
 *  Input (FP32 or BF16)              Packed Weight Buffer
 *        |                                   |
 *  Row-wise Quantization (INT8)              |
 *        |                                   |
 *  INT8 Input Activation       INT4 Quantized Weights + Scales
 *             \                       /
 *              \                     /
 *            Quantized Matrix Multiply
 *                      |
 *            Output Tensor (BF16 or FP32)
 *
 * Notes:
 *  - Groupwise kernels expect BF16 scales
 *  - Channelwise kernels expect FP32 scales
 *  - Bias is currently unsupported in fallback path
 */
void dyn_quant_matmul_4bit_kernel(
    const Tensor& output,
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
|
||||
const int64_t block_size) {
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
|
||||
if (weight_packed_size == packed_weights.numel()) {
|
||||
// KleidiAI interface internally handles the Channelwise and groupwise
|
||||
// distinction
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||
output, inp, packed_weights, M, N, K, block_size);
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
|
||||
const auto weights_size = N * K / 2;
|
||||
// The weights needs to be in uint8_t data type after quantization
|
||||
auto extracted_weights =
|
||||
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
|
||||
auto float32_scales =
|
||||
(packed_weights.narrow(
|
||||
0, weights_size, packed_weights.size(0) - weights_size))
|
||||
.to(kFloat);
|
||||
uint8_t* rhs_4bit =
|
||||
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
|
||||
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_size,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
block_size == K || (!(block_size % 32) && !(K % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple 32 or in_features [",
|
||||
K,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
{
|
||||
void* input = inp.data_ptr();
|
||||
void* dst = output.data_ptr();
|
||||
|
||||
// Extract weights, scales and bias from the packed tensor
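// Packed layout in the reference path: N*K/2 bytes of 4-bit weights,
// then N*(K/block_size) float scales, then an optional N-element float bias.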
const int weights_elements = N * K / 2;
|
||||
const int scale_elements = N * (K / block_size);
|
||||
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
|
||||
|
||||
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
|
||||
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
|
||||
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
|
||||
|
||||
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
|
||||
float* weight_scales = float32_scales.data_ptr<float>();
|
||||
|
||||
void* bias_data = nullptr;
|
||||
if (bias_elements) {
|
||||
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
|
||||
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
|
||||
bias_data = float32_bias.data_ptr();
|
||||
|
||||
}
|
||||
// Two 4-bit weight elements are packed into one uint8 byte
|
||||
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
|
||||
// Dispatch to reference kernels
|
||||
if (inp.scalar_type() == at::kBFloat16) {
|
||||
// BF16 input, BF16 output
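// BF16_MAX below is the largest finite bfloat16 value, (2 - 2^-7) * 2^127
// ≈ 3.3895314e38, so the clamp spans the full representable BF16 range.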
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
M, N, K,
|
||||
(uint16_t*)input, weights_4bit, weight_scales,
|
||||
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
|
||||
}
|
||||
} else if (inp.scalar_type() == at::kFloat) {
|
||||
// FP32 input, FP32 output
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M, N, K,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M, N, K, block_size,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
}

ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
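
The fallback and KleidiAI paths above are reachable through the existing ATen ops. The sketch below mirrors the inductor tests later in this diff; the shapes, the `_group_quantize_tensor_symmetric` helper, and its assumed import location (`torch.testing._internal.common_quantization`) are illustrative assumptions rather than part of this change:

```python
# Minimal sketch: dynamic 4-bit weight-only matmul with a BF16 input on CPU.
import torch
from torch.testing._internal.common_quantization import (
    _group_quantize_tensor_symmetric,  # assumed location of the test helper
)

m, k, n = 32, 128, 128
block_size = k  # channelwise: BF16 inputs require block_size == in_features

a = torch.rand((m, k), dtype=torch.bfloat16)  # activations
b = torch.rand((k, n), dtype=torch.bfloat16)  # weights to quantize to 4 bit

# Symmetric 4-bit group quantization of the weights (packed into uint8) plus scales.
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
    b, n_bit=4, groupsize=block_size
)
# Channelwise kernels expect FP32 scales (groupwise kernels expect BF16 scales).
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)

b_int4pack = torch._dyn_quant_pack_4bit_weight(
    b_uint8, b_scales_and_zeros, None, block_size, k, n
)
out = torch.ops.aten._dyn_quant_matmul_4bit(a, b_int4pack, block_size, k, n)  # BF16 output
```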
@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
// On non-CK ROCm systems, make sure use_fast_path stays false
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
|
||||
use_fast_path = true;
|
||||
}
|
||||
#endif //USE_ROCM_CK_GEMM
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
#ifndef USE_ROCM
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
|
||||
#endif //USE_ROCM_CK_GEMM
|
||||
#endif
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
|
||||
@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
if (weight.scalar_type() == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype) {
|
||||
size_t packed_size = n * k;
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
if (tensor_dtype == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
});
|
||||
}
|
||||
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
// Kernel IDs for GEMM and GEMV
|
||||
constexpr kai_kernel_id gemm_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
|
||||
constexpr kai_kernel_id gemv_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
|
||||
|
||||
// Get total threads and select kernel
|
||||
const int64_t total_threads = at::get_num_threads();
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
|
||||
if (cpuinfo_has_arm_i8mm() && m > 1) {
|
||||
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
|
||||
}
|
||||
|
||||
// Thread blocking parameters
|
||||
const int64_t n_step = kernel_packet.ukernel.get_n_step();
|
||||
const size_t mr = kernel_packet.ukernel.get_mr();
|
||||
const size_t kr = kernel_packet.ukernel.get_kr();
|
||||
const size_t sr = kernel_packet.ukernel.get_sr();
|
||||
|
||||
const size_t lhs_packed_size =
|
||||
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
|
||||
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
|
||||
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||
const uint8_t* lhs_native_mtx_bf16 =
|
||||
reinterpret_cast<const uint8_t*>(input.data_ptr());
|
||||
const uint8_t* rhs_packed_mtx_qs4cx =
|
||||
reinterpret_cast<const uint8_t*>(weight.data_ptr());
|
||||
uint8_t* lhs_packed_base = lhs_packed.get();
|
||||
|
||||
constexpr int32_t element_size = sizeof(uint16_t);
|
||||
const size_t lhs_stride = k * element_size;
|
||||
const size_t dst_stride = n * element_size;
|
||||
|
||||
// LHS quantization packing
|
||||
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
|
||||
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
|
||||
const size_t src_stride = vec_per_thread * lhs_stride;
|
||||
|
||||
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.kai_run_lhs_quant_pack(
|
||||
vec_num,
|
||||
k,
|
||||
mr,
|
||||
kr,
|
||||
sr,
|
||||
0,
|
||||
(const uint16_t*)lhs_src_ptr,
|
||||
lhs_stride,
|
||||
lhs_packed_ptr);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
lhs_quant_pack(thread_id);
|
||||
}
|
||||
});
|
||||
|
||||
// Matrix multiplication
|
||||
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
|
||||
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
|
||||
|
||||
auto mm = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
|
||||
kernel_packet.ukernel.get_rhs_packed_offset(
|
||||
thread_id * vec_per_thread, k);
|
||||
auto dst_ptr = dst_act_mtx_bf16 +
|
||||
kernel_packet.ukernel.get_dst_offset(
|
||||
0, thread_id * vec_per_thread, dst_stride);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (n - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.ukernel.run_matmul(
|
||||
m,
|
||||
vec_num,
|
||||
k,
|
||||
lhs_packed_base,
|
||||
rhs_packed_ptr,
|
||||
(uint16_t*)dst_ptr,
|
||||
dst_stride,
|
||||
element_size, // dst_stride_col
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
mm(thread_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
const at::Tensor& output,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
const auto input_dtype = input.dtype();
|
||||
|
||||
if (input_dtype == at::kBFloat16) {
|
||||
if (cpuinfo_has_arm_bf16()) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
|
||||
}
|
||||
} else if (input_dtype == at::kFloat) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
|
||||
}
|
||||
} else if ((bl % 32 == 0) && (k % bl == 0)) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
output, input, weight, m, n, k, bl);
|
||||
}
|
||||
|
||||
@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype = at::kFloat);
|
||||
|
||||
/**
|
||||
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
|
||||
|
||||
@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
auto weight_packed_data =
|
||||
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||
const auto scales_data = scales.data_ptr<float>();
|
||||
|
||||
const auto scales_data = scales.to(kFloat).data_ptr<float>();
|
||||
|
||||
if (weight_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
|
||||
@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
|
||||
// Kernel Mapping - BF16 Channelwise
|
||||
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
|
||||
bf16_channelwise_8bit_4bit_kernels = {
|
||||
{kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
|
||||
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return bf16_channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
||||
|
||||
@ -10,21 +10,32 @@
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
enum class kai_kernel_id {
|
||||
// FP32 inputs, 4-bit weights, FP32 output
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
|
||||
0, // Groupwise 4 bit GEMV
|
||||
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
|
||||
1, // Groupwise 4 bit GEMM
|
||||
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
|
||||
2, // Channelwise 4 bit GEMV
|
||||
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
|
||||
3 // Channelwise 4 bit GEMM
|
||||
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
|
||||
|
||||
// BF16 inputs, 4-bit weights, BF16 output
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
|
||||
4, // Channelwise 4-bit GEMV with BF16 input/output
|
||||
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
|
||||
5 // Channelwise 4-bit GEMM with BF16 input/output
|
||||
};
|
||||
|
||||
// Channelwise Kernel mapping
|
||||
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
|
||||
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// bf16 Channelwise Kernel mapping
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
|
||||
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
|
||||
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
|
||||
size_t (*kai_get_lhs_packed_size)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
size_t (*kai_get_rhs_packed_size)(
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
void (*kai_run_lhs_quant_pack)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t m_idx_start,
|
||||
const void* lhs,
|
||||
size_t lhs_stride,
|
||||
void* lhs_packed);
|
||||
void (*kai_run_rhs_pack)(
|
||||
size_t num_groups,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
const uint8_t* rhs,
|
||||
const float* bias,
|
||||
const float* scale,
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
: ukernel(kernel),
|
||||
kai_get_lhs_packed_size(
|
||||
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
|
||||
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// Groupwise Kernel mapping
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
|
||||
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
|
||||
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||
|
||||
@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
|
||||
static void initOpenRegStreamsOnce() {
|
||||
c10::call_once(init_flag, initGlobalStreamState);
|
||||
|
||||
for (const auto i : c10::irange(num_devices)) {
|
||||
c10::call_once(
|
||||
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
|
||||
}
|
||||
|
||||
if (current_streams) {
|
||||
return;
|
||||
}
|
||||
@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
|
||||
if (device_index == -1) {
|
||||
device_index = current_device();
|
||||
}
|
||||
c10::call_once(
|
||||
device_flags[device_index], initDeviceStreamState, device_index);
|
||||
auto pri_idx =
|
||||
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
|
||||
const auto idx = get_idx(priority_counters[device_index][pri_idx]);
|
||||
|
||||
@ -4101,6 +4101,53 @@ if HAS_CUDA_AND_TRITON:
|
||||
compiled_out = compiled_foo(x)
|
||||
self.assertEqual(eager_out, compiled_out)
|
||||
|
||||
# Use autotune_at_compile_time=True to test standalone_compile
|
||||
@parametrize("autotune_at_compile_time", [True, False])
|
||||
@config.patch("graph_partition", True)
|
||||
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
|
||||
def foo(x):
|
||||
# partition 1
|
||||
x1 = x @ x
|
||||
y1 = x1 + 1
|
||||
z_cpu = y1.cpu() + 1
|
||||
# partition 2
|
||||
# partition 2 should reuse the fused triton kernel generated
|
||||
# in partition 1
|
||||
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
|
||||
y2 = x2 + 1
|
||||
return y1, y2
|
||||
|
||||
with config.patch(
|
||||
"triton.autotune_at_compile_time", autotune_at_compile_time
|
||||
):
|
||||
compiled_foo = torch.compile(foo)
|
||||
x = torch.randn((20, 20), device="cuda")
|
||||
eager_out = foo(x)
|
||||
compiled_out, code = run_and_get_code(compiled_foo, x)
|
||||
self.assertEqual(eager_out, compiled_out)
|
||||
|
||||
if autotune_at_compile_time:
|
||||
# auto-tuning block should only appear once. We generate auto-tuning code
|
||||
# for all the kernels no matter if they are defined in the main graph or
|
||||
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
|
||||
FileCheck().check_count(
|
||||
"Compile-time auto-tuning block", 1, exactly=True
|
||||
).run(code[0])
|
||||
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
|
||||
# and then in the main code block
|
||||
FileCheck().check_count(
|
||||
"def triton_poi_fused_add_", 2, exactly=True
|
||||
).run(code[0])
|
||||
# cpu kernel definition should only appear once, not in the auto-tuning block
|
||||
FileCheck().check_count(
|
||||
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
|
||||
).run(code[0])
|
||||
else:
|
||||
# triton_poi_fused_add_ should appear once, because of kernel reuse
|
||||
FileCheck().check_count(
|
||||
"def triton_poi_fused_add_", 1, exactly=True
|
||||
).run(code[0])
|
||||
|
||||
def test_meta_tensor(self):
|
||||
def foobar(x, y):
|
||||
return x * 2, y * 3
|
||||
|
||||
@ -4,8 +4,9 @@ from functools import partial
|
||||
from unittest import skipIf
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.ir import Pointwise
|
||||
from torch._inductor.lowering import make_pointwise, register_lowering
|
||||
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.virtualized import ops
|
||||
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
|
||||
@ -237,6 +238,17 @@ class TestCustomLowering(InductorTestCase):
|
||||
out2 = fn_opt(a, b)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
@config.patch(joint_graph_constant_folding=False)
|
||||
def test_constant_creation(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + torch.tensor(1)
|
||||
|
||||
make_fallback(torch.ops.aten.lift_fresh_copy.default)
|
||||
self.assertTrue(
|
||||
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
@ -30,6 +30,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch._dynamo.config as dynamo_config
|
||||
import torch._inductor.aoti_eager
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.nn as nn
|
||||
from torch._C._dynamo.guards import assert_alignment, assert_size_stride
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
@ -2476,12 +2477,11 @@ class CommonTemplate:
|
||||
b_int8pack, b_scales = convert_weight_to_int8pack(b)
|
||||
self.common(fn, (a, b_int8pack, b_scales, c))
|
||||
|
||||
@xfail_if_mps_unimplemented
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
|
||||
def test__dyn_quant_pack_4bit_weight(self):
|
||||
def test__dyn_quant_pack_4bit_weight_fp32(self):
|
||||
q_group = 32
|
||||
k = 128
|
||||
n = 128
|
||||
@ -2512,12 +2512,46 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (b, in_features, out_features))
|
||||
|
||||
@xfail_if_mps_unimplemented
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
|
||||
def test__dyn_quant_pack_4bit_weight_bf16(self):
|
||||
q_group = 32
|
||||
k = 128
|
||||
n = 128
|
||||
|
||||
torch.manual_seed(1)
|
||||
b = torch.rand((k, n), dtype=torch.bfloat16)
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def fn(b, in_features, out_features):
|
||||
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
|
||||
return b_int4pack
|
||||
|
||||
self.common(fn, (b, in_features, out_features))
|
||||
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
|
||||
def test__dyn_quant_matmul_4bit(self):
|
||||
def test__dyn_quant_matmul_4bit_fp32_input(self):
|
||||
q_group = 32
|
||||
m = 32
|
||||
k = 128
|
||||
@ -2557,6 +2591,60 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, (a, q_group, in_features, out_features))
|
||||
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
|
||||
def test__dyn_quant_matmul_4bit_bf16_input(self):
|
||||
m = 32
|
||||
k = 128
|
||||
n = 128
|
||||
q_group = k
|
||||
|
||||
torch.manual_seed(1)
|
||||
a = torch.rand((m, k), dtype=torch.bfloat16)
|
||||
b = torch.rand((k, n), dtype=torch.bfloat16)
|
||||
|
||||
# codegen_dynamic_shape test fails without explicitly marking these dynamic
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(b, 1)
|
||||
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
|
||||
if not self.is_dtype_supported(torch.bfloat16):
|
||||
raise unittest.SkipTest(
|
||||
f"torch.bfloat16 not supported for device {self.device}"
|
||||
)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def fn(a, q_group, in_features, out_features):
|
||||
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
|
||||
res = torch.ops.aten._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
return res
|
||||
|
||||
self.common(fn, (a, q_group, in_features, out_features), atol=1, rtol=0.5)
|
||||
|
||||
def test_expanded_reduction(self):
|
||||
def fn(x, y):
|
||||
z = x * y
|
||||
@ -13564,6 +13652,224 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
        FileCheck().check_regex(size_assert_pattern).run(code)

    def test_lite_mode_fallback(self):
        def f(x):
            z = x.sin()
            return z.cos()

        f = torch.compile(f, mode="lite")

        _, code = run_and_get_code(f, torch.randn(2, device=self.device))

        # Checks that aten ops are kept and run
        if config.cpp_wrapper:
            FileCheck().check("aoti_torch_call_dispatcher(").check("aten::sin").check(
                "aoti_torch_call_dispatcher("
            ).check("aten::cos").run(code[0])
        else:
            FileCheck().check("torch.ops.aten.sin.default(").check(
                "torch.ops.aten.cos.default("
            ).run(code[0])
        # Checks that no triton code is run in the generated code
        self.assertFalse(".run(" in code[0])

    # skip cpu test since rms norm is always decomposed on cpu
    def test_lite_mode_not_decompose(self):
        if self.device != GPU_TYPE or self.device == "mps":
            raise unittest.SkipTest("requires GPU")

        def f(x, shape):
            y = x + 1
            z = torch.ops.aten._fused_rms_norm(y, shape, None, None)
            return z[0] + z[1]

        f = torch.compile(f, mode="lite")

        x = torch.randn(2, 3, device=self.device)
        _, code = run_and_get_code(f, x, [2, 3])
        if config.cpp_wrapper:
            FileCheck().check(
                "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda__fused_rms_norm("
            ).run(code[0])
        else:
            FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])

        if config.cpp_wrapper:
            # arg type List[int] is not yet supported by custom_op_wrapper
            pass
        else:
            x = torch.randn(2, 3, device=self.device, requires_grad=True)
            _, codes = run_fw_bw_and_get_code(lambda: f(x, [2, 3]))
            self.assertEqual(len(codes), 2)
            FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])

    def test_lite_regional_compile_flex_attention(self):
        if self.device != GPU_TYPE or self.device == "mps":
            raise unittest.SkipTest("requires GPU")

        from torch.nn.attention.flex_attention import create_block_mask, flex_attention

        def _squared(score, b, h, m, n):
            return score * score

        def mask_mod(b, h, q, k):
            return q >= 0

        a = 12
        b = 64
        block_mask = create_block_mask(
            mask_mod, None, None, a * b, a * b, device=self.device
        )

        def fn(x):
            x = torch.sin(x)
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
            return torch.cos(x)

        x = torch.randn(
            1,
            1,
            a * b,
            b,
            dtype=torch.bfloat16,
            device=self.device,
            requires_grad=True,
        )

        opt_fn = torch.compile(
            fn,
            mode="lite",
            fullgraph=True,
        )

        # Check that inductor compilation is called twice
        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        self.assertEqual(len(codes), 2)

    @unittest.skipIf(
        config.cpp_wrapper,
        "codegen invoke_subgraph is not implemented for cpp wrapper",
    )
    def test_lite_regional_compile_invoke_subgraph(self):
        # Checks that get_attr nodes custom metadata is propagated
        @torch.compiler.nested_compile_region
        def gn(x):
            return torch.sin(x)

        def fn(x):
            x = x + 1
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                z = gn(x)
            return torch.sigmoid(z)

        opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
        x = torch.randn(10, requires_grad=True)

        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        self.assertEqual(len(codes), 2)

    @unittest.skipIf(
        config.cpp_wrapper,
        "codegen triton_kernel_wrapper_functional is not implemented for cpp wrapper",
    )
    def test_lite_triton_kernel_wrapper_functional(self):
        if self.device != GPU_TYPE or self.device == "mps":
            raise unittest.SkipTest("requires GPU")

        from torch._higher_order_ops.triton_kernel_wrap import (
            kernel_side_table,
            triton_kernel_wrapper_functional,
        )
        from torch.testing._internal.triton_utils import mul2_kernel

        kernel_side_table.reset_table()

        def f(x, output):
            out = triton_kernel_wrapper_functional(
                kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
                constant_args_idx=kernel_side_table.add_constant_args(
                    {"n_elements": output.numel(), "BLOCK_SIZE": 16}
                ),
                grid=[(x.numel(),)],
                tma_descriptor_metadata={},
                kwargs={
                    "in_ptr0": x,
                    "out_ptr": output,
                },
                tensors_to_clone=["in_ptr0", "out_ptr"],
            )
            return out["out_ptr"]

        t1 = torch.rand(5, device=self.device)
        t2 = torch.rand(5, device=self.device)

        compiled_f = torch.compile(f, mode="lite")
        out = compiled_f(t1, t2)

        # Make sure t2 was not modified
        self.assertNotEqual(out, t2)

    def test_lite_regional_compile_repeated_blocks(self):
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                a = fn(x, y)
                return fn(a, y)

        mod = Mod()

        opt_mod = torch.compile(
            mod,
            mode="lite",
            fullgraph=True,
        )
        x = torch.randn(10, requires_grad=True)
        y = torch.randn(10, requires_grad=True)

        _, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
        self.assertEqual(len(codes), 2)

    def test_lite_dynamic_shape_assertion(self):
        class Model(torch.nn.Module):
            def forward(self, c):
                d = torch.concat([c, c], dim=0)
                with fx_traceback.annotate({"compile_with_inductor": "my_region"}):
                    d = d + 1
                return d

        model = Model()
        model = torch.compile(
            model,
            mode="lite",
            fullgraph=True,
        )

        c = torch.randn((64, 32), device=self.device)
        torch._dynamo.decorators.mark_unbacked(c, 0)

        _, code = run_and_get_code(model, c)
        # Checks that unbacked symint assertions are kept
        if config.cpp_wrapper:
            FileCheck().check_regex(r"if \(!\(u.* >= 0L\)\)").check_regex(
                "Expected u.* >= 0 but receive"
            ).run(code[0])
        else:
            FileCheck().check_regex(r"if not \(u.* >= 0\):").check_regex(
                r"raise RuntimeError\('u.* >= 0'\)"
            ).run(code[0])

    @lowering.force_fallback(aten.sort.default)
    @unittest.skipIf(
        config.cpp_wrapper,
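Note (illustrative, not part of the diff): the tests above all exercise the same user-facing pattern, so a condensed sketch may help when skimming. It assumes fx_traceback is torch.fx.traceback, matching the test file's existing imports.

    import torch
    import torch.fx.traceback as fx_traceback

    def fn(x):
        x = torch.sin(x)
        # Under mode="lite", only the annotated region is lowered by Inductor;
        # every other op stays as an ATen fallback in the generated code.
        with fx_traceback.annotate({"compile_with_inductor": 0}):
            x = x * 2 + 1
        return torch.cos(x)

    opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
    out = opt_fn(torch.randn(8))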
@ -31,7 +31,6 @@ from torch.testing._internal.common_utils import (
    serialTest,
    TEST_CUDA_MEM_LEAK_CHECK,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
@ -93,17 +92,6 @@ if not torch._inductor.config.cpp_wrapper:
        ("cuda",)
    )

if TEST_WITH_ROCM:
    # Tensor-likes are not close
    test_failures["test_dynamic_stride_nobreak"] = TestFailure(
        ("cpu", "cuda"), is_skip=True
    )
    test_failures["test_item_to_inputs_kernel_nobreak"] = TestFailure(
        ("cpu", "cuda"), is_skip=True
    )
    test_failures["test_unbacked_reduction"] = TestFailure(("cpu"), is_skip=True)


if any(os.getenv("BUILD_ENVIRONMENT", "").endswith(x) for x in ("-debug", "-asan")):
    # Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
    # After https://github.com/pytorch/pytorch/pull/161586, starts failing UBSAN so we can't even xfail.
@ -492,6 +492,36 @@ class PackedSequenceTest(TestCase):
            torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
        )

    def test_empty_packed_sequence(self):
        """
        Regression test for https://github.com/pytorch/pytorch/issues/149622
        Tests that pad_packed_sequence and unpack_sequence handle empty tensors
        without segmentation fault (CVE-2025-2998, CVE-2025-2999)
        """
        # Test case 1: pad_packed_sequence with empty tensors
        # Previously caused segmentation fault
        empty_data = torch.randn(0, 5)
        empty_batch_sizes = torch.tensor([], dtype=torch.int64)
        empty_packed = rnn_utils.PackedSequence(
            empty_data, empty_batch_sizes, None, None
        )

        # Should not crash - either return empty result or raise informative error
        with self.assertRaises(RuntimeError):
            rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)

        # Test case 2: unpack_sequence with empty tensors
        # Previously caused segmentation fault
        empty_data = torch.tensor([])
        empty_batch_sizes = torch.tensor([], dtype=torch.int64)
        packed = rnn_utils.PackedSequence(
            data=empty_data, batch_sizes=empty_batch_sizes
        )

        # Should not crash - either return empty list or raise informative error
        with self.assertRaises(RuntimeError):
            rnn_utils.unpack_sequence(packed)


if __name__ == "__main__":
    run_tests()
@ -7798,7 +7798,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
    @parametrize("m", [1, 32])
    @parametrize("k", [64, 128])
    @parametrize("n", [4096, 11008])
    def test__dyn_quant_matmul_4bit(self, device, m, k, n):
    def test__dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
        if self.device_type == "cuda":
            self.skipTest("CUDA is unsupported")

@ -7870,7 +7870,86 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
    @parametrize("m", [1, 32])
    @parametrize("k", [64, 128])
    @parametrize("n", [4096, 11008])
    def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
    def test__dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
        if self.device_type == "cuda":
            self.skipTest("CUDA is unsupported")

        torch.manual_seed(1)
        a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
        in_features = b_bfloat16.size(0)
        out_features = b_bfloat16.size(1)
        q_group = in_features

        def dyn_quant_pack_4bit_weight(b, in_features, out_features):
            b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
                b, n_bit=4, groupsize=q_group
            )

            if q_group == in_features:
                b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
            else:
                b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
            b_int4pack = torch._dyn_quant_pack_4bit_weight(
                b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
            )

            return b_int4pack, b_scales_and_zeros

        def dyn_quant_matmul_4bit(
            a, b_int4pack, q_group, in_features, out_features
        ):
            return torch.ops.aten._dyn_quant_matmul_4bit(
                a,
                b_int4pack,
                q_group,
                in_features,
                out_features,
            )

        b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
            b_bfloat16, in_features, out_features
        )

        dtypes = [torch.bfloat16]

        for dtype in dtypes:
            a = a_bfloat16.to(dtype=dtype)
            b = b_bfloat16.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = dyn_quant_matmul_4bit(
                a,
                b_int4pack,
                q_group,
                in_features,
                out_features,
            )

            # Mean relative error check
            expected_mean_err = 0.00952
            mean_err_tol = 0.005  # allow small deviation (±0.005)
            mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
            self.assertTrue(
                abs(mean_err - expected_mean_err) < mean_err_tol,
                f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
            )

            # Elementwise relative error check
            elementwise_diff = (res - ref).abs()
            elementwise_relative_error = elementwise_diff / ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
            self.assertTrue(
                torch.all(elementwise_relative_error < 0.070),
                "Some elements have relative error >= 7%"
            )

    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
    @onlyNativeDeviceTypes
    @parametrize("m", [1, 32])
    @parametrize("k", [64, 128])
    @parametrize("n", [4096, 11008])
    def test_compile_dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
        if self.device_type == "cuda":
            self.skipTest("CUDA is unsupported")

@ -7928,6 +8007,83 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
        )

    @onlyNativeDeviceTypes
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
    @onlyNativeDeviceTypes
    @parametrize("m", [1, 32])
    @parametrize("k", [64, 128])
    @parametrize("n", [4096, 11008])
    def test_compile_dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
        if self.device_type == "cuda":
            self.skipTest("CUDA is unsupported")

        torch.manual_seed(1)
        a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
        in_features = b_bfloat16.size(0)
        out_features = b_bfloat16.size(1)
        q_group = in_features

        b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
            b_bfloat16, n_bit=4, groupsize=q_group
        )

        if q_group == in_features:
            b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
        else:
            b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)

        @torch.compile
        def dyn_quant_matmul_4bit(
            a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
        ):
            b_int4pack = torch._dyn_quant_pack_4bit_weight(
                b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
            )
            return torch._dyn_quant_matmul_4bit(
                a,
                b_int4pack,
                q_group,
                in_features,
                out_features,
            )

        res = dyn_quant_matmul_4bit(
            a_bfloat16,
            b_uint8,
            b_scales_and_zeros,
            q_group,
            in_features,
            out_features,
        )
        ref = torch.mm(a_bfloat16, b_bfloat16)

        # === Accuracy checks ===

        # Mean relative error check
        expected_mean_err = 0.00952
        mean_err_tol = 0.005  # allow small deviation (±0.005)
        mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
        self.assertTrue(
            abs(mean_err - expected_mean_err) < mean_err_tol,
            f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
        )

        # Avoid divide-by-zero with clamp
        denominator = ref.abs().clamp(min=torch.finfo(ref.dtype).eps)

        # Compute elementwise relative error — always non-negative
        elementwise_relative_error = (res - ref).abs() / denominator

        # Check that all elements are within 7% relative error
        assert torch.all(elementwise_relative_error >= 0), "Relative error should never be negative"
        self.assertTrue(
            torch.all(elementwise_relative_error < 0.070),
            "Some elements have relative error >= 7%"
        )

    @onlyCPU
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
@ -103,6 +103,17 @@ from .utils import (
_thread_local = threading.local()


@contextmanager
def maybe_skip_decompose(aot_config: AOTConfig):
    old_decomp = aot_config.decompositions
    try:
        if config.selective_decompose:
            aot_config.decompositions = {}
        yield
    finally:
        aot_config.decompositions = old_decomp


# Saved tensor hooks context
# Compiled saved tensor hooks are a convenient way to inline some logic in the graphs
# for saved nodes from forward to backward. (E.g. activations quantization)
@ -196,11 +207,28 @@ def aot_stage1_graph_capture(
        # deterministic TLS can be different
        aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
    updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]]
    if aot_state.needs_autograd and not aot_config.pre_dispatch:
        # FYI: this being moved to trigger in export is new, seems fine!
        with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):

    with maybe_skip_decompose(aot_config):
        # if config.selective_decompose, skip decomposition and apply selective_decompose
        # after we get the joint graph. See [Note: Selective Decomposition] for details.
        if aot_state.needs_autograd and not aot_config.pre_dispatch:
            # FYI: this being moved to trigger in export is new, seems fine!
            with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
                (
                    graph,
                    updated_flat_args,
                    updated_flat_args_descs,
                    maybe_subclass_meta,
                ) = aot_dispatch_autograd_graph(
                    flat_fn,
                    aot_state.flat_args,
                    aot_state.flat_args_descs,
                    aot_config,
                    fw_metadata=aot_state.fw_metadata,
                )
        else:
            graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
                aot_dispatch_autograd_graph(
                aot_dispatch_base_graph(
                    flat_fn,
                    aot_state.flat_args,
                    aot_state.flat_args_descs,
@ -208,15 +236,17 @@ def aot_stage1_graph_capture(
                    fw_metadata=aot_state.fw_metadata,
                )
            )
    else:
        graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
            aot_dispatch_base_graph(  # type: ignore[assignment]
                flat_fn,
                aot_state.flat_args,
                aot_state.flat_args_descs,
                aot_config,
                fw_metadata=aot_state.fw_metadata,
            )

    if config.selective_decompose:
        from torch.fx.experimental.proxy_tensor import selective_decompose
        from torch.fx.passes.regional_inductor import _needs_inductor_compile

        graph = selective_decompose(
            graph,
            *updated_flat_args,
            decomposition=aot_config.decompositions,
            should_decompose=_needs_inductor_compile,
            trace_joint_graph=aot_state.needs_autograd and not aot_config.pre_dispatch,
        )

    return AOTGraphCapture(
@ -374,6 +374,13 @@ saved_tensors_hooks_filtering_mode = "donated"
# This callback is invoked on the joint graph before partitioning
joint_custom_pass: Callable = None  # type: ignore[assignment]

# Note [Selective Decomposition]
# This config allows selective decomposition of certain operators in the graph.
# When True, it does NOT decompose any nodes, except those nodes that users explicitly
# annotated with regional inductor compile. Please read torch.fx.passes.regional_inductor
# on how to explicitly annotate. This is currently only used by inductor lite mode.
selective_decompose: bool = False


if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
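Note (illustrative, not part of the diff): this flag is consumed by the compile_fx changes later in this diff via functorch_config.patch. A minimal sketch of toggling it directly:

    import torch._functorch.config as functorch_config

    # Temporarily enable selective decomposition, as inductor lite mode does;
    # nodes are then only decomposed if explicitly annotated for regional compile.
    with functorch_config.patch(selective_decompose=True):
        ...  # trace / export here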
@ -315,6 +315,25 @@ def aot_compile(
    )


lite_mode_options = {
    # Fallback by default unless users explicitly annotated with
    # regional inductor compile.
    "fallback_by_default": True,
    "selective_decompose": True,
    # Disable reorder optimizations
    "reorder_for_peak_memory": False,
    "reorder_for_compute_comm_overlap": False,
    "triton.reorder_for_reducing_graph_partitions": False,
    # Disable pre-, joint-, post-grad passes
    "use_pre_grad_passes": False,
    "use_joint_graph_passes": False,
    "use_post_grad_passes": False,
    # Disable dead code elimination (dce) and buffer reuse
    "use_dce": False,
    "allow_buffer_reuse": False,
}


def list_mode_options(
    mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> dict[str, Any]:
@ -332,6 +351,8 @@ def list_mode_options(

    mode_options: dict[str, dict[str, bool]] = {
        "default": {},
        # lite backend for opt-in optimizations
        "lite": lite_mode_options,
        # enable cudagraphs
        "reduce-overhead": {
            "triton.cudagraphs": True,
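Note (illustrative, not part of the diff): the preset can be inspected through the same public helper being extended here.

    import torch._inductor as inductor

    opts = inductor.list_mode_options("lite")
    assert opts["fallback_by_default"] is True
    assert opts["selective_decompose"] is True
    assert opts["use_dce"] is False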
@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
        gpu: bool = True,
        cpp_definition: Optional[str] = None,
    ):
        if config.triton.autotune_at_compile_time:
        if config.triton.autotune_at_compile_time and gpu:
            body = self._format_kernel_definition(
                kernel_name, kernel_body, metadata=metadata
            )
@ -3745,6 +3745,13 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):

        super().__init__()

        root = self.get_root_graph()
        # Only generate auto-tuning block in the main graph
        self.kernel_autotune_defs = root.kernel_autotune_defs
        self.kernel_autotune_calls = root.kernel_autotune_calls
        # Only store kernel src to name mapping in the main graph
        self.src_to_kernel = root.src_to_kernel

    def set_launcher_fn_name(self) -> None:
        # This sets up the name of the function containing the launcher code of
        # the subgraph.
@ -3837,3 +3844,16 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
        #     V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
        # )
        self.parent_wrapper.write_get_raw_stream_header_once()

    @cache_on_self
    def get_root_graph(self) -> PythonWrapperCodegen:
        root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
        while isinstance(root, SubgraphPythonWrapperCodegen):
            root = root.parent_wrapper

        assert isinstance(root, PythonWrapperCodegen)
        return root

    def generate_and_run_autotune_block(self):
        # Only execute auto-tuning block in the main graph
        pass
@ -508,6 +508,9 @@ def _recursive_pre_grad_passes(
        log_pt2_compile_event=True,
        dynamo_compile_column_us="pre_grad_pass_time_us",
    ):
        if not config.use_pre_grad_passes:
            return gm

        add_passes = config.add_pre_grad_passes
        remove_passes = config.remove_pre_grad_passes
        for subgraph_name in _get_subgraph_names(gm):
@ -526,6 +529,9 @@ def _recursive_joint_graph_passes(
        log_pt2_compile_event=True,
        dynamo_compile_column_us="joint_graph_pass_time_us",
    ):
        if not config.use_joint_graph_passes:
            return

        # invoke_subgraph already runs the _recursive_joint_graph_passes. In
        # AOTAutograd, `run_joint_graph_passes_on_hops` partitions the
        # invoke_subgraph HOP before calling the partitioner on the outer graph.
@ -544,6 +550,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
        log_pt2_compile_event=True,
        dynamo_compile_column_us="post_grad_pass_time_us",
    ):
        if not config.use_post_grad_passes:
            return

        for subgraph_name in _get_subgraph_names(gm):
            subgraph = getattr(gm, subgraph_name)
            _recursive_post_grad_passes(subgraph, is_inference)
@ -2708,7 +2717,10 @@ def _compile_fx_main(

    is_valid_aoti_model_name()

    with functorch_config.patch(unlift_effect_tokens=True):
    with functorch_config.patch(
        unlift_effect_tokens=True,
        selective_decompose=config.selective_decompose,
    ):
        gm, graph_signature = aot_export_module(
            model_,
            example_inputs_,
@ -2768,7 +2780,10 @@ def _compile_fx_main(
        V.set_fake_mode(fake_mode),
        torch._guards.tracing(tracing_context),
        compiled_autograd._disable(),
        functorch_config.patch(unlift_effect_tokens=True),
        functorch_config.patch(
            unlift_effect_tokens=True,
            selective_decompose=config.selective_decompose,
        ),
    ):
        try:
            return aot_autograd(
@ -550,6 +550,32 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
    "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper()  # type: ignore[assignment]


# Fall back to ATen for all ops by default, except those nodes that users explicitly
# annotated with regional inductor compile. Please read torch.fx.passes.regional_inductor
# on how to explicitly annotate. This is currently only used by inductor lite mode.
# Different from the default inductor mode that fuses all nodes, this config enables an
# opt-in mode that only fuses user-specified nodes. The motivation is to provide
# guaranteed numeric correctness and give full control to users.
fallback_by_default: bool = False


# This config allows selective decomposition of certain operators in the graph.
# Currently the only use case is to patch the same-name config in functorch, for
# inductor lite mode. See more details in [Note: Selective Decomposition]
selective_decompose: bool = False


# Use dead code elimination
use_dce: bool = True


# Use fx graph passes
use_pre_grad_passes: bool = True
use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True


cutedsl_enable_autotuning: bool = (
    os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
@ -1373,6 +1399,10 @@ class triton:
        default=False,
    )

    # reorder nodes to minimize the number of graph partitions while
    # not incurring large memory overhead
    reorder_for_reducing_graph_partitions: bool = True

    # assertions on the fast path
    fast_path_cudagraph_asserts = False
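Note (illustrative, not part of the diff): the "lite" preset flips these knobs as a bundle, but they remain ordinary inductor config options, so a subset can also be passed through torch.compile options. Whether a partial combination is useful depends on the workload; this is a sketch, not a recommendation.

    import torch

    def f(x):
        return torch.sin(x) + 1

    # Keep ATen fallbacks and skip scheduler dead code elimination,
    # without the rest of the lite settings.
    opt_f = torch.compile(f, options={"fallback_by_default": True, "use_dce": False})
    opt_f(torch.randn(4))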
@ -110,6 +110,7 @@ from .utils import (
    maybe_get_suppress_shape_guards_ctx,
    normalize_name,
    should_assume_input_aligned,
    should_fallback_by_default,
    SUPPORTED_MKLDNN_DEVICES,
    ValueWithLineMap,
)
@ -1634,6 +1635,20 @@ class GraphLowering(torch.fx.Interpreter):
                    *args,  # type: ignore[possibly-undefined]
                    **kwargs,  # type: ignore[possibly-undefined]
                )
            elif (
                n.op == "call_function"
                and isinstance(
                    n.target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
                )
                and should_fallback_by_default(n)
            ):
                # this path supports fallback due to inductor lite mode. It supports
                # both OpOverload and HOPs (e.g., triton_kernel_wrapper_functional).
                debug("fallback_handler")
                result = fallback_handler(n.target, add_to_fallback_set=False)(
                    *args,  # type: ignore[possibly-undefined]
                    **kwargs,  # type: ignore[possibly-undefined]
                )
            elif (
                n.op == "call_function"
                and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation
@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.node import Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
from torch.utils._sympy.symbol import SymT

@ -6135,9 +6136,12 @@ class ExternKernel(InputsKernel):
        if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
            return ShapeAsConstantBuffer(expr=x)
        if isinstance(x, Constant):
            return V.graph.add_tensor_constant(
                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
            )
            # We need to unset fake mode, or else the torch.tensor() call will
            # turn into a FakeTensor
            with _disable_current_modes():
                return V.graph.add_tensor_constant(
                    torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
                )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):
@ -2742,8 +2742,9 @@ class Scheduler:
        self.process_grouped_nodes()

        if (
            torch._inductor.config.graph_partition
            and torch._inductor.config.triton.cudagraphs
            config.graph_partition
            and config.triton.cudagraphs
            and config.triton.reorder_for_reducing_graph_partitions
        ):
            self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
            self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
@ -3191,6 +3192,9 @@ class Scheduler:
        """
        Remove any nodes without users
        """
        if not config.use_dce:
            return

        # self.nodes is in topological order, so by iterating in reverse order
        # we have visited (and potentially removed) all users before visiting a
        # given node.
@ -58,6 +58,7 @@ import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties
from torch.fx.passes.regional_inductor import _needs_inductor_compile
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten, tree_map_only
@ -4036,3 +4037,40 @@ def load_template(name: str, template_dir: Path) -> str:
    """Load a template file and return its content."""
    with open(template_dir / f"{name}.py.jinja") as f:
        return f.read()


def should_fallback_by_default(node: torch.fx.Node) -> bool:
    """Decide whether to fall back for a node. This is only used in inductor lite mode."""
    target = node.target

    assert isinstance(
        target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ), f"Expected OpOverload or HigherOrderOperator, but found {type(target)}"

    if not config.fallback_by_default:
        return False

    # some ops need special handling due to dynamic shapes. we can avoid
    # fallback if they do not impact numerics.
    skip_fallback_due_to_dynamic_shape = OrderedSet(
        [
            torch.ops.aten._assert_scalar.default,
            torch.ops.aten.lift_fresh_copy.default,
        ]
    )

    if target in skip_fallback_due_to_dynamic_shape:
        return False

    # Most hops have registered lowering. We should follow the lowering and not fallback.
    # However, in rare cases, hops may not register lowering, such as
    # torch.ops.higher_order.triton_kernel_wrapper_functional. We should fallback for
    # these hops.
    fallback_hops = OrderedSet(
        [torch.ops.higher_order.triton_kernel_wrapper_functional]
    )

    if isinstance(target, torch._ops.HigherOrderOperator):
        return target in fallback_hops

    return not _needs_inductor_compile(node)
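Note (illustrative, not part of the diff): the decision order above, written out as a plain-Python sketch; the string names stand in for the real op objects.

    def sketch_should_fallback(target, is_hop, annotated_for_inductor, fallback_by_default):
        if not fallback_by_default:
            # lite mode is off: never force a fallback here
            return False
        if target in ("aten._assert_scalar.default", "aten.lift_fresh_copy.default"):
            # dynamic-shape helpers that do not affect numerics
            return False
        if is_hop:
            # only HOPs without a registered lowering fall back
            return target == "higher_order.triton_kernel_wrapper_functional"
        # plain ops fall back unless the user annotated them for Inductor
        return not annotated_for_inductor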
@ -3722,6 +3722,7 @@ def kai_roundup(a: int, b: int) -> int:

def get_kai_packed_weight_size(n_bits, N, K, groupsize):
    if n_bits == 4:
        # Works for both fp32 and bf16 Kernels
        if groupsize == K:  # channelwise
            # dotprod params only [1x8x32_neon_dotprod]
            kai_nr = 8
@ -3851,6 +3852,8 @@ def meta__dyn_quant_pack_4bit_weight(
        )
        return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
    packed_weight_size = weights.numel() + scales_zeros.numel()
    if bias is not None:
        packed_weight_size += bias.numel()
    return weights.new_empty(packed_weight_size, dtype=torch.float)


@ -3864,8 +3867,12 @@ def meta__dyn_quant_matmul_4bit(
):
    torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
    torch._check(
        inp.dtype == torch.float32,
        lambda: f"expected input to be f32, got {inp.dtype}",
        (inp.dtype == torch.float32)
        or (inp.dtype == torch.bfloat16 and block_size == in_features),
        lambda: (
            f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), "
            f"got {inp.dtype} with block_size={block_size} and in_features={in_features}"
        ),
    )
    M = inp.size(0)
    return inp.new_empty(M, out_features, dtype=inp.dtype)
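Note (illustrative, not part of the diff): the relaxed meta check accepts fp32 activations with any supported block size, and bf16 activations only for channelwise quantization. A standalone sketch of that predicate:

    import torch

    def sketch_input_ok(dtype, block_size, in_features):
        return dtype == torch.float32 or (
            dtype == torch.bfloat16 and block_size == in_features
        )

    assert sketch_input_ok(torch.float32, 32, 128)
    assert sketch_input_ok(torch.bfloat16, 128, 128)
    assert not sketch_input_ok(torch.bfloat16, 32, 128)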
@ -1,6 +1,5 @@
#pragma once

#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device_struct.h>
@ -120,7 +119,7 @@ struct FromImpl<ScalarType> {
      case ScalarType::UInt64:
        return from(aoti_torch_dtype_uint64());
      default:
        TORCH_CHECK(
        STD_TORCH_CHECK(
            false,
            "Not yet supported ScalarType, please file an issue describing your use case.");
    }
@ -151,7 +150,7 @@ struct FromImpl<DeviceType> {
      case DeviceType::PrivateUse1:
        return from(aoti_torch_device_type_privateuse1());
      default:
        TORCH_CHECK(
        STD_TORCH_CHECK(
            false,
            "Not yet supported DeviceType, please file an issue describing your use case.");
    }
@ -379,7 +378,7 @@ struct ToImpl<ScalarType> {
    } else if (shim_scalartype == aoti_torch_dtype_uint64()) {
      return ScalarType::UInt64;
    } else {
      TORCH_CHECK(
      STD_TORCH_CHECK(
          false,
          "Not yet supported ScalarType ",
          std::to_string(shim_scalartype),
@ -409,7 +408,7 @@ struct ToImpl<DeviceType> {
    } else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
      return DeviceType::PrivateUse1;
    } else {
      TORCH_CHECK(
      STD_TORCH_CHECK(
          false,
          "Not yet supported DeviceType ",
          std::to_string(shim_devicetype),
@ -42,6 +42,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_impls import fast_detach
from torch._subclasses.fake_tensor import (
    FakeTensor,
@ -90,6 +91,7 @@ __all__ = [
    "dispatch_trace",
    "make_fx",
    "DecompositionInterpreter",
    "selective_decompose",
    "py_sym_types",
    "get_innermost_proxy_mode",
    "get_proxy_mode",
@ -1881,6 +1883,93 @@ class DecompositionInterpreter(fx.Interpreter):
        return super().run(*args, **kwargs)  # type: ignore[arg-type]


class _SelectiveDecomposeInterpreter(fx.Interpreter):
    def __init__(
        self,
        module: fx.GraphModule,
        should_decompose: Callable[[fx.Node], bool],
        decomposition_table: Mapping[OpOverload, Callable],
        **kwargs: object,
    ) -> None:
        """
        For all nodes in `module`, selectively decompose a node if `should_decompose`
        returns True for it, following the given `decomposition_table`.
        """
        super().__init__(module, **kwargs)  # type: ignore[arg-type]
        self.should_decompose = should_decompose
        self.decomposition_table = decomposition_table

    @staticmethod
    def recursive_wrap(
        gm: fx.GraphModule,
        should_decompose: Callable[[fx.Node], bool],
        decomposition_table: Mapping[OpOverload, Callable],
        **kwargs: object,
    ) -> _SelectiveDecomposeInterpreter:
        """
        Recursively wrap gm and its sub graph modules. Specifically, HOPs take
        sub graph modules as args. We may not want to decompose all nodes within
        these sub graph modules. So we also need to wrap these sub graph modules.
        As a result:
        - if should_decompose(hop) is True, we decompose all nodes within the hop.
        - if should_decompose(hop) is False, we check each node within the hop
          and decide whether to decompose or not.
        """
        for node in gm.graph.nodes:
            if node.op == "call_function" and isinstance(
                node.target, HigherOrderOperator
            ):
                new_args = []
                for arg in node.args:
                    if isinstance(arg, fx.GraphModule):
                        new_arg = _SelectiveDecomposeInterpreter.recursive_wrap(
                            arg, should_decompose, decomposition_table, **kwargs
                        )
                    else:
                        new_arg = arg
                    new_args.append(new_arg)
                node.args = tuple(new_args)

        return _SelectiveDecomposeInterpreter(
            gm, should_decompose, decomposition_table, **kwargs
        )

    def run_node(self, n):
        if self.should_decompose(n):
            with decompose(self.decomposition_table):
                result = super().run_node(n)
        else:
            result = super().run_node(n)
        return result


def selective_decompose(
    joint_gm: fx.GraphModule,
    *args,
    decomposition,
    should_decompose,
    trace_joint_graph: bool,
) -> fx.GraphModule:
    """Retrace a joint graph module and selectively apply decomposition."""

    if trace_joint_graph:
        # the arg names, primals and tangents, are important.
        # make_fx keeps the names in the traced graph and the partitioner later relies
        # on them to partition the joint graph correctly.
        def wrap_fn(primals: list[Any], tangents: list[Any]):
            return _SelectiveDecomposeInterpreter.recursive_wrap(
                joint_gm, should_decompose, decomposition
            ).run(*args)
    else:

        def wrap_fn(*args):
            return _SelectiveDecomposeInterpreter.recursive_wrap(
                joint_gm, should_decompose, decomposition
            ).run(*args)

    return make_fx(wrap_fn, decomposition_table={})(*args)


def wrapper_and_args_for_make_fx(
    func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
) -> tuple[Callable[[list[object]], R], list[object]]:
@ -112,7 +112,7 @@ def _compile_submod(gm, prefix):
    return gm


def _needs_inductor_compile(node):
def _needs_inductor_compile(node: torch.fx.Node):
    return (
        node.op not in ("placeholder", "output")
        and hasattr(node, "meta")
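Note (illustrative, not part of the diff): a usage sketch for the new selective_decompose helper. It assumes aten.gelu has a decomposition registered in torch._decomp, and uses a hand-written should_decompose instead of _needs_inductor_compile.

    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx, selective_decompose

    def f(x):
        return torch.nn.functional.gelu(x).sum()

    x = torch.randn(4)
    gm = make_fx(f, decomposition_table={})(x)  # capture without decomposing

    gm2 = selective_decompose(
        gm,
        x,
        decomposition=get_decompositions([torch.ops.aten.gelu]),
        should_decompose=lambda n: n.target is torch.ops.aten.gelu.default,
        trace_joint_graph=False,
    )
    # gm2 has gelu expanded by its decomposition; every other node is retraced as-is.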