From e84bf88dde509d44175a0a1c00cec13c9926843e Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Fri, 31 Jan 2025 06:42:07 +0000 Subject: [PATCH] [ATen][CUDA] Implement 128 bit vectorization v2 (#145746) This is a re-base PR to my previous one #141959. Description from the original PR: This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.
The benchmark code used ```Python import time import torch from torch.profiler import profile, ProfilerActivity def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False): device = torch.device("cuda") shapes = [] for p in range(24, 30): shape = 1<

Results | op | dtype | size | time after | time before | % improvement | | ---- | ---- | ---- | ---- | ---- | ---- | | relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 | | relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 | | relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 | | relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 | | relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 | | relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 | | relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 | | relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 | | relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 | | relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 | | relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 | | relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 | | relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 | | relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 | | relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 | | relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 | | relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 | | sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 | | sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 | | sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 | | sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 | | sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 | | sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 | | sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 | | sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 | | sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 | | sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 | | sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 | | sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 | | sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 | | sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 | | sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 | | sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 | | sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 | | sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 | | tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 | | tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 | | tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 | | tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 | | tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 | | tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 | | tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 | | tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 | | tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 | | tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 | | tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 | | tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 | | tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 | | tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 | | tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 | | tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 | | tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 | | tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 | | gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 | | gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 | | gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 | | gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 | | gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 | | gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 | | gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 | | gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 | | gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 | | gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 | | gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 | | gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 | | gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 | | gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 | | gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 | | gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 | | gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 | | gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 | | sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 | | sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 | | sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 | | sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 | | sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 | | sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 | | sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 | | sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 | | sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 | | sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 | | sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 | | sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 | | sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 | | sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 | | sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 | | sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 | | sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 | | sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 | | exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 | | exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 | | exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 | | exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 | | exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 | | exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 | | exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 | | exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 | | exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 | | exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 | | exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 | | exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 | | exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 | | exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 | | exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 | | exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 | | exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 | | exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145746 Approved by: https://github.com/eqy, https://github.com/ngimel Co-authored-by: Aaron Gokaslan --- aten/src/ATen/native/cuda/CUDAJitLoops.cuh | 20 ++++++++++++--- aten/src/ATen/native/cuda/CUDALoops.cuh | 27 ++++++++++++++++++-- aten/src/ATen/native/cuda/Dropout.cu | 5 +++- aten/src/ATen/native/cuda/MemoryAccess.cuh | 6 ++++- aten/src/ATen/native/cuda/jit_utils.cpp | 11 +++++--- aten/src/ATen/native/cuda/jit_utils.h | 12 ++++++--- aten/src/ATen/native/cuda/thread_constants.h | 5 +++- aten/src/ATen/test/cuda_vectorized_test.cu | 14 +++++----- 8 files changed, 78 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh index 9136f63acabb..c4c3af83ccd8 100644 --- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh +++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh @@ -49,8 +49,8 @@ struct JittedVecKernelCache { at::cuda::jit::NvrtcFunction vec1; at::cuda::jit::NvrtcFunction vec2; at::cuda::jit::NvrtcFunction vec4; -#ifdef USE_ROCM at::cuda::jit::NvrtcFunction vec8; +#ifdef USE_ROCM at::cuda::jit::NvrtcFunction vec16; #endif @@ -131,6 +131,18 @@ void launch_jitted_vectorized_kernel( int vec_size = at::cuda::jit::can_vectorize_up_to( desc, c10::ArrayRef(data.data(), data.size())); +#ifndef USE_ROCM + const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + const int optimal_vec_size = 16 / static_cast(input_size); + vec_size = std::min(optimal_vec_size, vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. + if (input_size < 2) { + vec_size = std::min(vec_size, 4); + } +#endif + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) // fn_ptr is set to the appropriate function based on the vec size and GPU used at::cuda::jit::NvrtcFunction* fn_ptr = nullptr; @@ -138,11 +150,11 @@ void launch_jitted_vectorized_kernel( #ifdef USE_ROCM if (vec_size == 16) { fn_ptr = &fn_cache.vec16; - } else if (vec_size == 8) { - fn_ptr = &fn_cache.vec8; } else #endif - if (vec_size == 4) { + if (vec_size == 8) { + fn_ptr = &fn_cache.vec8; + } else if (vec_size == 4) { fn_ptr = &fn_cache.vec4; } else if (vec_size == 2) { fn_ptr = &fn_cache.vec2; diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 44d2aa6ecc2a..6c4ad7a48a0b 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -61,6 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence) { } } +#ifdef USE_ROCM template constexpr auto elems_per_thread(){ if constexpr (io_sizes == 1) { @@ -71,6 +72,16 @@ constexpr auto elems_per_thread(){ return 4; } } +#else +template +constexpr auto elems_per_thread(){ + if constexpr (io_sizes == 1) { + return 16; + } else { + return 8; + } +} +#endif template constexpr auto io_block_work_size() { @@ -191,8 +202,20 @@ static inline void launch_vectorized_kernel( constexpr auto io_size = calc_io_size(); int64_t grid = (N + io_block_work_size() - 1) / io_block_work_size(); auto stream = at::cuda::getCurrentCUDAStream(); +#ifdef USE_ROCM int vec_size = memory::can_vectorize_up_to(data); - +#else + using cpp_type = typename function_traits::result_type; + const uint16_t max_vec_size = memory::can_vectorize_up_to(data); + uint16_t vec_size = 16 / static_cast(sizeof(cpp_type)); + vec_size = std::min(vec_size, max_vec_size); + // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC + // that causes some numerical mismatches with uint8 on sm80 and sm90. + // TODO: Revisit this after CUDA 12.8 update. + if constexpr (sizeof(cpp_type) < 2) { + vec_size = std::min(vec_size, 4); + } +#endif switch (vec_size) { #ifdef USE_ROCM case 16: @@ -200,12 +223,12 @@ static inline void launch_vectorized_kernel( <<>>(N, f, data); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; +#endif case 8: vectorized_elementwise_kernel<8, func_t, array_t> <<>>(N, f, data); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; -#endif case 4: vectorized_elementwise_kernel<4, func_t, array_t> <<>>(N, f, data); diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index e37fa1b1835a..9c1a6e046de7 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -217,8 +217,11 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { // make sure we don't break assumption that we can't have > 16 elements / thread TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]"); #else + const int optimal_vec_size = 16 / static_cast(sizeof(scalar_t)); + vec_size = std::min(optimal_vec_size, vec_size); + // make sure we don't break assumption that we can't have > 4 elements / thread - TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]"); + TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]"); #endif } diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index cda0e10a1fa4..4a00d714a0c7 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -351,8 +351,8 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { uint64_t address = reinterpret_cast(pointer); constexpr int vec2_alignment = std::alignment_of_v>; constexpr int vec4_alignment = std::alignment_of_v>; -#ifdef USE_ROCM constexpr int vec8_alignment = std::alignment_of_v>; +#ifdef USE_ROCM constexpr int vec16_alignment = std::alignment_of_v>; constexpr int type_size = sizeof(scalar_t); if (type_size == 1 && (address % vec16_alignment == 0)) { @@ -360,6 +360,10 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { } else if (type_size <= 2 && (address % vec8_alignment == 0)) { return 8; } else +#else + if (address % vec8_alignment == 0) { + return 8; + } else #endif if (address % vec4_alignment == 0) { return 4; diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 7a6f0d0be479..83007cc7159d 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -932,7 +932,6 @@ void initializeCudaContext() { } } -#ifdef USE_ROCM int calc_io_size( const int nInputs, const int nOutputs, @@ -952,7 +951,6 @@ int calc_io_size( return 0; } -#endif int calc_thread_work_size( const int nInputs, @@ -971,7 +969,14 @@ int calc_thread_work_size( } return io_size; #else - return JIT_THREAD_WORK_SIZE; + auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type); + TORCH_INTERNAL_ASSERT(io_size > 0); + if (io_size == 1) { + return 16; + } else { + return 8; + } + return io_size; #endif } diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h index c7dcbd8cf728..d971df01396e 100644 --- a/aten/src/ATen/native/cuda/jit_utils.h +++ b/aten/src/ATen/native/cuda/jit_utils.h @@ -60,6 +60,10 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) { return 8; } +#else + if (ip % (8 * default_alignment) == 0) { + return 8; + } #endif if (ip % (4 * default_alignment) == 0) { return 4; @@ -88,15 +92,17 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef CUDALoops.cuh -> jit_utils.h -> Loops.cuh -#define JIT_THREAD_WORK_SIZE 4 - #ifdef USE_ROCM +#define JIT_THREAD_WORK_SIZE 4 +#else +#define JIT_THREAD_WORK_SIZE 8 +#endif + int calc_io_size( const int nInputs, const int nOutputs, const c10::ScalarType& inputs_type, const c10::ScalarType& result_type); -#endif int calc_thread_work_size( const int nInputs, diff --git a/aten/src/ATen/native/cuda/thread_constants.h b/aten/src/ATen/native/cuda/thread_constants.h index 651053d663e4..bcc797a26e1c 100644 --- a/aten/src/ATen/native/cuda/thread_constants.h +++ b/aten/src/ATen/native/cuda/thread_constants.h @@ -12,11 +12,14 @@ constexpr int num_threads() { return 256; } + +constexpr int thread_work_size() { return 4; } #else constexpr uint32_t num_threads() { return C10_WARP_SIZE * 4; } + +constexpr int thread_work_size() { return 8; } #endif -constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu index c1143fabb277..6b120f7eb304 100644 --- a/aten/src/ATen/test/cuda_vectorized_test.cu +++ b/aten/src/ATen/test/cuda_vectorized_test.cu @@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) { TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) { char *ptr = reinterpret_cast(buffer1); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 4); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 4); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 4); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 4); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 4); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 1), 1); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 1), 1); @@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) { ASSERT_EQ(memory::can_vectorize_up_to(ptr + 4), 2); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 4), 1); - ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 4); - ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 4); + ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 8); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 4); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 2); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 1);