Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)
**Summary:** Support masked vectorization for the tail_loop for fp8 datatype.

**Example:**
```
import torch

def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```

**Generated code:**

- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(432L);x0_tail < static_cast<int64_t>(441L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = c10::convert<float>(tmp0);
                        auto tmp2 = static_cast<float>(100.0);
                        auto tmp3 = float(tmp1 - tmp2);
                        auto tmp4 = static_cast<float>(0.01);
                        auto tmp5 = float(tmp3 * tmp4);
                        auto tmp6 = c10::convert<float>(tmp5);
                        auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
                        auto tmp8 = float(tmp7 * tmp2);
                        auto tmp9 = std::nearbyint(tmp8);
                        auto tmp10 = float(tmp9 + tmp2);
                        auto tmp11 = static_cast<float>(-128.0);
                        auto tmp12 = max_propagate_nan(tmp10, tmp11);
                        auto tmp13 = static_cast<float>(127.0);
                        auto tmp14 = min_propagate_nan(tmp12, tmp13);
                        auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel
ghstack dependencies: #163316
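To see the new behavior locally, one can compile the example above and inspect the kernel Inductor emits. The sketch below is illustrative and not part of this PR: it assumes the internal test helper `torch._inductor.utils.run_and_get_cpp_code` is available in your build and that the CPU vectorizes with a 16-lane main loop as in the listings above (otherwise the generated code, and hence the string check, will differ). Setting `TORCH_LOGS="output_code"` is an alternative way to dump the same source.

```
# Hypothetical verification sketch (not part of this PR): capture the C++ that
# torch.compile generates for the fp8 dequantize/relu/quantize example and
# check that the tail is no longer lowered to the scalar `x0_tail` loop.
import torch
from torch._inductor.utils import run_and_get_cpp_code  # internal helper used by Inductor CPU tests

def fn(x, scale, zero_point, quant_min, quant_max, dtype):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    x = torch.relu(x)
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )

quant_min, quant_max = -128, 127
dtype = torch.float8_e4m3fn
x = torch.clamp(
    torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max
).to(dtype)

with torch.no_grad():
    _, code = run_and_get_cpp_code(
        torch.compile(fn), x, 0.01, 100, quant_min, quant_max, dtype
    )

# Before this change the tail fell back to a scalar loop over `x0_tail`;
# with masked vectorization that loop should no longer appear.
assert "x0_tail" not in code
```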
Committed by: PyTorch MergeBot
Parent: e9d8973427
Commit: e8cb34dd52
```
@@ -165,6 +165,8 @@ MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
     torch.float16,
     torch.uint8,
     torch.int8,
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
 ]
```
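A quick way to confirm the new entries is to check the list directly. This is a sketch under the assumption that `MASKED_VECTORIZABLE_DTYPES` lives in `torch/_inductor/codegen/cpp.py` (the file name is not shown in the hunk above, so adjust the import if it is defined elsewhere):

```
# Sanity check (assumes MASKED_VECTORIZABLE_DTYPES is defined in
# torch/_inductor/codegen/cpp.py, matching the hunk context above).
import torch
from torch._inductor.codegen import cpp

assert torch.float8_e4m3fn in cpp.MASKED_VECTORIZABLE_DTYPES
assert torch.float8_e5m2 in cpp.MASKED_VECTORIZABLE_DTYPES
```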