Compare commits


40 Commits

Author SHA1 Message Date
8690e80b9d Merge branch 'main' into bf/lite 2025-11-11 10:02:15 -08:00
24edea44b3 nit 2025-11-11 10:01:20 -08:00
ce8672c24f Fix use of TORCH_CHECK in torch/csrc/stable (#167495)
Tested by above PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167495
Approved by: https://github.com/janeyx99
ghstack dependencies: #166579, #166694, #166695, #167362
2025-11-11 17:58:30 +00:00
402c465030 [ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250)
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch, boosting LLM prefill & decode performance.

**This change improves decode throughput by ~15% and reduces the memory required to run inference on the model by 50%.**
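
As a rough illustration, here is a minimal sketch of how the new BF16 path can be exercised, mirroring the tests added later in this diff (channelwise quantization, i.e. block size == in_features). The `_group_quantize_tensor_symmetric` import path is an assumption based on PyTorch's internal test utilities.

```
import torch
# Assumed helper location; the tests in this PR use this symmetric group quantizer.
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

m, k, n = 32, 128, 128
q_group = k  # block size == in_features -> channelwise int4-bf16 KleidiAI kernel
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((k, n), dtype=torch.bfloat16)

b_uint8, b_scales = _group_quantize_tensor_symmetric(b, n_bit=4, groupsize=q_group)
b_scales = b_scales.to(torch.float)  # channelwise kernels expect fp32 scales
b_int4pack = torch._dyn_quant_pack_4bit_weight(b_uint8, b_scales, None, q_group, k, n)
# BF16 output; on builds without KleidiAI this falls back to the reference kernels added in this PR
out = torch.ops.aten._dyn_quant_matmul_4bit(a, b_int4pack, q_group, k, n)
```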

### Benchmark Setup
```
Model: meta-llama/Llama-3.1-8B
Test Platform: Neoverse V2
```
### Detailed Results

| Metric                           | With `--compile`         | Without `--compile`      |
|----------------------------------|---------------------------|---------------------------|
| Quantization Scheme              | INT4 symmetric channelwise | INT4 symmetric channelwise |
| Input Precision                  | BF16                      | BF16                      |
| Number of Layers Quantized       | 32                        | 32                        |
| Average Compression Ratio        | 87.49%                    | 87.49%                    |
| Total Quantization Time (s)      | 9.62                      | 10.32                     |
| Compile Time (First) (s)         | 134.48                    | 1.69                      |
| Compile Time (Second) (s)        | 80.44                     | 1.60                      |
| Compile Time (Subsequent) (s)    | 0.19                      | 0.22                      |
| Prefill Tokens                   | 54                        | 54                        |
| Decoded Tokens                   | 33                        | 33                        |
| Prefill Time (s)                 | 0.19                      | 0.22                      |
| Decode Time (s)                  | 0.76                      | 1.38                      |
| E2E Generation Time (s)          | 0.95                      | 1.60                      |
| Prefill Throughput (tokens/s)    | 288.13                    | 249.91                    |
| Decode Throughput (tokens/s)     | 43.42                     | 23.83                     |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158250
Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01

Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-11-11 17:50:22 +00:00
573a79fffa [OpenReg] Initialize device stream states for all devices in initOpenRegStreamsOnce (#167528)
Fixes #167527

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167528
Approved by: https://github.com/fffrog
2025-11-11 16:53:22 +00:00
4945180468 Add empty tensor check for _pad_packed_sequence (#167521)
That prevents null pointer dereference
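A minimal repro, mirroring the regression test added further down in this diff (previously a segmentation fault, now a RuntimeError from the new TORCH_CHECK):

```
import torch
import torch.nn.utils.rnn as rnn_utils

empty_data = torch.randn(0, 5)
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
packed = rnn_utils.PackedSequence(empty_data, empty_batch_sizes, None, None)
# batch_sizes[0] used to be dereferenced on an empty tensor; now this raises.
rnn_utils.pad_packed_sequence(packed, batch_first=True)
```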

Fixes https://github.com/pytorch/pytorch/issues/149622
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167521
Approved by: https://github.com/albanD
2025-11-11 16:46:13 +00:00
1df723e6f5 [inductor] Fix constant creation (#167398)
We ran into this issue when debugging inductor-lite. Calling `torch.tensor` within a fake mode (which is the case inside of inductor) will create a FakeTensor, which causes this FakeTensor to be used as a constant within inductor.
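A minimal sketch of the failure mode (not of the fix itself); it only assumes the public `FakeTensorMode` entry point:

```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    t = torch.tensor(1)

# The "constant" is a FakeTensor with no real storage, which is why inductor
# must not bake it into the graph as a real constant.
print(type(t))  # <class 'torch._subclasses.fake_tensor.FakeTensor'>
```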

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167398
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng
2025-11-11 16:30:46 +00:00
f9b81e23e4 [ROCm] Disable group gemm CK path when composable kernel (CK) is not enabled (#167403)
For ROCm builds without CK support, ensure use_fast_path is false so that the CK path is not triggered, since CK is currently not available in this configuration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167403
Approved by: https://github.com/Skylion007, https://github.com/ScottTodd, https://github.com/jeffdaily
2025-11-11 16:15:51 +00:00
ffe6cc39c7 [inductor] Optimize cold compile time when cudagraphs-partition is enabled (#167132)
Summary: When cudagraphs partition is enabled, we have seen an increase in cold compile time in the vllm benchmark (see https://github.com/vllm-project/vllm/issues/27080). After some profiling, we found that Triton compilation time increased the most. Further investigation revealed this was caused by duplicated Triton kernels not being shared among different partitions. This PR fixes the issue by reusing the Triton kernel source code cache at the top-level PythonWrapperCodegen.

In theory we could further reduce the compilation time by completely skipping compilation of duplicated partitions. That can come as a future improvement.
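
For reference, a hedged sketch of how to exercise this combination, using the flag names from the new test further down in this diff (`graph_partition` and `triton.autotune_at_compile_time`); a CUDA device is assumed:

```
import torch
from torch._inductor import config

def foo(x):
    y = (x @ x) + 1
    z = y.cpu() + 1  # the cpu op forces a graph partition boundary
    # the second partition should reuse the fused add kernel from the first
    return (z.to("cuda") @ z.to("cuda")) + 1

with config.patch("graph_partition", True), config.patch(
    "triton.autotune_at_compile_time", True
):
    compiled_foo = torch.compile(foo)
    compiled_foo(torch.randn(20, 20, device="cuda"))
```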

Some vllm benchmarking data:

```
VLLM_USE_STANDALONE_COMPILE=0 VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=True --model meta-llama/Meta-Llama-3.1-8B
```
Before:
```
torch.compile takes 69.18 s in total
```
After:
```
torch.compile takes 26.81 s in total
```

As a reference, this is the compile time with inductor graph partition turned off. Looks like we still have some gap to close.
```
VLLM_USE_STANDALONE_COMPILE=0 VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=False --model meta-llama/Meta-Llama-3.1-8B

torch.compile takes 19.41 s in total
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167132
Approved by: https://github.com/eellison
ghstack dependencies: #167131
2025-11-11 15:54:31 +00:00
db1f3f6901 [inductor] Only generate compile-time auto-tuning block in the main graph (#167131)
Summary: When cudagraphs partition and autotune_at_compile_time are enabled, currently each subgraph generates its own auto-tuning code block and runs it separately. This PR improves that by generating a single auto-tuning code block at the main graph level and executing it once to auto-tune all the kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167131
Approved by: https://github.com/eellison
2025-11-11 15:54:31 +00:00
43041f0a43 Remove superflous/misplaced TestFailure specs (#165989)
The tests are in the class `TestInductorDynamic`, which isn't affected by the `test_failures` dict; that dict is only used as an argument to `copy_tests` for the `CommonTemplate` defined in another file.

So those entries have no effect.

Idea: Enhance `copy_tests` to detect unused keys
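
A hypothetical sketch of that idea (the helper name and signature are made up, not the existing `copy_tests` API): after copying the template tests, flag any `test_failures` keys that never matched a test.

```
# Hypothetical helper; copy_tests itself is unchanged here.
def check_unused_test_failures(template_cls, test_failures):
    defined = {name for name in dir(template_cls) if name.startswith("test_")}
    unused = sorted(k for k in test_failures if k not in defined)
    if unused:
        raise RuntimeError(f"TestFailure keys with no matching test: {unused}")
```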

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165989
Approved by: https://github.com/benjaminglass1, https://github.com/ezyang
2025-11-11 15:36:43 +00:00
dc00842b81 [ROCm][CI] trigger magma build with gfx950 for ROCm7.1 (#167390)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167390
Approved by: https://github.com/jeffdaily
2025-11-11 15:17:37 +00:00
f1a129a6d0 Clarify that crashes/OOB accesses are not security threats (#167519)
Added note on crashes and out of bounds access in PyTorch.

Addresses https://github.com/pytorch/pytorch/issues/166881#issuecomment-3513245388

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167519
Approved by: https://github.com/albanD
2025-11-11 15:14:51 +00:00
fad48ffa62 [ROCm][CI] Match workflow names with workflow file names (#167483)
Fixes an issue with artifact uploading, which was inadvertently disabled for some renamed workflows via https://github.com/pytorch/pytorch/pull/167225

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167483
Approved by: https://github.com/jeffdaily
2025-11-11 14:45:44 +00:00
8f0fa2e52d skip an mps test since mul2_kernel does not exist 2025-11-10 23:05:11 -08:00
0f120be0b1 minor doc/format improve 2025-11-10 17:41:42 -08:00
fee30c60a8 lint 2025-11-10 15:50:22 -08:00
5a8affeba1 add docs 2025-11-10 14:54:09 -08:00
0932f1ff21 Merge branch 'main' into bf/lite 2025-11-09 22:43:20 -08:00
814cb7024b support triton_kernel_wrapper_functional 2025-11-09 22:41:47 -08:00
f2a7f85f11 also fallback by default for hop 2025-11-09 16:06:14 -08:00
308414c1b6 nit 2025-11-07 17:52:20 -08:00
5c9a710c99 make _SelectiveDecomposeInterpreter private 2025-11-07 10:09:03 -08:00
a24161475f Merge branch 'main' into bf/lite 2025-11-07 10:07:03 -08:00
f79e116bfa fix more cpp_wrapper 2025-11-06 23:39:44 -08:00
db30ee1a88 check different string for cpp_wrapper tests 2025-11-06 21:54:42 -08:00
5c17c94af8 nit 2025-11-06 19:28:12 -08:00
1d56ee10de nit 2025-11-06 18:01:30 -08:00
ccfc98f8a0 support dynamic shape assertion 2025-11-06 17:35:39 -08:00
7dddcb24c9 Merge branch 'main' into bf/lite 2025-11-06 17:15:58 -08:00
cd453f5e7c more tests 2025-11-06 17:09:36 -08:00
16d170b0ef add inductor selective_decompose config 2025-11-06 17:01:40 -08:00
8c23bb9fef support both joint graph and fwd only graph 2025-11-06 16:47:56 -08:00
08517c8556 nit 2025-11-06 13:10:37 -08:00
b42d8bdc76 Merge branch 'main' into bf/lite 2025-11-06 11:15:47 -08:00
a68e77bdea selective decompose for regional compile 2025-11-05 22:47:25 -08:00
9d977f1f68 nit 2025-11-05 13:35:04 -08:00
0bba0cdb9c Merge branch 'main' into bf/lite 2025-11-05 11:47:25 -08:00
7dba73a0a7 nit 2025-11-05 11:46:18 -08:00
8ec9f7de82 init 2025-11-05 11:37:54 -08:00
34 changed files with 1433 additions and 187 deletions

View File

@ -30,7 +30,6 @@ into a tarball, with the following structure:
More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
Outputted binaries should be in the `output` folder.
## Pushing
Packages can be uploaded to an S3 bucket using:

View File

@ -1,4 +1,4 @@
name: inductor-rocm
name: inductor-rocm-mi200
on:
schedule:

View File

@ -1,4 +1,4 @@
name: rocm
name: rocm-mi200
on:
push:

View File

@ -18,6 +18,8 @@ Please report security issues using https://github.com/pytorch/pytorch/security/
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat

View File

@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
__func__,
" : expect input to be 32-bit float tensor.");
" : expect input to be float32 or bfloat16 tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),

View File

@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
auto batch_sizes_t = _batch_sizes.contiguous();
checkLongTensor(batch_sizes_t);
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
int64_t max_batch_size = batch_sizes[0];

View File

@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
@ -793,6 +794,139 @@ bool can_use_kleidiai(
}
#endif
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
size_t m,
size_t n,
size_t k,
const uint16_t* lhs_bf16,
const uint8_t* rhs_qs4cx,
const float* rhs_scales,
uint16_t* dst_bf16,
float scalar_min,
float scalar_max,
const float* bias) {
// Roundup lambda for internal stride calculations
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};
// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
// Quantization pack lambda (channelwise QA8DX)
auto quant_pack_8bit_channelwise =
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
for (size_t i = 0; i < M; ++i) {
const uint16_t* row_ptr = src_bf16 + i * K;
// find min/max
float mn = FLT_MAX, mx = -FLT_MAX;
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
mn = std::min(mn, v);
mx = std::max(mx, v);
}
float rmin = std::min(0.0f, mn);
float rmax = std::max(0.0f, mx);
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
float recip = scale ? 1.0f / scale : 0.0f;
int32_t zp;
float des_min = rmin * scale;
float des_max = rmax * scale;
float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f =
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::clamp(zp_f, qmin, qmax);
zp = std::lrintf(zp_f);
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
// store header
*reinterpret_cast<float*>(out_ptr) = recip;
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
out_ptr += sizeof(float) + sizeof(int32_t);
// quantize
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
q = std::clamp(
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
*out_ptr++ = static_cast<int8_t>(q);
}
}
};
// MatMul lambda: (M x K int8 LHS) x (N x K packed int4 RHS) -> M x N BF16 output
auto matmul_kernel = [&](size_t M,
size_t N,
size_t K,
const int8_t* lhs,
const uint8_t* rhs,
const float* scales,
uint16_t* dst,
float lo,
float hi) {
const size_t lhs_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = roundup(K, 2) / 2;
for (size_t i = 0; i < M; ++i) {
const int8_t* lhs_row = lhs + i * lhs_stride;
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
const int8_t* lptr = lhs_row;
const uint8_t* rptr = rhs + j * rhs_stride;
float lhs_scale = *reinterpret_cast<const float*>(lptr);
int32_t lhs_off =
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
lptr += sizeof(float) + sizeof(int32_t);
for (size_t t = 0; t < K; ++t) {
int32_t lv = static_cast<int32_t>(lptr[t]);
uint8_t bv = rptr[t / 2];
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
: (static_cast<int32_t>(bv >> 4) - 8);
acc += lv * rv + lhs_off * rv;
}
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
if (bias) {
res += bias[j];
}
res = std::clamp(res, lo, hi);
*dst++ = cast_f32_to_bf16(res);
}
}
};
// allocate and run
std::unique_ptr<int8_t[]> packed(
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
matmul_kernel(
m,
n,
k,
packed.get(),
rhs_qs4cx,
rhs_scales,
dst_bf16,
scalar_min,
scalar_max);
}
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
if (bias.has_value()) {
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
}
auto res = at::cat(tensors_to_cat, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
zero_point0 = std::min(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[n_idx];
}
// Clamp (min-max) operation
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
float scalar_max,
const float* bias) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
min0 = std::min(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
zero_point0 = std::max(zero_point0, qmin);
zero_point0 = std::min(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = std::max(
std::min(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
static_cast<int32_t>(kI8Min));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[col_idx];
}
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
* Dynamic INT4 weight-only MatMul with per-row input quantization.
*
* Execution Flow:
*
* (INT4 Weights + FP Scales [+ optional Bias])
*
* Input (FP32 or BF16) Packed Weight Buffer
* | |
* Row-wise Quantization (INT8) |
* | |
* INT8 Input Activation INT4 Quantized Weights + Scales
* \ /
* \ /
* Quantized Matrix Multiply
* |
* Output Tensor (BF16 or FP32)
*
* Notes:
* - Groupwise kernels expect BF16 scales
* - Channelwise kernels expect FP32 scales
* - Bias is currently unsupported in fallback path
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
{
void* input = inp.data_ptr();
void* dst = output.data_ptr();
// Extract weights, scales and biases from the packed tensor
const int weights_elements = N * K / 2;
const int scale_elements = N * (K / block_size);
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
float* weight_scales = float32_scales.data_ptr<float>();
void* bias_data = nullptr;
if (bias_elements) {
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
bias_data = float32_bias.data_ptr();
}
// 2 elements of 4 bit weights are packed into 1 uint8 packet
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
// Dispatch to reference kernels
if (inp.scalar_type() == at::kBFloat16) {
// BF16 input, BF16 output
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
M, N, K,
(uint16_t*)input, weights_4bit, weight_scales,
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
}
} else if (inp.scalar_type() == at::kFloat) {
// FP32 input, FP32 output
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M, N, K,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M, N, K, block_size,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
}
} else {
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
}
}
}
}
} // anonymous namespace
}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)

View File

@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
// On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);

View File

@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
if (weight.scalar_type() == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
const int64_t bl,
at::ScalarType tensor_dtype) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
if (tensor_dtype == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
}
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
});
}
void kai_quant_pack_lhs_int4_mm(
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
constexpr kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
constexpr kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
// Get total threads and select kernel
const int64_t total_threads = at::get_num_threads();
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm() && m > 1) {
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
}
// Thread blocking parameters
const int64_t n_step = kernel_packet.ukernel.get_n_step();
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_bf16 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
const uint8_t* rhs_packed_mtx_qs4cx =
reinterpret_cast<const uint8_t*>(weight.data_ptr());
uint8_t* lhs_packed_base = lhs_packed.get();
constexpr int32_t element_size = sizeof(uint16_t);
const size_t lhs_stride = k * element_size;
const size_t dst_stride = n * element_size;
// LHS quantization packing
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
const size_t src_stride = vec_per_thread * lhs_stride;
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.kai_run_lhs_quant_pack(
vec_num,
k,
mr,
kr,
sr,
0,
(const uint16_t*)lhs_src_ptr,
lhs_stride,
lhs_packed_ptr);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
lhs_quant_pack(thread_id);
}
});
// Matrix multiplication
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
auto mm = [=, &kernel_packet](int64_t thread_id) {
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(
thread_id * vec_per_thread, k);
auto dst_ptr = dst_act_mtx_bf16 +
kernel_packet.ukernel.get_dst_offset(
0, thread_id * vec_per_thread, dst_stride);
const int64_t vec_num = (thread_id == num_threads - 1)
? (n - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.ukernel.run_matmul(
m,
vec_num,
k,
lhs_packed_base,
rhs_packed_ptr,
(uint16_t*)dst_ptr,
dst_stride,
element_size, // dst_stride_col
-FLT_MAX,
FLT_MAX);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
mm(thread_id);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
const auto input_dtype = input.dtype();
if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}

View File

@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
const int64_t bl,
at::ScalarType tensor_dtype = at::kFloat);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

View File

@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
const auto scales_data = scales.to(kFloat).data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(

View File

@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
bf16_channelwise_8bit_4bit_kernels = {
{kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

View File

@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
};
// Channelwise Kernel mapping
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const void* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

View File

@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
static void initOpenRegStreamsOnce() {
c10::call_once(init_flag, initGlobalStreamState);
for (const auto i : c10::irange(num_devices)) {
c10::call_once(
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
}
if (current_streams) {
return;
}
@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
if (device_index == -1) {
device_index = current_device();
}
c10::call_once(
device_flags[device_index], initDeviceStreamState, device_index);
auto pri_idx =
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
const auto idx = get_idx(priority_counters[device_index][pri_idx]);

View File

@ -4101,6 +4101,53 @@ if HAS_CUDA_AND_TRITON:
compiled_out = compiled_foo(x)
self.assertEqual(eager_out, compiled_out)
# Use autotune_at_compile_time=True to test standalone_compile
@parametrize("autotune_at_compile_time", [True, False])
@config.patch("graph_partition", True)
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
def foo(x):
# partition 1
x1 = x @ x
y1 = x1 + 1
z_cpu = y1.cpu() + 1
# partition 2
# partition 2 should reuse the fused triton kernel generated
# in partition 1
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
y2 = x2 + 1
return y1, y2
with config.patch(
"triton.autotune_at_compile_time", autotune_at_compile_time
):
compiled_foo = torch.compile(foo)
x = torch.randn((20, 20), device="cuda")
eager_out = foo(x)
compiled_out, code = run_and_get_code(compiled_foo, x)
self.assertEqual(eager_out, compiled_out)
if autotune_at_compile_time:
# auto-tuning block should only appear once. We generate auto-tuning code
# for all the kernels no matter if they are defined in the main graph or
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
FileCheck().check_count(
"Compile-time auto-tuning block", 1, exactly=True
).run(code[0])
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
# and then in the main code block
FileCheck().check_count(
"def triton_poi_fused_add_", 2, exactly=True
).run(code[0])
# cpu kernel definition should only appear once, not in the auto-tuning block
FileCheck().check_count(
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
).run(code[0])
else:
# triton_poi_fused_add_ should appear once, because of kernel reuse
FileCheck().check_count(
"def triton_poi_fused_add_", 1, exactly=True
).run(code[0])
def test_meta_tensor(self):
def foobar(x, y):
return x * 2, y * 3

View File

@ -4,8 +4,9 @@ from functools import partial
from unittest import skipIf
import torch
from torch._inductor import config
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
@ -237,6 +238,17 @@ class TestCustomLowering(InductorTestCase):
out2 = fn_opt(a, b)
self.assertEqual(out1, out2)
@config.patch(joint_graph_constant_folding=False)
def test_constant_creation(self):
class M(torch.nn.Module):
def forward(self, x):
return x + torch.tensor(1)
make_fallback(torch.ops.aten.lift_fresh_copy.default)
self.assertTrue(
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -30,6 +30,7 @@ import numpy as np
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.aoti_eager
import torch.fx.traceback as fx_traceback
import torch.nn as nn
from torch._C._dynamo.guards import assert_alignment, assert_size_stride
from torch._dispatch.python import enable_python_dispatcher
@ -2476,12 +2477,11 @@ class CommonTemplate:
b_int8pack, b_scales = convert_weight_to_int8pack(b)
self.common(fn, (a, b_int8pack, b_scales, c))
@xfail_if_mps_unimplemented
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
def test__dyn_quant_pack_4bit_weight(self):
def test__dyn_quant_pack_4bit_weight_fp32(self):
q_group = 32
k = 128
n = 128
@ -2512,12 +2512,46 @@ class CommonTemplate:
self.common(fn, (b, in_features, out_features))
@xfail_if_mps_unimplemented
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
def test__dyn_quant_pack_4bit_weight_bf16(self):
q_group = 32
k = 128
n = 128
torch.manual_seed(1)
b = torch.rand((k, n), dtype=torch.bfloat16)
in_features = b.size(0)
out_features = b.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(b, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
return b_int4pack
self.common(fn, (b, in_features, out_features))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
def test__dyn_quant_matmul_4bit(self):
def test__dyn_quant_matmul_4bit_fp32_input(self):
q_group = 32
m = 32
k = 128
@ -2557,6 +2591,60 @@ class CommonTemplate:
self.common(fn, (a, q_group, in_features, out_features))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
def test__dyn_quant_matmul_4bit_bf16_input(self):
m = 32
k = 128
n = 128
q_group = k
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16)
b = torch.rand((k, n), dtype=torch.bfloat16)
# codegen_dynamic_shape test fails without explicitly marking these dynamic
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(b, 1)
in_features = b.size(0)
out_features = b.size(1)
if not self.is_dtype_supported(torch.bfloat16):
raise unittest.SkipTest(
f"torch.bfloat16 not supported for device {self.device}"
)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(a, q_group, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
res = torch.ops.aten._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
return res
self.common(fn, (a, q_group, in_features, out_features), atol=1, rtol=0.5)
def test_expanded_reduction(self):
def fn(x, y):
z = x * y
@ -13564,6 +13652,224 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.."
FileCheck().check_regex(size_assert_pattern).run(code)
def test_lite_mode_fallback(self):
def f(x):
z = x.sin()
return z.cos()
f = torch.compile(f, mode="lite")
_, code = run_and_get_code(f, torch.randn(2, device=self.device))
# Checks that aten ops are kept and run
if config.cpp_wrapper:
FileCheck().check("aoti_torch_call_dispatcher(").check("aten::sin").check(
"aoti_torch_call_dispatcher("
).check("aten::cos").run(code[0])
else:
FileCheck().check("torch.ops.aten.sin.default(").check(
"torch.ops.aten.cos.default("
).run(code[0])
# Checks that no triton code run in the generated code
self.assertFalse(".run(" in code[0])
# skip cpu test since rms norm is always decomposed on cpu
def test_lite_mode_not_decompose(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
def f(x, shape):
y = x + 1
z = torch.ops.aten._fused_rms_norm(y, shape, None, None)
return z[0] + z[1]
f = torch.compile(f, mode="lite")
x = torch.randn(2, 3, device=self.device)
_, code = run_and_get_code(f, x, [2, 3])
if config.cpp_wrapper:
FileCheck().check(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda__fused_rms_norm("
).run(code[0])
else:
FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])
if config.cpp_wrapper:
# arg type List[int] is not yet supported by custom_op_wrapper
pass
else:
x = torch.randn(2, 3, device=self.device, requires_grad=True)
_, codes = run_fw_bw_and_get_code(lambda: f(x, [2, 3]))
self.assertEqual(len(codes), 2)
FileCheck().check("torch.ops.aten._fused_rms_norm.default(").run(code[0])
def test_lite_regional_compile_flex_attention(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
def _squared(score, b, h, m, n):
return score * score
def mask_mod(b, h, q, k):
return q >= 0
a = 12
b = 64
block_mask = create_block_mask(
mask_mod, None, None, a * b, a * b, device=self.device
)
def fn(x):
x = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
return torch.cos(x)
x = torch.randn(
1,
1,
a * b,
b,
dtype=torch.bfloat16,
device=self.device,
requires_grad=True,
)
opt_fn = torch.compile(
fn,
mode="lite",
fullgraph=True,
)
# Check that inductor compilation is called twice
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
@unittest.skipIf(
config.cpp_wrapper,
"codegen invoke_subgraph is not implemented for cpp wrapper",
)
def test_lite_regional_compile_invoke_subgraph(self):
# Checks that get_attr nodes custom metadata is propagated
@torch.compiler.nested_compile_region
def gn(x):
return torch.sin(x)
def fn(x):
x = x + 1
with fx_traceback.annotate({"compile_with_inductor": 0}):
z = gn(x)
return torch.sigmoid(z)
opt_fn = torch.compile(fn, mode="lite", fullgraph=True)
x = torch.randn(10, requires_grad=True)
_, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
self.assertEqual(len(codes), 2)
@unittest.skipIf(
config.cpp_wrapper,
"codegen triton_kernel_wrapper_functional is not implemented for cpp wrapper",
)
def test_lite_triton_kernel_wrapper_functional(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
from torch._higher_order_ops.triton_kernel_wrap import (
kernel_side_table,
triton_kernel_wrapper_functional,
)
from torch.testing._internal.triton_utils import mul2_kernel
kernel_side_table.reset_table()
def f(x, output):
out = triton_kernel_wrapper_functional(
kernel_idx=kernel_side_table.add_kernel(mul2_kernel),
constant_args_idx=kernel_side_table.add_constant_args(
{"n_elements": output.numel(), "BLOCK_SIZE": 16}
),
grid=[(x.numel(),)],
tma_descriptor_metadata={},
kwargs={
"in_ptr0": x,
"out_ptr": output,
},
tensors_to_clone=["in_ptr0", "out_ptr"],
)
return out["out_ptr"]
t1 = torch.rand(5, device=self.device)
t2 = torch.rand(5, device=self.device)
compiled_f = torch.compile(f, mode="lite")
out = compiled_f(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(out, t2)
def test_lite_regional_compile_repeated_blocks(self):
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
a = fn(x, y)
return fn(a, y)
mod = Mod()
opt_mod = torch.compile(
mod,
mode="lite",
fullgraph=True,
)
x = torch.randn(10, requires_grad=True)
y = torch.randn(10, requires_grad=True)
_, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
self.assertEqual(len(codes), 2)
def test_lite_dynamic_shape_assertion(self):
class Model(torch.nn.Module):
def forward(self, c):
d = torch.concat([c, c], dim=0)
with fx_traceback.annotate({"compile_with_inductor": "my_region"}):
d = d + 1
return d
model = Model()
model = torch.compile(
model,
mode="lite",
fullgraph=True,
)
c = torch.randn((64, 32), device=self.device)
torch._dynamo.decorators.mark_unbacked(c, 0)
_, code = run_and_get_code(model, c)
# Checks that unbacked symint assertions are kept
if config.cpp_wrapper:
FileCheck().check_regex(r"if \(!\(u.* >= 0L\)\)").check_regex(
"Expected u.* >= 0 but receive"
).run(code[0])
else:
FileCheck().check_regex(r"if not \(u.* >= 0\):").check_regex(
r"raise RuntimeError\('u.* >= 0'\)"
).run(code[0])
@lowering.force_fallback(aten.sort.default)
@unittest.skipIf(
config.cpp_wrapper,

View File

@ -31,7 +31,6 @@ from torch.testing._internal.common_utils import (
serialTest,
TEST_CUDA_MEM_LEAK_CHECK,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
@ -93,17 +92,6 @@ if not torch._inductor.config.cpp_wrapper:
("cuda",)
)
if TEST_WITH_ROCM:
# Tensor-likes are not close
test_failures["test_dynamic_stride_nobreak"] = TestFailure(
("cpu", "cuda"), is_skip=True
)
test_failures["test_item_to_inputs_kernel_nobreak"] = TestFailure(
("cpu", "cuda"), is_skip=True
)
test_failures["test_unbacked_reduction"] = TestFailure(("cpu"), is_skip=True)
if any(os.getenv("BUILD_ENVIRONMENT", "").endswith(x) for x in ("-debug", "-asan")):
# Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
# After https://github.com/pytorch/pytorch/pull/161586, starts failing UBSAN so we can't even xfail.

View File

@ -492,6 +492,36 @@ class PackedSequenceTest(TestCase):
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
)
def test_empty_packed_sequence(self):
"""
Regression test for https://github.com/pytorch/pytorch/issues/149622
Tests that pad_packed_sequence and unpack_sequence handle empty tensors
without segmentation fault (CVE-2025-2998, CVE-2025-2999)
"""
# Test case 1: pad_packed_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.randn(0, 5)
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
empty_packed = rnn_utils.PackedSequence(
empty_data, empty_batch_sizes, None, None
)
# Should not crash - either return empty result or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)
# Test case 2: unpack_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.tensor([])
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
packed = rnn_utils.PackedSequence(
data=empty_data, batch_sizes=empty_batch_sizes
)
# Should not crash - either return empty list or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.unpack_sequence(packed)
if __name__ == "__main__":
run_tests()

View File

@ -7798,7 +7798,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test__dyn_quant_matmul_4bit(self, device, m, k, n):
def test__dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
@ -7870,7 +7870,86 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
def test__dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
torch.manual_seed(1)
a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
in_features = b_bfloat16.size(0)
out_features = b_bfloat16.size(1)
q_group = in_features
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def dyn_quant_matmul_4bit(
a, b_int4pack, q_group, in_features, out_features
):
return torch.ops.aten._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
b_bfloat16, in_features, out_features
)
dtypes = [torch.bfloat16]
for dtype in dtypes:
a = a_bfloat16.to(dtype=dtype)
b = b_bfloat16.to(dtype=dtype)
ref = torch.mm(a, b)
res = dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
# Mean relative error check
expected_mean_err = 0.00952
mean_err_tol = 0.005 # allow small deviation (±0.005)
mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
self.assertTrue(
abs(mean_err - expected_mean_err) < mean_err_tol,
f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
)
# Elementwise relative error check
elementwise_diff = (res - ref).abs()
elementwise_relative_error = elementwise_diff / ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
self.assertTrue(
torch.all(elementwise_relative_error < 0.070),
"Some elements have relative error >= 7%"
)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit_fp32(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
@ -7928,6 +8007,83 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
)
@onlyNativeDeviceTypes
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit_bf16(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
torch.manual_seed(1)
a_bfloat16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b_bfloat16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
in_features = b_bfloat16.size(0)
out_features = b_bfloat16.size(1)
q_group = in_features
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b_bfloat16, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
@torch.compile
def dyn_quant_matmul_4bit(
a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
):
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
res = dyn_quant_matmul_4bit(
a_bfloat16,
b_uint8,
b_scales_and_zeros,
q_group,
in_features,
out_features,
)
ref = torch.mm(a_bfloat16, b_bfloat16)
# === Accuracy checks ===
# Mean relative error check
expected_mean_err = 0.00952
mean_err_tol = 0.005 # allow small deviation (±0.005)
mean_err = ((res - ref).abs() / ref.abs().clamp(min=1e-5)).mean()
self.assertTrue(
abs(mean_err - expected_mean_err) < mean_err_tol,
f"Mean relative error {mean_err:.6f} deviates from expected {expected_mean_err}"
)
# Avoid divide-by-zero with clamp
denominator = ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
# Compute elementwise relative error — always non-negative
elementwise_relative_error = (res - ref).abs() / denominator
# Check that all elements are within 7% relative error
assert torch.all(elementwise_relative_error >= 0), "Relative error should never be negative"
self.assertTrue(
torch.all(elementwise_relative_error < 0.070),
"Some elements have relative error >= 7%"
)
@onlyCPU
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])

View File

@ -103,6 +103,17 @@ from .utils import (
_thread_local = threading.local()
@contextmanager
def maybe_skip_decompose(aot_config: AOTConfig):
old_decomp = aot_config.decompositions
try:
if config.selective_decompose:
aot_config.decompositions = {}
yield
finally:
aot_config.decompositions = old_decomp
# Saved tensor hooks context
# Compiled saved tensor hooks are convenient way to inline some logic in the graphs
# for saved nodes from forward to backward. (E.g. activations quantization)
@ -196,11 +207,28 @@ def aot_stage1_graph_capture(
# deterministic TLS can be different
aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]]
if aot_state.needs_autograd and not aot_config.pre_dispatch:
# FYI: this being moved to trigger in export is new, seems fine!
with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
with maybe_skip_decompose(aot_config):
# If config.selective_decompose is set, skip decomposition here and apply selective
# decomposition after we get the joint graph. See [Note: Selective Decomposition] for details.
if aot_state.needs_autograd and not aot_config.pre_dispatch:
# FYI: this being moved to trigger in export is new, seems fine!
with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
(
graph,
updated_flat_args,
updated_flat_args_descs,
maybe_subclass_meta,
) = aot_dispatch_autograd_graph(
flat_fn,
aot_state.flat_args,
aot_state.flat_args_descs,
aot_config,
fw_metadata=aot_state.fw_metadata,
)
else:
graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
aot_dispatch_autograd_graph(
aot_dispatch_base_graph(
flat_fn,
aot_state.flat_args,
aot_state.flat_args_descs,
@ -208,15 +236,17 @@ def aot_stage1_graph_capture(
fw_metadata=aot_state.fw_metadata,
)
)
else:
graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = (
aot_dispatch_base_graph( # type: ignore[assignment]
flat_fn,
aot_state.flat_args,
aot_state.flat_args_descs,
aot_config,
fw_metadata=aot_state.fw_metadata,
)
if config.selective_decompose:
from torch.fx.experimental.proxy_tensor import selective_decompose
from torch.fx.passes.regional_inductor import _needs_inductor_compile
graph = selective_decompose(
graph,
*updated_flat_args,
decomposition=aot_config.decompositions,
should_decompose=_needs_inductor_compile,
trace_joint_graph=aot_state.needs_autograd and not aot_config.pre_dispatch,
)
return AOTGraphCapture(

View File

@ -374,6 +374,13 @@ saved_tensors_hooks_filtering_mode = "donated"
# This callback is invoked on the joint graph before partitioning
joint_custom_pass: Callable = None # type: ignore[assignment]
# Note [Selective Decomposition]
# This config allows selective decomposition of certain operators in the graph.
# When True, it does NOT decompose any nodes, except those that users explicitly
# annotated for regional inductor compile. See torch.fx.passes.regional_inductor
# for how to annotate explicitly. This is currently only used by inductor lite mode.
selective_decompose: bool = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

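For illustration, a minimal sketch of what this config means in practice (assuming this branch, and reusing the `compile_with_inductor` annotation key from the tests above): only the annotated region is decomposed and compiled by inductor, while everything else is kept as un-decomposed ATen ops.

```
import torch
import torch.fx.traceback as fx_traceback

def f(x):
    # Outside the annotation: kept as a single ATen op (no decomposition,
    # no inductor codegen), because lite mode enables selective_decompose
    # and fallback_by_default.
    y = torch.nn.functional.silu(x)
    with fx_traceback.annotate({"compile_with_inductor": "hot_region"}):
        # Inside the annotation: decomposed and compiled by inductor as usual.
        y = torch.nn.functional.gelu(y) * 2
    return y.sum()

opt_f = torch.compile(f, mode="lite", fullgraph=True)
print(opt_f(torch.randn(8)))
```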
View File

@ -315,6 +315,25 @@ def aot_compile(
)
lite_mode_options = {
# Fall back to ATen by default, unless users explicitly annotate a region
# for regional inductor compile.
"fallback_by_default": True,
"selective_decompose": True,
# Disable reorder optimizations
"reorder_for_peak_memory": False,
"reorder_for_compute_comm_overlap": False,
"triton.reorder_for_reducing_graph_partitions": False,
# Disable pre-, joint-, post-grad passes
"use_pre_grad_passes": False,
"use_joint_graph_passes": False,
"use_post_grad_passes": False,
# Disable dead code elimination (dce) and buffer reuse
"use_dce": False,
"allow_buffer_reuse": False,
}
def list_mode_options(
mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> dict[str, Any]:
@ -332,6 +351,8 @@ def list_mode_options(
mode_options: dict[str, dict[str, bool]] = {
"default": {},
# lite mode with opt-in optimizations
"lite": lite_mode_options,
# enable cudagraphs
"reduce-overhead": {
"triton.cudagraphs": True,

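A quick way to inspect these overrides (a sketch assuming this branch; `list_mode_options` is the existing public helper that this hunk extends):

```
import torch
from torch._inductor import list_mode_options

# "lite" now maps to the overrides defined in lite_mode_options above, e.g.
# {'fallback_by_default': True, 'selective_decompose': True, ...}
print(list_mode_options("lite"))

# This is what torch.compile applies when the mode string is passed through:
opt = torch.compile(lambda x: x.sin().cos(), mode="lite")
```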
View File

@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
if config.triton.autotune_at_compile_time and gpu:
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
@ -3745,6 +3745,13 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
super().__init__()
root = self.get_root_graph()
# Only generate auto-tuning block in the main graph
self.kernel_autotune_defs = root.kernel_autotune_defs
self.kernel_autotune_calls = root.kernel_autotune_calls
# Only store kernel src to name mapping in the main graph
self.src_to_kernel = root.src_to_kernel
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of
# the subgraph.
@ -3837,3 +3844,16 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
# )
self.parent_wrapper.write_get_raw_stream_header_once()
@cache_on_self
def get_root_graph(self) -> PythonWrapperCodegen:
root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
while isinstance(root, SubgraphPythonWrapperCodegen):
root = root.parent_wrapper
assert isinstance(root, PythonWrapperCodegen)
return root
def generate_and_run_autotune_block(self):
# Only execute auto-tuning block in the main graph
pass

View File

@ -508,6 +508,9 @@ def _recursive_pre_grad_passes(
log_pt2_compile_event=True,
dynamo_compile_column_us="pre_grad_pass_time_us",
):
if not config.use_pre_grad_passes:
return gm
add_passes = config.add_pre_grad_passes
remove_passes = config.remove_pre_grad_passes
for subgraph_name in _get_subgraph_names(gm):
@ -526,6 +529,9 @@ def _recursive_joint_graph_passes(
log_pt2_compile_event=True,
dynamo_compile_column_us="joint_graph_pass_time_us",
):
if not config.use_joint_graph_passes:
return
# invoke_subgraph already runs the _recursive_joint_graph_passes. In
# AOTAutograd, `run_joint_graph_passes_on_hops` partitions the
# invoke_subgraph HOP before calling the partitioner on the outer graph.
@ -544,6 +550,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
log_pt2_compile_event=True,
dynamo_compile_column_us="post_grad_pass_time_us",
):
if not config.use_post_grad_passes:
return
for subgraph_name in _get_subgraph_names(gm):
subgraph = getattr(gm, subgraph_name)
_recursive_post_grad_passes(subgraph, is_inference)
@ -2708,7 +2717,10 @@ def _compile_fx_main(
is_valid_aoti_model_name()
with functorch_config.patch(unlift_effect_tokens=True):
with functorch_config.patch(
unlift_effect_tokens=True,
selective_decompose=config.selective_decompose,
):
gm, graph_signature = aot_export_module(
model_,
example_inputs_,
@ -2768,7 +2780,10 @@ def _compile_fx_main(
V.set_fake_mode(fake_mode),
torch._guards.tracing(tracing_context),
compiled_autograd._disable(),
functorch_config.patch(unlift_effect_tokens=True),
functorch_config.patch(
unlift_effect_tokens=True,
selective_decompose=config.selective_decompose,
),
):
try:
return aot_autograd(

View File

@ -550,6 +550,32 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
# Fall back to ATen for all ops by default, except those nodes that users explicitly
# annotated for regional inductor compile. See torch.fx.passes.regional_inductor
# for how to annotate explicitly. This is currently only used by inductor lite mode.
# Unlike the default inductor mode, which fuses all nodes, this config enables an
# opt-in mode that only fuses user-specified nodes. The motivation is to provide
# guaranteed numerical correctness and full control to users.
fallback_by_default: bool = False
# This config allows selective decomposition of certain operators in the graph.
# Currently the only use case is to patch the config of the same name in functorch,
# for inductor lite mode. See [Note: Selective Decomposition] for more details.
selective_decompose: bool = False
# Use dead code elimination
use_dce: bool = True
# Use fx graph passes
use_pre_grad_passes: bool = True
use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
@ -1373,6 +1399,10 @@ class triton:
default=False,
)
# reorder nodes to minimize the number of graph partitions while
# not incurring large memory overhead
reorder_for_reducing_graph_partitions: bool = True
# assertions on the fast path
fast_path_cudagraph_asserts = False

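These knobs can also be toggled individually via `config.patch`, without opting into the whole lite mode (a sketch, assuming this branch):

```
import torch
from torch._inductor import config as inductor_config

# Keep every op as an ATen fallback and skip the post-grad passes, but leave
# the rest of inductor's defaults (fusion, buffer reuse, ...) untouched.
with inductor_config.patch(fallback_by_default=True, use_post_grad_passes=False):
    f = torch.compile(lambda x: x.sin() + 1)
    f(torch.randn(4))
```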
View File

@ -110,6 +110,7 @@ from .utils import (
maybe_get_suppress_shape_guards_ctx,
normalize_name,
should_assume_input_aligned,
should_fallback_by_default,
SUPPORTED_MKLDNN_DEVICES,
ValueWithLineMap,
)
@ -1634,6 +1635,20 @@ class GraphLowering(torch.fx.Interpreter):
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
and isinstance(
n.target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
)
and should_fallback_by_default(n)
):
# This path handles fallbacks for inductor lite mode. It supports
# both OpOverloads and HOPs (e.g., triton_kernel_wrapper_functional).
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation

View File

@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.node import Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
from torch.utils._sympy.symbol import SymT
@ -6135,9 +6136,12 @@ class ExternKernel(InputsKernel):
if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
return ShapeAsConstantBuffer(expr=x)
if isinstance(x, Constant):
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
# We need to unset fake mode, or else the torch.tensor() call would
# produce a FakeTensor instead of a real constant
with _disable_current_modes():
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
if isinstance(x, ConstantBuffer):
return x
if isinstance(x, TensorBox):

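The failure mode this guards against is easy to reproduce on its own; a minimal sketch (using only existing APIs already imported in this diff):

```
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._python_dispatch import _disable_current_modes

with FakeTensorMode():
    t = torch.tensor(1.0)
    assert isinstance(t, FakeTensor)       # the "constant" became fake

    with _disable_current_modes():
        c = torch.tensor(1.0)              # real tensor, safe to use as a constant
    assert not isinstance(c, FakeTensor)
```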
View File

@ -2742,8 +2742,9 @@ class Scheduler:
self.process_grouped_nodes()
if (
torch._inductor.config.graph_partition
and torch._inductor.config.triton.cudagraphs
config.graph_partition
and config.triton.cudagraphs
and config.triton.reorder_for_reducing_graph_partitions
):
self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
@ -3191,6 +3192,9 @@ class Scheduler:
"""
Remove any nodes without users
"""
if not config.use_dce:
return
# self.nodes is in topological order, so by iterating in reverse order
# we have visited (and potentially removed) all users before visiting a
# given node.

View File

@ -58,6 +58,7 @@ import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties
from torch.fx.passes.regional_inductor import _needs_inductor_compile
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten, tree_map_only
@ -4036,3 +4037,40 @@ def load_template(name: str, template_dir: Path) -> str:
"""Load a template file and return its content."""
with open(template_dir / f"{name}.py.jinja") as f:
return f.read()
def should_fallback_by_default(node: torch.fx.Node) -> bool:
"""Decide whether fallback for a node. This is only used in inductor lite mode."""
target = node.target
assert isinstance(
target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
), f"Expected OpOverload or HigherOrderOperator, but found {type(target)}"
if not config.fallback_by_default:
return False
# Some ops need special handling due to dynamic shapes. We can avoid
# the fallback since they do not impact numerics.
skip_fallback_due_to_dynamic_shape = OrderedSet(
[
torch.ops.aten._assert_scalar.default,
torch.ops.aten.lift_fresh_copy.default,
]
)
if target in skip_fallback_due_to_dynamic_shape:
return False
# Most HOPs have a registered lowering; we should follow the lowering and not fall back.
# However, in rare cases a HOP may not register a lowering, such as
# torch.ops.higher_order.triton_kernel_wrapper_functional. We should fall back for
# these HOPs.
fallback_hops = OrderedSet(
[torch.ops.higher_order.triton_kernel_wrapper_functional]
)
if isinstance(target, torch._ops.HigherOrderOperator):
return target in fallback_hops
return not _needs_inductor_compile(node)

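A hedged sketch of the resulting behavior on a plain, un-annotated graph (assumes this branch; nothing below is annotated for regional compile, so every aten op reports True and is kept as an ATen fallback):

```
import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import should_fallback_by_default
from torch.fx.experimental.proxy_tensor import make_fx

gm = make_fx(lambda x: x.sin().cos())(torch.randn(4))

with inductor_config.patch(fallback_by_default=True):
    for n in gm.graph.nodes:
        if n.op == "call_function" and isinstance(n.target, torch._ops.OpOverload):
            print(n.target, should_fallback_by_default(n))  # aten.sin/aten.cos -> True
```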
View File

@ -3722,6 +3722,7 @@ def kai_roundup(a: int, b: int) -> int:
def get_kai_packed_weight_size(n_bits, N, K, groupsize):
if n_bits == 4:
# Works for both fp32 and bf16 kernels
if groupsize == K: # channelwise
# dotprod params only [1x8x32_neon_dotprod]
kai_nr = 8
@ -3851,6 +3852,8 @@ def meta__dyn_quant_pack_4bit_weight(
)
return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
packed_weight_size = weights.numel() + scales_zeros.numel()
if bias is not None:
packed_weight_size += bias.numel()
return weights.new_empty(packed_weight_size, dtype=torch.float)
@ -3864,8 +3867,12 @@ def meta__dyn_quant_matmul_4bit(
):
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
torch._check(
inp.dtype == torch.float32,
lambda: f"expected input to be f32, got {inp.dtype}",
(inp.dtype == torch.float32)
or (inp.dtype == torch.bfloat16 and block_size == in_features),
lambda: (
f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), "
f"got {inp.dtype} with block_size={block_size} and in_features={in_features}"
),
)
M = inp.size(0)
return inp.new_empty(M, out_features, dtype=inp.dtype)

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device_struct.h>
@ -120,7 +119,7 @@ struct FromImpl<ScalarType> {
case ScalarType::UInt64:
return from(aoti_torch_dtype_uint64());
default:
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported ScalarType, please file an issue describing your use case.");
}
@ -151,7 +150,7 @@ struct FromImpl<DeviceType> {
case DeviceType::PrivateUse1:
return from(aoti_torch_device_type_privateuse1());
default:
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported DeviceType, please file an issue describing your use case.");
}
@ -379,7 +378,7 @@ struct ToImpl<ScalarType> {
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
return ScalarType::UInt64;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported ScalarType ",
std::to_string(shim_scalartype),
@ -409,7 +408,7 @@ struct ToImpl<DeviceType> {
} else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
return DeviceType::PrivateUse1;
} else {
TORCH_CHECK(
STD_TORCH_CHECK(
false,
"Not yet supported DeviceType ",
std::to_string(shim_devicetype),

View File

@ -42,6 +42,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_impls import fast_detach
from torch._subclasses.fake_tensor import (
FakeTensor,
@ -90,6 +91,7 @@ __all__ = [
"dispatch_trace",
"make_fx",
"DecompositionInterpreter",
"selective_decompose",
"py_sym_types",
"get_innermost_proxy_mode",
"get_proxy_mode",
@ -1881,6 +1883,93 @@ class DecompositionInterpreter(fx.Interpreter):
return super().run(*args, **kwargs) # type: ignore[arg-type]
class _SelectiveDecomposeInterpreter(fx.Interpreter):
def __init__(
self,
module: fx.GraphModule,
should_decompose: Callable[[fx.Node], bool],
decomposition_table: Mapping[OpOverload, Callable],
**kwargs: object,
) -> None:
"""
For each node in `module`, decompose it with the given `decomposition_table`
only if `should_decompose(node)` returns True.
"""
super().__init__(module, **kwargs) # type: ignore[arg-type]
self.should_decompose = should_decompose
self.decomposition_table = decomposition_table
@staticmethod
def recursive_wrap(
gm: fx.GraphModule,
should_decompose: Callable[[fx.Node], bool],
decomposition_table: Mapping[OpOverload, Callable],
**kwargs: object,
) -> _SelectiveDecomposeInterpreter:
"""
Recursively wrap gm and its sub graph modules. Specifically, HOPs take
sub graph modules as args, and we may not want to decompose all nodes within
those sub graph modules, so we wrap them as well.
As a result:
- if should_decompose(hop) is True, we decompose all nodes within the hop.
- if should_decompose(hop) is False, we check each node within the hop
and decide whether to decompose it or not.
"""
for node in gm.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, HigherOrderOperator
):
new_args = []
for arg in node.args:
if isinstance(arg, fx.GraphModule):
new_arg = _SelectiveDecomposeInterpreter.recursive_wrap(
arg, should_decompose, decomposition_table, **kwargs
)
else:
new_arg = arg
new_args.append(new_arg)
node.args = tuple(new_args)
return _SelectiveDecomposeInterpreter(
gm, should_decompose, decomposition_table, **kwargs
)
def run_node(self, n):
if self.should_decompose(n):
with decompose(self.decomposition_table):
result = super().run_node(n)
else:
result = super().run_node(n)
return result
def selective_decompose(
joint_gm: fx.GraphModule,
*args,
decomposition,
should_decompose,
trace_joint_graph: bool,
) -> fx.GraphModule:
"""Retrace a joint graph module and selectively apply decomposition."""
if trace_joint_graph:
# The arg names, primals and tangents, are important:
# make_fx keeps the names in the traced graph, and the partitioner later relies
# on them to partition the joint graph correctly.
def wrap_fn(primals: list[Any], tangents: list[Any]):
return _SelectiveDecomposeInterpreter.recursive_wrap(
joint_gm, should_decompose, decomposition
).run(*args)
else:
def wrap_fn(*args):
return _SelectiveDecomposeInterpreter.recursive_wrap(
joint_gm, should_decompose, decomposition
).run(*args)
return make_fx(wrap_fn, decomposition_table={})(*args)
def wrapper_and_args_for_make_fx(
func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
) -> tuple[Callable[[list[object]], R], list[object]]:

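A hedged usage sketch for the non-joint path (assumes this branch; `get_decompositions` is the existing helper in `torch._decomp`, and the silu decomposition is used purely as an example):

```
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx, selective_decompose

def f(x):
    return torch.nn.functional.silu(x) + torch.nn.functional.gelu(x)

x = torch.randn(4)
gm = make_fx(f, decomposition_table={})(x)

# Decompose only silu; gelu (and everything else) is retraced unchanged.
new_gm = selective_decompose(
    gm,
    x,
    decomposition=get_decompositions([torch.ops.aten.silu.default]),
    should_decompose=lambda n: n.target is torch.ops.aten.silu.default,
    trace_joint_graph=False,
)
print(new_gm.graph)
```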
View File

@ -112,7 +112,7 @@ def _compile_submod(gm, prefix):
return gm
def _needs_inductor_compile(node):
def _needs_inductor_compile(node: torch.fx.Node):
return (
node.op not in ("placeholder", "output")
and hasattr(node, "meta")