[Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)

**Summary:**
Support masked vectorization for the tail_loop for fp8 datatype.

**Example:**
```
import torch

def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```

**Generated code:**

- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(432L);x0_tail < static_cast<int64_t>(441L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = c10::convert<float>(tmp0);
                        auto tmp2 = static_cast<float>(100.0);
                        auto tmp3 = float(tmp1 - tmp2);
                        auto tmp4 = static_cast<float>(0.01);
                        auto tmp5 = float(tmp3 * tmp4);
                        auto tmp6 = c10::convert<float>(tmp5);
                        auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
                        auto tmp8 = float(tmp7 * tmp2);
                        auto tmp9 = std::nearbyint(tmp8);
                        auto tmp10 = float(tmp9 + tmp2);
                        auto tmp11 = static_cast<float>(-128.0);
                        auto tmp12 = max_propagate_nan(tmp10, tmp11);
                        auto tmp13 = static_cast<float>(127.0);
                        auto tmp14 = min_propagate_nan(tmp12, tmp13);
                        auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel
ghstack dependencies: #163316
This commit is contained in:
Sun, Jiayi
2025-10-15 16:42:34 +00:00
committed by PyTorch MergeBot
parent e9d8973427
commit e8cb34dd52
2 changed files with 19 additions and 13 deletions

View File

@ -1543,22 +1543,26 @@ class CPUReproTests(TestCase):
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(
fn,
(
x,
scale,
zero_point,
use_dequant,
use_quant,
quant_min,
quant_max,
dtype,
dequant_out_dtype,
),
inputs = (
x,
scale,
zero_point,
use_dequant,
use_quant,
quant_min,
quant_max,
dtype,
dequant_out_dtype,
)
self.common(fn, inputs)
check_metrics_vec_kernel_count(1)
# Check that both main and tail loops are vectorized
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
compiled_fn = torch.compile(fn)
_, code = run_and_get_cpp_code(compiled_fn, *inputs)
FileCheck().check_count("loadu", 2, exactly=True).run(code)
@requires_vectorization
def test_dequant_quant_lowering_uint8(self):
self._test_dequant_quant_lowering_helper(torch.uint8)