[ATen][CUDA] Implement 128 bit vectorization v2 (#145746)

This is a rebased version of my previous PR #141959.

Description from the original PR:

This PR implements 128-bit vectorization for CUDA elementwise kernels. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.
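
For context, 128-bit vectorization means each thread moves 16 bytes per memory transaction (for example, 8 `half`/`bfloat16` elements or 4 `float` elements) using a single vectorized load and store. The CUDA path previously vectorized up to 4 elements, so `float32` already issued 16-byte accesses while the 16-bit dtypes only moved 8 bytes per thread; this is why the gains below are concentrated in `float16`/`bfloat16` and `float32` stays essentially flat. A minimal standalone sketch of the idea (the `half8` struct and `relu_vec128` kernel are illustrative only, not the kernels touched by this PR):

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// A 16-byte-aligned pack of 8 half values, similar in spirit to ATen's
// aligned_vector<scalar_t, vec_size>.
struct alignas(16) half8 {
  __half val[8];
};

// Each thread loads one 16-byte pack (one 128-bit load), applies the op
// element-wise, and writes one 16-byte pack back (one 128-bit store).
__global__ void relu_vec128(const half8* in, half8* out, int64_t n_packs) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n_packs) {
    half8 v = in[i];  // 128-bit load
    #pragma unroll
    for (int k = 0; k < 8; ++k) {
      const __half zero = __float2half(0.0f);
      v.val[k] = __hgt(v.val[k], zero) ? v.val[k] : zero;
    }
    out[i] = v;       // 128-bit store
  }
}
```

Because `alignas(16)` guarantees 16-byte alignment of the pack, the compiler can typically emit a single 128-bit load/store instruction per thread for the pack access.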

<details>

<summary>The benchmark code used </summary>

```Python

import time
import torch
from torch.profiler import profile, ProfilerActivity

def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")

    shapes = []
    for p in range(24, 30):
        shape = 1<<p
        shapes.append(shape)

    for shape in shapes:
        # warm-up iterations: trigger kernel/codegen caching before timing
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)

        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

        x = torch.randn(shape, device=device, dtype=dtype)
        # timed region: average wall-clock time over 6 iterations
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6

        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)

def main():
    ops = [
            torch.relu,
            torch.sigmoid,
            torch.tanh,
            torch.nn.functional.gelu,
            torch.sin,
            torch.exp,
    ]

    dtypes = [
            torch.float16,
            torch.bfloat16,
            torch.float32,
    ]

    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
```

</details>

<details>

<summary> Results </summary>

| op | dtype | size (elements) | time after (s) | time before (s) | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145746
Approved by: https://github.com/eqy, https://github.com/ngimel

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Author: Aidyn-A
Date: 2025-01-31 06:42:07 +00:00
Committed by: PyTorch MergeBot
Commit: e84bf88dde (parent eeb5e1bf20)
8 changed files with 78 additions and 22 deletions

View File

```diff
@@ -49,8 +49,8 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
-#ifdef USE_ROCM
   at::cuda::jit::NvrtcFunction vec8;
+#ifdef USE_ROCM
   at::cuda::jit::NvrtcFunction vec16;
 #endif
@@ -131,6 +131,18 @@ void launch_jitted_vectorized_kernel(
   int vec_size = at::cuda::jit::can_vectorize_up_to(
       desc, c10::ArrayRef<char*>(data.data(), data.size()));
+#ifndef USE_ROCM
+  const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
+  const int optimal_vec_size = 16 / static_cast<int>(input_size);
+  vec_size = std::min<int>(optimal_vec_size, vec_size);
+  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
+  // that causes some numerical mismatches with uint8 on sm80 and sm90.
+  // TODO: Revisit this after CUDA 12.8 update.
+  if (input_size < 2) {
+    vec_size = std::min<int>(vec_size, 4);
+  }
+#endif
   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
@@ -138,11 +150,11 @@ void launch_jitted_vectorized_kernel(
 #ifdef USE_ROCM
   if (vec_size == 16) {
     fn_ptr = &fn_cache.vec16;
-  } else if (vec_size == 8) {
-    fn_ptr = &fn_cache.vec8;
   } else
 #endif
-  if (vec_size == 4) {
+  if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
+  } else if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
     fn_ptr = &fn_cache.vec2;
```

View File

```diff
@@ -61,6 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
+#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -71,6 +72,16 @@ constexpr auto elems_per_thread(){
     return 4;
   }
 }
+#else
+template <int io_sizes>
+constexpr auto elems_per_thread(){
+  if constexpr (io_sizes == 1) {
+    return 16;
+  } else {
+    return 8;
+  }
+}
+#endif
 template <int io_sizes>
 constexpr auto io_block_work_size() {
@@ -191,8 +202,20 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
+#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
+#else
+  using cpp_type = typename function_traits<func_t>::result_type;
+  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
+  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
+  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
+  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
+  // that causes some numerical mismatches with uint8 on sm80 and sm90.
+  // TODO: Revisit this after CUDA 12.8 update.
+  if constexpr (sizeof(cpp_type) < 2) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
+#endif
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:
@@ -200,12 +223,12 @@ static inline void launch_vectorized_kernel(
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
-#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
```
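
For readers skimming the hunks above: on the non-ROCm path the vector width is now chosen as however many elements fit in 16 bytes, clamped by what `can_vectorize_up_to` reports for the pointer alignment, with 1-byte types additionally capped at 4 elements because of the NVCC issue noted in the comment. A minimal host-side sketch of that selection (the helper `pick_vec_size` and its output are illustrative, not code from this PR):

```cuda
#include <algorithm>
#include <cstdio>

// Hypothetical helper mirroring the non-ROCm selection above: target 16 bytes
// per thread, clamp by the alignment-derived maximum, and cap 1-byte types at
// 4 elements (the NVCC uint8 workaround noted in the diff).
int pick_vec_size(int element_size, int max_vec_size_from_alignment) {
  int vec_size = 16 / element_size;                 // 16-byte target
  vec_size = std::min(vec_size, max_vec_size_from_alignment);
  if (element_size < 2) {
    vec_size = std::min(vec_size, 4);               // skip vec8 for 1-byte dtypes
  }
  return vec_size;
}

int main() {
  // Assuming a fully aligned tensor, where can_vectorize_up_to reports 8.
  std::printf("half/bfloat16: vec%d\n", pick_vec_size(2, 8));  // vec8 -> 16 bytes
  std::printf("float32:       vec%d\n", pick_vec_size(4, 8));  // vec4 -> 16 bytes
  std::printf("uint8:         vec%d\n", pick_vec_size(1, 8));  // vec4 -> capped
  return 0;
}
```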

View File

```diff
@@ -217,8 +217,11 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
   // make sure we don't break assumption that we can't have > 16 elements / thread
   TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
 #else
+  const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
+  vec_size = std::min<int>(optimal_vec_size, vec_size);
   // make sure we don't break assumption that we can't have > 4 elements / thread
-  TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
+  TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
 #endif
 }
```

View File

```diff
@@ -351,8 +351,8 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-#ifdef USE_ROCM
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
+#ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
   if (type_size == 1 && (address % vec16_alignment == 0)) {
@@ -360,6 +360,10 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
+#else
+  if (address % vec8_alignment == 0) {
+    return 8;
+  } else
+#endif
   if (address % vec4_alignment == 0) {
     return 4;
```

View File

```diff
@@ -932,7 +932,6 @@ void initializeCudaContext() {
   }
 }
-#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
@@ -952,7 +951,6 @@ int calc_io_size(
   return 0;
 }
-#endif
 int calc_thread_work_size(
     const int nInputs,
@@ -971,7 +969,14 @@
   }
   return io_size;
 #else
-  return JIT_THREAD_WORK_SIZE;
+  auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type);
+  TORCH_INTERNAL_ASSERT(io_size > 0);
+  if (io_size == 1) {
+    return 16;
+  } else {
+    return 8;
+  }
+  return io_size;
 #endif
 }
```

View File

```diff
@@ -60,6 +60,10 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
   if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
     return 8;
   }
+#else
+  if (ip % (8 * default_alignment) == 0) {
+    return 8;
+  }
 #endif
   if (ip % (4 * default_alignment) == 0) {
     return 4;
@@ -88,15 +92,17 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*
 }
 //FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
-#define JIT_THREAD_WORK_SIZE 4
 #ifdef USE_ROCM
+#define JIT_THREAD_WORK_SIZE 4
+#else
+#define JIT_THREAD_WORK_SIZE 8
+#endif
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
     const c10::ScalarType& inputs_type,
     const c10::ScalarType& result_type);
-#endif
 int calc_thread_work_size(
     const int nInputs,
```

View File

```diff
@@ -12,11 +12,14 @@
 constexpr int num_threads() {
   return 256;
 }
+constexpr int thread_work_size() { return 4; }
 #else
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
+constexpr int thread_work_size() { return 8; }
 #endif
-constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
```
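
With the CUDA constants above (`num_threads()` = `C10_WARP_SIZE * 4`, i.e. 128 threads on NVIDIA hardware, and `thread_work_size()` = 8), `block_work_size()` becomes 1024 elements per block, matching the 8-elements-per-thread path in the CUDALoops.cuh hunk. A standalone sketch of that arithmetic (values hard-coded for illustration):

```cuda
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed CUDA-path values from the hunks above (NVIDIA warp size is 32).
  const int64_t num_threads = 32 * 4;       // C10_WARP_SIZE * 4 = 128 threads/block
  const int64_t thread_work_size = 8;       // 8 elements per thread
  const int64_t block_work_size = num_threads * thread_work_size;  // 1024 elements

  const int64_t n = int64_t{1} << 24;       // smallest size in the benchmark
  // Same ceil-div pattern as the grid computation in the launch code.
  const int64_t grid = (n + block_work_size - 1) / block_work_size;
  std::printf("blocks for n=%lld: %lld\n", (long long)n, (long long)grid);
  return 0;
}
```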

View File

```diff
@@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) {
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
```