Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)
**Summary:** Support masked vectorization for the tail_loop for the float64 datatype.

**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```

**Generated code:**

- Before
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0, double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(480L); x0_tail < static_cast<int64_t>(484L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = double(tmp0 * tmp0);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

- After
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0, double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163316
Approved by: https://github.com/mingfeima, https://github.com/jansel
committed by PyTorch MergeBot
parent 61d9a5180e
commit e9d8973427
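To reproduce the example above and inspect the emitted kernel, a minimal sketch is shown below. It assumes the internal Inductor test helper `torch._inductor.utils.run_and_get_cpp_code`, which is not a stable API and may change between releases; setting the `TORCH_COMPILE_DEBUG=1` environment variable is an alternative that dumps the generated code under `torch_compile_debug/`. The exact kernel also depends on the host CPU's vector ISA.

```
# Repro sketch: run the example through torch.compile and capture the
# generated C++ source. run_and_get_cpp_code is an internal test helper.
import torch
from torch._inductor.utils import run_and_get_cpp_code


def fn(x):
    return x * x


x = torch.randn((22, 22), dtype=torch.double)

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    # Returns the call's result plus the C++ source Inductor emitted; with this
    # change the tail of the 484-element loop should use masked loadu/store
    # calls with a count of 4 instead of a scalar fallback loop.
    _, cpp_code = run_and_get_cpp_code(compiled_fn, x)

print(cpp_code)
```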
```diff
@@ -159,6 +159,7 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
 ]

 MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
+    torch.float64,
     torch.float,
     torch.bfloat16,
     torch.float16,
```
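The hunk above adds `torch.float64` to `MASKED_VECTORIZABLE_DTYPES`, which is what allows the CPP backend to emit the masked loadu/store pair for the tail elements instead of the scalar loop. Since only the tail handling changes, a quick numerical comparison against eager mode is a reasonable sanity check; in this sketch the shape is chosen so the flattened size (22 * 22 = 484) is not a multiple of the 16-element vectorized step, forcing the tail path.

```
# Sanity-check sketch: compare the compiled result against eager for a
# float64 input whose size leaves a remainder for the tail loop.
import torch


def fn(x):
    return x * x


x = torch.randn((22, 22), dtype=torch.double)

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    torch.testing.assert_close(compiled_fn(x), fn(x))
```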