From e288c258f7d388c2237d136c81514dd53586da0a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 29 Jul 2025 23:32:07 +0000 Subject: [PATCH] Revert "Remove tensorexpr tests (#158928)" This reverts commit d742a2896c571a535003d5928fe80397325575a5. Reverted https://github.com/pytorch/pytorch/pull/158928 on behalf of https://github.com/yangw-dev due to this breaks bunch of internal dependency since some tests are still using the deleted test files from this pr, the internal reviewer please help fix this using codev ([comment](https://github.com/pytorch/pytorch/pull/158928#issuecomment-3134378616)) --- .ci/pytorch/build.sh | 4 + .ci/pytorch/test.sh | 10 + aten/src/ATen/test/thread_init_test.cpp | 11 +- caffe2/CMakeLists.txt | 4 + test/cpp/tensorexpr/CMakeLists.txt | 83 + test/cpp/tensorexpr/README.md | 55 + test/cpp/tensorexpr/gtest_assert_float_eq.h | 119 + test/cpp/tensorexpr/padded_buffer.cpp | 37 + test/cpp/tensorexpr/padded_buffer.h | 242 + test/cpp/tensorexpr/test_approx.cpp | 96 + test/cpp/tensorexpr/test_aten.cpp | 1068 +++ test/cpp/tensorexpr/test_base.h | 89 + test/cpp/tensorexpr/test_boundsinference.cpp | 1019 +++ test/cpp/tensorexpr/test_conv.cpp | 234 + test/cpp/tensorexpr/test_cpp_codegen.cpp | 259 + test/cpp/tensorexpr/test_cuda.cpp | 2344 ++++++ test/cpp/tensorexpr/test_dynamic_shapes.cpp | 701 ++ test/cpp/tensorexpr/test_expr.cpp | 836 ++ test/cpp/tensorexpr/test_external_calls.cpp | 1061 +++ test/cpp/tensorexpr/test_graph_opt.cpp | 319 + test/cpp/tensorexpr/test_ir_printer.cpp | 98 + test/cpp/tensorexpr/test_ir_verifier.cpp | 191 + test/cpp/tensorexpr/test_kernel.cpp | 2133 +++++ test/cpp/tensorexpr/test_llvm.cpp | 1799 +++++ test/cpp/tensorexpr/test_loopnest.cpp | 6894 +++++++++++++++++ test/cpp/tensorexpr/test_memdependency.cpp | 3252 ++++++++ test/cpp/tensorexpr/test_memplanning.cpp | 708 ++ test/cpp/tensorexpr/test_ops.cpp | 78 + test/cpp/tensorexpr/test_quantization.cpp | 452 ++ test/cpp/tensorexpr/test_reductions.cpp | 1928 +++++ test/cpp/tensorexpr/test_registerizer.cpp | 3702 +++++++++ test/cpp/tensorexpr/test_simplify.cpp | 5680 ++++++++++++++ test/cpp/tensorexpr/test_te_fuser_pass.cpp | 402 + test/cpp/tensorexpr/test_type.cpp | 202 + .../tensorexpr/test_type_specializations.cpp | 75 + test/cpp/tensorexpr/test_utils.h | 78 + test/cpp/tensorexpr/tutorial.cpp | 542 ++ test/test_jit_fuser_te.py | 5 +- torch/csrc/jit/runtime/static/ops.cpp | 2 +- 39 files changed, 36802 insertions(+), 10 deletions(-) create mode 100644 test/cpp/tensorexpr/CMakeLists.txt create mode 100644 test/cpp/tensorexpr/README.md create mode 100644 test/cpp/tensorexpr/gtest_assert_float_eq.h create mode 100644 test/cpp/tensorexpr/padded_buffer.cpp create mode 100644 test/cpp/tensorexpr/padded_buffer.h create mode 100644 test/cpp/tensorexpr/test_approx.cpp create mode 100644 test/cpp/tensorexpr/test_aten.cpp create mode 100644 test/cpp/tensorexpr/test_base.h create mode 100644 test/cpp/tensorexpr/test_boundsinference.cpp create mode 100644 test/cpp/tensorexpr/test_conv.cpp create mode 100644 test/cpp/tensorexpr/test_cpp_codegen.cpp create mode 100644 test/cpp/tensorexpr/test_cuda.cpp create mode 100644 test/cpp/tensorexpr/test_dynamic_shapes.cpp create mode 100644 test/cpp/tensorexpr/test_expr.cpp create mode 100644 test/cpp/tensorexpr/test_external_calls.cpp create mode 100644 test/cpp/tensorexpr/test_graph_opt.cpp create mode 100644 test/cpp/tensorexpr/test_ir_printer.cpp create mode 100644 test/cpp/tensorexpr/test_ir_verifier.cpp create mode 100644 test/cpp/tensorexpr/test_kernel.cpp create mode 100644 test/cpp/tensorexpr/test_llvm.cpp create mode 100644 test/cpp/tensorexpr/test_loopnest.cpp create mode 100644 test/cpp/tensorexpr/test_memdependency.cpp create mode 100644 test/cpp/tensorexpr/test_memplanning.cpp create mode 100644 test/cpp/tensorexpr/test_ops.cpp create mode 100644 test/cpp/tensorexpr/test_quantization.cpp create mode 100644 test/cpp/tensorexpr/test_reductions.cpp create mode 100644 test/cpp/tensorexpr/test_registerizer.cpp create mode 100644 test/cpp/tensorexpr/test_simplify.cpp create mode 100644 test/cpp/tensorexpr/test_te_fuser_pass.cpp create mode 100644 test/cpp/tensorexpr/test_type.cpp create mode 100644 test/cpp/tensorexpr/test_type_specializations.cpp create mode 100644 test/cpp/tensorexpr/test_utils.h create mode 100644 test/cpp/tensorexpr/tutorial.cpp diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 1f0b4b63843b..a7ce0fef736c 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -50,6 +50,9 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi +# Enable LLVM dependency for TensorExpr testing +export USE_LLVM=/opt/llvm +export LLVM_DIR=/opt/llvm/lib/cmake/llvm if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with @@ -189,6 +192,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then export USE_ASAN=1 export REL_WITH_DEB_INFO=1 export UBSAN_FLAGS="-fno-sanitize-recover=all" + unset USE_LLVM fi if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 35c828042ac8..fb4e0759d508 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1039,10 +1039,20 @@ test_libtorch_api() { mkdir -p $TEST_REPORTS_DIR OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml + "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml else # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest" + # On s390x, pytorch is built without llvm. + # Even if it would be built with llvm, llvm currently doesn't support used features on s390x and + # test fails with errors like: + # JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer + # unknown file: Failure + # C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) } + if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then + python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr + fi fi # quantization is not fully supported on s390x yet diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 60dd52d1dffc..7ad7a18e9c66 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,8 +1,7 @@ -#include - #include #include #include +#include #include @@ -10,7 +9,7 @@ // numbers of threads set and also whether the scheduler // will throw an exception when multiple threads call // their first parallel construct. -static void test(int given_num_threads) { +void test(int given_num_threads) { auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat)); ASSERT_TRUE(given_num_threads >= 0); ASSERT_EQ(at::get_num_threads(), given_num_threads); @@ -20,7 +19,7 @@ static void test(int given_num_threads) { } } -TEST(ThreadInitTest, ThreadInit) { +int main() { at::init_num_threads(); at::set_num_threads(4); @@ -33,11 +32,13 @@ TEST(ThreadInitTest, ThreadInit) { #if !AT_PARALLEL_NATIVE at::set_num_threads(5); - ASSERT_EQ(at::get_num_threads(), 5); + ASSERT_TRUE(at::get_num_threads() == 5); #endif // test inter-op settings at::set_num_interop_threads(5); ASSERT_EQ(at::get_num_interop_threads(), 5); ASSERT_ANY_THROW(at::set_num_interop_threads(6)); + + return 0; } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4f123bec2dc8..776688dccad5 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1346,6 +1346,10 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) + add_subdirectory( + ${TORCH_ROOT}/test/cpp/tensorexpr + ${CMAKE_BINARY_DIR}/test_tensorexpr + ) if(USE_DISTRIBUTED) add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) if(NOT WIN32) diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt new file mode 100644 index 000000000000..8fe6ffd525e9 --- /dev/null +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -0,0 +1,83 @@ +set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) + +set(TENSOREXPR_TEST_SRCS + ${TENSOREXPR_TEST_ROOT}/test_approx.cpp + ${TENSOREXPR_TEST_ROOT}/test_aten.cpp + ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp + ${TENSOREXPR_TEST_ROOT}/test_conv.cpp + ${TENSOREXPR_TEST_ROOT}/test_cpp_codegen.cpp + ${TENSOREXPR_TEST_ROOT}/test_dynamic_shapes.cpp + ${TENSOREXPR_TEST_ROOT}/test_expr.cpp + ${TENSOREXPR_TEST_ROOT}/test_external_calls.cpp + ${TENSOREXPR_TEST_ROOT}/test_graph_opt.cpp + ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp + ${TENSOREXPR_TEST_ROOT}/test_ir_verifier.cpp + ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp + ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp + ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp + ${TENSOREXPR_TEST_ROOT}/test_ops.cpp + ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp + ${TENSOREXPR_TEST_ROOT}/test_memplanning.cpp + ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp + ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp + ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp + ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp + ${TENSOREXPR_TEST_ROOT}/test_type.cpp + ${TENSOREXPR_TEST_ROOT}/test_type_specializations.cpp +) + +if(USE_CUDA) + list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp) +endif() + +if(USE_LLVM AND LLVM_FOUND) + list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp) +endif() + +add_executable(test_tensorexpr + ${TORCH_ROOT}/test/cpp/common/main.cpp + ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp + ${TENSOREXPR_TEST_SRCS}) + +target_link_libraries(test_tensorexpr PRIVATE torch gtest_main) +target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) +target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) + +add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) +target_link_libraries(tutorial_tensorexpr PRIVATE torch) +target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) + +# The test case depends on the xnnpack header which in turn depends on the +# pthreadpool header. For some build environment we need add the dependency +# explicitly. +if(USE_PTHREADPOOL) + target_link_libraries(test_tensorexpr PRIVATE pthreadpool_interface) +endif() +if(USE_CUDA) + target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) + target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) +elseif(USE_ROCM) + target_link_libraries(test_tensorexpr PRIVATE + hiprtc::hiprtc + hip::amdhip64 + ${TORCH_CUDA_LIBRARIES}) + target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) + + target_link_libraries(tutorial_tensorexpr PRIVATE + hiprtc::hiprtc + hip::amdhip64 + ${TORCH_CUDA_LIBRARIES}) + target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) +endif() + +if(INSTALL_TEST) + set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") + install(TARGETS test_tensorexpr DESTINATION bin) + set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") + install(TARGETS tutorial_tensorexpr DESTINATION bin) + # Install PDB files for MSVC builds + if(MSVC AND BUILD_SHARED_LIBS) + install(FILES $ DESTINATION bin OPTIONAL) + install(FILES $ DESTINATION bin OPTIONAL) + endif() +endif() diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md new file mode 100644 index 000000000000..f86a50a65e80 --- /dev/null +++ b/test/cpp/tensorexpr/README.md @@ -0,0 +1,55 @@ +# TensorExpr C++ Tests + +## How to add a new test +First, create a new test file. Test files should have be placed in this +directory, with a name that starts with `test_`, like `test_foo.cpp`. + +Here is an example test file you can copy-paste. +```cpp +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +// 1. Test cases are void() functions. +// 2. They start with the prefix `test` +void testCaseOne() { + // ... +} + +void testCaseTwo() { + // ... +} +} +} +``` + +Then, register your test in `tests.h`: +```cpp +// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests +#define TH_FORALL_TESTS(_) \ + _(ADFormulas) \ + _(Attributes) \ + ... + _(CaseOne) // note that the `test` prefix is omitted. + _(CaseTwo) +``` + +We glob all the test files together in `CMakeLists.txt` so that you don't +have to edit it every time you add a test. Unfortunately, this means that in +order to get the build to pick up your new test file, you need to re-run +cmake: +```bash +CMAKE_FRESH=1 python setup.py build +``` + +## How do I run the tests? +The following commands assume you are in PyTorch root. + + ```bash + # (re)build the test binary + ninja build/bin/test_tensorexpr + # run + build/bin/test_tensorexpr --gtest_filter='glob_style_filter*' + ``` diff --git a/test/cpp/tensorexpr/gtest_assert_float_eq.h b/test/cpp/tensorexpr/gtest_assert_float_eq.h new file mode 100644 index 000000000000..f85264a8f5d3 --- /dev/null +++ b/test/cpp/tensorexpr/gtest_assert_float_eq.h @@ -0,0 +1,119 @@ +#pragma once + +#include +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file declares functions and macros used internally by +// Google Test. They are subject to change without notice. + +using Bits = uint32_t; + +// this avoids the "dereferencing type-punned pointer +// will break strict-aliasing rules" error +union Float { + float float_; + Bits bits_; +}; + +// # of bits in a number. +static const size_t kBitCount = 8 * sizeof(Bits); +// The mask for the sign bit. +static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); + +// GOOGLETEST_CM0001 DO NOT DELETE + +// Converts an integer from the sign-and-magnitude representation to +// the biased representation. More precisely, let N be 2 to the +// power of (kBitCount - 1), an integer x is represented by the +// unsigned number x + N. +// +// For instance, +// +// -N + 1 (the most negative number representable using +// sign-and-magnitude) is represented by 1; +// 0 is represented by N; and +// N - 1 (the biggest number representable using +// sign-and-magnitude) is represented by 2N - 1. +// +// Read http://en.wikipedia.org/wiki/Signed_number_representations +// for more details on signed number representations. +static Bits SignAndMagnitudeToBiased(const Bits& sam) { + if (kSignBitMask & sam) { + // sam represents a negative number. + return ~sam + 1; + } else { + // sam represents a positive number. + return kSignBitMask | sam; + } +} + +// Given two numbers in the sign-and-magnitude representation, +// returns the distance between them as an unsigned number. +static Bits DistanceBetweenSignAndMagnitudeNumbers( + const Bits& sam1, + const Bits& sam2) { + const Bits biased1 = SignAndMagnitudeToBiased(sam1); + const Bits biased2 = SignAndMagnitudeToBiased(sam2); + return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); +} + +// How many ULP's (Units in the Last Place) we want to tolerate when +// comparing two numbers. The larger the value, the more error we +// allow. A 0 value means that two numbers must be exactly the same +// to be considered equal. +// +// The maximum error of a single floating-point operation is 0.5 +// units in the last place. On Intel CPU's, all floating-point +// calculations are done with 80-bit precision, while double has 64 +// bits. Therefore, 4 should be enough for ordinary use. +// +// See the following article for more details on ULP: +// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ +static const size_t kMaxUlps = 4; + +// Returns true if and only if this number is at most kMaxUlps ULP's away +// from rhs. In particular, this function: +// +// - returns false if either number is (or both are) NAN. +// - treats really large numbers as almost equal to infinity. +// - thinks +0.0 and -0.0 are 0 DLP's apart. +inline bool AlmostEquals(float lhs, float rhs) { + // The IEEE standard says that any comparison operation involving + // a NAN must return false. + if (std::isnan(lhs) || std::isnan(rhs)) + return false; + + Float l = {lhs}; + Float r = {rhs}; + + return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps; +} diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp new file mode 100644 index 000000000000..424d82c77453 --- /dev/null +++ b/test/cpp/tensorexpr/padded_buffer.cpp @@ -0,0 +1,37 @@ +#include "test/cpp/tensorexpr/padded_buffer.h" + +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +int PaddedBufferBase::Index(const std::vector& indices) const { + TORCH_DCHECK_EQ(dims_.size(), indices.size()); + int total_index = 0; + for (const auto i : c10::irange(dims_.size())) { + total_index += indices[i] * strides_[i]; + } + return total_index; +} + +PaddedBufferBase::PaddedBufferBase( + const std::vector& dims, + // NOLINTNEXTLINE(modernize-pass-by-value) + const std::string& name) + : dims_(dims), name_(name), strides_(dims.size()) { + for (int i = (int)dims.size() - 1; i >= 0; --i) { + if (i == (int)dims.size() - 1) { + strides_[i] = 1; + } else { + strides_[i] = strides_[i + 1] * dims[i + 1]; + } + } + total_size_ = strides_[0] * dims[0]; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.h b/test/cpp/tensorexpr/padded_buffer.h new file mode 100644 index 000000000000..b3e5227ae7e6 --- /dev/null +++ b/test/cpp/tensorexpr/padded_buffer.h @@ -0,0 +1,242 @@ +#pragma once + +#include +#include + +#include +#include "torch/csrc/jit/tensorexpr/eval.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +struct DefaultPaddedValue; + +template <> +struct DefaultPaddedValue { + static const int kValue = static_cast(0xDEADBEEF); +}; + +template <> +struct DefaultPaddedValue { + static const int8_t kValue = static_cast(0xBE); +}; + +template <> +struct DefaultPaddedValue { + static const uint8_t kValue = static_cast(0xBE); +}; + +template <> +struct DefaultPaddedValue { + static const int16_t kValue = static_cast(0xBEEF); +}; + +template <> +struct DefaultPaddedValue { + static const int64_t kValue = static_cast(0xDEADBEEF); +}; + +template <> +struct DefaultPaddedValue { + static constexpr float kValue = 0.1357; +}; + +template <> +struct DefaultPaddedValue { + // at::Half ctor isn't constexpr, so just fill it with bits. + static constexpr uint16_t kValue = 1357; +}; + +template <> +struct DefaultPaddedValue { + static constexpr double kValue = 0.1357; +}; + +// A concrete base to be used in PaddedBase. +class PaddedBufferBase { + public: + const std::string& name() const { + return name_; + } + + int size() const { + return total_size_; + } + + int raw_size() const { + return total_size_ + 2 * kPaddingSize; + } + + virtual ~PaddedBufferBase() {} + + protected: + explicit PaddedBufferBase( + const std::vector& dims, + const std::string& name); + int Index(const std::vector& indices) const; + + std::vector dims_; + std::string name_; + std::vector strides_; + int total_size_; // total number of useful element, does not include the + // paddings + static constexpr int kPaddingSize = 64; +}; + +// A padded buffer with wartermarks for testing. +// The buffer carries padded watermarks on both sides to catch potential +// out-of-bounds writes. For read-only data that are not supposed to change, it +// can also make a backup and be compared later. +template +class PaddedBuffer : public PaddedBufferBase { + public: + PaddedBuffer(int d0, const std::string& name = "") + : PaddedBuffer(std::vector({d0}), name) {} + PaddedBuffer(int d0, int d1, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1}), name) {} + PaddedBuffer(int d0, int d1, int d2, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1, d2}), name) {} + PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1, d2, d3}), name) {} + PaddedBuffer(const std::vector& dims, const std::string& name = "") + : PaddedBufferBase(dims, name) { + data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue); + } + PaddedBuffer(const PaddedBuffer& other, const std::string& name) + : PaddedBuffer(other) { + this->name_ = name; + } + + T* data() { + return data_.data() + kPaddingSize; + } + const T* data() const { + return const_cast(this)->data(); + } + T* raw_data() { + return data_.data(); + } + const T* raw_data() const { + return const_cast(this)->raw_data(); + } + T& operator()(int i0) { + // There is a bit performance impact with forming a vector here. But this + // data structure is for testing only, and not performance critical. + return this->operator()(std::vector({i0})); + } + const T& operator()(int i0) const { + return const_cast(this)->operator()(i0); + } + T& operator()(int i0, int i1) { + return this->operator()(std::vector({i0, i1})); + } + const T& operator()(int i0, int i1) const { + return const_cast(this)->operator()(i0, i1); + } + T& operator()(int i0, int i1, int i2) { + return this->operator()(std::vector({i0, i1, i2})); + } + const T& operator()(int i0, int i1, int i2) const { + return const_cast(this)->operator()(i0, i1, i2); + } + T& operator()(int i0, int i1, int i2, int i3) { + return this->operator()(std::vector({i0, i1, i2, i3})); + } + const T& operator()(int i0, int i1, int i2, int i3) const { + return const_cast(this)->operator()(i0, i1, i2, i3); + } + T& operator()(const std::vector& indices) { + return data_[kPaddingSize + Index(indices)]; + } + const T& operator()(const std::vector& indices) const { + return const_cast(this)->operator()(indices); + } + + template + friend void ExpectAllNear( + const PaddedBuffer& v1, + const PaddedBuffer& v2, + float abs_error); + template + friend void ExpectAllEqual( + const PaddedBuffer& v1, + const PaddedBuffer& v2); + void Backup() { + backup_data_ = data_; + } + + // Verify the watermarks in the paddings are intact. + void ValidateWatermark() const { + for (const auto i : c10::irange(kPaddingSize)) { + ASSERT_EQ(data_[i], kPaddingValue); + ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue); + } + } + + void CheckBackup() const { + ValidateWatermark(); + DCHECK(backup_data_.size() == data_.size()) + << "Please make sure you have call Backup() before calling CheckBackup()"; + for (const auto i : c10::irange(total_size_)) { + ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]); + } + } + + private: + std::vector data_; + std::vector backup_data_; + T kPaddingValue = DefaultPaddedValue::kValue; +}; + +template +inline CodeGen::CallArg::CallArg(const PaddedBuffer& buffer) + : data_(const_cast(buffer.data())) {} + +template +std::string CompareErrorMsg( + const PaddedBuffer& v1, + const PaddedBuffer& v2, + int index) { + std::ostringstream oss; + oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index) + << ")" + << ", v2: (" << v2.name() << ", " << v2(index) << ")"; + return oss.str(); +} + +template +void ExpectAllEqual(const PaddedBuffer& f1, const PaddedBuffer& f2) { + const std::vector& v1 = f1.data_; + const std::vector& v2 = f2.data_; + const int kPaddingSize = f1.kPaddingSize; + const int total_size = f1.total_size_; + ASSERT_EQ(v1.size(), v2.size()); + f1.ValidateWatermark(); + f2.ValidateWatermark(); + for (const auto i : c10::irange(total_size)) { + ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]); + } +} + +template +void ExpectAllNear( + const PaddedBuffer& f1, + const PaddedBuffer& f2, + float abs_error) { + const std::vector& v1 = f1.data_; + const std::vector& v2 = f2.data_; + const int kPaddingSize = f1.kPaddingSize; + const int total_size = f1.total_size_; + ASSERT_EQ(v1.size(), v2.size()); + f1.ValidateWatermark(); + f2.ValidateWatermark(); + for (const auto i : c10::irange(total_size)) { + ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); + } +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp new file mode 100644 index 000000000000..e1a576aecf52 --- /dev/null +++ b/test/cpp/tensorexpr/test_approx.cpp @@ -0,0 +1,96 @@ +#ifdef TORCH_ENABLE_LLVM + +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::indexing; +namespace te = torch::jit::tensorexpr; + +static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { + auto loops = ln->getLoopStmtsFor(target); + te::ForPtr inner, tail; + ln->splitWithTail(loops[0], width, &inner, &tail); + ASSERT_TRUE(te::LoopNest::vectorize(inner)); +} + +std::string diffs(const at::Tensor& a, const at::Tensor& b) { + auto diff = torch::abs(a.flatten() - b.flatten()); + auto count_diffs = torch::sum(diff > 0.f); + auto greatest_diff_index = torch::argmax(diff); + std::stringstream ss; + ss << "Found " << count_diffs << " unequal element(s). " + << "The greatest difference was " << diff.index({greatest_diff_index}) + << " at index " << greatest_diff_index; + return ss.str(); +} + +TEST(Approx, log_vml) { + te::VarHandle N("N", te::kInt); + te::BufHandle A("A", {N}, te::kFloat); + te::Tensor B = te::Compute( + "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); }); + + te::LoopNest ln({B}); + ln.prepareForCodegen(); + vectorize(&ln, B, 8); + te::StmtPtr s = ln.root_stmt(); + s = te::IRSimplifier::simplify(s); + te::LLVMCodeGen cg(s, {A, B, N}); + + auto eps = std::numeric_limits::epsilon(); + auto test = [&](const at::Tensor& A_t) { + at::Tensor B_ref = at::log(A_t); + at::Tensor B_t = at::empty_like(A_t); + auto ap = A_t.data_ptr(); + auto bp = B_t.data_ptr(); + cg.call({ap, bp, A_t.numel()}); + // Results should be bit-identical. + ASSERT_TRUE(torch::allclose( + B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true)) + << "Input[:8]\n" + << A_t.index({Slice(0, 8)}) << "\n" + << "Test[:8]\n" + << B_t.index({Slice(0, 8)}) << "\n" + << "Ref[:8]\n" + << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref); + }; + + // Generate every single-precision FP value in [1.0, 2.0). + at::Tensor A_t = torch::arange(1.0f, 2.0f, eps); + ASSERT_EQ(A_t.numel(), 1 << 23); + + test(A_t); + + test(A_t * 2.0f); + test(A_t * 0.5f); + + test(A_t * 4.0f); + test(A_t * 0.25f); + + test(A_t * powf(2.0f, 16)); + test(A_t * powf(2.0f, -16)); + + test(A_t * powf(2.0f, 126)); + test(A_t * powf(2.0f, -126)); + + test(torch::full({32}, INFINITY)); + test(torch::full({32}, NAN)); + + auto min = std::numeric_limits::min(); + auto denorm_min = std::numeric_limits::denorm_min(); + + // Denormals aren't bit precise, because sleef isn't bit-precise either. + A_t = torch::arange(0.0f, min, denorm_min); + ASSERT_EQ(A_t.numel(), 1 << 23); + auto B_ref = at::log(A_t); + auto B_t = at::empty_like(B_ref); + cg.call({A_t.data_ptr(), B_t.data_ptr(), A_t.numel()}); + ASSERT_TRUE(torch::allclose(B_t, B_ref)); +} + +#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp new file mode 100644 index 000000000000..34ce2bd069d5 --- /dev/null +++ b/test/cpp/tensorexpr/test_aten.cpp @@ -0,0 +1,1068 @@ +#include +#include +#include + +#include + +#include +#include +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "test/cpp/tensorexpr/test_base.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(ATen, _cast_Float) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle to_float = Cast::make(kFloat, load_a); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), static_cast(i)); + } +} + +TEST(ATen, negInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle to_float = Sub::make(0, load_a); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), -static_cast(i)); + } +} + +TEST(ATen, negFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle to_float = Sub::make(0, load_a); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), -i); + } +} + +TEST(ATen, addInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); + } +} + +TEST(ATen, addFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); + } +} + +TEST(ATen, subInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); + } +} + +TEST(ATen, subFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); + } +} + +TEST(ATen, lerp) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))); + } +} + +TEST(ATen, addcmulInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); + BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + ExprHandle load_d = d_buf.load(index); + StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + PaddedBuffer e_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + d_v(i) = 5 * i + 3; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); + ir_eval(a_v, b_v, c_v, d_v, e_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), 5 * i + 3); + ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); + } +} + +TEST(ATen, addcmulFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + ExprHandle load_d = d_buf.load(index); + StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + PaddedBuffer e_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + d_v(i) = 5 * i + 3; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); + ir_eval(a_v, b_v, c_v, d_v, e_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), 5 * i + 3); + ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); + } +} + +TEST(ATen, mulInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a * load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); + } +} + +TEST(ATen, mulFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a * load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); + } +} + +TEST(ATen, divInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a / load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = 2 * i + 1; + b_v(i) = i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), 2 * i + 1); + ASSERT_EQ(b_v(i), i + 1); + ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); + } +} + +TEST(ATen, divFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a / load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = 2 * i + 1; + b_v(i) = i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), 2 * i + 1); + ASSERT_EQ(b_v(i), i + 1); + ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); + } +} + +TEST(ATen, maxInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i))); + } +} + +TEST(ATen, maxFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))); + } +} + +TEST(ATen, minInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i))); + } +} + +TEST(ATen, minFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))); + } +} + +void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 1.0f / i); + } +} + +TEST(ATen, reluInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i - 64; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i - 64); + ASSERT_EQ(b_v(i), std::max(a_v(i), 0)); + } +} + +TEST(ATen, reluFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store( + {index}, Max::make(load_a, 0, false) // relu does not propagate nans + ); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i - 64; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i - 64); + ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0)); + } +} + +TEST(ATen, logFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i + 10); + ASSERT_EQ(b_v(i), std::log(a_v(i))); + } +} + +TEST(ATen, fastLogFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + auto ref = std::log(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_FLOAT_EQ(test, ref); + } + } +} + +TEST(ATen, fastTanhFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + auto ref = std::tanh(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_NEAR(test, ref, 1e-6); + } + } +} + +TEST(ATen, fastSigmoidFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + at::Tensor t = at::ones({1}) * a_v(i); + float ref = at::sigmoid(t).item().to(); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_NEAR(test, ref, 1e-6); + } + } +} + +TEST(ATen, log10Float) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, log10(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i + 10); + ASSERT_EQ(b_v(i), std::log10(a_v(i))); + } +} + +TEST(ATen, log2Float) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, log2(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i + 10); + ASSERT_EQ(b_v(i), std::log2(a_v(i))); + } +} + +TEST(ATen, expFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, exp(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i / 10.0f); + ASSERT_EQ(b_v(i), std::exp(a_v(i))); + } +} + +TEST(ATen, erfFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, erf(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i / 10.0f); + ASSERT_EQ(b_v(i), std::erf(a_v(i))); + } +} + +TEST(ATen, cosFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, cos(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i / 10.0f); + ASSERT_EQ(b_v(i), std::cos(a_v(i))); + } +} + +TEST(ATen, eqInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, geInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGE))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, gtInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 6); + std::vector b_buffer(N, 3); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGT))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, leInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLE))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, ltInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLT))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 0); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h new file mode 100644 index 000000000000..68b96fe6c90f --- /dev/null +++ b/test/cpp/tensorexpr/test_base.h @@ -0,0 +1,89 @@ +#pragma once + +#if defined(USE_GTEST) +#include +#include +#else +#include +#include "c10/util/Exception.h" +#include "test/cpp/tensorexpr/gtest_assert_float_eq.h" +#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__) +#define ASSERT_FLOAT_EQ(x, y, ...) \ + TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__) +#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__) +#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__) +#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__) +#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__) +#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__) + +#define ASSERT_NEAR(x, y, a, ...) \ + TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__) + +#define ASSERT_TRUE TORCH_INTERNAL_ASSERT +#define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) +#define ASSERT_THROWS_WITH(statement, substring) \ + try { \ + (void)statement; \ + ASSERT_TRUE(false); \ + } catch (const std::exception& e) { \ + ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ + } +#define ASSERT_ANY_THROW(statement) \ + { \ + bool threw = false; \ + try { \ + (void)statement; \ + } catch (const std::exception& e) { \ + threw = true; \ + } \ + ASSERT_TRUE(threw); \ + } + +#endif // defined(USE_GTEST) +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +void ExpectAllNear( + const std::vector& v1, + const std::vector& v2, + V threshold, + const std::string& name = "") { + ASSERT_EQ(v1.size(), v2.size()); + for (size_t i = 0; i < v1.size(); i++) { + ASSERT_NEAR(v1[i], v2[i], threshold); + } +} + +template +void ExpectAllNear( + const std::vector& vec, + const U& val, + V threshold, + const std::string& name = "") { + for (size_t i = 0; i < vec.size(); i++) { + ASSERT_NEAR(vec[i], val, threshold); + } +} + +template +static void assertAllEqual(const std::vector& vec, const T& val) { + for (auto const& elt : vec) { + ASSERT_EQ(elt, val); + } +} + +template +static void assertAllEqual(const std::vector& v1, const std::vector& v2) { + ASSERT_EQ(v1.size(), v2.size()); + for (size_t i = 0; i < v1.size(); ++i) { + ASSERT_EQ(v1[i], v2[i]); + } +} +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp new file mode 100644 index 000000000000..2605842d6e74 --- /dev/null +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -0,0 +1,1019 @@ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +static void verifyConstBounds( + const TensorAccessBoundsInfo& access_info, + const std::vector>& ref) { + size_t ndim = ref.size(); + ASSERT_EQ(access_info.start.size(), ndim); + ASSERT_EQ(access_info.stop.size(), ndim); + for (const auto i : c10::irange(ndim)) { + if (ref[i].first >= 0) { // Negative values are used to skip the check + ASSERT_TRUE(access_info.start[i]->isConstant()); + int start_i = immediateAs(access_info.start[i]); + ASSERT_EQ(start_i, ref[i].first); + } + if (ref[i].second >= 0) { + ASSERT_TRUE(access_info.stop[i]->isConstant()); + int stop_i = immediateAs(access_info.stop[i]); + ASSERT_EQ(stop_i, ref[i].second); + } + } +} + +TEST(BoundsInference, _1) { + // Verify that bounds inference works for the following example: + // for i in 0..100: + // b[i] = a[i] + // For this loop bounds inference should yield the following: + // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + auto bounds_info = inferBounds(l.root_stmt()); + + // We should have two entries: one for 'b' and one for 'a'. + ASSERT_EQ(bounds_info.size(), 2); + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); +} + +TEST(BoundsInference, _2) { + // Verify that bounds inference works for the following example: + // for i in 0..n: + // b[i] = a[i] + // For this loop bounds inference should yield the following: + // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + auto bounds_info = inferBounds(l.root_stmt()); + + // We should have two entries: one for 'b' and one for 'a'. + ASSERT_EQ(bounds_info.size(), 2); + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); +} + +TEST(BoundsInference, _3) { + // Verify that bounds inference works for the following example: + // for i in 0..100: + // b[i] = a[i] * a[i+10] + // For this loop bounds inference should yield the following: + // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} + ExprHandle n(100); + BufHandle a("a", {n + 10}, kFloat); + Tensor b = Compute( + "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); + LoopNest l({b}); + auto bounds_info = inferBounds(l.root_stmt()); + + // We should have two entries: one for 'b' and one for 'a'. + ASSERT_EQ(bounds_info.size(), 2); + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); +} + +TEST(BoundsInference, _4) { + // Verify that bounds inference works for the following example: + // + // for y in 0..200: + // for x in 0..320: + // b[y,x] = x*y + // for y in 0..200: + // for x in 0..320: + // c[y,x] = a[y,x] * b[y,x] + ExprHandle W(320); + ExprHandle H(200); + BufHandle a("a", {H, W}, kFloat); + Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return x * y; + }); + Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return a.load(y, x) * b.load(y, x); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + StmtPtr body = l.getLoopBodyFor(c); + { + // Infer bounds on the top-level loop scope + auto bounds_info = inferBounds(loops[0]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); + } + { + // Infer bounds on the inner loop scope + auto bounds_info = inferBounds(loops[1]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); + } + { + // Infer bounds on the inner loop body's scope + auto bounds_info = inferBounds(body); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); + } +} + +TEST(BoundsInference, _5) { + // Verify that bounds inference works for the following example: + // for i in 0..100: + // b[i] = a[i] + // + // ==> split ==> + // + // for i_outer in 0..100/16: + // for i_inner in 0..16: + // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] + // for i_tail in 0..100%16: + // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + + ForPtr inner; + ForPtr tail; + std::vector loops = l.getLoopStmtsFor(b); + LoopNest::splitWithTail(loops[0], 16, &inner, &tail); + ForPtr outer = loops[0]; + + { + // Verify inferred bounds for the outer loop + auto bounds_info = inferBounds(outer); + ASSERT_EQ(bounds_info.size(), 2); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); + } + { + // Verify inferred bounds for the tail loop + auto bounds_info = inferBounds(tail); + ASSERT_EQ(bounds_info.size(), 2); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); + } +} + +TEST(BoundsInference, _6) { + // Verify that bounds inference works for the following example: + // + // for y in 0..200: + // for x in 0..320: + // b[y,x] = x*y + // for y in 0..20: + // for x in 0..32: + // c[y,x] = a[y+100,x+100] * b[y*2,x*5] + ExprHandle W(320); + ExprHandle H(200); + ExprHandle CW(32); + ExprHandle CH(20); + BufHandle a("a", {H, W}, kFloat); + Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return x * y; + }); + Tensor c = + Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) { + return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + StmtPtr body = l.getLoopBodyFor(c); + { + // Infer bounds on the top-level loop scope + auto bounds_info = inferBounds(loops[0]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); + } + { + // Infer bounds on the inner loop scope + auto bounds_info = inferBounds(loops[1]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); + } + { + // Infer bounds on the inner loop body's scope + auto bounds_info = inferBounds(body); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); + } +} + +TEST(BoundsInference, Adjacent) { + ExprHandle H(6); + BufHandle a("a", {20}, kFloat); + Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); }); + Tensor c = + Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); }); + LoopNest l({b, c}); + std::vector loops = NodeFinder::find(l.root_stmt()); + + { + // Infer bounds on the top-level loop scope + auto bounds_info = inferBounds(loops[0]); + ASSERT_EQ(bounds_info.size(), 2); + + // reads from a[0:5], writes to b[0:5] + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); + } + { + // Infer bounds on the inner loop scope + auto bounds_info = inferBounds(loops[1]); + ASSERT_EQ(bounds_info.size(), 2); + + // reads from a[0+6:5+6], writes to c[0:5] + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); + } + { + // Infer bounds on the high level program. + auto bounds_info = inferBounds(l.root_stmt()); + ASSERT_EQ(bounds_info.size(), 3); + + // Should be union of above 2 bounds, but this time the bounds of A can be + // merged. + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); + } +} + +TEST(BoundsInference, MultipleTopLoopLoad) { + BufHandle a("a", {100}, kFloat); + Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); }); + Tensor c = + Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); }); + Tensor d = + Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); }); + LoopNest l({b, c, d}); + + auto bounds_info = inferBounds(l.root_stmt()); + + ASSERT_EQ(bounds_info.size(), 4); + + // a only read. + { + auto bounds = bounds_info[a.node()]; + ASSERT_EQ(bounds.size(), 1); + // One dimension. + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); + // Bounds: + // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). + // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = + // 96 + 2 - 1 (d). + verifyConstBounds(bound, {{0, 97}}); + } + + // b, c, d only written. + { + auto bounds = bounds_info[b.buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for b. + verifyConstBounds(bound, {{0, 63}}); + } + { + auto bounds = bounds_info[c.buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for c. + verifyConstBounds(bound, {{0, 31}}); + } + { + auto bounds = bounds_info[d.buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for d. + verifyConstBounds(bound, {{0, 95}}); + } +} + +TEST(BoundsInference, MultipleTopLoopStore) { + BufHandle a("a", {100}, kFloat); + BufHandle b("b", {100}, kFloat); + BufHandle c("c", {100}, kFloat); + BufHandle d("d", {100}, kFloat); + VarHandle x("x", kInt); + + // Same as above but the offsets are on the Store now. + // Can't do this through ComputeAPI without transforms we don't have yet. + StmtPtr stmt = Block::make( + {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), + For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), + For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); + + auto bounds_info = inferBounds(stmt); + + ASSERT_EQ(bounds_info.size(), 4); + + // a only read. + { + auto bounds = bounds_info[a.node()]; + ASSERT_EQ(bounds.size(), 1); + // One dimension. + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); + // Bounds: there are no offsets, so this is just the max loop bounds. + verifyConstBounds(bound, {{0, 95}}); + } + + // b, c, d only written. + { + auto bounds = bounds_info[b.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the b loop. + // b loop has no offset, so just the loop extents. + verifyConstBounds(bound, {{0, 63}}); + } + { + auto bounds = bounds_info[c.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the c loop. + // Offset is 10, extent is 32-1. + verifyConstBounds(bound, {{10, 41}}); + } + { + auto bounds = bounds_info[d.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the d loop. + // Offset is 2, extent is 96-1. + verifyConstBounds(bound, {{2, 97}}); + } +} + +TEST(BoundsInference, CacheReads) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 3); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}); + auto bounds_info_before = inferBounds(l.root_stmt()); + + StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); + + auto bounds_info_after = inferBounds(l.root_stmt()); + + // CacheAccesses should not change existing bounds, but add a new one for the + // cache. + for (auto& pair : bounds_info_after) { + auto beforeIt = bounds_info_before.find(pair.first); + if (beforeIt != bounds_info_before.end()) { + // Same number of TensorAccessBoundInfos. + ASSERT_EQ(pair.second.size(), beforeIt->second.size()); + + for (const auto i : c10::irange(pair.second.size())) { + TensorAccessBoundsInfo& after = pair.second[i]; + TensorAccessBoundsInfo& before = beforeIt->second[i]; + // Same number of dimensions. + ASSERT_EQ(before.start.size(), after.start.size()); + + // Bounds are equal. + for (const auto j : c10::irange(before.start.size())) { + ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); + ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); + } + } + } else { + // This should be the cache. + ASSERT_EQ(pair.first->name_hint(), "A_local"); + // Should have both a load and a store. + ASSERT_EQ(pair.second.size(), 2); + TensorAccessBoundsInfo& first = pair.second[0]; + TensorAccessBoundsInfo& second = pair.second[1]; + + ASSERT_NE(first.kind, second.kind); + // 2 dimensions. + ASSERT_EQ(first.start.size(), second.start.size()); + ASSERT_EQ(first.start.size(), 2); + + // bounds for load and store are equal. + for (const auto j : c10::irange(first.start.size())) { + ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); + ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); + } + } + } +} + +TEST(BoundsInference, Flattened) { + Tensor b = Compute( + "b", + {3, 4, 5}, + [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { + return x * y + z; + }); + + LoopNest l({b}); + // Flatten indices. + l.prepareForCodegen(); + auto bounds_info = inferBounds(l.root_stmt()); + + // There's only one buffer. + ASSERT_EQ(bounds_info.size(), 1); + auto& TABI = bounds_info[b.buf()][0]; + ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); + // Flattened bounds should have a single dimension. + ASSERT_EQ(TABI.start.size(), 1); + ASSERT_EQ(TABI.stop.size(), 1); + + // Bounds should be 0 -> (3*4*5)-1 + ASSERT_TRUE(exprEquals(TABI.start[0], alloc(0))); + ASSERT_TRUE(exprEquals(TABI.stop[0], alloc(3 * 4 * 5 - 1))); +} + +TEST(BoundsInference, GetPotentialHazards) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* + * A[0] = B[0]; + * B[0] = 3; WAR on B + * A[0] = B[0]; WAW on A, RAW on B + * C[0] = 5; + */ + + StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); + StorePtr store2 = Store::make(b, {0}, 3); + StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); + StorePtr store4 = Store::make(c, {0}, 5); + StmtPtr stmt = Block::make({store1, store2, store3, store4}); + + MemDependencyChecker analyzer; + stmt->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterRead, + getPotentialHazards(analyzer, store1, store2)); + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, store2, store3)); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, + getPotentialHazards(analyzer, store1, store3)); + + // Fourth store has no dependencies + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store1, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store2, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store3, store4)); + } +} + +TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return (i + 1) * (j + 1); + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; + ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; + + // No dependencies between loops. + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, loopRootA, loopRootB)); +} + +TEST(BoundsInference, GetPotentialHazardsLoopCall) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j) + 5; + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; + ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, loopRootA, loopRootB)); +} + +TEST(BoundsInference, GetPotentialHazardsLoopSplit) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + + LoopNest l({A}); + ForPtr inner, tail; + + // Splitting with tail by something offset creates a tail which also writes to + // A. + ForPtr outer = l.getLoopStmtsFor(A)[0]; + // `outer` loop get transformed to the outer loop after splitting. + LoopNest::splitWithTail(outer, 5, &inner, &tail); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // A[k-1] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // A[k] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // B[k] = A[k]; + // } + BufHandle a_buf("A", {200}, kInt); + BufHandle b_buf("B", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k}))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // A[k+100] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { + // Input IR: + // for (const auto i : c10::irange(20)) { + // for (const auto j : c10::irange(100)) { + // A[i,j] = i * j * 500; + // } + // } + // for (const auto m : c10::irange(20)) { + // for (const auto n : c10::irange(50)) { + // A[m+1,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = + Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM)); +} + +TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { + // Input IR: + // for (const auto i : c10::irange(20)) { + // for (const auto j : c10::irange(100)) { + // A[i,j] = i * j * 500; + // } + // } + // for (const auto m : c10::irange(20)) { + // for (const auto n : c10::irange(50)) { + // A[m+20,n+100] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = + Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); +} + +TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { + // Input IR: + // for (const auto i : c10::irange(20)) { + // for (const auto j : c10::irange(100)) { + // A[i,j] = i * j * 500; + // } + // } + // for (const auto m : c10::irange(20)) { + // for (const auto n : c10::irange(50)) { + // B[m,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); +} + +TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { + // Input IR: + // for (const auto j : c10::irange(100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(100)) { + // B[k] = 20 * A[99-k]; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { + // Input IR: + // for (const auto k : c10::irange(100)) { + // B[k] = 20 * A[99-k]; + // } + // for (const auto j : c10::irange(100)) { + // A[j] = 10 * j; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto par = Block::make({forK, forJ}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapWithLoads) { + // Input IR: + // for (const auto k : c10::irange(10, 100)) { + // B[k] = 20 * A[99-k]; + // } + // for (const auto j : c10::irange(10, 100)) { + // C[j] = 10 * A[j]; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + BufHandle c_buf("C", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forK = For::make( + k, + 10, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto forJ = For::make( + j, + 10, + 100, + Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j})))); + auto par = Block::make({forK, forJ}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, IsOverlapping) { + // Input IR: + // for (const auto i : c10::irange(100)) { + // A[i] = i * 10; // storeA1 + // B[i] = A[99-i] * 20; // loadA1 + // C[i] = A[i + 100] * 10; // loadA2 + // A[i + 50] = i * 50; // storeA2 + // A[i + 150] = i * 150; // storeA3 + // } + BufHandle a_buf("A", {300}, kInt); + BufHandle b_buf("B", {100}, kInt); + BufHandle c_buf("C", {100}, kInt); + VarHandle i("i", kInt); + auto storeA1 = Store::make(a_buf, {i}, i * 10); + auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i}); + auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20)); + auto loadA2 = Load::make(a_buf, {i + 100}); + auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10)); + auto storeA2 = Store::make(a_buf, {i + 50}, i * 50); + auto storeA3 = Store::make(a_buf, {i + 150}, i * 150); + auto forI = For::make( + i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); + tensorexpr::analysis::MemDependencyChecker analyzer; + forI->accept(&analyzer); + ASSERT_TRUE(isOverlapping(analyzer, storeA1, to(loadA1.node()))); + ASSERT_FALSE(isOverlapping(analyzer, storeA1, to(loadA2.node()))); + ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); + ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp new file mode 100644 index 000000000000..e72303873a6c --- /dev/null +++ b/test/cpp/tensorexpr/test_conv.cpp @@ -0,0 +1,234 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace te = torch::jit::tensorexpr; +namespace F = torch::nn::functional; + +#ifdef TORCH_ENABLE_LLVM + +// Generate test data with few bits of precision, to minimize error +// accumulation from floating-point reordering. +static at::Tensor genTestData(c10::IntArrayRef args) { + return at::trunc(at::randn(args) * 256.0f) / 256.0f; +} + +TEST(Conv, DepthwiseConv2D) { + constexpr int N = 1, C = 72, H = 56, W = 56; + constexpr int K = 72, R = 3, S = 3; + constexpr int kPad = 1, kStride = 2, kGroups = C; + constexpr int CperG = C / kGroups; + + te::BufHandle input("input", {N, C, H, W}, te::kFloat); + te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); + te::BufHandle bias("bias", {K}, te::kFloat); + te::Tensor output = + te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups); + + te::LoopNest loop({output}); + loop.simplify(); + loop.prepareForCodegen(); + te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output}); + + auto it = genTestData({N, C, H, W}); + auto wt = genTestData({K, CperG, R, S}); + auto bt = genTestData({K}); + auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups); + auto ot = at::zeros_like(ref); + cg.call( + {it.data_ptr(), + wt.data_ptr(), + bt.data_ptr(), + ot.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, ot)); +} + +TEST(Conv, DepthwiseConv2DNoBias) { + constexpr int N = 1, C = 72, H = 56, W = 56; + constexpr int K = 72, R = 3, S = 3; + constexpr int kPad = 1, kStride = 2, kGroups = C; + constexpr int CperG = C / kGroups; + + te::BufHandle input("input", {N, C, H, W}, te::kFloat); + te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); + te::Tensor output = + te::conv2d_depthwise(input, weight, kStride, kPad, kGroups); + + te::LoopNest loop({output}); + loop.simplify(); + loop.prepareForCodegen(); + te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output}); + + auto it = genTestData({N, C, H, W}); + auto wt = genTestData({K, CperG, R, S}); + auto ref = + at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); + auto ot = at::zeros_like(ref); + cg.call({it.data_ptr(), wt.data_ptr(), ot.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, ot)); +} + +TEST(Conv, DepthwiseConv2DDynamicShapes) { + te::VarHandle N_var("N", te::kInt); + te::VarHandle C_var("C", te::kInt); + te::VarHandle H_var("H", te::kInt); + te::VarHandle W_var("W", te::kInt); + te::VarHandle K_var("K", te::kInt); + te::VarHandle CperG_var("CperG", te::kInt); + te::VarHandle R_var("R", te::kInt); + te::VarHandle S_var("S", te::kInt); + te::VarHandle kPad_var("kPad", te::kInt); + te::VarHandle kStride_var("kStride", te::kInt); + te::VarHandle kGroups_var("kGroups", te::kInt); + + te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat); + te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat); + te::Tensor output = te::conv2d_depthwise( + input, + weight, + N_var, + C_var, + H_var, + W_var, + K_var, + CperG_var, + R_var, + S_var, + kStride_var, + kPad_var, + kGroups_var); + + te::LoopNest loop({output}); + loop.simplify(); + loop.prepareForCodegen(); + std::vector buffer_args = { + input, + weight, + N_var, + C_var, + H_var, + W_var, + K_var, + CperG_var, + R_var, + S_var, + kPad_var, + kStride_var, + kGroups_var, + output}; + te::LLVMCodeGen cg(loop.root_stmt(), buffer_args); + + constexpr int N = 1, C = 72, H = 56, W = 56; + constexpr int K = 72, R = 3, S = 3; + constexpr int kPad = 1, kStride = 2, kGroups = C; + constexpr int CperG = C / kGroups; + + auto it = genTestData({N, C, H, W}); + auto wt = genTestData({K, CperG, R, S}); + auto ref = + at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); + auto ot = at::zeros_like(ref); + std::vector call_args = { + it.data_ptr(), + wt.data_ptr(), + N, + C, + H, + W, + K, + CperG, + R, + S, + kPad, + kStride, + kGroups, + ot.data_ptr()}; + cg.call(call_args); + + ASSERT_TRUE(at::allclose(ref, ot)); +} + +#endif + +TEST(Conv, Conv2D) { + // Input dimensions. + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 11; + constexpr int W = 11; + + // Filter dimensions. + constexpr int K = 8; + constexpr int R = 3; + constexpr int S = 3; + + // Output dims. + constexpr int OH = H - R + 1; + constexpr int OW = W - S + 1; + + // Compute reference result. + at::Tensor input = torch::randn({N, C, H, W}); + at::Tensor filter = torch::randn({K, C, R, S}); + at::Tensor ref = F::conv2d(input, filter); + + // Double check the output size is as expected. + ASSERT_EQ(ref.size(0), N); + ASSERT_EQ(ref.size(1), K); + ASSERT_EQ(ref.size(2), OH); + ASSERT_EQ(ref.size(3), OW); + + te::BufHandle inputB("input", {N, C, H, W}, te::kFloat); + te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat); + + te::Tensor conv = te::Reduce( + "conv", + {N, K, OH, OW}, + te::Sum(), + // FIXME: We have to use a `std::vector` parameter here and then unpack + // it, because we don't have an overload allowing for an arbitrary number + // of ExprHandle/VarHandle parameters. + [&](const std::vector& v) { + auto const& n = v[0]; + auto const& k = v[1]; + auto const& oh = v[2]; + auto const& ow = v[3]; + auto const& c = v[4]; + auto const& r = v[5]; + auto const& s = v[6]; + // FIXME: We have to use `call` and construct a `std::vector` here + // because the `operator()` overload is only specialized for a small + // number of arguments. + return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); + }, + // FIXME: If you forget one of the reduction dims, you get a segfault. + // Could that be caught by a verifier? + {C, R, S}); + + // FIXME: It'd be nice to have a single header that pulls in things like + // LoopNest, IRSimplifier, etc. + te::LoopNest loop({conv}); + loop.prepareForCodegen(); + te::StmtPtr s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + + at::Tensor result = at::empty_like(ref); + te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); + cg.call( + {input.data_ptr(), + filter.data_ptr(), + result.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp new file mode 100644 index 000000000000..ed7679053637 --- /dev/null +++ b/test/cpp/tensorexpr/test_cpp_codegen.cpp @@ -0,0 +1,259 @@ +#include + +#include "test/cpp/tensorexpr/test_base.h" + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +#define STR_CHECK(node, expected) \ + std::stringstream ss; \ + CppPrinter printer(&ss); \ + printer.visit(node); \ + ASSERT_EQ(ss.str(), expected) + +#define FILE_CHECK(node, pattern) \ + std::stringstream ss; \ + CppPrinter printer(&ss); \ + printer.visit(node); \ + torch::jit::testing::FileCheck().run(pattern, ss.str()) + +TEST(CppPrinter, IntImm) { + auto i = alloc(10); + STR_CHECK(i, "10"); +} + +TEST(CppPrinter, FloatImm) { + auto f = alloc(10); + STR_CHECK(f, "10.f"); +} + +TEST(CppPrinter, FloatImm1) { + auto f = alloc(10); + STR_CHECK(f, "10.f"); +} + +TEST(CppPrinter, DoubleImm) { + auto d = alloc(10); + STR_CHECK(d, "10.0"); +} + +TEST(CppPrinter, DoubleImm1) { + auto d = alloc(10.1); + STR_CHECK(d, "10.1"); +} + +TEST(CppPrinter, HalfImm) { + auto h = alloc(10); + STR_CHECK(h, "10"); +} + +TEST(CppPrinter, Add) { + auto add = alloc(alloc(1), alloc(2)); + STR_CHECK(add, "1 + 2"); +} + +TEST(CppPrinter, AddExpr1) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc(alloc(2), alloc(3))); + STR_CHECK(add, "(0 + 1) + (2 - 3)"); +} + +TEST(CppPrinter, AddExpr2) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc(alloc(2), alloc(3))); + STR_CHECK(add, "0 * 1 + (2 - 3)"); +} + +TEST(CppPrinter, AddExpr3) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc
(alloc(2), alloc(3))); + STR_CHECK(add, "(0 + 1) + 2 / 3"); +} + +TEST(CppPrinter, Mod) { + auto mod = alloc(alloc(1), alloc(2)); + STR_CHECK(mod, "1 % 2"); +} + +TEST(CppPrinter, ModFloat) { + auto mod = alloc(alloc(1), alloc(2)); + STR_CHECK(mod, "std::fmod(1.f, 2.f)"); +} + +TEST(CppPrinter, Max) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "std::max(1, 2)"); +} + +TEST(CppPrinter, MaxFloat) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "std::max(1.f, 2.f)"); +} + +TEST(CppPrinter, MaxHalf) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "(1 < 2) ? 2 : 1"); +} + +TEST(CppPrinter, And) { + auto v = alloc(alloc(1), alloc(2)); + STR_CHECK(v, "1 & 2"); +} + +TEST(CppPrinter, CompareSelect) { + auto cs = alloc( + alloc(1), + alloc(2), + alloc(1), + alloc(2), + CompareSelectOperation::kLE); + STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)"); +} + +TEST(CppPrinter, IfThenElse) { + auto cond = alloc(alloc(1), alloc(2)); + auto true_value = alloc(alloc(0), alloc(1)); + auto false_value = alloc(alloc(2), alloc(3)); + auto v = alloc(cond, true_value, false_value); + STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)"); +} + +TEST(CppPrinter, AllocateFree) { + BufHandle buf("x", {2, 3}, kInt); + AllocatePtr alloc = Allocate::make(buf); + FreePtr free = Free::make(buf); + BlockPtr block = Block::make({alloc, free}); + + const std::string pattern = R"( + # CHECK: { + # CHECK: int* x = static_cast(malloc(24)); + # CHECK: free(x); + # CHECK: } + )"; + FILE_CHECK(block, pattern); +} + +TEST(CppPrinter, LoadStore) { + BufHandle a("A", {2, 3}, kInt); + BufHandle b("B", {3, 4}, kInt); + auto store = b.store({2, 2}, a.load(1, 1)); + STR_CHECK( + store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n"); +} + +TEST(CppPrinter, Var) { + auto var = alloc("x", kInt); + STR_CHECK(var, "x"); +} + +TEST(CppPrinter, Cast) { + auto cast = alloc(kFloat, alloc(1)); + STR_CHECK(cast, "static_cast(1)"); +} + +TEST(CppPrinter, BitCast) { + auto cast = alloc(kInt, alloc(20)); + STR_CHECK(cast, "std::bitcast(20.f)"); +} + +TEST(CppPrinter, Let) { + auto var = alloc("x", kFloat); + auto val = alloc(2); + auto let = alloc(var, val); + STR_CHECK(let, "float x = 2.f;\n"); +} + +TEST(CppPrinter, For) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + VarHandle i("i", kInt); + auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); + const std::string pattern = R"( + # CHECK: for (int i = 0; i < 1024; i++) { + # CHECK: C[i] = (A[i]) + (B[i]); + # CHECK: } + )"; + FILE_CHECK(f, pattern); +} + +TEST(CppPrinter, Cond) { + BufHandle x("X", {1}, kInt); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + const std::string pattern = R"( + # CHECK: if (((X[0] < 10) ? 1 : 0)) { + # CHECK: X[0] = (X[0]) + 1; + # CHECK: } else { + # CHECK: X[0] = (X[0]) - 1; + # CHECK: } + )"; + FILE_CHECK(cond, pattern); +} + +TEST(CppPrinter, Intrinsics) { + const std::unordered_set> unsupported_ops{ + kRand, kSigmoid}; + for (const auto i : c10::irange(static_cast(kMaxIntrinsicsOp))) { + IntrinsicsOp op = static_cast(i); + if (unsupported_ops.count(op)) { + continue; + } + + if (Intrinsics::OpArgCount(op) == 1) { + auto v = alloc(op, alloc(2.0f)); + STR_CHECK(v, "std::" + v->func_name() + "(2.f)"); + } else { + auto v = + alloc(op, alloc(1.0f), alloc(2.0f)); + STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)"); + } + } +} + +TEST(CppPrinter, ExternalCall) { + std::vector dims{alloc(2), alloc(2)}; + auto output = alloc("out", dims, kFloat); + auto buf_arg1 = alloc("a", dims, kFloat); + auto buf_arg2 = alloc("b", dims, kFloat); + auto scalar_arg = alloc(alloc(1), alloc(2)); + std::vector buf_args{buf_arg1, buf_arg2}; + std::vector scalar_args{scalar_arg}; + auto call = + alloc(output, "nnc_aten_matmul", buf_args, scalar_args); + const std::string pattern = R"( + # CHECK: { + # CHECK: void* buf_ptrs[]{out, a, b}; + # CHECK: int64_t buf_ranks[]{2, 2, 2}; + # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; + # CHECK: int8_t buf_dtypes[]{6, 6, 6}; + # CHECK: int64_t extra_args[]{1 + 2}; + # CHECK: nnc_aten_matmul( + # CHECK: 3, + # CHECK: buf_ptrs, + # CHECK: buf_ranks, + # CHECK: buf_dims, + # CHECK: buf_dtypes, + # CHECK: 1, + # CHECK: extra_args); + # CHECK: } + )"; + FILE_CHECK(call, pattern); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp new file mode 100644 index 000000000000..2e1e84e758db --- /dev/null +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -0,0 +1,2344 @@ +#ifdef USE_CUDA + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using namespace torch::jit::tensorexpr; + +template +static void testCudaTestVectorAdd01_impl() { + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Dtype dtype = ToDtype(); + BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); + BufHandle b_buf("b", {num_iter, block_count, block_size}, dtype); + Tensor c = Compute( + "c", + { + num_iter, + block_count, + block_size, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + loops[1]->set_gpu_block_index(0); + loops[2]->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + const int N = block_count * block_size * num_iter; + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (const auto i : c10::irange(N)) { + a_v(i) = ctype(i); + b_v(i) = ctype(i * 3 + 7); + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + ctype* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(ctype))); + ctype* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(ctype))); + ctype* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(ctype))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +float sigmoid(float x) { + return 1.0f / (1.0f + expf(-0.0f - x)); +} + +TEST(Cuda, Sigmoid_CUDA) { + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Dtype dtype = ToDtype(); + BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); + Tensor c = Compute( + "c", + { + num_iter, + block_count, + block_size, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + loops[1]->set_gpu_block_index(0); + loops[2]->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf); + const int N = block_count * block_size * num_iter; + PaddedBuffer a_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (const auto i : c10::irange(N)) { + a_v(i) = float(i); + c_ref(i) = sigmoid(sigmoid(a_v(i))); + } + + // TODO: move gpu support into PaddedBuffer + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, a_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +TEST(Cuda, TestVectorAdd01_CUDA) { + // floating types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + + // integer types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); +} + +static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { + BufHandle a_buf("a", {N}, kFloat); + BufHandle b_buf("b", {N}, kFloat); + Tensor c = Compute("c", {N}, [&](const VarHandle& n) { + return a_buf.load(n) + b_buf.load(n); + }); + LoopNest l({c}); + ForPtr n_inner; + std::vector loops = l.getLoopStmtsFor(c); + l.splitWithMask(loops[0], block_size, &n_inner); + loops[0]->set_gpu_block_index(0); + n_inner->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (const auto i : c10::irange(N)) { + a_v(i) = i; + b_v(i) = i * 3 + 7; + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +TEST(Cuda, TestVectorAdd02_CUDA) { + testCudaTestVectorAdd02_impl(1024, 128); + testCudaTestVectorAdd02_impl(1030, 128); +} + +TEST(Cuda, HalfCast_CUDA) { + auto half = ToDtype(); + BufHandle a("a", {4}, half); + Tensor b = Compute("b", {4}, [&](const VarHandle& i) { + return Cast::make(kFloat, a.load(i)); + }); + + LoopNest l({b}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b}); + + std::vector aData(4, 2.0f); + std::vector bData(4, 0.0f); + at::Half* aDev = nullptr; + float* bDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = bData.size() * sizeof(bData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy(bData.data(), bDev, bSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(bData, 2.0f); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); +} + +TEST(Cuda, DynamicShape2D_CUDA) { + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle a("a", {m, n}, kFloat); + BufHandle b("b", {m, n}, kFloat); + Tensor c = + Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(i, j); + }); + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, c, m, n}); + + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + float* aDev = nullptr; + float* bDev = nullptr; + float* cDev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); + C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); + C10_CUDA_CHECK(cudaMalloc(&cDev, cData.size() * sizeof(cData[0]))); + C10_CUDA_CHECK(cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(bData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + cDev, + cData.data(), + cData.size() * sizeof(cData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, cDev, M, N}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy( + cData.data(), + cDev, + cData.size() * sizeof(cData[0]), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); + C10_CUDA_CHECK(cudaFree(cDev)); + }; + testWithSize(32, 32); + testWithSize(1, 16); + testWithSize(27, 13); +} + +TEST(Cuda, TestRand01_CUDA) { + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Tensor c = Compute( + "c", + { + num_iter, + block_count, + block_size, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return Intrinsics::make(IntrinsicsOp::kRand, kFloat); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + loops[1]->set_gpu_block_index(0); + loops[2]->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c); + const int N = block_count * block_size * num_iter; + PaddedBuffer c_v(N); + + // TODO: move gpu support into PaddedBuffer + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + float sum1 = 0; + float sum2 = 0; + float sum3 = 0; + for (const auto i : c10::irange(N)) { + float v = c_v.data()[i]; + sum1 += v; + sum2 += v * v; + sum3 += v * v * v; + ASSERT_TRUE(v >= 0 && v < 1); + } + sum1 /= N; + sum2 /= N; + sum3 /= N; + float sum1_mean = 1.f / 2; + float sum2_mean = 1.f / 3; + float sum3_mean = 1.f / 4; + + ASSERT_NEAR(sum1, sum1_mean, 2e-2); + ASSERT_NEAR(sum2, sum2_mean, 2e-2); + ASSERT_NEAR(sum3, sum3_mean, 2e-2); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +TEST(Cuda, DynamicShapeSplit_CUDA) { + constexpr int64_t N = 4096; + VarHandle n("n", kLong); + BufHandle a("a", {n}, kFloat); + Tensor b = + Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); + LoopNest l({b}); + ForPtr inner; + std::vector loops = l.getLoopStmtsFor(b); + l.splitWithMask(loops[0], 1024, &inner); + loops[0]->set_gpu_block_index(0); + inner->set_gpu_thread_index(0); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, n}); + + std::vector aData(N, 1.0f); + std::vector bData(N, 1.0f); + float* aDev = nullptr; + float* bDev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); + C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); + C10_CUDA_CHECK(cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, N}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy( + bData.data(), + bDev, + bData.size() * sizeof(aData[0]), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); +} + +TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { + const static int N = 1024; + BufHandle data_buf("data", {N}, kFloat); + BufHandle output_buf("output", {1}, kFloat); + + // The test adds the following code for trivial reduction: + // for (const auto bidx : c10::irange(1)) { // blockIdx.x + // for (const auto tidx : c10::irange(1)) { // threadIdx.x + // output[0] = 0.f; + // for (const auto i1 : c10::irange(1024)) { + // output[0] = output[0] + data[i1]; + // } + // } + // } + + StorePtr init_store = output_buf.store({0}, 0.f); + VarHandle i1("i1", kInt); + ExprHandle load_data = Load::make(data_buf, {i1}); + ExprHandle load_output = Load::make(output_buf, {0}); + ExprHandle add_value = load_output + load_data; + StorePtr store_output = output_buf.store({0}, add_value); + ForPtr for_output = For::make(i1, 0, N, store_output); + StmtPtr reduce_block = Block::make({init_store, for_output}); + VarHandle thread_idx("tidx", kInt); + LoopOptions thread_idx_options; + thread_idx_options.set_gpu_thread_index(0); + ForPtr thread_idx_loop = + For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); + VarHandle block_idx("bidx", kInt); + LoopOptions block_idx_options; + block_idx_options.set_gpu_block_index(0); + ForPtr block_idx_loop = + For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); + + CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); + PaddedBuffer data_v(N); + PaddedBuffer output_v(1, "output_v"); + PaddedBuffer output_ref(1, "output_ref"); + + output_ref(0) = 0; + for (const auto i : c10::irange(N)) { + data_v(i) = i; + output_ref(0) += data_v(i); + } + + float* data_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&data_dev, N * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + data_dev, data_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + float* output_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&output_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(data_dev, output_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + output_v.data(), output_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(output_v, output_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(data_dev)); + C10_CUDA_CHECK(cudaFree(output_dev)); +} + +TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { + const static int N = 1024; + + // This test does the following reduction: + // clang-format off + // for b in 0..1 // block-idx + // for t in 0..1024: // thread-idx + // if t < 1: + // b[0] = 0 + // // implied sync_threads + // for t in 0..1024: // thread-idx + // b[0] = b[0] + a[t] // implied atomic + // clang-format on + + BufHandle a_buf("a", {N}, kFloat); + BufHandle b_buf("b", {1}, kFloat); + + StorePtr init_store = b_buf.store({0}, 0.f); + VarHandle t("t", kInt); + VarHandle b("b", kInt); + + // for t in 0..1024: // thread-idx + // if t < 1: + // b[0] = 0 + ExprHandle cond_t_lt_1 = + CompareSelect::make(t, 1, CompareSelectOperation::kLT); + CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); + LoopOptions thread_idx_options; + thread_idx_options.set_gpu_thread_index(0); + ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); + + // for t in 0..1024: // thread-idx + // b[0] = b[0] + a[t] // implied atomic + ExprHandle load_a = Load::make(a_buf, {t}); + ExprHandle load_b = Load::make(b_buf, {0}); + ExprHandle add_value = load_b + load_a; + StorePtr store_b = b_buf.store({0}, add_value); + ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); + + StmtPtr reduce_block = Block::make({for_init, for_b}); + + VarHandle block_idx("bidx", kInt); + LoopOptions block_idx_options; + block_idx_options.set_gpu_block_index(0); + ForPtr block_idx_loop = + For::make(block_idx, 0, 1, reduce_block, block_idx_options); + + CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); + PaddedBuffer a_v(N); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(N)) { + a_v(i) = i; + b_ref(0) += a_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, NoThreadIdxWrite_1_CUDA) { + // This test does the following reduction: + // + // for k in 0..1: // block-idx + // a[0] = 0 + // for n in 0..2: + // a[0] = a[0] + n + // for m in 0..1024: // thread-idx + // b[m] = m + // a[1] = 1 + // for l in 0..2: + // a[1] = a[1] + n + // + // note that the statements not covered by thread-idx are supposed to be + // covered by its own thread-idx + + const static int N = 1024; + BufHandle a_buf("a", {2}, kFloat); + BufHandle b_buf("b", {N}, kFloat); + + VarHandle k("k", kInt); + VarHandle l("l", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + // a[0] = 0 + // for n in 0..2: + // a[0] = a[0] + n + StorePtr store_a0_0 = a_buf.store({0}, 0.f); + ExprHandle load_a0 = Load::make(a_buf, {0}); + ExprHandle v1 = load_a0 + n; + StorePtr store_a0_v1 = a_buf.store({0}, v1); + ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); + + // for m in 0..1024: // thread-idx + // b[m] = m + StorePtr store_bm_m = b_buf.store({m}, m + 0.f); + LoopOptions thread_idx_options; + thread_idx_options.set_gpu_thread_index(0); + ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); + + // a[1] = 1 + // for l in 0..2: + // a[1] = a[1] + l + StorePtr store_a1_1 = a_buf.store({1}, 1.f); + ExprHandle load_a1 = a_buf.load(1); + ExprHandle v2 = load_a1 + l; + StorePtr store_a1_v2 = a_buf.store({1}, v2); + ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); + + StmtPtr reduce_block = + Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); + + VarHandle block_idx("bidx", kInt); + LoopOptions block_idx_options; + block_idx_options.set_gpu_block_index(0); + ForPtr block_idx_loop = + For::make(block_idx, 0, 1, reduce_block, block_idx_options); + + CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); + PaddedBuffer a_v(2); + PaddedBuffer b_v(N, "b_v"); + PaddedBuffer a_ref(2, "a_ref"); + PaddedBuffer b_ref(N, "b_ref"); + + a_ref(0) = 0; + for (const auto i : c10::irange(2)) { + a_ref(0) += i; + } + a_ref(1) = a_ref(0) + 1; + for (const auto i : c10::irange(N)) { + b_ref(i) = i; + } + + // TODO: add check of the generated code. + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, 2 * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(a_v.data(), a_dev, 2 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(a_v, a_ref, 1e-5); + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, SharedMemReduce_1_CUDA) { + // FIXME: this test is flaky in CI. + // This test does the following: + // for k in 0..1: // block-idx + // alloc(c, 64) + // for n in 0..64: // thread-idx + // c(n) = 0 + // for m in 0..128: + // for n in 0..64: // thread_idx + // c(n) = c(n) + a(k, m, n) + // b(k) = 0 + // for n in 0..64: // thread_idx + // b(k) = b(k) + c(n) + // free(c) + + const int M = 128; + const int N = 64; + const int kTotalSize = M * N; + LoopOptions thread_idx_opt; + thread_idx_opt.set_gpu_thread_index(0); + LoopOptions block_idx_opt; + block_idx_opt.set_gpu_block_index(0); + + BufHandle a("a", {1, M, N}, kFloat); + BufHandle b("b", {1}, kFloat); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + std::vector block; + std::vector dims; + dims.push_back(ExprHandle(N).node()); + BufHandle c{alloc("c", dims, kFloat)}; + { + // alloc(c, 64); + AllocatePtr alloc = Allocate::make(c); + block.push_back(alloc); + } + + { + // for n in 0..64: // thread-idx + // c(n) = 0 + StorePtr store_cn_0 = Store::make(c, {n}, 0.f); + ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); + block.push_back(loop_n1); + } + + { + // for m in 0..128: + // for n in 0..64: // thread_idx + // c(n) = c(n) + a(k, m, n) + ExprHandle load_cn = Load::make(kFloat, c, {n}); + ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}); + ExprHandle v_add = load_cn + a_kmn; + StorePtr store_cn_v = Store::make(c, {n}, v_add); + ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); + ForPtr loop_m1 = For::make(m, 0, M, loop_n2); + block.push_back(loop_m1); + } + + { + // b(k) = 0 + // for n in 0..64: // thread_idx + // b(k) = b(k) + c(n) + StorePtr store_bk_0 = b.store({k}, 0.f); + block.push_back(store_bk_0); + ExprHandle load_bk = b.load(k); + ExprHandle load_cn = Load::make(kFloat, c, {n}); + ExprHandle v_add = load_bk + load_cn; + StorePtr store_bk = b.store({k}, v_add); + ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); + block.push_back(loop_n3); + } + + { + // free(c) + FreePtr free_stmt = Free::make(c); + block.push_back(free_stmt); + } + + BlockPtr reduce_body = Block::make(block); + ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); + + // TODO: check the generated code for correctness. + CudaCodeGen cuda_cg(loop_k1, a, b); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Check the c write is not masked, but the d write is. + const std::string& verification_pattern = + R"IR( +# CHECK: c_1 = 0 +# CHECK: for (int m = 0; m < 128 +# CHECK: c_1 = c_1 + +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<1 +# CHECK: b[blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: atomicAdd(&b[blockIdx.x], c_1) +)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, LocalMemReduce_1_CUDA) { + // This test does the following: + // for k in 0..1: // block-idx + // b(k) = 0 + // for n in 0..64: // thread-idx + // alloc(c, 1) + // c(0) = 0 + // for m in 0..128: + // c(0) = c(0) + a(k, m, n) + // b(k) = b(k) + c(0) + // free(c) + + const int M = 128; + const int N = 64; + const int kTotalSize = M * N; + LoopOptions thread_idx_opt; + thread_idx_opt.set_gpu_thread_index(0); + LoopOptions block_idx_opt; + block_idx_opt.set_gpu_block_index(0); + + BufHandle a("a", {1, M, N}, kFloat); + BufHandle b("b", {1}, kFloat); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + BufHandle c{ + alloc("c", std::vector({alloc(1)}), kFloat)}; + std::vector block_k; + { + // b(k) = 0 + StorePtr store_bk_0 = b.store({k}, 0.f); + block_k.push_back(store_bk_0); + } + std::vector block_n; + { + // alloc(c, 1); + AllocatePtr alloc = Allocate::make(c); + block_n.push_back(alloc); + } + { + // c(0) = 0 + StorePtr store_c0_0 = Store::make(c, {0}, 0.f); + block_n.push_back(store_c0_0); + } + { + // for m in 0..128: + // c(0) = c(0) + a(k, m, n) + ExprHandle load_c0 = Load::make(kFloat, c, {0}); + ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); + ExprHandle v_add = load_c0 + a_kmn; + StorePtr store_c0_v = Store::make(c, {0}, v_add); + ForPtr loop_m = For::make(m, 0, M, store_c0_v); + block_n.push_back(loop_m); + } + { + // b(k) = b(k) + c(0) + ExprHandle load_bk = b.load(k); + ExprHandle load_c0 = Load::make(kFloat, c, {0}); + ExprHandle v_add = load_bk + load_c0; + StorePtr store_bk = b.store({k}, v_add); + block_n.push_back(store_bk); + } + { + // free(c) + FreePtr free_stmt = Free::make(c); + block_n.push_back(free_stmt); + } + { + BlockPtr block_n_stmt = Block::make(block_n); + ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); + block_k.push_back(for_n); + } + BlockPtr block_k_stmt = Block::make(block_k); + ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); + + CudaCodeGen cuda_cg(loop_k, a, b); + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, HalfSupport_CUDA) { + auto half = ToDtype(); + BufHandle a("a", {4}, half); + Tensor b = Compute("b", {4}, [&](const VarHandle& i) { + return Cast::make(half, ExprHandle(2.0f) * a.load(i)); + }); + + Tensor c = Compute("c", {4}, [&](const VarHandle& i) { + return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); + }); + + Tensor d = Compute("d", {4}, [&](const VarHandle& i) { + return Cast::make(half, c.load(i)); + }); + + LoopNest l({b, c, d}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, c, d}); + + std::vector aData(4, 2.0f); + std::vector cData(4, 0.0f); + std::vector dData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* bDev = nullptr; + at::Half* cDev = nullptr; + at::Half* dDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = aData.size() * sizeof(aData[0]); + auto cSize = cData.size() * sizeof(float); + auto dSize = dData.size() * sizeof(dData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); + C10_CUDA_CHECK(cudaMalloc(&cDev, cSize)); + C10_CUDA_CHECK(cudaMalloc(&dDev, dSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(cDev, cData.data(), cSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(dDev, dData.data(), dSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, cDev, dDev}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy(cData.data(), cDev, cSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy(dData.data(), dDev, dSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(cData, 46.0f); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); + C10_CUDA_CHECK(cudaFree(cDev)); + C10_CUDA_CHECK(cudaFree(dDev)); +} + +TEST(Cuda, HalfPropagation_CUDA) { + auto half = ToDtype(); + BufHandle a("a", {4}, half); + Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { + return Max::make(a.load(i), ExprHandle(alloc(0)), true); + }); + + LoopNest l({relu}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, relu}); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the types used by the Max are Float. + const std::string& verification_pattern = + R"IR( +# CHECK: for ( +# CHECK: float v = float(a[i]); +# CHECK: relu[i] = half(Max(v, 0.f +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector aData(4, 2.0f); + std::vector reluData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* reluDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto reluSize = reluData.size() * sizeof(reluData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, reluDev}); + C10_CUDA_CHECK( + cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(aData, reluData); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(reluDev)); +} + +TEST(Cuda, UnusedHalfArgument_CUDA) { + BufHandle a("a", {4}, kFloat); + auto half = ToDtype(); + BufHandle b("b", {4}, half); + Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { + return Max::make(a.load(i), ExprHandle(alloc(0)), true); + }); + + LoopNest l({relu}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, relu}); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the types used by the Max are Float. + const std::string& verification_pattern = + R"IR( +# CHECK: for ( +# CHECK: float v = a[i]; +# CHECK: relu[i] = Max(v, 0.f +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // Sanity Cbeck; + std::vector aData(4, 2.0f); + std::vector bData(4, 2.0f); + std::vector reluData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* bDev = nullptr; + at::Half* reluDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = bData.size() * sizeof(bData[0]); + auto reluSize = reluData.size() * sizeof(reluData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); + C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, reluDev}); + C10_CUDA_CHECK( + cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(aData, reluData); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); + C10_CUDA_CHECK(cudaFree(reluDev)); +} + +TEST(Cuda, PrioritizeDependents_CUDA) { + BufHandle a("a", {10}, kFloat); + BufHandle b("b", {12}, kFloat); + BufHandle c("c", {12}, kFloat); + + LoopOptions block_idx_opt; + block_idx_opt.set_gpu_block_index(0); + + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + /* + * for (const auto i : c10::irange(12)) { + * c[i] = (i < 10 ? a[i] + b[i] : b[i]); + * } + */ + ExprHandle load_a = a.load({i}); + ExprHandle load_b = b.load({i}); + ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); + ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); + + ForPtr loop = + For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); + + CudaCodeGen cuda_cg(loop, a, b, c); + + PaddedBuffer a_v(10, "a_v"); + PaddedBuffer b_v(12, "b_v"); + PaddedBuffer c_v(12, "c_v"); + PaddedBuffer c_ref(12, "c_ref"); + + for (const auto i : c10::irange(10)) { + a_v(i) = i * 100; + b_v(i) = i; + c_v(i) = 0; + } + + for (const auto i : c10::irange(10, 12)) { + b_v(i) = i; + c_v(i) = 0; + } + + float* a_dev = nullptr; + float* b_dev = nullptr; + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, 10 * sizeof(float))); + C10_CUDA_CHECK(cudaMalloc(&b_dev, 12 * sizeof(float))); + C10_CUDA_CHECK(cudaMalloc(&c_dev, 12 * sizeof(float))); + + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), 10 * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), 12 * sizeof(float), cudaMemcpyHostToDevice)); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev, c_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, 12 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + for (const auto i : c10::irange(12)) { + if (i < 10) { + c_ref(i) = i + i * 100; + } else { + c_ref(i) = i; + } + } + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +/// Tests the case where there are two loops which have different extents bound +/// to the same block dimension. We must mask the smaller extent loop body. +TEST(Cuda, MaskBlockDim_CUDA) { + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Check the c write is not masked, but the d write is. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if (blockIdx +# CHECK: c[blockIdx.x] = +# CHECK: if (blockIdx.x<50 +# CHECK: d[blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(1))); + + // Sanity check that the kernel works. + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case with two loops, which have different extents that are bound +/// to the same thread dimension. This is the same as the above - the smaller +/// rank write should be masked. But this time we also need to syncthreads. +TEST(Cuda, MaskThreadDim_CUDA) { + int A_SIZE = 50; + int B_SIZE = 100; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i / 2) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Check the c write is masked, but the d write is not. + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.x<50 +# CHECK: c[threadIdx.x] = +# CHECK: __syncthreads(); +# CHECK-NOT: if (threadIdx.x +# CHECK: d[threadIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); + + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i / 2) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case where there are two loops, and each is bound to a different +/// block dimension. In this case all writes should be masked since they occur +/// in distinct dimensions. +// Note: this is an extremely dumb pattern which we should never see, but is a +// useful edge case to make sure we've got things covered. +TEST(Cuda, MaskMultiBlockDim_CUDA) { + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(1); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Write to c should be masked against y, write to d against x. + const std::string& verification_pattern = + R"IR( +# CHECK: if (blockIdx.y<1 +# CHECK: c[blockIdx.x] = +# CHECK: if (blockIdx.x<1 +# CHECK: d[blockIdx.y] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[1], alloc(B_SIZE))); + + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case where both the blockDim and threadDim are bound to different +/// loops. In this instance both stores should be masked since they are +/// distinct. +// Note: this is an extremely dumb pattern which we should never see, but is a +// useful edge case to make sure we've got things covered. +TEST(Cuda, MaskBlockAndThreadDim_CUDA) { + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.x<1 +# CHECK: c[blockIdx.x] = +# CHECK: } +# CHECK: if (blockIdx.x<1 +# CHECK: d[threadIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); + + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case where the loopnest has two loops of depth two: each with the +/// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In +/// this case all writes with a rank smaller than the max should be masked. +TEST(Cuda, MaskMultiDim_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The write to D should be masked, but not the write to C. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if ( +# CHECK: C[threadIdx.x + 100 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<50 +# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case where loop extents are symbolic and not known at compile time. +// In this case both stores must be masked against the extent of the other loop, +// in case it is larger. +TEST(Cuda, MaskMultiDimSymbolic_CUDA) { + VarHandle OUTER_SIZE("OUTER_SIZE", kLong); + VarHandle A_SIZE("A_SIZE", kLong); + VarHandle B_SIZE("B_SIZE", kLong); + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Since we don't know which is bigger (A_SIZE or B_SIZE) we must mask both. + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.x(A_SIZE.node(), B_SIZE.node(), true))); + + int64_t OUTER_EXTENT = 10; + int64_t A_EXTENT = 100; + int64_t B_EXTENT = 50; + + PaddedBuffer a_v(OUTER_EXTENT, A_EXTENT); + PaddedBuffer b_v(OUTER_EXTENT, B_EXTENT); + PaddedBuffer c_v(OUTER_EXTENT, A_EXTENT); + PaddedBuffer d_v(OUTER_EXTENT, B_EXTENT); + + PaddedBuffer c_ref(OUTER_EXTENT, A_EXTENT); + PaddedBuffer d_ref(OUTER_EXTENT, B_EXTENT); + + for (const auto o : c10::irange(OUTER_EXTENT)) { + for (const auto i : c10::irange(A_EXTENT)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_EXTENT)) { + for (const auto i : c10::irange(B_EXTENT)) { + b_v(o, i) = (float)(B_EXTENT - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_EXTENT * A_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_EXTENT * B_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_EXTENT * A_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_EXTENT * B_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, OUTER_EXTENT, A_EXTENT, B_EXTENT, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_EXTENT * A_EXTENT * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_EXTENT * B_EXTENT * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case where two loops are fused at a common parent loop, which is +// bound to the block dimension. Internally the inner loops have different +// extents but are bound to the same thread dimension. The smaller loop should +// be masked. +TEST(Cuda, MaskCompoundInnerLoop_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); + + // Can't build this using Compute and transforms yet. + LoopOptions blockBound; + blockBound.set_gpu_block_index(0); + LoopOptions threadBound; + threadBound.set_gpu_thread_index(0); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + + StmtPtr stmt = For::make( + i, + 0, + OUTER_SIZE, + Block::make( + {For::make( + j, + 0, + A_SIZE, + c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), + threadBound), + For::make( + k, + 0, + B_SIZE, + d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), + threadBound)}), + blockBound); + + stmt = FlattenIndexes(stmt); + stmt = IRSimplifier::simplify(stmt); + + CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The write to D should be masked, but not the write to C. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if ( +# CHECK: c[threadIdx.x + 100 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<50 +# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev, c_dev, d_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case with two loops fused into a common parent, which is not bound +// to any block or thread dimension - however it's two inner loops are bound to +// the first thread dimensions. This should work just like the MaskThreadDim +// test where the bigger loop is unmasked but the smaller is masked. +TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); + + // Can't build this using Compute and transforms yet. + LoopOptions blockBound; + blockBound.set_gpu_block_index(0); + LoopOptions threadBound; + threadBound.set_gpu_thread_index(0); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + + StmtPtr stmt = For::make( + i, + 0, + OUTER_SIZE, + Block::make( + {For::make( + j, + 0, + A_SIZE, + c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), + threadBound), + For::make( + k, + 0, + B_SIZE, + d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), + threadBound)})); + + stmt = FlattenIndexes(stmt); + stmt = IRSimplifier::simplify(stmt); + + CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The other loop remains the D write is masked. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 10 +# CHECK-NOT: if ( +# CHECK: c[threadIdx.x + 100 * i] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<50 +# CHECK: d[threadIdx.x + 50 * i] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev, c_dev, d_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case with two loop nests, each of which bound to the same block +// size, but with internal loops bound to different thread rank (ie x and y). In +// this case both bodies must be masked against the other dimension being > 0. +// Note: this is a bit degenerate no one would actually write this for perf. +TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 30; + int B_SIZE = 15; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(1); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Both stores masked against the other thread dim < 1. + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.y<1 +# CHECK: C[threadIdx.x + 30 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<1 +# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case with two loop nests, each bound to both Block and Thread but +// the second loop is smaller in both cases - the second store must be masked +// for both the block and thread dimension. +TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { + int OUTER_A_SIZE = 10; + int OUTER_B_SIZE = 5; + int A_SIZE = 30; + int B_SIZE = 15; + BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The write to D should be masked twice, but not the write to C. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if ( +# CHECK: C[threadIdx.x + 30 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (blockIdx.x<5 +# CHECK: if (threadIdx.x<15 +# CHECK: D[threadIdx.x + 15 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_A_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_B_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_A_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_B_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_A_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_B_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_A_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_B_SIZE)) { + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_A_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_B_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_A_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_B_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_A_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_B_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +} // namespace jit +} // namespace torch + +#endif diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp new file mode 100644 index 000000000000..07b9872fb832 --- /dev/null +++ b/test/cpp/tensorexpr/test_dynamic_shapes.cpp @@ -0,0 +1,701 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::indexing; +using namespace torch::jit::tensorexpr; + +TEST(DynamicShapes, SimpleGraph) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Tensor, + %SS_2 : int, + %SS_3 : int): + %3 : Tensor = aten::tanh(%x) + %4 : Tensor = aten::erf(%3) + return (%4))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto x_type = TensorType::create(at::rand({10, 5})); + std::vector x_sym_dims( + {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); + auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); + graph->inputs().at(0)->setType(x_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-2), SS(-3)), + // %SS_2 : int, + // %SS_3 : int): + // %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x) + // %4 : Float(SS(-2), SS(-3)) = aten::erf(%3) + // return (%4) + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + std::vector symbolic_shape_inputs = c10::fmap( + x_sym_dims, + [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::erf(at::tanh(a)); + + std::vector stack = fmap(std::vector({a})); + stack.push_back(10); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::erf(at::tanh(a)); + + std::vector stack = fmap(std::vector({a})); + stack.push_back(50); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWith2InputsSameDims) { +#ifdef TORCH_ENABLE_LLVM + // The two inputs in this graph must have the same dims. + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Tensor, + %y : Tensor, + %SS_2 : int, + %SS_3 : int): + %3 : Tensor = aten::tanh(%x) + %4 : Tensor = aten::erf(%3) + %5 : Tensor = aten::mul(%4, %y) + return (%5))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto x_type = TensorType::create(at::rand({10, 5})); + std::vector x_sym_dims( + {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); + auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(x_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-4), SS(-5)), + // %y : Float(SS(-4), SS(-5)), + // %SS_2 : int, + // %SS_3 : int): + // %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x) + // %5 : Float(SS(-4), SS(-5)) = aten::erf(%4) + // %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y) + // return (%6) + + std::vector symbolic_shape_inputs = c10::fmap( + x_sym_dims, + [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(10); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(50); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWith2InputsAndBroadcast) { +#ifdef TORCH_ENABLE_LLVM + // The second input to the graph has a dim of size 1 which should be + // broadcasted in the at::mul op. + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Float(10, 5, requires_grad=0, device=cpu), + %y : Float(1, 5, requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int): + %3 : Tensor = aten::tanh(%x) + %4 : Tensor = aten::erf(%3) + %5 : Tensor = aten::mul(%4, %y) + return (%5))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto x_type = TensorType::create(at::rand({10, 5})); + auto y_type = TensorType::create(at::rand({1, 5})); + auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); + auto x_sym_type = x_type->withSymbolicShapes( + std::vector({x_dim0_sym, x_dim1_sym})); + auto y_sym_type = y_type->withSymbolicShapes(std::vector( + {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(y_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-6), SS(-7)), + // %y : Float(1, SS(-7)), + // %SS_2 : int, + // %SS_3 : int): + // %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x) + // %5 : Float(SS(-6), SS(-7)) = aten::erf(%4) + // %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y) + // return (%6) + + std::vector symbolic_shape_inputs( + {x_dim0_sym.value(), x_dim1_sym.value()}); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(10); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(50); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) { +#ifdef TORCH_ENABLE_LLVM + // The second input to the graph has a dim of size 1 which should be + // broadcasted in the at::mul op. + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Float(1, 5, requires_grad=0, device=cpu), + %y : Float(1, 5, requires_grad=0, device=cpu), + %SS_2 : int): + %4 : Tensor = aten::tanh(%x) + %5 : Tensor = aten::mul(%4, %y) + return (%5))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto x_type = TensorType::create(at::rand({1, 5})); + auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); + auto x_sym_type = x_type->withSymbolicShapes(std::vector( + {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(x_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(1, SS(-2)), + // %y : Float(1, SS(-2)), + // %SS_2 : int): + // %3 : Float(1, SS(-2)) = aten::tanh(%x) + // %4 : Float(1, SS(-2)) = aten::mul(%3, %y) + // return (%4) + + std::vector symbolic_shape_inputs({x_dim1_sym.value()}); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::tanh(a), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::tanh(a), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWithSymbolicStrides) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), + %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), + %SS_3 : int, + %SS_2 : int): + %15 : int = prim::Constant[value=1]() + %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15) + %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0) + return (%22))IR"; + parseIR(graph_string, &*graph); + + std::vector input_desc = { + torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE}; + std::vector output_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = output_desc; + std::vector symbolic_shape_inputs = {-3, -2}; + TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + { + auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::add(x0, x1, 1), x0); + + std::vector inputs = {x0, x1}; + std::vector stack = at::fmap(inputs); + stack.push_back(32); + stack.push_back(10); + k.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + { + auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto out = + at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::add(x0, x1, 1), x0); + + std::vector inputs = {out, x0, x1}; + std::vector stack = at::fmap(inputs); + stack.push_back(32); + stack.push_back(10); + k.runWithAllocatedOutputs(stack); + + ASSERT_TRUE(at::allclose(out, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWithCatAndBroadcast) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Float(10, 5, requires_grad=0, device=cpu), + %y : Float(4, 5, requires_grad=0, device=cpu), + %z : Float(1, 1, requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int, + %SS_4 : int, + %SS_5 : int): + %11 : int = prim::Constant[value=0]() + %3 : Tensor = aten::tanh(%x) + %out1 : Tensor = aten::erf(%3) + %out2 : Tensor = aten::relu(%y) + %10 : Tensor[] = prim::ListConstruct(%out1, %out2) + %25 : Tensor = aten::cat(%10, %11) + %28 : Tensor = aten::hardswish(%25) + %29 : Tensor = aten::mul(%28, %z) + return (%29))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto z_inp = graph->inputs()[2]; + auto x_type = TensorType::create(at::rand({10, 5})); + auto y_type = TensorType::create(at::rand({4, 5})); + auto z_type = TensorType::create(at::rand({1, 1})); + auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); + auto x_sym_type = x_type->withSymbolicShapes( + std::vector({x_dim0_sym, x_dim1_sym})); + auto y_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto y_sym_type = y_type->withSymbolicShapes( + std::vector({y_dim0_sym, x_dim1_sym})); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(y_sym_type); + auto cat_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto cat_out_type = x_type->withSymbolicShapes( + std::vector({cat_dim0_sym, x_dim1_sym})); + auto nodeIt = graph->nodes().begin(); + ++nodeIt; + nodeIt->output()->setType(x_sym_type); // aten::tanh + ++nodeIt; + nodeIt->output()->setType(x_sym_type); // aten::erf + ++nodeIt; + nodeIt->output()->setType(y_sym_type); // aten::relu + ++nodeIt; + ++nodeIt; + nodeIt->output()->setType(cat_out_type); // aten::cat + ++nodeIt; + nodeIt->output()->setType(cat_out_type); // aten::hardswish + ++nodeIt; + nodeIt->output()->setType(cat_out_type); // aten::mul + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-2), SS(-3)), + // %y : Float(SS(-4), SS(-3)), + // %z : Float(1, 1), + // %SS_2 : int, + // %SS_3 : int, + // %SS_4 : int, + // %SS_5 : int): + // %7 : int = prim::Constant[value=0]() + // %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x) + // %9 : Float(SS(-2), SS(-3)) = aten::erf(%8) + // %10 : Float(SS(-4), SS(-3)) = aten::relu(%y) + // %11 : Tensor[] = prim::ListConstruct(%9, %10) + // %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7) + // %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12) + // %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z) + // return (%14) + + std::vector symbolic_shape_inputs( + {x_dim0_sym.value(), + x_dim1_sym.value(), + y_dim0_sym.value(), + cat_dim0_sym.value()}); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[z_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul( + at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c); + + std::vector stack = fmap(std::vector({a, b, c})); + stack.push_back(10); + stack.push_back(5); + stack.push_back(4); + stack.push_back(14); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); +#endif +} + +TEST(DynamicShapes, GraphFromModel) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), + %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu), + %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu), + %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu), + %4 : Float(SS(-7), requires_grad=0, device=cpu), + %5 : Float(SS(-7), requires_grad=0, device=cpu), + %SS_10 : int, + %SS_9 : int, + %SS_8 : int, + %SS_7 : int, + %SS_6 : int, + %SS_5 : int, + %SS_4 : int, + %SS_3 : int, + %SS_2 : int): + %15 : int = prim::Constant[value=1]() + %16 : bool = prim::Constant[value=0]() + %17 : int = prim::Constant[value=6]() + %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16) + %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2) + %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15) + %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15) + %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4) + return (%22))IR"; + parseIR(graph_string, &*graph); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->inputs().at(2)] = input_desc; + symbolic_strides[graph->inputs().at(3)] = input_desc; + symbolic_strides[graph->inputs().at(4)] = input_desc; + symbolic_strides[graph->inputs().at(5)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + std::vector symbolic_shape_inputs = { + -10, -9, -8, -7, -6, -5, -4, -3, -2}; + TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + int64_t i2 = 10; + int64_t i3 = 32; + int64_t i4 = 19; + int64_t i5 = 71; + int64_t i6 = 139; + int64_t i7 = 261; + int64_t i8 = 261; + int64_t i9 = 261; + int64_t i10 = 261; + auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong)); + auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4); + + { + std::vector inputs = {x0, x1, x2, x3, x4, x5}; + std::vector stack = at::fmap(inputs); + stack.emplace_back(i10); + stack.emplace_back(i9); + stack.emplace_back(i8); + stack.emplace_back(i7); + stack.emplace_back(i6); + stack.emplace_back(i5); + stack.emplace_back(i4); + stack.emplace_back(i3); + stack.emplace_back(i2); + k.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + { + auto out = + at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + std::vector inputs = {out, x0, x1, x2, x3, x4, x5}; + std::vector stack = at::fmap(inputs); + stack.emplace_back(i10); + stack.emplace_back(i9); + stack.emplace_back(i8); + stack.emplace_back(i7); + stack.emplace_back(i6); + stack.emplace_back(i5); + stack.emplace_back(i4); + stack.emplace_back(i3); + stack.emplace_back(i2); + k.runWithAllocatedOutputs(stack); + + ASSERT_TRUE(at::allclose(out, ref)); + } +#endif +} + +TEST(DynamicShapes, MultiThreadedExecution) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_template = R"IR( + graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), + %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), + %SS_2 : int, + %SS_3 : int): + %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x) + %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3) + %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y) + return (%5))IR"; + for (bool use_cuda : {false, true}) { + if (!torch::cuda::is_available() && use_cuda) { + continue; + } + auto device = use_cuda ? at::kCUDA : at::kCPU; + at::jit::TemplateEnv env; + env.s("device", use_cuda ? "cuda:0" : "cpu"); + const auto graph_string = format(graph_template, env); + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + auto run_kernel = [&](int dim1, int dim2) { + auto a = + at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); + auto b = + at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); + + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.emplace_back(dim1); + stack.emplace_back(dim2); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + // Run the kernel in parallel to ensure that the run() method calls in + // TensorExprKernel are not changing any state. + constexpr size_t kNumThreads = 4; + std::vector threads; + for (size_t id = 0; id < kNumThreads; ++id) { + threads.emplace_back(run_kernel, id + 5, id + 20); + } + for (auto& t : threads) { + t.join(); + } + } +#endif +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp new file mode 100644 index 000000000000..eb2d6296b229 --- /dev/null +++ b/test/cpp/tensorexpr/test_expr.cpp @@ -0,0 +1,836 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +using SimpleIRExprEval = ExprEval; + +TEST(Expr, BasicValueTest) { + ExprHandle a = IntImm::make(2), b = IntImm::make(3); + ExprHandle c = Add::make(a, b); + SimpleIRExprEval eval(c); + ASSERT_EQ(eval.value(), 5); +} + +TEST(Expr, BasicValueTest02) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + SimpleIRExprEval eval(f); + ASSERT_EQ(eval.value(), -4.0f); +} + +TEST(Expr, IsChannelsLastContiguous) { + std::vector vars = { + VarHandle("var1", kLong), + VarHandle("var2", kLong), + VarHandle("var3", kLong), + VarHandle("var4", kLong), + VarHandle("var5", kLong)}; + + // { + // key: ndims, + // value: [ + // ... + // [dim_2, dim_1, ..., dim_n] + // ] + // } + using shapGenInfo = std::unordered_map>>; + + // { + // size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n], + // strides: [ + // ... + // [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z] + // ] + // } + using shapeInfo = + std::pair, std::vector>>; + + std::vector dims = {3, 4, 5}; + + std::unordered_map> dims_expr_vec_conf = { + {3, std::vector(vars.begin(), vars.begin() + 2)}, + {4, std::vector(vars.begin(), vars.begin() + 3)}, + {5, std::vector(vars.begin(), vars.begin() + 4)}, + }; + + shapGenInfo channels_last_cont_shape_conf = { + {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}}; + shapGenInfo channels_last_non_cont_shape_conf = { + {3, {{2, 1, 0}, {1, 0, 2}}}, + {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}}, + {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}}; + + shapGenInfo cont_shape_conf = { + {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}}; + + auto shape_gen_fn = [dims_expr_vec_conf]( + int ndims, shapGenInfo shape_gen_info) -> shapeInfo { + auto dims_expr_vec = dims_expr_vec_conf.at(ndims); + std::vector> strides_expr_vec; + for (size_t i = 0; i < strides_expr_vec.size(); i++) { + strides_expr_vec[i].resize(ndims); + } + + auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) { + if (indicator % 2 == 0) { + return a * b; + } else { + return b * a; + } + }; + + auto stride_order_vec = shape_gen_info.at(ndims); + for (size_t i = 0; i < strides_expr_vec.size(); i++) { + auto stride_order = stride_order_vec[i]; + + strides_expr_vec[i][stride_order[0]] = 1; + for (size_t j = 1; j < stride_order.size(); j++) { + auto cur_dim_idx = stride_order[j]; + auto adjacent_dim_idx = stride_order[j - 1]; + + strides_expr_vec[i][cur_dim_idx] = stride_gen_fn( + i, + dims_expr_vec[adjacent_dim_idx], + strides_expr_vec[i][adjacent_dim_idx]); + } + } + + return {dims_expr_vec, strides_expr_vec}; + }; + + auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool { + if (ndims == 3) { + return buf_handle.is_channels_last_1d_contiguous(); + } else if (ndims == 4) { + return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast); + } else { + return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d); + } + }; + + // channels-last contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true); + } + } + + // channels-last non-contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false); + } + } + + // contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(buf_handle.is_contiguous(), true); + } + } + + // non-contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(buf_handle.is_contiguous(), false); + } + } +} + +TEST(Expr, LetTest01) { + VarHandle x("x", kFloat); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, LetTest02) { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = + ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3.f)); + eval.bindVar(y, ExprHandle(6.f)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); +} + +TEST(Expr, LetStmtTest01) { + BufHandle a_buf("a", {1}, kFloat); + BufHandle b_buf("b", {1}, kFloat); + + ExprHandle load_a = a_buf.load(0); + VarHandle var = VarHandle("v", kFloat); + StmtPtr let_store = Let::make(var, load_a); + StmtPtr store_b = b_buf.store({0}, var); + BlockPtr block = Block::make({let_store, store_b}); + + SimpleIREvaluator eval(block, {a_buf, b_buf}); + + PaddedBuffer a_v(1); + PaddedBuffer b_v(1); + PaddedBuffer b_ref(1); + + a_v(0) = 23; + b_ref(0) = a_v(0); + eval(a_v, b_v); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +TEST(Expr, IntTest) { + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, FloatTest) { + VarHandle x("x", kFloat); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, ByteTest) { + VarHandle x("x", kByte); + ExprHandle body = ExprHandle((uint8_t)2) + + (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((uint8_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, CharTest) { + VarHandle x("x", kChar); + ExprHandle body = ExprHandle((int8_t)2) + + (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((int8_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, ShortTest) { + VarHandle x("x", kShort); + ExprHandle body = ExprHandle((int16_t)2) + + (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((int16_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, LongTest) { + VarHandle x("x", kLong); + ExprHandle body = ExprHandle((int64_t)2) + + (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((int64_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, HalfTest) { + VarHandle x("x", kHalf); + ExprHandle body = ExprHandle((at::Half)2) + + (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((at::Half)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, DoubleTest) { + VarHandle x("x", kDouble); + ExprHandle body = ExprHandle((double)2) + + (x * ExprHandle((double)3) + ExprHandle((double)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((double)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, VectorAdd01) { + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + BufHandle a_buf("A", {kTotalSize}, kFloat); + BufHandle b_buf("B", {kTotalSize}, kFloat); + BufHandle c_buf("C", {kTotalSize}, kFloat); + + /* + Build the following: + for (const auto index : c10::irange(kVectorCount)) { + store(c_buf, ramp(index * 8, 1, 8), + load(a_buf, ramp(index * 8, 1, 8) + + load(b_buf, ramp(index * 8, 1, 8)))) + } + */ + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = + a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); + ExprHandle load_b = + b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); + ExprHandle value = load_a + load_b; + StmtPtr store_c = + c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); + StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); + + ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); + ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); + ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer c_ref(kTotalSize); + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i * i; + b_v(i) = i * i * 4; + c_ref(i) = a_v(i) + b_v(i); + } + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(Expr, CompareSelectEQ) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 1); +} + +TEST(Expr, CompareSelectDtypes) { + // LHS and RHS expressions should have the same dtype, but this dtype could + // differ from the dtype of the return values (but dtypes of true and false + // return values should be the same). + // This test constructs a CompareSelect expression where the input dtype is + // different from the output dtype and verifies that it works correctly: + // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0.0f); + std::vector c_ref(N, 3.14f); + + VarHandle i("i", kInt); + // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f + // A and B are int, C is float. + auto select_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), + b.load(i), + FloatImm::make(3.14f), + FloatImm::make(2.78f), + CompareSelectOperation::kEQ))); + + SimpleIREvaluator ir_eval(select_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + assertAllEqual(b_buffer, 1); + ExpectAllNear(c_buffer, c_ref, 1e-7); +} + +TEST(Expr, IntrinsicsDtypes) { + constexpr int N = 256; + BufHandle a("A", {N}, kDouble); + BufHandle b("B", {N}, kDouble); + std::vector a_buffer(N, -10.0); + std::vector b_buffer(N, 0.0); + std::vector b_ref(N, 10.0); + + VarHandle i("i", kInt); + auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); + + SimpleIREvaluator ir_eval(abs_expr, {a, b}); + ir_eval(a_buffer, b_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + + assertAllEqual(a_buffer, -10.0); + ExpectAllNear(b_buffer, b_ref, 1e-7); +} + +TEST(Expr, Substitute01) { + VarPtr x = alloc("x", kFloat); + VarPtr y = alloc("y", kFloat); + ExprPtr e = + alloc(alloc(x, alloc(1.0f)), alloc(x, y)); + + VarPtr z = alloc("z", kFloat); + ExprPtr e2 = Substitute(e, {{x, alloc(z, alloc(5.0f))}}); + ExprPtr e2_ref = alloc( + alloc(alloc(z, alloc(5.0f)), alloc(1.0f)), + alloc(alloc(z, alloc(5.0f)), y)); + std::ostringstream oss; + oss << *e2; + std::string e2_str = oss.str(); + + oss.str(""); + oss << *e2_ref; + std::string e2_ref_str = oss.str(); + ASSERT_EQ(e2_str, e2_ref_str); +} + +TEST(Expr, Math01) { + ExprHandle v = sin(ExprHandle(1.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "sin(1.f)"); + + SimpleIRExprEval eval(v); + float v_ref = std::sin(1.0f); + float res = eval.value(); + ASSERT_NEAR(res, v_ref, 1e-6); +} + +TEST(Expr, UnaryMath01) { + struct TestConfig { + std::function func; + std::function ref_func; + }; + + std::vector test_configs = { + {[](const ExprHandle& v) { return sin(v); }, + [](float v) { return std::sin(v); }}, + {[](const ExprHandle& v) { return sin(v); }, + [](float v) { return std::sin(v); }}, + {[](const ExprHandle& v) { return tan(v); }, + [](float v) { return std::tan(v); }}, + {[](const ExprHandle& v) { return asin(v); }, + [](float v) { return std::asin(v); }}, + {[](const ExprHandle& v) { return acos(v); }, + [](float v) { return std::acos(v); }}, + {[](const ExprHandle& v) { return atan(v); }, + [](float v) { return std::atan(v); }}, + {[](const ExprHandle& v) { return sinh(v); }, + [](float v) { return std::sinh(v); }}, + {[](const ExprHandle& v) { return cosh(v); }, + [](float v) { return std::cosh(v); }}, + {[](const ExprHandle& v) { return tanh(v); }, + [](float v) { return std::tanh(v); }}, + {[](const ExprHandle& v) { return exp(v); }, + [](float v) { return std::exp(v); }}, + {[](const ExprHandle& v) { return tensorexpr::abs(v); }, + [](float v) { return std::fabs(v); }}, + {[](const ExprHandle& v) { return log(v); }, + [](float v) { return std::log(v); }}, + {[](const ExprHandle& v) { return log2(v); }, + [](float v) { return std::log2(v); }}, + {[](const ExprHandle& v) { return log10(v); }, + [](float v) { return std::log10(v); }}, + {[](const ExprHandle& v) { return erf(v); }, + [](float v) { return std::erf(v); }}, + {[](const ExprHandle& v) { return sqrt(v); }, + [](float v) { return std::sqrt(v); }}, + {[](const ExprHandle& v) { return rsqrt(v); }, + [](float v) { return 1.0f / std::sqrt(v); }}, + {[](const ExprHandle& v) { return ceil(v); }, + [](float v) { return std::ceil(v); }}, + {[](const ExprHandle& v) { return floor(v); }, + [](float v) { return std::floor(v); }}, + {[](const ExprHandle& v) { return round(v); }, + [](float v) { return std::round(v); }}, + {[](const ExprHandle& v) { return trunc(v); }, + [](float v) { return std::trunc(v); }}, + }; + + for (const TestConfig& test_config : test_configs) { + const float input_v = 0.8765f; + ExprHandle v = test_config.func(ExprHandle(input_v)); + float v_ref = test_config.ref_func(input_v); + SimpleIRExprEval eval(v); + ASSERT_NEAR(eval.value(), v_ref, 1e-6); + } + + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + for (float input_v : {std::nan("1"), 0., .5}) { + ExprHandle v = FloatImm::make(input_v); + SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); + ASSERT_NEAR(eval.value(), std::isnan(input_v), 0); + } +} + +TEST(Expr, BinaryMath01) { + struct TestConfig { + std::function func; + std::function ref_func; + }; + + std::vector test_configs = { + {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, + [](float v1, float v2) { return std::pow(v1, v2); }}, + {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, + [](float v1, float v2) { return std::fmod(v1, v2); }}, + }; + + for (const TestConfig& test_config : test_configs) { + const float v1 = 0.8765f; + float v2 = 1.2345f; + ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); + float v_ref = test_config.ref_func(v1, v2); + SimpleIRExprEval eval(v_expr); + ASSERT_NEAR(eval.value(), v_ref, 1e-6); + } +} + +TEST(Expr, LogicalOps01) { + ExprHandle a(23); + ExprHandle b(11); + ExprHandle c(0.72f); + ExprHandle d(0.69f); + ExprHandle f1 = (a > b) && (c > d); + ExprHandle f2 = (a > b) && (c < d); + ExprHandle f3 = (a < b) && (c > d); + ExprHandle f4 = (a < b) && (c < d); + ExprHandle f5 = (a < b) || (c > d); + ExprHandle f6 = (a < b) || (c < d); + ExprHandle f7 = (a > b) || (c < d); + ExprHandle f8 = (a > b) || (c > d); + + SimpleIRExprEval eval1(f1); + SimpleIRExprEval eval2(f2); + SimpleIRExprEval eval3(f3); + SimpleIRExprEval eval4(f4); + SimpleIRExprEval eval5(f5); + SimpleIRExprEval eval6(f6); + SimpleIRExprEval eval7(f7); + SimpleIRExprEval eval8(f8); + ASSERT_EQ(eval1.value(), 1); + ASSERT_EQ(eval2.value(), 0); + ASSERT_EQ(eval3.value(), 0); + ASSERT_EQ(eval4.value(), 0); + ASSERT_EQ(eval5.value(), 1); + ASSERT_EQ(eval6.value(), 0); + ASSERT_EQ(eval7.value(), 1); + ASSERT_EQ(eval8.value(), 1); +} + +TEST(Expr, LogicalOps02) { + ExprHandle a(23); + ExprHandle b(11); + ExprHandle c(0.72f); + ExprHandle d(0.72f); + + ExprHandle f1 = (a > b) || (c > d); + ExprHandle f2 = (a > b) && (c <= d); + ExprHandle f3 = (a > b) && (c > d); + ExprHandle ff1 = f1 && f2; + ExprHandle ff2 = f2 || f3; + + SimpleIRExprEval eval1(ff1); + SimpleIRExprEval eval2(ff2); + ASSERT_EQ(eval1.value(), 1); + ASSERT_EQ(eval2.value(), 1); +} + +TEST(Expr, LogicalOps03) { + ExprHandle a(23); + ExprHandle b(11); + ExprHandle c(0.72f); + ExprHandle d(0.69f); + + // Bool types + ExprHandle bool_f1 = (a > b) && BoolImm::make(true); + ExprHandle bool_f2 = (c <= d) || BoolImm::make(true); + + // Int types + ExprHandle int_f1 = (a > b) && IntImm::make(1); + ExprHandle int_f2 = (c <= d) || IntImm::make(1); + + // Short types + ExprHandle short_f1 = (a > b) && ShortImm::make(1); + ExprHandle short_f2 = (c <= d) || ShortImm::make(1); + + // Long types + ExprHandle long_f1 = (a > b) && LongImm::make(1); + ExprHandle long_f2 = (c <= d) || LongImm::make(1); + + // Char types + ExprHandle char_f1 = (a > b) && CharImm::make(1); + ExprHandle char_f2 = (c <= d) || CharImm::make(1); + + // Byte types + ExprHandle byte_f1 = (a > b) && ByteImm::make(1); + ExprHandle byte_f2 = (c <= d) || ByteImm::make(1); + + SimpleIRExprEval eval1(bool_f1); + SimpleIRExprEval eval2(bool_f2); + SimpleIRExprEval eval3(int_f1); + SimpleIRExprEval eval4(int_f2); + SimpleIRExprEval eval5(short_f1); + SimpleIRExprEval eval6(short_f2); + SimpleIRExprEval eval7(long_f1); + SimpleIRExprEval eval8(long_f2); + SimpleIRExprEval eval9(char_f1); + SimpleIRExprEval eval10(char_f2); + SimpleIRExprEval eval11(byte_f1); + SimpleIRExprEval eval12(byte_f2); + + ASSERT_EQ(eval1.value(), true); + ASSERT_EQ(eval2.value(), true); + ASSERT_EQ(eval3.value(), 1); + ASSERT_EQ(eval4.value(), 1); + ASSERT_EQ(eval5.value(), 1); + ASSERT_EQ(eval6.value(), 1); + ASSERT_EQ(eval7.value(), 1); + ASSERT_EQ(eval8.value(), 1); + ASSERT_EQ(eval9.value(), 1); + ASSERT_EQ(eval10.value(), 1); + ASSERT_EQ(eval11.value(), 1); + ASSERT_EQ(eval12.value(), 1); +} + +TEST(Expr, BitwiseOps) { + ExprHandle a(59); + ExprHandle b(11); + ExprHandle c(101); + ExprHandle d(2); + ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; + + SimpleIRExprEval eval(f); + ASSERT_EQ(eval.value(), 11); +} + +TEST(Expr, DynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + BufHandle c("c", {n}, kFloat); + VarHandle i("i", kInt); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(Expr, OutOfBounds) { + ExprHandle N(10); + ExprHandle start(0); + ExprHandle stop(15); + VarHandle i("i", kInt); + + BufHandle X("X", {N}, kInt); + + auto body = Store::make(X, {i}, i); + auto stmt = For::make(i, start, stop, body); + + PaddedBuffer data(20); + + EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); +} + +TEST(Expr, OutOfBounds2d) { + std::vector> size_options = {{10, 15}, {15, 10}}; + for (auto sizes : size_options) { + ExprHandle N(sizes.first); + ExprHandle M(sizes.second); + ExprHandle start(0); + ExprHandle stopInner(15); + ExprHandle stopOuter(15); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + BufHandle X("X", {N, M}, kInt); + + auto body = Store::make(X, {i, j}, i); + auto inner = For::make(j, start, stopInner, body); + auto stmt = For::make(i, start, stopOuter, inner); + + PaddedBuffer data(400); + + EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); + } +} + +TEST(Expr, OutOfBounds2dFlattenedIndex) { + ExprHandle buf_size(149); + ExprHandle start(0); + ExprHandle stopInner(15); + ExprHandle stopOuter(10); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + BufHandle X("X", {buf_size}, kInt); + + auto idx = Add::make(Mul::make(i, stopInner), j); + auto body = Store::make(X, {idx}, i); + auto inner = For::make(j, start, stopInner, body); + auto stmt = For::make(i, start, stopOuter, inner); + + PaddedBuffer data(400); + + EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); +} + +void testCond01() { + const int N = 16; + PaddedBuffer a_v(N); + BufHandle a_buf("a", {N}, kFloat); + VarHandle index = VarHandle("index", kInt); + StmtPtr assign_x2 = a_buf.store({index}, cast(index) * 2); + StmtPtr assign_x3 = a_buf.store({index}, cast(index) * 3); + ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); + StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); + StmtPtr for_stmt = For::make(index, 0, N, assign); + SimpleIREvaluator(for_stmt, {a_buf})(a_v); + + PaddedBuffer a_ref(N); + for (const auto i : c10::irange(N)) { + if (i % 2 == 0) { + a_ref(i) = i * 2; + } else { + a_ref(i) = i * 3; + } + } + ExpectAllNear(a_v, a_ref, 1e-5); +} + +void testIfThenElse01() { + ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 1.0f); +} + +void testIfThenElse02() { + ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 2.0f); +} + +void testIfThenElse03() { + ExprHandle v = + ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 2.0f); +} + +void testStmtClone() { + const int N = 16; + + BufHandle a_buf("a", {N}, kInt); + VarHandle index = VarHandle("index", kInt); + StmtPtr body = a_buf.store({index}, 5); + StmtPtr loop = For::make(index, 0, N, body); + + StmtPtr cloned_loop = Stmt::clone(loop); + std::vector orig_loop_results(N); + std::vector cloned_loop_results(N); + SimpleIREvaluator(loop, {a_buf})(orig_loop_results); + SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); + + assertAllEqual(orig_loop_results, 5); + assertAllEqual(cloned_loop_results, 5); + + // Let's add another assign to the body in the cloned loop and verify that the + // original statement hasn't changed while the cloned one has. + StmtPtr body_addition = a_buf.store({index}, 33); + BlockPtr cloned_body = static_to(static_to(cloned_loop)->body()); + cloned_body->append_stmt(body_addition); + + std::vector orig_loop_results_after_mutation(N); + std::vector cloned_loop_results_after_mutation(N); + SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); + SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); + + assertAllEqual(orig_loop_results_after_mutation, 5); + assertAllEqual(cloned_loop_results_after_mutation, 33); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp new file mode 100644 index 000000000000..49f43d16b499 --- /dev/null +++ b/test/cpp/tensorexpr/test_external_calls.cpp @@ -0,0 +1,1061 @@ +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +TEST(ExternalCall, Conv1d_float) { + BufHandle Input("Input", {1, 100, 115}, kFloat); + BufHandle Weight("Weight", {100, 1, 7}, kFloat); + BufHandle Bias("Bias", {100}, kFloat); + BufHandle ResultBuf("Result", {1, 100, 115}, kFloat); + int64_t stride = 1; + int64_t pad = 3; + int64_t dilation = 1; + int64_t groups = 100; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv1d", + {Input, Weight, Bias}, + {stride, pad, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 100, 115}, options) * 5.f; + at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f; + at::Tensor bias = at::ones({100}, options) * 11.f; + at::Tensor ref = + at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 100 * 115, 5.f); + std::vector weight_buf(100 * 1 * 7, 6.f); + std::vector bias_buf(100, 11.f); + std::vector result_buf(1 * 100 * 115, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv1d_int) { + // A similar test, but now using kInt tensors + BufHandle Input("Input", {1, 100, 115}, kInt); + BufHandle Weight("Weight", {100, 1, 7}, kInt); + BufHandle Bias("Bias", {100}, kInt); + BufHandle ResultBuf("Result", {1, 100, 115}, kInt); + int64_t stride = 1; + int64_t pad = 3; + int64_t dilation = 1; + int64_t groups = 100; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv1d", + {Input, Weight, Bias}, + {stride, pad, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kInt) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 100, 115}, options) * 5; + at::Tensor weight = at::ones({100, 1, 7}, options) * 6; + at::Tensor bias = at::ones({100}, options) * 11; + at::Tensor ref = + at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 100 * 115, 5); + std::vector weight_buf(100 * 1 * 7, 6); + std::vector bias_buf(100, 11); + std::vector result_buf(1 * 100 * 115, -1); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv1d_nobias_noargs) { + BufHandle Input("Input", {1, 1, 115}, kFloat); + BufHandle Weight("Weight", {10, 1, 7}, kFloat); + BufHandle ResultBuf("Result", {1, 10, 109}, kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 1, 115}, options) * 5.f; + at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f; + at::Tensor ref = at::conv1d(input, weight); + + at::Tensor nnc_result; + std::vector input_buf(1 * 1 * 115, 5.f); + std::vector weight_buf(10 * 1 * 7, 6.f); + std::vector result_buf(1 * 10 * 109, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); + + llvm_codegen.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); + + ir_eval.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv2d_float) { + BufHandle Input("Input", {1, 3, 224, 224}, kFloat); + BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat); + BufHandle Bias("Bias", {16}, kFloat); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); + int64_t stride = 2; + int64_t pad = 1; + int64_t dilation = 1; + int64_t groups = 1; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv2d", + {Input, Weight, Bias}, + {stride, stride, pad, pad, dilation, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f; + at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f; + at::Tensor bias = at::ones({16}, options) * 11.f; + at::Tensor ref = at::conv2d( + input, + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 3 * 224 * 224, 5.f); + std::vector weight_buf(16 * 3 * 3 * 3, 6.f); + std::vector bias_buf(16, 11.f); + std::vector result_buf(1 * 16 * 112 * 112, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv2d_int) { + // A similar test, but now using kInt tensors + + BufHandle Input("Input", {1, 3, 224, 224}, kInt); + BufHandle Weight("Weight", {16, 3, 3, 3}, kInt); + BufHandle Bias("Bias", {16}, kInt); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt); + int64_t stride = 2; + int64_t pad = 1; + int64_t dilation = 1; + int64_t groups = 1; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv2d", + {Input, Weight, Bias}, + {stride, stride, pad, pad, dilation, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kInt) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5; + at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6; + at::Tensor bias = at::ones({16}, options) * 11; + at::Tensor ref = at::conv2d( + input, + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 3 * 224 * 224, 5); + std::vector weight_buf(16 * 3 * 3 * 3, 6); + std::vector bias_buf(16, 11); + std::vector result_buf(1 * 16 * 112 * 112, -1); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv2d_nobias_noargs) { + BufHandle Input("Input", {1, 16, 112, 112}, kFloat); + BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; + at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; + at::Tensor ref = at::conv2d(input, weight); + + at::Tensor nnc_result; + std::vector input_buf(1 * 16 * 112 * 112, 5.f); + std::vector weight_buf(16 * 16 * 1 * 1, 6.f); + std::vector result_buf(1 * 16 * 112 * 112, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); + + llvm_codegen.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); + + ir_eval.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Addmm_float) { + BufHandle Input("Input", {100, 300}, kFloat); + BufHandle Mat1("Mat1", {100, 200}, kFloat); + BufHandle Mat2("Mat2", {200, 300}, kFloat); + BufHandle ResultBuf("Result", {100, 300}, kFloat); + int64_t beta = 2; + int64_t alpha = 2; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({100, 300}, options) * 5.f; + at::Tensor mat1 = at::ones({100, 200}, options) * 6.f; + at::Tensor mat2 = at::ones({200, 300}, options) * 11.f; + at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha); + + at::Tensor nnc_result; + std::vector input_buf(100 * 300, 5.f); + std::vector mat1_buf(100 * 200, 6.f); + std::vector mat2_buf(200 * 300, 11.f); + std::vector result_buf(100 * 300, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result}); + + llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result}); + + ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Embedding) { + BufHandle Weight("Weight", {256, 100}, kFloat); + BufHandle Indices("Indices", {1, 115}, kLong); + BufHandle ResultBuf("Result", {1, 115, 100}, kFloat); + int64_t padding_idx = -1; + bool scale_grad_by_freq = false; + bool sparse = false; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_embedding", + {Weight, Indices}, + {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + + at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f; + at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6; + at::Tensor ref = + at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); + + at::Tensor nnc_result; + std::vector weight_buf(256 * 100, 5.f); + std::vector indices_buf(1 * 115, 6); + std::vector result_buf(1 * 115 * 100, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result}); + + llvm_codegen.call({weight_buf, indices_buf, result_buf}); + nnc_result = at::from_blob( + result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result}); + + ir_eval.call({weight_buf, indices_buf, result_buf}); + nnc_result = at::from_blob( + result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, MaxReduction) { + BufHandle Input("Input", {1, 115, 152}, kFloat); + BufHandle ResultBuf("Result", {1, 152}, kFloat); + int64_t dim = 1; + bool keep_dim = false; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + + at::Tensor input = at::ones({1, 115, 152}, options) * 5.f; + at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim)); + + at::Tensor nnc_result; + std::vector input_buf(1 * 115 * 152, 5.f); + std::vector result_buf(1 * 152, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result}); + + llvm_codegen.call({input_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result}); + + ir_eval.call({input_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +#ifdef USE_XNNPACK + +TEST(ExternalCall, Prepacked_Linear_float) { + using namespace at::native::xnnpack; + + BufHandle Input("Input", {100, 200}, kFloat); + BufHandle ResultBuf("Result", {100, 300}, kFloat); + + // Calculate reference result using at::linear. + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = + at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200}); + at::Tensor weight = + at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200}); + at::Tensor bias = at::linspace(-10.0, 10.0, 300, options); + at::Tensor ref = at::linear(input, weight, bias); + + // Create prepacked xnnpack context object. + auto linear_clamp_prepack_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("prepacked::linear_clamp_prepack", "") + .typed( + at::Tensor, + std::optional, + const std::optional&, + const std::optional&)>(); + auto prepacked = linear_clamp_prepack_op.call( + weight, bias, std::optional(), std::optional()); + + BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_prepacked_linear_clamp_run", + {Input, DummyPrepacked}, + {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + at::Tensor nnc_result; + std::vector input_buf( + input.data_ptr(), input.data_ptr() + 100 * 200); + std::vector result_buf(100 * 300, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); + + llvm_codegen.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); + + ir_eval.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Prepacked_Conv2d_float) { + using namespace at::native::xnnpack; + + BufHandle Input("Input", {1, 3, 224, 224}, kFloat); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); + int64_t stride = 2; + int64_t pad = 1; + int64_t dilation = 1; + int64_t groups = 1; + + // Calculate reference result using at::conv2d. + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options) + .resize_({1, 3, 224, 224}); + at::Tensor weight = + at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3}); + at::Tensor bias = at::linspace(-10.0, 10.0, 16, options); + at::Tensor ref = at::conv2d( + input, + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups); + + // Create prepacked xnnpack context object. + auto conv2d_clamp_prepack_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "") + .typed( + at::Tensor, + std::optional, + std::vector, + std::vector, + std::vector, + int64_t, + const std::optional&, + const std::optional&)>(); + auto prepacked = conv2d_clamp_prepack_op.call( + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups, + std::optional(), + std::optional()); + + BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_prepacked_conv2d_clamp_run", + {Input, DummyPrepacked}, + {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + at::Tensor nnc_result; + std::vector input_buf( + input.data_ptr(), input.data_ptr() + 1 * 3 * 224 * 224); + std::vector result_buf(1 * 16 * 112 * 112, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); + + llvm_codegen.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); + + ir_eval.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); +} + +#endif // USE_XNNPACK + +TEST(ExternalCall, BinaryFloat) { + using TensorFunc = std::function; + using Test = std::tuple< + std::vector, + std::vector, + std::vector, + TensorFunc, + std::string>; + std::vector tests = {}; + tests.push_back( + Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"}); + tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"}); + tests.push_back(Test{ + {100, 200}, + {200, 300}, + {100, 300}, + [&](const at::Tensor& a, const at::Tensor& b) { return at::mm(a, b); }, + "nnc_aten_mm"}); + for (auto curTest : tests) { + auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest; + auto toExprHandleVec = [](std::vector v) { + auto intV = std::vector(v.begin(), v.end()); + return std::vector(intV.begin(), intV.end()); + }; + BufHandle A("A", toExprHandleVec(aShape), kFloat); + BufHandle B("B", toExprHandleVec(bShape), kFloat); + BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, externCallName, {A, B}, {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; + at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f; + at::Tensor ref = torchFunc(a, b); + + auto prod = [](std::vector v) { + // NOLINTNEXTLINE(modernize-use-transparent-functors) + return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); + }; + + at::Tensor nnc_result; + std::vector a_buf(prod(aShape), 5.f); + std::vector b_buf(prod(bShape), 6.f); + std::vector result_buf(prod(resShape), -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result}); + + llvm_codegen.call({a_buf, b_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result}); + ir_eval.call({a_buf, b_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); + } +} + +TEST(ExternalCall, UnaryFloat) { + using TensorFunc = std::function; + auto toExprHandleVec = [](std::vector v) { + auto intV = std::vector(v.begin(), v.end()); + return std::vector(intV.begin(), intV.end()); + }; + using Test = std::tuple< + std::vector, + std::vector, + TensorFunc, + std::string, + std::vector>; + std::vector tests = {}; + tests.push_back(Test{ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + {1, 64, 8, 9}, + {1, 64, 5, 7}, + [](at::Tensor x) { return at::adaptive_avg_pool2d(x, {5, 7}); }, + "nnc_aten_adaptive_avg_pool2d", + toExprHandleVec({5, 7})}); + tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + {100, 200}, + {100}, + [](at::Tensor x) { return at::mean(x, {1}); }, + "nnc_aten_mean", + toExprHandleVec({1, /*keepdim=*/0})}); + for (auto curTest : tests) { + auto [aShape, resShape, torchFunc, externCallName, externCallArgs] = + curTest; + BufHandle A("A", toExprHandleVec(aShape), kFloat); + BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs)); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; + at::Tensor ref = torchFunc(a); + + auto prod = [](std::vector v) { + // NOLINTNEXTLINE(modernize-use-transparent-functors) + return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); + }; + + at::Tensor nnc_result; + std::vector a_buf(prod(aShape), 5.f); + std::vector result_buf(prod(resShape), -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result}); + + llvm_codegen.call({a_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result}); + ir_eval.call({a_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); + } +} + +TEST(ExternalCall, ComputeInterop) { + // This test verifies that Tensors using external calls can be used by and can + // use Tensors built with Compute API. + + BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat); + BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat); + + Tensor Input = Compute( + "Input", + {1, 16, 32, 32}, + [&](const VarHandle& n, + const VarHandle& c, + const VarHandle& h, + const VarHandle& w) { return FloatImm::make(5.0f); }); + Tensor Weight = Compute( + "Weight", + {16, 16, 1, 1}, + [&](const VarHandle& n, + const VarHandle& c, + const VarHandle& h, + const VarHandle& w) { return FloatImm::make(6.0f); }); + + Tensor ConvResult = Tensor( + ConvResultBuf.node(), + ExternalCall::make( + ConvResultBuf, + "nnc_aten_conv2d", + {BufHandle(Input.buf()), BufHandle(Weight.buf())}, + {})); + Tensor MatmulResult = Tensor( + MatmulResultBuf.node(), + ExternalCall::make( + MatmulResultBuf, + "nnc_aten_matmul", + {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())}, + {})); + Tensor Result = Compute( + "Result", + {1, 16, 32, 32}, + [&](const VarHandle& n, + const VarHandle& c, + const VarHandle& h, + const VarHandle& w) { + return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w); + }); + + LoopNest l({Input, Weight, ConvResult, MatmulResult, Result}); + + // Inlining should not inline anything here since all Bufs are either defined + // or used in ExternalCalls - we run it just for testing + l.inlineIntermediateBufs(true); + + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f; + at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; + at::Tensor t = at::conv2d(input, weight); + at::Tensor t2 = at::matmul(t, t); + at::Tensor ref = t + t2; + + at::Tensor nnc_result; + std::vector input_buf(1 * 16 * 32 * 32, 5.f); + std::vector weight_buf(16 * 16 * 1 * 1, 6.f); + std::vector conv_result_buf(1 * 16 * 32 * 32, -1.f); + std::vector matmul_result_buf(1 * 16 * 32 * 32, -1.f); + std::vector result_buf(1 * 16 * 32 * 32, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen( + l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); + + llvm_codegen.call( + {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval( + l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); + + ir_eval.call( + {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Inlining) { + // This test verifies that Tensors using external calls can be used by and + // can use Tensors built with Compute API. + + BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat); + + Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { + return FloatImm::make(5.0f); + }); + Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { + return FloatImm::make(4.0f); + }); + Tensor MatmulResult = Tensor( + MatmulResultBuf.node(), + ExternalCall::make( + MatmulResultBuf, + "nnc_aten_matmul", + {BufHandle(A.buf()), BufHandle(B.buf())}, + {})); + Tensor Result = + Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { + return MatmulResult.load(i, j) + FloatImm::make(3.0f); + }); + + StmtPtr root_stmt = alloc(std::vector( + {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()})); + LoopNest l(root_stmt, {Result.buf()}); + + // Inlining should not inline anything here since all Bufs are either + // defined or used in ExternalCalls + l.inlineIntermediateBufs(false); + + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor a = at::ones({8, 8}, options) * 5.f; + at::Tensor b = at::ones({8, 8}, options) * 4.f; + at::Tensor t = at::matmul(a, b); + at::Tensor ref = t + 3.f; + + at::Tensor nnc_result; + std::vector result_buf(8 * 8); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Result}); + + llvm_codegen.call({result_buf}); + nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Result}); + + ir_eval.call({result_buf}); + nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, JitCustomFusionOp) { + const char* custom_op_schema_literal = + "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor"; + const char* external_func_name = "nnc_add_mul"; + + auto add_mul_lowering_func = + [external_func_name]( + const std::vector& inputs, + const std::vector& output_shape, + const std::vector& output_strides, + const std::optional& output_type, + at::Device device) { + auto output_dtype = Dtype(*output_type); + torch::jit::tensorexpr::BufHandle result_buf( + "nnc_add_mul_res_buf", output_shape, output_dtype); + const torch::jit::tensorexpr::BufHandle& a = + std::get(inputs[0]); + const torch::jit::tensorexpr::BufHandle& b = + std::get(inputs[1]); + const torch::jit::tensorexpr::BufHandle& c = + std::get(inputs[1]); + torch::jit::tensorexpr::StmtPtr s = + torch::jit::tensorexpr::ExternalCall::make( + result_buf, external_func_name, {a, b, c}, {}); + return Tensor(result_buf.node(), s); + }; + + auto add_mul_external_func = [](int64_t bufs_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int64_t* buf_strides, + int8_t* buf_dtypes, + int64_t args_num, + int64_t* extra_args) {}; + + torch::jit::RegisterOperators reg({Operator( + custom_op_schema_literal, + [](const Node* node) -> Operation { + return [](Stack& _stack) { + auto a = std::move(peek(_stack, 0, 3)).toTensor(); + auto b = std::move(peek(_stack, 1, 3)).toTensor(); + auto c = std::move(peek(_stack, 2, 3)).toTensor(); + drop(_stack, 3); + auto result = (a + b) * c; + pack(_stack, std::move(result)); + return 0; + }; + }, + c10::AliasAnalysisKind::FROM_SCHEMA)}); + + auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet(); + custom_operator_set.insert({custom_op_schema_literal}); + + auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry(); + te_lowering_registry.insert( + parseSchema(custom_op_schema_literal), add_mul_lowering_func); + + auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry(); + te_nnc_func_registry[external_func_name] = add_mul_external_func; + + std::string graph_string = R"IR( + graph(%a : Float(10, 20, strides=[20, 1], device=cpu), + %b : Float(10, 20, strides=[20, 1], device=cpu), + %c : Float(10, 20, strides=[20, 1], device=cpu)): + %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c) + return (%res))IR"; + + auto graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::string shape_compute_python_string = R"PY( + def computOutput(a: List[int], b: List[int], c: List[int]): + expandedSizes: List[int] = [] + dimsA = len(a) + dimsB = len(b) + dimsC = len(c) + ndim = max(dimsA, dimsB, dimsC) + for i in range(ndim): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + dimC = dimsC - 1 - offset + sizeA = a[dimA] if (dimA >= 0) else 1 + sizeB = b[dimB] if (dimB >= 0) else 1 + sizeC = a[dimC] if (dimC >= 0) else 1 + + if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1: + # TODO: only assertion error is bound in C++ compilation right now + raise AssertionError( + "The size of tensor a {} must match the size of tensor b (" + "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i) + ) + + expandedSizes.append(max(sizeA, sizeB, sizeC)) + + return expandedSizes + )PY"; + auto cu_ptr = torch::jit::compile(shape_compute_python_string); + torch::jit::GraphFunction* gf = + (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput"); + ASSERT_TRUE(gf); + +#ifdef TORCH_ENABLE_LLVM + auto static_graph_case = graph->copy(); + FuseTensorExprs(static_graph_case, 1); + torch::jit::testing::FileCheck() + .check("prim::TensorExprGroup_") + ->check("nnc_custom::add_mul") + ->run(*static_graph_case); + + auto dynamic_graph_case = graph->copy(); + auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal); + ASSERT_TRUE(custom_op); + torch::jit::RegisterShapeComputeGraphForSchema( + custom_op->schema(), gf->graph()); + FuseTensorExprs(dynamic_graph_case, 1, false, true); + torch::jit::testing::FileCheck() + .check("prim::TensorExprGroup_") + ->check("nnc_custom::add_mul") + ->run(*dynamic_graph_case); +#else + torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph); +#endif +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp new file mode 100644 index 000000000000..aed73d09d14d --- /dev/null +++ b/test/cpp/tensorexpr/test_graph_opt.cpp @@ -0,0 +1,319 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +class GraphOpt : public ::testing::Test { + public: + void SetUp() override { + old_cat_wo_conditionals_ = getCatWoConditionals(); + getCatWoConditionals() = true; + } + + void TearDown() override { + getCatWoConditionals() = old_cat_wo_conditionals_; + } + + private: + bool old_cat_wo_conditionals_; +}; + +TEST_F(GraphOpt, OptimizeCat) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::log` op must be moved to the inputs of `aten::cat`. + testing::FileCheck() + .check("aten::log") + ->check("aten::log") + ->check("aten::log") + ->check("aten::cat") + ->check_not("aten::log") + ->run(*kernel.graph()); + + auto x = at::rand({10}, at::kFloat); + auto y = at::rand({20}, at::kFloat); + auto z = at::rand({30}, at::kFloat); + auto ref = at::log(at::cat({x, y, z}, 0)); + + std::vector inputs = {x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCat2) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) + %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5) + return (%6))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::log` and `aten::tanh` ops must be moved to the inputs of + // `aten::cat`. + testing::FileCheck() + .check("aten::log") + ->check("aten::log") + ->check("aten::log") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::cat") + ->check_not("aten::log") + ->check_not("aten::tanh") + ->run(*kernel.graph()); + + auto x = at::rand({10}, at::kFloat); + auto y = at::rand({20}, at::kFloat); + auto z = at::rand({30}, at::kFloat); + auto ref = at::tanh(at::log(at::cat({x, y, z}, 0))); + + std::vector inputs = {x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCat3) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%a : Float(60, strides=[1], device=cpu), + %x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) + %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5) + return (%6))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::tanh` op must be moved to the inputs of `aten::cat`. + // But the `aten::mul` op must not be moved since it is not a single-tensor + // op (it has 2 tensor inputs). + testing::FileCheck() + .check("aten::tanh") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::cat") + ->check("aten::mul") + ->check_not("aten::tanh") + ->run(*kernel.graph()); + + auto a = at::rand({60}, at::kFloat); + auto x = at::rand({10}, at::kFloat); + auto y = at::rand({20}, at::kFloat); + auto z = at::rand({30}, at::kFloat); + auto ref = at::tanh(at::cat({x, y, z}, 0)) * a; + + std::vector inputs = {a, x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Int(10, strides=[1], device=cpu), + %y : Int(20, strides=[1], device=cpu), + %z : Int(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::tanh` op must be moved to the inputs of `aten::cat`. + // The scalar type of the inputs to `cat` should now be `Float` since they + // are the result of `tanh` which does the type promotion. + testing::FileCheck() + .check("aten::tanh") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::cat") + ->check_not("aten::tanh") + ->run(*kernel.graph()); + + auto x = at::randint(std::numeric_limits::max(), {10}, at::kInt); + auto y = at::randint(std::numeric_limits::max(), {20}, at::kInt); + auto z = at::randint(std::numeric_limits::max(), {30}, at::kInt); + auto ref = at::tanh(at::cat({x, y, z}, 0)); + + std::vector inputs = {x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Double(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // No transformation should have happened because the `aten::cat` op performs + // type promotion. This case is currently not handled. + testing::FileCheck() + .check("aten::cat") + ->check("aten::log") + ->check_not("aten::cat") + ->check_not("aten::log") + ->run(*kernel.graph()); +#endif +} + +TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(60, strides=[1], device=cpu), + %x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // No transformation is expected since the consumers of cat are not + // single-tensor element-wise ops. + testing::FileCheck() + .check("aten::cat") + ->check("aten::mul") + ->check_not("aten::cat") + ->check_not("aten::mul") + ->run(*kernel.graph()); +#endif +} + +TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(60, strides=[1], device=cpu), + %1 : Float(60, strides=[1], device=cpu), + %x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %one : int = prim::Constant[value=1]() + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) + %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one) + return (%6))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // No transformation is expected since the consumers of cat are not + // single-tensor element-wise ops. + testing::FileCheck() + .check("aten::cat") + ->check("aten::mul") + ->check("aten::add") + ->check_not("aten::cat") + ->check_not("aten::mul") + ->check_not("aten::add") + ->run(*kernel.graph()); +#endif +} + +TEST_F(GraphOpt, AOTGraphPrepPasses) { + const auto graph_string = R"IR( + graph(%x, %y, %z, %i : int): + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + return (%xyz_list, %i))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + removeGraphOutput(g, 1); + replaceListOutputWithTuple(g); + LowerAllTuples(g); + + testing::FileCheck().check("return (%x, %y, %z)")->run(*g); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp new file mode 100644 index 000000000000..4d2f8c6e906e --- /dev/null +++ b/test/cpp/tensorexpr/test_ir_printer.cpp @@ -0,0 +1,98 @@ +#include + +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include +#include +#include +#include +#include +#include + +#include +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(IRPrinter, BasicValueTest) { + ExprHandle a = IntImm::make(2), b = IntImm::make(3); + ExprHandle c = Add::make(a, b); + + std::stringstream ss; + ss << c; + ASSERT_EQ(ss.str(), "2 + 3"); +} + +TEST(IRPrinter, BasicValueTest02) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + + std::stringstream ss; + ss << f; + ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)"); +} + +TEST(IRPrinter, BasicValueTest03) { + ExprHandle a(3.402823466385289e+38f); + ExprHandle b(-3.402823466385289e+38f); + std::stringstream ss; + ss << a << ", " << b; + ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f"); +} + +TEST(IRPrinter, CastTest) { + VarHandle x("x", kHalf); + VarHandle y("y", kFloat); + ExprHandle body = ExprHandle(2.f) + + (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y); + + std::stringstream ss; + ss << body; + ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)"); +} + +TEST(IRPrinter, FunctionName) { + int M = 4; + int N = 20; + + Tensor producer = Compute( + "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return m * n; + }); + + Tensor chunk_0 = Compute( + "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { + return producer.load(m, n); + }); + + Tensor chunk_1 = Compute( + "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { + return producer.load(m, n + ExprHandle(N / 2)); + }); + + Tensor consumer = Compute( + "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) { + return i * chunk_1.load(i, j); + }); + + LoopNest l({chunk_0, chunk_1, consumer}); + auto body = LoopNest::sanitizeNames(l.root_stmt()); + + std::stringstream ss; + ss << *body; + + const std::string& verification_pattern = + R"IR( + # CHECK: for (int i_2 + # CHECK: for (int j_2 + # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, ss.str()); +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp new file mode 100644 index 000000000000..886213ea9c76 --- /dev/null +++ b/test/cpp/tensorexpr/test_ir_verifier.cpp @@ -0,0 +1,191 @@ +#include + +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include +#include +#include +#include +#include +#include + +#include +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(IRVerifier, BitwiseOps) { + VarPtr X = alloc("x", kInt); + VarPtr Y = alloc("y", kFloat); + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, CompareSelect) { + ExprPtr X = alloc(1); + ExprPtr Y = alloc(3.14f); + { + auto a = alloc(X, X, X, Y, kEQ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y, X, X, kEQ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, Ramp) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kFloat); + { + auto a = alloc(I, J, 4); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, Load) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + BufPtr B = alloc( + "b", + std::vector({alloc(10), alloc(20)}), + kFloat); + { + // Indices with different int dtypes (kInt, kLong) are ok + auto a = alloc(B, std::vector({I, J})); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_NO_THROW(verify(a)); + } + { + // Float index + auto a = alloc(B, std::vector({K, K})); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Multilanes are only allowed in flattened indices + auto multilane_index = alloc(I, alloc(1), 4); + auto a = alloc(B, std::vector({I, multilane_index})); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, IfThenElse) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + { + // Condition must be integral + auto a = alloc(K, I, I); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Dtypes of true and false exprs must match + auto a = alloc(I, I, J); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Can't have multiple lanes in condition expr + auto a = alloc(alloc(I, 4), I, I); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, For) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kInt); + StmtPtr body = alloc(std::vector({})); + { + // Can't have nullptr as a Var + auto a = alloc(nullptr, I, J, body); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, Block) { + VarPtr I = alloc("i", kInt); + BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); + { + StmtPtr store = alloc(B, std::vector({I}), I); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + StmtPtr block1 = alloc(std::vector({store})); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + StmtPtr block2 = alloc(std::vector({store})); + // Stmt can't have multiple parents, thus inserting it into several blocks + // is illegal + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(block2)); + } +} + +TEST(IRVerifier, Store) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + BufPtr B = alloc( + "b", + std::vector({alloc(10), alloc(20)}), + kFloat); + { + // Indices with different int dtypes (kInt, kLong) are ok + auto a = alloc(B, std::vector({I, J}), K); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_NO_THROW(verify(a)); + } + { + // Float index + auto a = alloc(B, std::vector({K, K}), K); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Multilanes are only allowed in flattened indices + auto multilane_index = alloc(I, alloc(1), 4); + auto a = alloc(B, std::vector({I, multilane_index}), K); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Value and buf dtypes mismatch + auto a = alloc(B, std::vector({I}), I); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp new file mode 100644 index 000000000000..dc67928b111a --- /dev/null +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -0,0 +1,2133 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::indexing; +using namespace torch::jit::tensorexpr; + +class Kernel : public ::testing::Test { + public: + void SetUp() override { + getTEMustUseLLVMOnCPU() = false; + } +}; + +TEST_F(Kernel, ParallelExternalCallBuf) { + const auto graph_string = R"IR( + graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu), + %1 : Float(1000, 5000, strides=[5000, 1], device=cpu), + %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)): + %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1) + %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) + return (%4))IR"; + auto graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); +#ifdef TORCH_ENABLE_LLVM + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +#endif +} + +TEST_F(Kernel, InliningIntermediates) { + // here, each mul has only one use, so it should be completely inlined + { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %one : int = prim::Constant[value=1]() + %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) + return (%5))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); + } + { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), + %1 : Float(5, 3, strides=[3, 1], device=${device})): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %one : int = prim::Constant[value=1]() + %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) + %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) + %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) + return (%4, %5))IR"; + for (bool use_cuda : {false, true}) { + if (!torch::cuda::is_available() && use_cuda) { + continue; + } + + at::jit::TemplateEnv env; + env.s("device", use_cuda ? "cuda:0" : "cpu"); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + // aten_mul only has one use, inlined completely + torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); + + // aten_sub should be removed by the CUDA backend by metavar rewriting + // and by the CPU backend by horizontal fusion. + torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str()); + } + } +} + +TEST_F(Kernel, PreAllocIntermediateBufs) { + const auto graph_string = R"IR( +graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu), + %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)): + %2 : int = prim::Constant[value=1]() + %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12 + %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15 + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::matmul(a, b) + a; + TensorExprKernel k(graph, {}, {}, true); + + std::vector inputs = {a, b}; + auto stmt = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *stmt; + + // Check whether the intermediate buffer has been added to constants + auto constants = k.getConstantDescriptors(); + ASSERT_EQ(constants.size(), 1); + + // Check the IR we produced + torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str()); + torch::jit::testing::FileCheck().check_not("Free")->run(oss.str()); + + // Check correctness + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, _1) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, _2) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, _3) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) + .index({Slice(None, None, 2), Slice(None, None, 2)}); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, Huge) { + const auto graph_string = R"IR( + graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): + %1 : int = prim::Constant[value=0]() + %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) + %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + std::ostringstream oss; + oss << *k.getCodeGenStmt(); + // The 4000000000 iterations loop will be split into 500000000 x 8 and the + // outer loop will be parallel. If LLVM is not present, it will not be split, + // and to cover both of these cases we're looking for 00000000ll; in the + // output. + const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST_F(Kernel, ParallelStrided) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), + %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): + %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) + .index( + {Slice(None, None, 2), + Slice(None, None, 2), + Slice(None, None, 2)}); + auto ref = a * (a * b); + auto o = at::zeros_like(ref); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, DISABLED_Shape_Inference) { + // disabled: doesn't do stride propagation, and isn't being used currently + + // Test TensorExpr shape inference capabilities: it should only require shapes + // for the inputs + { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): + %2 : Tensor = aten::mul(%0, %1) + %3 : Tensor = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) + .index({Slice(None, None, 2), Slice(None, None, 2)}); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + const auto graph_string = R"IR( + graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), + %1 : Float(8, 8, strides=[8, 1], device=cpu)): + %2 : Tensor = aten::mul(%0, %1) + %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) + %r : Tensor = aten::mul(%3, %4) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat)); + auto t = torch::chunk(a * b, 2, 1); + auto ref = t[0] * t[1]; + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + TORCH_CHECK_EQ(o.sizes()[0], 8); + TORCH_CHECK_EQ(o.sizes()[1], 4); + for (size_t i = 0; i < 8 * 4; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + // Test that shape inference handles aten::unsqueeze + + const auto graph_string = R"IR( + graph(%a : Float(4, 2, strides=[2, 1], device=cpu), + %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), + %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): + %one : int = prim::Constant[value=1]() + %minus_one : int = prim::Constant[value=-1]() + %three : int = prim::Constant[value=3]() + %minus_four : int = prim::Constant[value=-4]() + %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2] + %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1] + %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1] + %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2] + %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1] + %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2] + return (%abc))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) * + at::unsqueeze(c, -4); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_mul)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + // Test that shape inference handles aten::cat + + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::cat({a, b, c}, 1); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + // Test that we throw an error when input list for aten::cat is empty + + const auto graph_string = R"IR( + graph(): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct() + %r : Tensor = aten::cat(%inputs, %dim) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + auto compile = [&]() { + TensorExprKernel k(graph); + k.getCodeGenStmt(); + }; + ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat"); + } + { + // Test that we throw an error when 'dim' passed to aten::cat is invalid + + const auto ir_dim_99 = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=99]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) + return (%r))IR"; + const auto ir_dim_minus_6 = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=-6]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) + return (%r))IR"; + + auto compile = [](const std::string& graph_string) { + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + k.getCodeGenStmt(); + }; + ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index"); + ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index"); + } +} + +TEST_F(Kernel, CatInputTypesPromotion) { + { + // Test that we properly promote input types for aten::cat + + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); + auto ref = at::cat({a, b, c}, 1); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); + } + } +} + +TEST_F(Kernel, ToDType) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): + %1 : NoneType = prim::Constant() + %2 : bool = prim::Constant[value=0]() + %3 : int = prim::Constant[value=6]() + %4 : int = prim::Constant[value=15]() + %5 : int = prim::Constant[value=5]() + %6 : bool = prim::Constant[value=1]() + %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1) + %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4) + %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6) + %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1) + %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1) + %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1) + return (%k.3))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_to +# CHECK-NEXT: } +# CHECK-NEXT: })IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16)); + auto ref = + at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat)); + + std::vector inputs = {a}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); +#endif +} + +TEST_F(Kernel, CatAndInlineWithAConstantDim) { + const auto graph_string = R"IR( + graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), + %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): + %2 : bool = prim::Constant[value=0]() + %3 : int = prim::Constant[value=1]() + %4 : Tensor[] = prim::ListConstruct(%0, %1) + %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) + %6 : Tensor[] = prim::ListConstruct(%5) + %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) + %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) + return (%8, %7))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + + auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); + + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, CatWithEmptyInputs) { + bool curr_cat_wo_conditionals = getCatWoConditionals(); + for (auto cat_wo_conditionals : {true, false}) { + getCatWoConditionals() = cat_wo_conditionals; + const auto graph_string = R"IR( + graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu), + %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)): + %3 : int = prim::Constant[value=0]() + %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0) + %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1) + %10 : Tensor[] = prim::ListConstruct(%6, %7) + %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3) + return (%11))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + + auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0); + + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); + } + getCatWoConditionals() = curr_cat_wo_conditionals; +} + +TEST_F(Kernel, CatWoConditionals) { + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) + return (%r))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK: for +# CHECK: for +# CHECK: aten_cat +# CHECK: for +# CHECK: for +# CHECK: aten_cat +# CHECK: for +# CHECK: for +# CHECK: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::cat({a, b, c}, 1); + + std::vector inputs = {a, b, c}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + getCatWoConditionals() = old_cat_wo_conditionals; +} + +TEST_F(Kernel, OptimizeConditionals) { + bool old_cat_wo_conditionals = getCatWoConditionals(); + bool old_opt_conditionals = getOptConditionals(); + getCatWoConditionals() = false; + getOptConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(5, 3, strides=[3, 1], device=cpu), + %b : Float(5, 7, strides=[7, 1], device=cpu), + %c : Float(5, 9, strides=[9, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim) + %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r) + return (%t))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_relu +# CHECK: for +# CHECK-NEXT: aten_relu +# CHECK: for +# CHECK-NEXT: aten_relu +# CHECK-NOT: Allocate +# CHECK-NOT: Free)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat)); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::relu(at::cat({a, b, c}, 1)); + + std::vector inputs = {a, b, c}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + getOptConditionals() = old_opt_conditionals; + getCatWoConditionals() = old_cat_wo_conditionals; +} + +namespace { + +std::string dtypeConstant(ScalarType scalar_type) { + if (scalar_type == ScalarType::Undefined) { + return "None = prim::Constant()"; + } else { + at::jit::TemplateEnv env_dtype; + env_dtype.d("scalar_type", static_cast(scalar_type)); + return format("int = prim::Constant[value=${scalar_type}]()", env_dtype); + } +} + +at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { + int64_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + 1, + // NOLINTNEXTLINE(modernize-use-transparent-functors) + std::multiplies()); + std::vector values(numel); + std::iota(values.begin(), values.end(), 0); + auto a = at::tensor(values, options); + return a.reshape(sizes); +} + +} // namespace + +TEST_F(Kernel, SumAllAxes) { + // Test lowering of sum on all axes. + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : ${dtype} + %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1) + return (%2))IR"; + auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { + at::jit::TemplateEnv env; + env.s("dtype", dtypeConstant(scalar_type)); + if (scalar_type == ScalarType::Undefined) { + env.s("out_dtype", "Float"); + } else { + env.s("out_dtype", "Double"); + } + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto o = at::empty({}, TensorOptions(kCPU)); + std::optional dtype; + if (scalar_type != ScalarType::Undefined) { + dtype = static_cast(scalar_type); + } + auto ref = a.sum(/*dtype=*/dtype); + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); + } +} + +std::string li_to_str(at::ArrayRef li) { + std::stringstream out; + bool first = true; + for (auto elem : li) { + if (!first) { + out << ", "; + } + out << elem; + first = false; + } + return out.str(); +} + +TEST_F(Kernel, SumOneAxis) { + // Test lowering of sum on one axis. + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : int[] = prim::Constant[value=[${dim}]]() + %2 : bool = prim::Constant[value=${keepdim}]() + %3 : ${dtype} + %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) + return (%4))IR"; + auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + for (int dim = -a.dim(); dim < a.dim(); ++dim) { + for (bool keepdim : {false, true}) { + for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { + at::jit::TemplateEnv env; + env.d("dim", dim); + env.d("keepdim", keepdim); + env.s("dtype", dtypeConstant(scalar_type)); + std::optional dtype; + if (scalar_type != ScalarType::Undefined) { + dtype = static_cast(scalar_type); + } + auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); + if (scalar_type == ScalarType::Undefined) { + env.s("out_dtype", "Float"); + } else { + env.s("out_dtype", "Double"); + } + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto o = at::empty({}, TensorOptions(kCPU)); + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t +# CHECK-NEXT: sum +# CHECK-NEXT: for (int64_t +# CHECK-NEXT: sum)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); + } + } + } +} + +TEST_F(Kernel, SumMultipleAxes) { + // Test lowering of sum on multiple axes. + const auto graph_template = R"IR( + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)): + %1 : int = prim::Constant[value=${dim1}]() + %2 : int = prim::Constant[value=${dim2}]() + %3 : int[] = prim::ListConstruct(%1, %2) + %4 : bool = prim::Constant[value=${keepdim}]() + %5 : ${dtype} + %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5) + return (%6))IR"; + auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + // Only iterate over positive values of axes to keep the running time + // reasonable, since the number of pairs is quadratic. + for (const auto dim1 : c10::irange(a.dim())) { + for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) { + for (bool keepdim : {false, true}) { + at::jit::TemplateEnv env; + env.d("dim1", dim1); + env.d("dim2", dim2); + env.d("keepdim", keepdim); + env.s("dtype", dtypeConstant(ScalarType::Undefined)); + auto o = at::empty({}, TensorOptions(kCPU)); + auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); + + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t +# CHECK: for (int64_t +# CHECK: for (int64_t +# CHECK: for (int64_t +# CHECK: sum)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); + } + } + } +} + +// This test and the following ones testing Softmax only tests with dim set +// to one of the valid input dimensions. It does not test with dim=None +// because that is supposed to be deprecated. +TEST_F(Kernel, Softmax2D) { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %dt_float : int = prim::Constant[value=7]() + %dt_none : NoneType = prim::Constant() + %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt}) + return (%4))IR"; + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 5 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (bool empty_dtype : {false, true}) { + for (auto log_softmax : {false, true}) { + for (const auto softmax_dim : c10::irange(a.dim())) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + auto other_dim = (softmax_dim + 1) % a.dim(); + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + at::jit::TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + env.s("dt", empty_dtype ? "dt_none" : "dt_float"); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + at::jit::TemplateEnv ver_env; + ver_env.d("other_dim", other_dim); + ver_env.d("other_dim_size", a.sizes()[other_dim]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = + format(verification_template, ver_env); + + // verification string temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, + // oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } + } +} + +TEST_F(Kernel, Softmax3D) { + const auto graph_template = R"IR( + graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 3 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (const auto softmax_dim : c10::irange(a.dim())) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (const auto i : c10::irange(a.dim())) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + + at::jit::TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + at::jit::TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verification string temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST_F(Kernel, Softmax4D) { + const auto graph_template = R"IR( + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 2 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 + # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (const auto softmax_dim : c10::irange(a.dim())) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (const auto i : c10::irange(a.dim())) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + + at::jit::TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + at::jit::TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("dim3", other_dims[2]); + ver_env.d("dim3_size", a.sizes()[other_dims[2]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verification string temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST_F(Kernel, SignTest) { + const auto graph_template = R"IR( + graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): + %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) + return (%2))IR"; + + auto run_test = [](const std::string& graph_string, const at::Tensor& input) { + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + + std::vector inputs = {input}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = at::sign(input); + ASSERT_TRUE(at::allclose(o, ref)); + }; + auto common_options = at::TensorOptions() + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + int default_input_size = 100; + for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { + at::Tensor corner_case_inputs; + at::jit::TemplateEnv env; + auto options = common_options; + switch (scalar_type) { + case ScalarType::Float: { + env.s("dtype", "Float"); + options = options.dtype(at::kFloat); + std::vector input_float = { + 0.0f, + -0.0f, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::nanf("1"), + -std::nanf("1")}; + corner_case_inputs = at::from_blob( + input_float.data(), + {static_cast(input_float.size())}, + options); + auto rand_input = at::rand({default_input_size}, options); + auto input = at::cat({rand_input, corner_case_inputs}); + env.d("size", at::numel(input)); + const auto graph_string = format(graph_template, env); + run_test(graph_string, input); + break; + } + case ScalarType::Double: { + env.s("dtype", "Double"); + options = options.dtype(at::kDouble); + std::vector input_double = { + 0.0, + -0.0, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::nan("1"), + -std::nan("1")}; + corner_case_inputs = at::from_blob( + input_double.data(), + {static_cast(input_double.size())}, + options); + auto rand_input = at::rand({default_input_size}, options); + auto input = at::cat({rand_input, corner_case_inputs}); + env.d("size", at::numel(input)); + const auto graph_string = format(graph_template, env); + run_test(graph_string, input); + break; + } + default: + throw unsupported_dtype(); + } + } +} + +TEST_F(Kernel, InlineProducerIntoReduction) { + // Inline producer (mul) into reduction (sum). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=7]() + %4 : Double(device=cpu) = aten::sum(%2, %3) + return (%4))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have only one loop in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 + # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 + # CHECK-NEXT: sum + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kDouble); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, InlineReductionIntoConsumer) { + // Inline producer (mul %2) into reduction (sum %4) but DO NOT + // inline the reduction into consumer (mul %4). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=6]() + %4 : Float(device=cpu) = aten::sum(%2, %3) + %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4) + return (%5))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have two loops in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 + # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 + # CHECK-NEXT: sum + # CHECK: for (int64_t i_2 = 0ll; i_2 < 5 + # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3 + # CHECK-NEXT: aten_mul + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kFloat) * (a * b); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, SanitizeNames_CUDA) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0), + %1 : Float(5, 3, strides=[3, 1], device=cuda:0)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%4))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + graph->inputs().at(0)->setDebugName("aten::add:"); + graph->inputs().at(1)->setDebugName("aten::add_"); + TensorExprKernel k(graph); + auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto ref = a * (a * b); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, SanitizeConstants_CUDA) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)): + %none : NoneType = prim::Constant() + %size : int = prim::Constant[value=16]() + %sizes : int[] = prim::ListConstruct(%size, %size) + %30 : Device = prim::Constant[value="cuda"]() + %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none) + %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we insert a call to + // aten::ones and then const-prop it + ConstantPropagation(graph); + + // We set the name of the constant to include special characters that are + // not allowed. This should be fixed by the sanitizer in TensorExprKernel. + graph->nodes().front()->output()->setDebugName("illegal.name"); + + // Check if we have a constant node with illegal name in the graph. + auto const_node = graph->nodes().front(); + ASSERT_EQ(const_node->kind(), prim::Constant); + ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, ConstantTensors) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): + %none : NoneType = prim::Constant() + %size : int = prim::Constant[value=16]() + %sizes : int[] = prim::ListConstruct(%size, %size) + %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) + %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we insert a call to + // aten::ones and then const-prop it + ConstantPropagation(graph); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, ConstantTensorsNonContiguous) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): + %none : NoneType = prim::Constant() + %dtype : int = prim::Constant[value=6]() + %c0 : int = prim::Constant[value=0]() + %c256 : int = prim::Constant[value=256]() + %c16 : int = prim::Constant[value=16]() + %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) + %sizes : int[] = prim::ListConstruct(%c16, %c16) + %y_t : Tensor = aten::view(%y_flat, %sizes) + %y : Tensor = aten::t(%y_t) + %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we generate several aten + // calls to produce non-contiguous constant tensor and then const-prop it + ConstantPropagation(graph); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat)) + .view({16, 16}) + .t(); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, RunFast) { +#ifdef TORCH_ENABLE_LLVM + // TODO: Implement call_raw in IREval and remove the ifdef + + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + + k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +TEST_F(Kernel, RunWithAllocatedOutputs) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + + std::vector args = {o, a, b}; + std::vector stack = fmap(args); + k.runWithAllocatedOutputs(stack); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +TEST_F(Kernel, CodegenInspection) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): + %none : NoneType = prim::Constant() + %dtype : int = prim::Constant[value=6]() + %c0 : int = prim::Constant[value=0]() + %c256 : int = prim::Constant[value=256]() + %c16 : int = prim::Constant[value=16]() + %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) + %sizes : int[] = prim::ListConstruct(%c16, %c16) + %y_t : Tensor = aten::view(%y_flat, %sizes) + %y : Tensor = aten::t(%y_t) + %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we generate several aten + // calls to produce non-contiguous constant tensor and then const-prop it + ConstantPropagation(graph); + + TensorExprKernel k(graph); + + // Check that we could retrieve generated assembly + auto asm_str = k.getCodeText("asm"); + const std::string& asm_verification_pattern = + R"ASM( + # CHECK: .text + # CHECK: retq)ASM"; + torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str); + + // Check that we could retrieve info about codegen parameters + auto constants = k.getConstantDescriptors(); + auto buf_args = k.getBufferArgs(); + // Expected buf args: [input0, output0, constant0] + ASSERT_EQ(buf_args.size(), 3); + ASSERT_EQ(constants.size(), 1); + ASSERT_TRUE( + !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar()); +#endif +} + +Tensor lowerNanToNum( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device) { + auto input_buf = std::get(inputs[0]); + auto e = Compute( + "custom_nan_to_num", + outputShape, + outputStrides, + [&](const std::vector& axes) { + std::vector indices(axes.begin(), axes.end()); + auto load = input_buf.load(indices); + return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load); + }); + return e; +} + +TEST_F(Kernel, CustomLowering) { + const auto graph_string = R"IR( + graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): + %none : NoneType = prim::Constant() + %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) + return (%y) +)IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + std::unordered_map lowerings = { + {aten::nan_to_num, lowerNanToNum}}; + TensorExprKernel k(graph, lowerings); + + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + + // Check that our custom lowering is actually used + torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str()); + torch::jit::testing::FileCheck().check("isnan")->run(oss.str()); +} + +TEST_F(Kernel, Vectorize) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(100, 16, strides=[16, 1], device=cpu), + %1 : Float(100, 16, strides=[16, 1], device=cpu)): + %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1) + %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 100 * 16; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first. +TEST_F(Kernel, DISABLED_FlattenVectorize) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(100, 3, strides=[3, 1], device=cpu), + %1 : Float(100, 3, strides=[3, 1], device=cpu)): + %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 100 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +TEST_F(Kernel, Strided1dWithinBounds) { + auto ir = R"IR( + graph(%0 : Float(3, strides=[1], device=cpu), + %1 : Float(3, strides=[2], device=cpu)): + %2 : int = prim::Constant[value=1]() + %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2) + return (%3))IR"; + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(ir, graph.get(), vmap); + TensorExprKernel k(graph); + + auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat)) + .index({Slice(None, None, 2)}); + auto expect = a + b; + + std::vector inputs = {a, b}; + + std::vector stack = fmap(inputs); + k.run(stack); + + auto output = stack[0].toTensor(); + + for (size_t i = 0; i < 3; ++i) { + TORCH_CHECK_EQ( + ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); + } +} + +TEST_F(Kernel, InputAsOutput) { + const auto graph_string = R"IR( + graph(%x : Float(5, 3, strides=[3, 1], device=cpu), + %y : Float(5, 3, strides=[1, 5], device=cpu)): + return (%x, %y))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto y = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + TensorExprKernel k(graph); + std::vector inputs = {x, y}; + + std::vector stack = fmap(inputs); + k.run(stack); + CHECK(at::allclose(x, stack[0].toTensor())); + CHECK(at::allclose(y, stack[1].toTensor())); +} + +TEST_F(Kernel, ScalarOut) { + auto ir = R"IR( +graph(%x : int, %y : int): + %z : int = aten::mul(%x, %y) + %r : int = aten::mul(%z, %x) + return (%r, %z))IR"; + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(ir, graph.get(), vmap); + TensorExprKernel k(graph); + + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + + // Verify the generated IR. We expect to see a scalar variable (Let) followed + // by a store to a 0-dim buffer. + const std::string& verification_pattern = R"IR( +# CHECK: int64_t +# CHECK-NEXT: [0ll] = +# CHECK-NEXT: int64_t +# CHECK-NEXT: [0ll] = +)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + int64_t x = 2, y = 3, r = 0, z = 0; + + // Verify that TEK::runFast works correctly with scalar outputs + std::vector inputs = {&x, &y}; + std::vector outputs = {&r, &z}; + k.runFast(inputs, outputs); + TORCH_CHECK_EQ(z, x * y); + TORCH_CHECK_EQ(r, z * x); + + // Verify that TEK::run works correctly with scalar outputs + std::vector stack = {x, y}; + k.run(stack); + TORCH_CHECK_EQ(stack[0], x * y * x); + TORCH_CHECK_EQ(stack[1], x * y); +} + +TEST_F(Kernel, ScalarTensorOut) { + auto ir = R"IR( +graph(%x : int, + %xt : Long(3, strides=[1], device=cpu), + %y : int, + %yt : Long(3, strides=[1], device=cpu)): + %z : int = aten::mul(%x, %y) + %r : int = aten::mul(%z, %x) + %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y) + %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt) + return (%r, %rt, %z, %zt))IR"; + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(ir, graph.get(), vmap); + TensorExprKernel k(graph); + int64_t x = 2, y = 3, r = 0, z = 0; + auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2; + auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3; + auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); + auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); + + // Verify that TEK::runFast works correctly with mixed scalar and tensor + // inputs/outputs + std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; + std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; + k.runFast(inputs, outputs); + TORCH_CHECK_EQ(z, x * y); + TORCH_CHECK_EQ(r, z * x); + ASSERT_TRUE(at::equal(zt, xt * yt)); + ASSERT_TRUE(at::equal(rt, zt * xt)); + + // Verify that TEK::run works correctly with mixed scalar and tensor + // inputs/outputs + std::vector stack = {x, xt, y, yt}; + k.run(stack); + TORCH_CHECK_EQ(stack[0], x * y * x); + ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt)); + TORCH_CHECK_EQ(stack[2], x * y); + ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt)); +} + +TEST_F(Kernel, FuseLoopsWithVariableBounds) { +#ifdef TORCH_ENABLE_LLVM + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu), + %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu), + %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->inputs().at(2)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + std::ostringstream oss; + oss << *kernel.getCodeGenStmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i +# CHECK-NEXT: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK-NOT: for (int64_t i + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto run_kernel = [&](int dim1, int dim2) { + auto a = + at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto c = + at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + + auto ref = at::cat({a, b, c}, 1); + + std::vector stack = + fmap(std::vector({a, b, c})); + stack.emplace_back(dim1); + stack.emplace_back(dim2); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + run_kernel(10, 20); + getCatWoConditionals() = old_cat_wo_conditionals; +#endif +} + +TEST_F(Kernel, FuseLoopsWithVariableConcatDim) { +#ifdef TORCH_ENABLE_LLVM + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int, + %SS_4 : int, + %SS_5 : int): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3, -4, -5}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->inputs().at(2)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + std::ostringstream oss; + oss << *kernel.getCodeGenStmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i +# CHECK-NEXT: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK-NOT: for (int64_t i + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto run_kernel = [&](int dim1, int dim2, int dim3) { + auto a = + at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto c = + at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + + auto ref = at::cat({a, b, c}, 1); + + std::vector stack = + fmap(std::vector({a, b, c})); + stack.emplace_back(dim1); + stack.emplace_back(dim2); + stack.emplace_back(dim3); + stack.emplace_back(3 * dim3); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + run_kernel(10, 20, 15); + getCatWoConditionals() = old_cat_wo_conditionals; +#endif +} + +TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) { +#ifdef TORCH_ENABLE_LLVM + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int, + %SS_4 : int, + %SS_5 : int, + %SS_6 : int): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3, -4, -5, -6}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + std::ostringstream oss; + oss << *kernel.getCodeGenStmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i +# CHECK-NEXT: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK-NOT: for (int64_t j +# CHECK-NOT: for (int64_t i + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) { + auto a = + at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); + + auto ref = at::cat({a, b}, 1); + + std::vector stack = fmap(std::vector({a, b})); + stack.emplace_back(dim2); + stack.emplace_back(dim3); + stack.emplace_back(dim4); + stack.emplace_back(dim5); + stack.emplace_back(dim4 + dim5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + run_kernel(10, 20, 15, 8); + getCatWoConditionals() = old_cat_wo_conditionals; +#endif +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp new file mode 100644 index 000000000000..f6ffc84f62c0 --- /dev/null +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -0,0 +1,1799 @@ +#ifdef TORCH_ENABLE_LLVM +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +using LLVMExprEval = ExprEval; + +// Typed tests, can't use gtest params here due to the way we instantiate tests. +#define TEST_LLVM_SCALAR_TYPES(_) \ + _(uint8_t, Byte, 24) \ + _(int8_t, Char, -20) \ + _(int16_t, Short, 3332) \ + _(int, Int, 123456) \ + _(int64_t, Long, 2631563121321) \ + _(float, Float, 0.122) \ + _(double, Double, 0.21312) \ + _(at::Half, Half, 0.128f) + +#define IMM_TEST(Type, Name, Val) \ + TEST(LLVM, Name##ImmTest) { \ + auto a = Name##Imm::make(Val); \ + LLVMExprEval cg(a); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(IMM_TEST) +#undef IMM_TEST + +#define ADD_TEST(Type, Name, Val) \ + TEST(LLVM, Name##AddTest) { \ + auto a = Name##Imm::make(Val); \ + auto b = Name##Imm::make(Val * 2); \ + auto c = Add::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val * 3, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val * 3); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(ADD_TEST) +#undef ADD_TEST + +#define SUB_TEST(Type, Name, Val) \ + TEST(LLVM, Name##SubTest) { \ + auto a = Name##Imm::make(Val * 2); \ + auto b = Name##Imm::make(Val); \ + auto c = Sub::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(SUB_TEST) +#undef SUB_TEST + +#define MUL_TEST(Type, Name, Val) \ + TEST(LLVM, Name##MulTest) { \ + auto a = Name##Imm::make(Val); \ + auto b = Name##Imm::make((Type)4); \ + auto c = Mul::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val * 4, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val * 4); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(MUL_TEST) +#undef MUL_TEST + +#define DIV_TEST(Type, Name, Val) \ + TEST(LLVM, Name##DivTest) { \ + auto a = Name##Imm::make((Type)6); \ + auto b = Name##Imm::make((Type)3); \ + auto c = Div::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), 2, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), 2); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(DIV_TEST) +#undef DIV_TEST + +TEST(LLVM, IntToFloatCastTest) { + auto a = IntImm::make(2); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b, {}); + ASSERT_EQ(cg.value(), 2.0); +} + +TEST(LLVM, FloatToIntCastTest) { + auto a = FloatImm::make(2.0); + auto b = Cast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 2); +} + +TEST(LLVM, IntToLongCastTest) { + auto a = IntImm::make(12345); + auto b = Cast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 12345); +} + +TEST(LLVM, ByteToCharCastTest) { + auto a = ByteImm::make(250); + auto b = Cast::make(kChar, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), (int8_t)250); +} + +TEST(LLVM, HalfToLongCastTest) { + auto a = HalfImm::make(2.0); + auto b = Cast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 2); +} + +TEST(LLVM, ByteToDoubleCastTest) { + auto a = ByteImm::make(2); + auto b = Cast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 2); +} + +TEST(LLVM, FloatToByteCastTest) { + auto a = FloatImm::make(254.0); + auto b = Cast::make(kByte, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 254); +} + +TEST(LLVM, FloatToCharCastTest) { + auto a = FloatImm::make(-2.0); + auto b = Cast::make(kChar, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), -2); +} + +TEST(LLVM, ByteToFloatCastTest) { + auto a = ByteImm::make(254); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 254.0); +} + +TEST(LLVM, CharToFloatCastTest) { + auto a = CharImm::make(-2); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), -2.0); +} + +TEST(LLVM, BitCast) { + /* constexpr int16_t ref16 = 1337; */ + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + + // this is broken + /*{ + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(k); + auto b = BitCast::make(kShort, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } +} + +TEST(LLVM, fastLogFloat) { + const int kTotalSize = 128 * 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + LLVMCodeGen ir_eval(stmt, {a_buf, b_buf}); + ir_eval.call({a_v, b_v}); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + auto ref = std::log(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_FLOAT_EQ(test, ref); + } + } +} + +TEST(LLVM, LetTest01) { + BufHandle a("A", {1}, kFloat); + std::vector v = {1, 0}; + std::vector args({v.data()}); + VarHandle x("x", kFloat); + auto block = Block::make({ + Let::make(x, 3.f), + a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))), + }); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f); +} + +TEST(LLVM, LetTest02) { + BufHandle a("A", {1}, kFloat); + std::vector v = {1, 0}; + std::vector args({v.data()}); + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + auto block = Block::make( + {Let::make(x, 3.f), + Let::make(y, 6.f), + a.store( + {IntImm::make(0)}, + ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))}); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f); +} + +TEST(LLVM, LetTestMultitype) { + BufHandle a("A", {1}, kDouble); + std::vector v = {1, 0}; + std::vector args({v.data()}); + VarHandle x("x", kByte); + VarHandle y("y", kHalf); + auto block = Block::make( + {Let::make(x, 3), + Let::make(y, 6.f), + a.store( + {0}, + Cast::make( + kDouble, + ExprHandle(2.f) + + (x * ExprHandle(3.f) + y * ExprHandle(4.f))))}); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f); +} + +TEST(LLVM, BufferTest) { + BufHandle a("A", {32}, kFloat); + std::vector v(5); + std::vector args({v.data()}); + auto rv = IntImm::make(0); + LLVMExprEval cg(rv, {a}); + ASSERT_EQ(cg.value(args), 0); +} + +TEST(LLVM, BlockTest) { + BufHandle a("A", {32}, kInt); + std::vector v = {1, 2}; + std::vector args({v.data()}); + + auto block = Block::make({ + a.store({0}, 3), + a.store({1}, 4), + a.store({0}, 4), + }); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 4); + ASSERT_EQ(v[1], 4); +} + +TEST(LLVM, LoadStoreTest) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + + auto store = b.store({0}, a.load(0)); + LLVMCodeGen cg(store, {a, b}); + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(a_buffer[0], 42); + ASSERT_EQ(b_buffer[0], 42); +} + +TEST(LLVM, IfThenElseTest) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {1}, kInt); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + std::vector c_buffer = {1}; + + auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0)); + LLVMCodeGen cg(store, {a, b, c}); + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(a_buffer[0], 42); + ASSERT_EQ(b_buffer[0], 42); +} + +// if (x < 10) x = x + 1 +TEST(LLVM, CondNoFalseBlockTest) { + BufHandle x("X", {1}, kInt); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value); + } + } +} + +// if (x < 10) { +// x = x + 1; +// } else { +// x = x - 1; +// } +TEST(LLVM, CondTest) { + BufHandle x("X", {1}, kInt); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto block = Block::make({ + cond, + x.store({0}, x.load(0) * 2), + }); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(block, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); + } else { + ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); + } + } +} + +// if (x < 10) { +// if (x > 5) { +// x = x + 1; +// } else { +// x = x - 1; +// } +// } else { +// if (x <= 15) { +// x = x + 2; +// } else { +// x = x - 2; +// } +// } +TEST(LLVM, CondNestedTest) { + BufHandle x("X", {1}, kInt); + auto true_cmp = + CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); + auto true_cond = Cond::make( + true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto false_cmp = + CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); + auto false_cond = Cond::make( + false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, true_cond, false_cond); + + for (int32_t x_value : {0, 8, 15, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + if (x_value > 5) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value - 1); + } + } else { + if (x_value <= 15) { + ASSERT_EQ(x_buffer[0], x_value + 2); + } else { + ASSERT_EQ(x_buffer[0], x_value - 2); + } + } + } +} + +TEST(LLVM, DirectVectorization) { + constexpr int M = 3; + constexpr int N = 64; + BufHandle a("a", {M, N}, kFloat); + BufHandle b("b", {M, N}, kFloat); + BufHandle c("c", {M, N}, kFloat); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + StmtPtr s = For::make( + m, + 0, + M, + Store::make( + c, + {Ramp::make(m * 64, 1, 64)}, + Load::make({kFloat, 64}, a, {Ramp::make(m * 64, 1, 64)}) * + Load::make({kFloat, 64}, b, {Ramp::make(m * 64, 1, 64)}))); + LLVMCodeGen cg(s, {a, b, c}); +} + +TEST(LLVM, VecLoadStoreTest) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + std::vector a_buffer = {1, 1, 1, 1}; + std::vector b_buffer = {2, 2, 2, 2}; + + auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)})); + LLVMCodeGen cg(store, {a, b}); + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(a_buffer[0], 1); + ASSERT_EQ(a_buffer[1], 1); + ASSERT_EQ(a_buffer[2], 1); + ASSERT_EQ(a_buffer[3], 1); + ASSERT_EQ(b_buffer[0], 1); + ASSERT_EQ(b_buffer[1], 1); + ASSERT_EQ(b_buffer[2], 1); + ASSERT_EQ(b_buffer[3], 1); +} + +#define FLOAT_INTRINSICS_TEST(Name, Lanes) \ + TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ + BufHandle a("A", {1}, kFloat); \ + BufHandle b("B", {1}, kFloat); \ + float val = 0.5f; \ + std::vector a_buffer(Lanes, val); \ + std::vector b_buffer(Lanes, val); \ + auto store = b.store( \ + {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ + LLVMCodeGen cg(store, {a, b}); \ + std::vector args({a_buffer.data(), b_buffer.data()}); \ + ASSERT_EQ(cg.value(args), 0); \ + for (const auto i : c10::irange(Lanes)) { \ + ASSERT_FLOAT_EQ(a_buffer[i], val); \ + } \ + } // namespace jit +FLOAT_INTRINSICS_TEST(erf, 4) +FLOAT_INTRINSICS_TEST(erfc, 4) +FLOAT_INTRINSICS_TEST(acos, 4) +FLOAT_INTRINSICS_TEST(asin, 4) +FLOAT_INTRINSICS_TEST(atan, 4) +FLOAT_INTRINSICS_TEST(cosh, 4) +FLOAT_INTRINSICS_TEST(sinh, 4) +FLOAT_INTRINSICS_TEST(tanh, 4) +FLOAT_INTRINSICS_TEST(expm1, 4) +FLOAT_INTRINSICS_TEST(lgamma, 4) +FLOAT_INTRINSICS_TEST(erf, 8) +FLOAT_INTRINSICS_TEST(erfc, 8) +FLOAT_INTRINSICS_TEST(acos, 8) +FLOAT_INTRINSICS_TEST(asin, 8) +FLOAT_INTRINSICS_TEST(atan, 8) +FLOAT_INTRINSICS_TEST(cosh, 8) +FLOAT_INTRINSICS_TEST(sinh, 8) +FLOAT_INTRINSICS_TEST(tanh, 8) +FLOAT_INTRINSICS_TEST(expm1, 8) +FLOAT_INTRINSICS_TEST(lgamma, 8) +#undef FLOAT_INTRINSICS_TEST + +#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ + TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ + BufHandle a("A", {1}, kDouble); \ + BufHandle b("B", {1}, kDouble); \ + float val = 0.5f; \ + std::vector a_buffer(Lanes, val); \ + std::vector b_buffer(Lanes, val); \ + auto store = b.store( \ + {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ + LLVMCodeGen cg(store, {a, b}); \ + std::vector args({a_buffer.data(), b_buffer.data()}); \ + ASSERT_EQ(cg.value(args), 0); \ + for (const auto i : c10::irange(Lanes)) { \ + ASSERT_FLOAT_EQ(a_buffer[i], val); \ + } \ + } // namespace jit +DOUBLE_INTRINSICS_TEST(erf, 2) +DOUBLE_INTRINSICS_TEST(erfc, 2) +DOUBLE_INTRINSICS_TEST(acos, 2) +DOUBLE_INTRINSICS_TEST(asin, 2) +DOUBLE_INTRINSICS_TEST(atan, 2) +DOUBLE_INTRINSICS_TEST(cosh, 2) +DOUBLE_INTRINSICS_TEST(sinh, 2) +DOUBLE_INTRINSICS_TEST(tanh, 2) +DOUBLE_INTRINSICS_TEST(expm1, 2) +DOUBLE_INTRINSICS_TEST(lgamma, 2) +DOUBLE_INTRINSICS_TEST(erf, 4) +DOUBLE_INTRINSICS_TEST(erfc, 4) +DOUBLE_INTRINSICS_TEST(acos, 4) +DOUBLE_INTRINSICS_TEST(asin, 4) +DOUBLE_INTRINSICS_TEST(atan, 4) +DOUBLE_INTRINSICS_TEST(cosh, 4) +DOUBLE_INTRINSICS_TEST(sinh, 4) +DOUBLE_INTRINSICS_TEST(tanh, 4) +DOUBLE_INTRINSICS_TEST(expm1, 4) +DOUBLE_INTRINSICS_TEST(lgamma, 4) +#undef DOUBLE_INTRINSICS_TEST + +TEST(LLVM, VectorizerLoadStoreTest) { + BufHandle a("A", {1}, kInt); + + Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); + + ASSERT_TRUE(to(to(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(4, 21); + std::vector c_vec(4, 0); + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 21); +} + +TEST(LLVM, VectorizeBitCast) { + BufHandle a("A", {128}, kInt); + + Tensor c = Compute("c", {128}, [&](const VarHandle& i) { + return bitcast(a.load(i)); + }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); + ASSERT_TRUE(to(to(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(128); + std::vector c_vec(128); + for (const auto i : c10::irange(128)) { + a_vec[i] = raw_bitcast(1337.f); + } + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 1337.f); +} + +TEST(LLVM, MemcpyTest) { + constexpr int N = 32; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + std::vector a_buffer(N, 42); + std::vector b_buffer(N, 0); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, b.store({i}, a.load(i))); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, 42); + assertAllEqual(b_buffer, 42); +} + +TEST(LLVM, BzeroTest) { + constexpr int N = 32; + BufHandle b("B", {N}, kInt); + std::vector b_buffer(N, 11); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, b.store({i}, 0)); + + LLVMCodeGen cg(expr, {b}); + + std::vector args({b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(b_buffer, 0); +} + +TEST(LLVM, ElemwiseAdd) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 42); +} + +TEST(LLVM, ElemwiseAddFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 42.0f); +} + +TEST(LLVM, ElemwiseLog10Float) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + std::vector a_buffer(N, 10.0f); + std::vector b_buffer(N, 2.0f); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N / 4, + b.store( + {Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)})))); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, 10.0f); + assertAllEqual(b_buffer, 1.0f); +} + +TEST(LLVM, ElemwiseLog1pFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + std::vector a_buffer(N, expf(3.0f) - 1); + std::vector b_buffer(N, 42.0f); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N / 4, + b.store( + {Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)})))); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, expf(3.0f) - 1); + ExpectAllNear(b_buffer, 3.0f, 1e-5f); +} + +TEST(LLVM, ElemwiseMaxInt) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 41); +} + +TEST(LLVM, ElemwiseMinInt) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 1); +} + +TEST(LLVM, ElemwiseMaxFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 41.0f); +} + +TEST(LLVM, ElemwiseMaxNaNFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(b_buffer, 1.0f); + for (auto const& elt : c_buffer) { + ASSERT_TRUE(std::isnan(elt)); + } +} + +TEST(LLVM, ElemwiseMinFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1.0f); +} + +TEST(LLVM, ElemwiseMinNaNFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(b_buffer, 1.0f); + for (auto const& elt : c_buffer) { + ASSERT_TRUE(std::isnan(elt)); + } +} + +TEST(LLVM, ElemwiseMod) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 23); + std::vector c_buffer(N, 18); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i)))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 23); + assertAllEqual(c_buffer, 18); +} + +TEST(LLVM, CompareSelectIntEQ) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + for (int i = 0; i < N / 2; i++) { + b_buffer[i] = 0; + c_ref[i] = 0; + } + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectFloatEQ) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1.0f); + std::vector b_buffer(N, 1.0f); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1); +} + +TEST(LLVM, CompareSelectByteGT) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 0); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 0); + + for (int i = 0; i < N / 2; i++) { + a_buffer[i] = 128; + c_ref[i] = 1; + } + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGT))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(0)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectByteGE) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 0); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGE))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(0)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectByteLT) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 128); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + for (int i = 0; i < N / 2; i++) { + a_buffer[i] = 128; + c_ref[i] = 0; + } + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLT))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(128)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectByteLE) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 128); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLE))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(128)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, StoreFloat) { + BufHandle result("result", {1}, kFloat); + std::vector result_buffer = {0.0f}; + auto expr = result.store({0}, FloatImm::make(3.14f)); + LLVMCodeGen cg(expr, {result}); + std::vector args({result_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(result_buffer[0], 3.14f); +} + +TEST(LLVM, SimpleMath01) { + const int N = 1024; + Tensor tensor = Compute( + "f", {N}, [](const VarHandle& i) { return cast(i * i + 1); }); + LoopNest l({tensor}); + StmtPtr stmt = l.root_stmt(); + BufHandle f_buf(tensor.buf()); + LLVMCodeGen cg(stmt, {f_buf}); + + PaddedBuffer f_v(N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(N, "f_ref"); + for (const auto i : c10::irange(N)) { + f_ref(i) = i * i + 1; + } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +TEST(LLVM, ComputeMul) { + const int N = 1024; + BufHandle a("a", {N}, kFloat); + BufHandle b("b", {N}, kFloat); + Tensor c = Compute( + "c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + + LLVMCodeGen cg(s, {a, b, c_buf}); + + std::vector a_vec(N, 21.0f); + std::vector b_vec(N, 2.0f); + std::vector c_vec(N, 0.0f); + std::vector args({a_vec.data(), b_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 42.0f); +} + +TEST(LLVM, BroadcastAdd) { + const int M = 32; + const int N = 1024; + BufHandle a("a", {M, N}, kFloat); + BufHandle b("b", {N}, kFloat); + Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(j); + }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + LLVMCodeGen cg(s, {a, b, c_buf}); + + std::vector av(M * N); + std::iota(av.begin(), av.end(), 0); + std::vector bv(N); + std::iota(bv.begin(), bv.end(), 0); + std::vector cv(M * N, 0); + std::vector args({av.data(), bv.data(), cv.data()}); + ASSERT_EQ(cg.value(args), 0); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]); + } + } +} + +TEST(LLVM, BitwiseOps) { + auto a = IntImm::make(59); + auto b = IntImm::make(11); + auto c = IntImm::make(101); + auto d = IntImm::make(2); + + ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; + LLVMExprEval cg(f); + + ASSERT_EQ(cg.value(), 11); +} + +TEST(LLVM, ArithmeticRightShift) { + auto a = CharImm::make(-4); + auto b = CharImm::make(1); + ExprHandle f = a >> b; + LLVMExprEval cg(f); + ASSERT_EQ(cg.value(), -2); +} + +TEST(LLVM, LogicalRightShift) { + auto a = ByteImm::make(0xfc); + auto b = ByteImm::make(1); + ExprHandle f = a >> b; + LLVMExprEval cg(f); + ASSERT_EQ(cg.value(), 0x7e); +} + +TEST(LLVM, DynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + BufHandle c("c", {n}, kFloat); + VarHandle i("i", kInt); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + LLVMCodeGen cg(s, {a, b, c, n}); + std::vector args({aData.data(), bData.data(), cData.data(), &size}); + cg.value(args); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(LLVM, BindDynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + BufHandle c("c", {n}, kFloat); + VarHandle i("i", kInt); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + LLVMCodeGen cg(s, {a, b, c, n}); + cg.call({aData, bData, cData, size}); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(LLVM, TensorDynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + Tensor c = Compute( + "c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + LLVMCodeGen cg(s, {a, b, c, n}); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + cg.call({aData, bData, cData, size}); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(LLVM, DynamicShape2D) { + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle a("a", {m, n}, kFloat); + BufHandle b("b", {m, n}, kFloat); + Tensor c = + Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(i, j); + }); + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + LLVMCodeGen cg(s, {a, b, c, m, n}); + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + cg.call({aData, bData, cData, M, N}); + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + }; + testWithSize(1, 8); + testWithSize(16, 32); + testWithSize(37, 11); +} + +TEST(LLVM, EmptyStmt) { + StmtPtr s = alloc(std::vector({})); + + LLVMCodeGen cg(s, {}); + cg.call({}); + // Just don't crash. +} + +TEST(LLVM, EliminatedStmt) { + BufHandle a("a", {1}, kFloat); + + Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; }); + + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + s = IRSimplifier::simplify(s); + LLVMCodeGen cg(s, {a, c}); + std::vector aData(1, 1.0f); + std::vector cData(0, 0.0f); + cg.call({aData, cData}); +} + +TEST(LLVM, SimpleReduction) { + int M = 128; + int N = 64; + + BufHandle a("a", {1, M, N}, kFloat); + + Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); + LoopNest loop({b}); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + LLVMCodeGen cg(s, {a, b}); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + cg.call({a_v, b_v}); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +TEST(LLVM, RFactorReduction) { + int M = 128; + int N = 64; + + BufHandle a("a", {1, M, N}, kFloat); + + Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); + LoopNest loop({b}); + + std::vector loops = loop.getLoopStmtsFor(b); + ForPtr loop_m = loops.at(1); + ForPtr loop_n = loops.at(2); + loop.reorderAxis(loop_m, loop_n); + + loops = loop.getLoopStmtsFor(b); + loop_m = loops.at(2); + loop_n = loops.at(1); + auto b_body = loop.getAllWritesToBuf(b.buf())[1]; + ASSERT_TRUE(loop.rfactor(b_body, loop_n)); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + LLVMCodeGen cg(s, {a, b}); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + cg.call({a_v, b_v}); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +TEST(LLVM, RFactorVectorizedReduction) { + int M = 128; + int N = 64; + + BufHandle a("a", {1, M, N}, kFloat); + + Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); + LoopNest loopnest({b}); + std::vector loops = loopnest.getLoopStmtsFor(b); + // Reorder n and m loops + loopnest.reorderAxis(loops.at(1), loops.at(2)); + auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1); + auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf()); + ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); + ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); + auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]); + + // Vectorize initializer of rfac_buf + ASSERT_TRUE(LoopNest::vectorize(distributed_loops[0])); + // Vectorize producer of rfac_buf + ASSERT_TRUE(LoopNest::vectorize(distributed_loops[1])); + loopnest.simplify(); + + loopnest.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt()); + LLVMCodeGen cg(s, {a, b}); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + cg.call({a_v, b_v}); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +template +static void testSimpleParallel() { + // Compute a simple operation, and try all loop-axis combination to be + // parallel or sequential. + const int M = 4; + const int N = 6; + Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) { + return cast(m + n); + }); + LoopNest loop_nest({f}); + auto const& loops = loop_nest.getLoopStmtsFor(f); + ForPtr m = loops[0]; + ForPtr n = loops[1]; + if (outer) { + m->set_parallel(); + } + if (inner) { + n->set_parallel(); + } + loop_nest.prepareForCodegen(); + StmtPtr stmt = loop_nest.root_stmt(); + LLVMCodeGen cg(stmt, {f}); + + PaddedBuffer f_v(M, N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(M, N, "f_ref"); + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + f_ref(m, n) = m + n; + } + } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +TEST(LLVM, SimpleParallelSS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelSP) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPP) { + testSimpleParallel(); +} + +TEST(LLVM, CompositeParallel) { + int loop_count = 6; + int test_count = 1 << loop_count; + // Compute a composite operation, and try all loop-axis combination to be + // parallel or sequential. + for (const auto test_cfg : c10::irange(test_count)) { + int M = 5; + int N = 7; + Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; }); + Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; }); + Tensor t3 = + Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) { + return t1.load(m) * t2.load(n); + }); + Tensor t4 = + Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) { + return t3.load(m, n) + m + n; + }); + LoopNest loop_nest({t4}, {t1, t2, t3, t4}); + std::vector loop_list; + { + auto const& loops = loop_nest.getLoopStmtsFor(t1); + loop_list.push_back(loops[0]); + } + { + auto const& loops = loop_nest.getLoopStmtsFor(t2); + loop_list.push_back(loops[0]); + } + { + auto const& loops = loop_nest.getLoopStmtsFor(t3); + loop_list.push_back(loops[0]); + loop_list.push_back(loops[1]); + } + { + auto const& loops = loop_nest.getLoopStmtsFor(t4); + loop_list.push_back(loops[0]); + loop_list.push_back(loops[1]); + } + ASSERT_EQ(loop_list.size(), loop_count); + for (const auto i : c10::irange(loop_count)) { + if (test_cfg & (1 << i)) { + loop_list[i]->set_parallel(); + } + } + loop_nest.prepareForCodegen(); + StmtPtr stmt = loop_nest.root_stmt(); + LLVMCodeGen cg(stmt, {t4}); + + PaddedBuffer t4_v(M, N, "t4_v"); + std::vector args({t4_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer t4_ref(M, N, "t4_ref"); + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + t4_ref(m, n) = (m + 1) * (n + 2) + m + n; + } + } + ExpectAllNear(t4_v, t4_ref, 1e-5); + } +} + +TEST(LLVM, VectorizedGEMM) { + int M = 32; + int N = 32; + int K = 48; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr m = loops[0]; + loop.splitWithMask(m, 16); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr n = loops[2]; + loop.splitWithMask(n, 16); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[1]; + ForPtr no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr ni = loops[3]; + ForPtr k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[2]; + ForPtr k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto loops = NodeFinder::find(loop.root_stmt()); + ASSERT_TRUE(LoopNest::vectorize(loops[3])); + ASSERT_TRUE(LoopNest::vectorize(loops.back())); + } + + loop.prepareForCodegen(); + + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + LLVMCodeGen cg(s, {AP, BP, CT}); + + PaddedBuffer a_v(M, K, "a_v"); + PaddedBuffer b_v(K, N, "b_v"); + PaddedBuffer c_v(M, N, "c_v"); + PaddedBuffer c_ref(M, N, "c_ref"); + + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + c_ref(m, n) = 0.f; + for (const auto k : c10::irange(K)) { + c_ref(m, n) += a_v(m, k) * b_v(k, n); + } + } + } + + cg.call({a_v, b_v, c_v}); + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LLVM, CallRaw) { + const int M = 32; + VarHandle N("N", kInt); + BufHandle a("a", {M, N}, kFloat); + BufHandle b("b", {N}, kFloat); + Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(j); + }); + + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + int32_t N_value = 1024; + std::vector av(M * N_value); + std::iota(av.begin(), av.end(), 0); + std::vector bv(N_value); + std::iota(bv.begin(), bv.end(), 0); + std::vector cv(M * N_value, 0); + std::vector args({av.data(), bv.data(), cv.data(), &N_value}); + + LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N}); + cg.call_raw(args); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N_value)) { + ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); + } + } + + SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N}); + eval.call_raw(args); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N_value)) { + ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); + } + } +} + +TEST(LLVM, CustomTarget) { + constexpr int M = 16; + BufHandle a("a", {M}, kFloat); + BufHandle b("b", {M}, kFloat); + BufHandle c("c", {M}, kFloat); + Tensor d = Compute("d", {M}, [&](const VarHandle& m) { + return a.load(m) * b.load(m) + c.load(m); + }); + LoopNest nest({d}); + nest.prepareForCodegen(); + auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d}) + .triple("i686-elf") + .cpu("i386") + .build(); + std::ostringstream ss; + ss << cg->getCodeText("asm"); + torch::jit::testing::FileCheck() + .check("fadds") + ->check("fmuls") + ->check_not("vfmadd") + ->run(ss.str()); +} + +TEST(LLVM, CodeGenKernelFuncName) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + auto store = b.store({0}, a.load(0)); + + { + LLVMCodeGen cg(store, {a, b}); + // Check that the kernel function name used by LLVMCodeGen + // is not empty. + ASSERT_NE(cg.kernel_func_name(), ""); + } + + { + LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func"); + // Check that the kernel function name used by LLVMCodeGen + // is the one that was given above. + ASSERT_EQ(cg.kernel_func_name(), "new_func"); + } +} + +} // namespace jit +} // namespace torch + +#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp new file mode 100644 index 000000000000..a8bda8814dba --- /dev/null +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -0,0 +1,6894 @@ +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +void checkIR(StmtPtr s, const std::string& pattern) { + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run(pattern, oss.str()); +} + +void checkExprIR(ExprPtr e, const std::string& pattern) { + std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; + std::ostringstream oss; + oss << *e << "\n"; + torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); +} + +void checkExprIR(const ExprHandle& e, const std::string& pattern) { + checkExprIR(e.node(), pattern); +} + +TEST(LoopNest, ExprSimple01) { + Tensor tensor = + Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + LoopNest::splitWithTail(loops[0], 2); + LoopNest::splitWithTail(loops[0], 2); +} + +TEST(LoopNest, ExprLower01) { + Tensor tensor = + Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 20); + ASSERT_LT(oss.str().size(), 200); +} + +TEST(LoopNest, ExprSimple02) { + auto func = [](const ExprHandle& x, const ExprHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }; + Tensor tensor = Compute("f", {26, 5}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + LoopNest::splitWithTail(loops[0], 4); + + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 200); + ASSERT_LT(oss.str().size(), 600); + + { + // Compare to a reference loop structure structure. + VarHandle x_outer("i_outer", kInt); + VarHandle x_inner("i_inner", kInt); + VarHandle y("i", kInt); + VarHandle x_tail("i_tail", kInt); + BufHandle f("f", {26, 5}, kFloat); + ExprHandle x_1 = x_outer * 4 + x_inner; + ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; + ForPtr stmt1 = For::make( + x_outer, + 0, + x_outer_end, + For::make( + x_inner, + 0, + 4, + For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); + ExprHandle x_2 = x_tail + x_outer_end * 4; + ForPtr stmt2 = For::make( + x_tail, + 0, + (ExprHandle(26) - 0) % 4, + For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); + StmtPtr stmt = Block::make({stmt1, stmt2}); + + std::ostringstream oss_ref; + oss_ref << *stmt; + ASSERT_EQ(oss.str(), oss_ref.str()); + } + + { + PaddedBuffer f_v(26, 5, "f_v"); + PaddedBuffer f_ref(26, 5, "f_res"); + + stmt = FlattenIndexes(stmt); + SimpleIREvaluator ir_eval(stmt, {tensor}); + ir_eval(f_v); + + for (int x = 0; x < 26; x++) { + for (int y = 0; y < 5; y++) { + f_ref(x, y) = 1 + x * x + y * y; + } + } + + ExpectAllNear(f_v, f_ref, 1e-5); + } +} + +BlockPtr getSimplifiedBody(const LoopNest& l) { + StmtPtr stmt = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(stmt); + return to(simplified); +} + +void assertForRange(ForPtr f, int expected_start, int expected_stop) { + ASSERT_NE(f, nullptr); + IntImmPtr start = to(f->start()); + ASSERT_NE(start, nullptr); + ASSERT_EQ(start->value(), expected_start); + IntImmPtr stop = to(f->stop()); + ASSERT_NE(stop, nullptr); + ASSERT_EQ(stop->value(), expected_stop); +} + +void assertForRanges( + BlockPtr body, + const std::vector>& start_stops) { + ASSERT_EQ(body->nstmts(), start_stops.size()); + + auto it = body->begin(); + for (size_t i = 0; i < start_stops.size(); i++, it++) { + ForPtr loop = to(*it); + assertForRange(loop, start_stops[i].first, start_stops[i].second); + } +} + +TEST(LoopNest, ExprSliceHeadWithLoopOptions) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::sliceHead(loops[0], 2, &head, &tail); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 2}, {0, 8}}); + + ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); + ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + ASSERT_TRUE(head->loop_options().isDefault()); +} + +TEST(LoopNest, ExprSliceTailWithLoopOptions) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 4, &head, &tail); + + ForPtr tail_head; + ForPtr tail_tail; + tail->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); + + ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); + ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + ASSERT_TRUE(head->loop_options().isDefault()); + ASSERT_TRUE(tail_tail->loop_options().isDefault()); +} + +TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { + // When factor equals the For loop's original size, keep using the original + // For loop. + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceHead(loops[0], 10, &head, &tail); + + ASSERT_EQ(head, loops[0]); + ASSERT_EQ(tail, nullptr); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceHead(loops[0], 100, &head, &tail); + + ASSERT_EQ(head, loops[0]); + ASSERT_EQ(tail, nullptr); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceHead) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceHead(loops[0], 4, &head, &tail); + + ASSERT_NE(head, nullptr); + ASSERT_NE(head, loops[0]); + ASSERT_NE(tail, nullptr); + ASSERT_EQ(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 4}, {4, 10}}); +} + +TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + ForPtr head; + ForPtr tail; + LoopNest::sliceTail(loops[0], 4, &head, &tail); + // head: [0, 6) + // tail: [6, 10) + + LoopNest::sliceHead(tail, 2); + // tail_head: [6, 8) + // tail_tail: [8, 10) + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); +} + +TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { + // When factor equals the For loop's original size, keep using the original + // For loop. + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 10, &head, &tail); + + ASSERT_EQ(head, nullptr); + ASSERT_EQ(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { + // When factor equals the For loop's original size, keep using the original + // For loop. + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 100, &head, &tail); + + ASSERT_EQ(head, nullptr); + ASSERT_EQ(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceTail) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 4, &head, &tail); + + ASSERT_NE(head, nullptr); + ASSERT_EQ(head, loops[0]); + ASSERT_NE(tail, nullptr); + ASSERT_NE(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 6}, {6, 10}}); +} + +TEST(LoopNest, ExprSplitAndSlice) { + // 0: splitWithTail + // 1: sliceTail on inner loop + // 2: sliceHead on outer loop + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {100}, func); + LoopNest l({tensor}); + + ForPtr inner; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // outer: [0, 4) + // inner: [0, 21) + // tail: [84, 100) + LoopNest::splitWithTail(loops[0], 21, &inner, &tail); + LoopNest::sliceTail(inner, 2); + LoopNest::sliceHead(loops[0], 2); + + // for (int x_outer = 0; x_outer < 2; x_outer++) { + // for (int x_inner = 0; x_inner < 19; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // for (int x_inner = 19; x_inner < 21; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // } + // for (int x_outer = 2; x_outer < 4; x_outer++) { + // for (int x_inner = 0; x_inner < 19; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // for (int x_inner = 19; x_inner < 21; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // } + // for (int x_tail = 0; x_tail < 16; x_tail++) { + // f[x_tail + 84] = 1.f + float(x_tail + 84); + // } + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); + + auto biter = body->begin(); + + ForPtr loop = to(*biter++); + assertForRanges(loop->body(), {{0, 19}, {19, 21}}); + + loop = to(*biter); + assertForRanges(loop->body(), {{0, 19}, {19, 21}}); +} + +TEST(LoopNest, ExprSliceAndNormalize) { + // 0: sliceHead + // 1: normalize tail + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + ForPtr head; + ForPtr tail; + LoopNest::sliceHead(loops[0], 2, &head, &tail); + // head: [0, 2) + // tail: [2, 10) + + LoopNest::normalize(tail); + // normalized_tail: [0, 8) + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 2}, {0, 8}}); +} + +template +T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { + ExprEval eval(expr, {var}); + return eval.value(value); +} + +TEST(LoopNest, ExprSliceWithVariableDimension) { + auto testWithDimension = + [](int dimension, + const std::vector>& expected_for_ranges) { + VarHandle dim("dim", kInt); + Tensor tensor = + Compute("f", {dim}, [](const ExprHandle& x) { return x; }); + LoopNest l({tensor}); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + ForPtr head; + ForPtr tail; + LoopNest::sliceHead(loops[0], 2, &head, &tail); + + LoopNest::sliceTail(tail, 2); + + BlockPtr body = getSimplifiedBody(l); + ASSERT_EQ(expected_for_ranges.size(), 3); + auto it = body->begin(); + for (auto& start_stop : expected_for_ranges) { + ForPtr loop = to(*it++); + int start = evalExpr(ExprHandle(loop->start()), dim, dimension); + int stop = evalExpr(ExprHandle(loop->stop()), dim, dimension); + ASSERT_EQ(start, start_stop.first); + ASSERT_EQ(stop, start_stop.second); + } + }; + + testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); + testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); + testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); + testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); + testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); + testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); +} + +TEST(LoopNest, ExprSplitWithTail) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {199}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + LoopNest::splitWithTail(loops[0], 17); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + LoopNest::splitWithTail(loops[0], 7); + + StmtPtr stmt = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(stmt); + BlockPtr body = to(simplified); + ASSERT_EQ(body->nstmts(), 3); + auto biter = body->begin(); + + // Verify that the split loops are ordered correctly. + ForPtr loop = to(*biter++); + assertForRange(loop, 0, 7); + + loop = to(*biter++); + assertForRange(loop, 0, 4); + + loop = to(*biter); + assertForRange(loop, 0, 12); +} + +TEST(LoopNest, ExprSplitWithTailNone) { + auto func = [](const ExprHandle& x, const ExprHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }; + Tensor tensor = Compute("f", {24, 5}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::splitWithTail(loops[0], 4); + + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 200); + ASSERT_LT(oss.str().size(), 600); + + { + // Compare to a reference loop structure structure. + VarHandle x_outer("i_outer", kInt); + VarHandle x_inner("i_inner", kInt); + VarHandle y("i", kInt); + VarHandle x_tail("i_tail", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + BufHandle f("f", {24, 5}, kFloat); + ExprHandle x_1 = x_outer * 4 + x_inner; + ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; + StmtPtr stmt = alloc(std::vector({For::make( + x_outer, + 0, + x_outer_end, + For::make( + x_inner, + 0, + 4, + For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); + + std::ostringstream oss_ref; + oss_ref << *stmt; + ASSERT_EQ(oss.str(), oss_ref.str()); + } + + { + PaddedBuffer f_v(24, 5, "f_v"); + PaddedBuffer f_ref(24, 5, "f_res"); + + SimpleIREvaluator ir_eval(stmt, {tensor}); + ir_eval(f_v); + + for (int x = 0; x < 24; x++) { + for (int y = 0; y < 5; y++) { + f_ref(x, y) = 1 + x * x + y * y; + } + } + + ExpectAllNear(f_v, f_ref, 1e-5); + } +} + +TEST(LoopNest, ExprSplitWithMask01) { + const int M = 26; + const int N = 5; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {M, N}, kFloat); + Tensor tensor = + Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::splitWithMask(loops[1], 4); + + StmtPtr stmt = l.root_stmt(); + + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +// Tests the case where we split a loop cleanly multiple times, we should not +// insert any masks. +TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { + const int M = 64; + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M}, kFloat); + Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { + return a_buf.load(m) + b_buf.load(m) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + LoopNest::splitWithMask(loops[0], 4); + + StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); + + // Two splits mean 3 loops, but should need no masks in this case. + checkIR(stmt1, R"IR( +# CHECK: for ( +# CHECK-NOT: if ( +# CHECK: for ( +# CHECK-NOT: if ( +# CHECK: for ( +# CHECK-NOT: if ( +# CHECK: f[)IR"); +} + +TEST(LoopNest, getLoopAt) { + // Input IR: + // for (int i = 0; i < 100; i++) { + // for (int j = 0; j < 100; j++) { + // A[i, j] = sin(i * j); + // for (int k1 = 0; k1 < 200; k1++) { + // B[i, j, k1] = (A[i, j]) / (k1 + 1); + // } + // for (int k2 = 0; k2 < 300; k2++) { + // C[i, j, k2] = (A[i, j]) * (k2 + 1); + // } + // } + // } + BufPtr A = alloc( + "A", + std::vector({alloc(100), alloc(100)}), + kInt); + BufPtr B = alloc( + "B", + std::vector( + {alloc(100), alloc(100), alloc(200)}), + kInt); + BufPtr C = alloc( + "C", + std::vector( + {alloc(100), alloc(100), alloc(300)}), + kInt); + BufHandle a_buf(A); + BufHandle b_buf(B); + BufHandle c_buf(C); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k1("k1", kInt); + VarHandle k2("k2", kInt); + auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); + auto store2 = Store::make( + b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); + auto store3 = Store::make( + c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); + auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); + auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); + auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); + auto for_i = For::make(i, 0, 100, for_j); + LoopNest l(Block::make({for_i}), {B, C}); + auto ret_k2 = l.getLoopAt(for_i, {0, 2}); + TORCH_CHECK(ret_k2 == for_k2); + + std::ostringstream oss; + oss << *ret_k2; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int k2 +# CHECK-NEXT: C[i, j, k2] = + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, TileSimple) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + const int M = 64, N = 64; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {M, N}, kFloat); + Tensor tensor = + Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + l.tile(loops[0], loops[1], 4, 8); + + // IR check + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + checkIR(stmt, R"IR( +# CHECK: for (int i_outer +# CHECK: for (int i_outer_1 +# CHECK: for (int i_inner +# CHECK: for (int i_inner_1 +# CHECK: f[ +# CHECK-NOT: for (int i_tail +# CHECK-NOT: for (int i_tail)IR"); + + // Correctness check + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, TileWithTails) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + const int M = 64, N = 64; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {M, N}, kFloat); + Tensor tensor = + Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + l.tile(loops[0], loops[1], 5, 9); + + // IR check + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + checkIR(stmt, R"IR( +# CHECK: for (int i_outer +# CHECK: for (int i_outer_1 +# CHECK: for (int i_inner +# CHECK: for (int i_inner_1 +# CHECK: f[ +# CHECK: for (int i_inner +# CHECK: f[ +# CHECK: for (int i_tail)IR"); + + // Correctness check + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, TileInMiddle) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + const int M = 8, N = 8, L = 8, K = 8; + BufHandle a_buf("a", {M, N, L, K}, kFloat); + BufHandle b_buf("b", {M, N, L, K}, kFloat); + Tensor tensor = Compute( + "f", + {M, N, L, K}, + [&](const ExprHandle& m, + const ExprHandle& n, + const ExprHandle& l, + const ExprHandle& k) { + return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; + }); + + LoopNest nest({tensor}); + std::vector loops = + nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + nest.tile(loops[1], loops[2], 3, 3); + + // IR check + StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); + checkIR(stmt, R"IR( +# CHECK: for (int i +# CHECK: for (int i_outer +# CHECK: for (int i_outer_1 +# CHECK: for (int i_inner +# CHECK: for (int i_inner_1 +# CHECK: for (int i_1 +# CHECK: f[ +# CHECK: for (int i_tail_1 +# CHECK: for (int i_inner_1 +# CHECK: for (int i_1 +# CHECK: f[ +# CHECK: for (int i_tail)IR"); + + // Correctness check + PaddedBuffer a_v(M, N, L, K, "a"); + PaddedBuffer b_v(M, N, L, K, "b"); + PaddedBuffer c_v(M, N, L, K, "c"); + PaddedBuffer c_ref(M, N, L, K, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int l = 0; l < L; l++) { + for (int k = 0; k < K; k++) { + a_v(m, n, l, k) = 2 * (m + l); + b_v(m, n, l, k) = 3 * (n + k); + c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; + } + } + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, SplitWithTailWithLoopOptions) { + const int M = 21; + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M}, kFloat); + Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { + return a_buf.load(m) + b_buf.load(m) + 1.0f; + }); + ForPtr inner, tail; + + LoopNest l({tensor}); + auto loops = NodeFinder::find(l.root_stmt()); + ASSERT_GT(loops.size(), 0); + loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::splitWithTail(loops[0], 4, &inner, &tail); + ASSERT_NE(inner, nullptr); + ASSERT_NE(tail, nullptr); + ForPtr outer = loops[0]; + + // Outer loop carries loop axis bindings. + ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); + ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + // Inner loop has none. + ASSERT_TRUE(inner->loop_options().isDefault()); + + // Tail loop has none. + ASSERT_TRUE(tail->loop_options().isDefault()); +} + +TEST(LoopNest, SplitWithMaskWithLoopOptions) { + const int M = 21; + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M}, kFloat); + Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { + return a_buf.load(m) + b_buf.load(m) + 1.0f; + }); + ForPtr inner; + + LoopNest l({tensor}); + auto loops = NodeFinder::find(l.root_stmt()); + loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::splitWithMask(loops[0], 4, &inner); + ForPtr outer = loops[0]; + + // Outer loop carries loop axis bindings. + ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); + ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + // Inner loop has none. + ASSERT_TRUE(inner->loop_options().isDefault()); +} + +TEST(LoopNest, ScheduleBroadcastAddBuffer) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + Tensor c = Compute( + "broadcast_add", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + LoopNest l({c}); + StmtPtr stmt = l.root_stmt(); + + PaddedBuffer a_v(M, N, "a_v"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 7 * m * n; + } + } + a_v.Backup(); + + PaddedBuffer b_v(N, K, "b_v"); + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + b_v(n, k) = 11 * n * k; + } + } + b_v.Backup(); + + PaddedBuffer c_v(M, N, K, "c_buf"); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); + ir_eval(a_v, b_v, c_v); + + a_v.CheckBackup(); + b_v.CheckBackup(); + PaddedBuffer c_ref(M, N, K, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + c_ref(m, n, k) = 7 * m * n + 11 * n * k; + } + } + } + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, ScheduleFunctionCall01) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + Tensor c = Compute( + "broadcast_add", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor d = Compute( + "d", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c.load(m, n, k) + 1; + }); + + LoopNest l({d}, {c, d}); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 100); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N, K); + PaddedBuffer d_v(M, N, K); + PaddedBuffer d_ref(M, N, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < K; k++) { + d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; + } + } + } + + SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); + eval(a_v, b_v, d_v); + + ExpectAllNear(d_v, d_ref, 1e-5); +} + +TEST(LoopNest, ScheduleInlineSimple) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + BufHandle c_buf("c", {M, N}, kFloat); + BufHandle d_buf("d", {M, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); + }); + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer y_1(M, N, K); + PaddedBuffer y_2(M, N, K); + + eval1(a_v, b_v, c_v, d_v, y_1); + eval2(a_v, b_v, c_v, d_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +static std::string remove_space(const std::string& str) { + std::string str_new = str; + str_new.erase( + remove_if(str_new.begin(), str_new.end(), isspace), str_new.end()); + return str_new; +} + +void InlineFunc01Helper(const std::vector& inline_order) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + BufHandle c_buf("c", {M, N}, kFloat); + BufHandle d_buf("d", {M, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); + }); + Tensor z = Compute( + "z", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + y.load(m, n, k); + }); + + LoopNest l({z}, {x, y, z}); + for (const std::string& order : inline_order) { + if (order == "x") { + l.computeInline(x.buf()); + } else if (order == "y") { + l.computeInline(y.buf()); + } else { + throw std::runtime_error("Invalid order: " + order); + } + } + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + + std::ostringstream oss; + oss << *stmt; + std::string str1 = remove_space(oss.str()); + + { + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer z_v(M, N, K); + PaddedBuffer z_ref(M, N, K); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); + } + } + } + + SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); + eval(a_v, b_v, c_v, d_v, z_v); + ExpectAllNear(z_v, z_ref, 1e-5); + } + + if (inline_order.size() == 2) { + Tensor z2 = Compute( + "z", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k) + + (c_buf.load(m, n) * d_buf.load(m, k) + + a_buf.load(m, n) * b_buf.load(n, k)); + }); + LoopNest l2({z2}); + l2.prepareForCodegen(); + StmtPtr stmt2 = l2.root_stmt(); + + std::ostringstream oss2; + oss2 << *stmt2; + std::string str2 = remove_space(oss2.str()); + + ASSERT_EQ(str1, str2); + ASSERT_GT(str1.size(), 100); + } +} + +TEST(LoopNest, ScheduleInlineFunc01) { + InlineFunc01Helper({"x", "y"}); + InlineFunc01Helper({"y", "x"}); + InlineFunc01Helper({"x"}); + InlineFunc01Helper({"y"}); + InlineFunc01Helper({}); +} + +// Make sure we cache random vars if we should. +TEST(LoopNest, ScheduleInlineRandom) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Mod::make(Intrinsics::make(kRand, kInt), 5); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + x.load(m, n, k); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: int x = rand(); +# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); +} + +// Make sure we don't cache random vars that are not being inlined. +TEST(LoopNest, ScheduleInlineRandomUnrelated) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return m * n * k; + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + + Intrinsics::make(kRand, kInt); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR"); +} + +// Make sure we generate the right number of random values == the dimensionality +// of the production tensor. +TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute("x", {M}, [&](const VarHandle& m) { + return Mod::make(Intrinsics::make(kRand, kInt), 5); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m) + x.load(m); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: int x = rand(); +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); +} + +// Make sure we don't screw up intrinsics thinking they're rand. +TEST(LoopNest, ScheduleInlineIntrinsics) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Intrinsics::make(kSqrt, x.load(m, n, k)); + }); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); + + PaddedBuffer y_1(M, N, K); + PaddedBuffer y_2(M, N, K); + + eval1(a_v, b_v, y_1); + eval2(a_v, b_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +// Make sure we can handle rand and non-rand intrinsics. +TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Intrinsics::make(kRand, kFloat); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Intrinsics::make(kSqrt, x.load(m, n, k)); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: float x = rand(); +# CHECK: y[i, i_1, i_2] = sqrt(x);)IR"); +} + +// Split a Compute then inline it into another compute. +TEST(LoopNest, ScheduleSplitAThenInline) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Split a Compute then inline another Compute into it. +TEST(LoopNest, ScheduleSplitBThenInline) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); + LoopNest::splitWithMask(loops[0], 3); + l.computeInline(a.buf()); + l.prepareForCodegen(); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + + std::vector output(6, 0); + SimpleIREvaluator eval(s, {b}); + eval(output); + + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(output[i], (i + 8) * (i + 8)); + } +} + +// Split a Compute twice then inline it. +TEST(LoopNest, ScheduleSplitTwiceThenInline) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + ForPtr i_inner; + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4, &i_inner); + LoopNest::splitWithMask(i_inner, 2); + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Inline a Compute, then split. +TEST(LoopNest, ScheduleInlineThenSplit) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + l.computeInline(a.buf()); + + std::vector loops = NodeFinder::find(l.root_stmt()); + LoopNest::splitWithMask(loops.back(), 3); + l.prepareForCodegen(); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(6, 0); + SimpleIREvaluator eval(s, {b}); + eval(output); + + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(output[i], (i + 8) * (i + 8)); + } +} + +// Split a Compute, inline it, then split the result. +TEST(LoopNest, ScheduleSplitInlineThenSplit) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + auto loops = NodeFinder::find(l.root_stmt()); + LoopNest::splitWithMask(loops.back(), 2); + l.computeInline(a.buf()); + + loops = NodeFinder::find(l.root_stmt()); + LoopNest::splitWithMask(loops.front(), 2); + l.prepareForCodegen(); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(16, 0); + SimpleIREvaluator eval(s, {b}); + eval(output); + + for (int i = 0; i < 16; ++i) { + ASSERT_EQ(output[i], (i + 8) * (i + 8)); + } +} + +// Oversplit a loop that is simplified out after inlining. +TEST(LoopNest, ScheduleSplitInlineSimplify) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { + return ExprHandle(4) * i - ExprHandle(2) * i; + }); + Tensor b = Compute( + "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Inline a Compute with two consumers. +TEST(LoopNest, ScheduleInlineThreeMixedOnce) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(a.buf()); + l.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(4 * 3, 0); + SimpleIREvaluator eval(s, {c}); + eval(output); + + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); + } + } +} + +// Inline Compute A into B, then inline B into C. +TEST(LoopNest, ScheduleInlineThreeMixedTwice) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(a.buf()); + l.computeInline(b.buf()); + l.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(4 * 3, 0); + SimpleIREvaluator eval(s, {c}); + eval(output); + + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); + } + } +} + +// Inline a Compute that is both a producer and consumer. +TEST(LoopNest, ScheduleInlineThreeMixedInner) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(b.buf()); + l.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(4 * 3, 0); + SimpleIREvaluator eval(s, {c}); + eval(output); + + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); + } + } +} + +// Split 3 Computes, then inline the first two into the last. +TEST(LoopNest, ScheduleInlineThreeMixedSplit) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); + LoopNest::splitWithMask(loops[0], 3); + loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::splitWithMask(loops[0], 2); + + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Check that inlining works for output tensors too +TEST(LoopNest, ScheduleInlineOutputTensors) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return m * n * k; + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + m; + }); + + LoopNest l1({x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2; +# CHECK: for (int i_3 = 0; i_3 < 4; i_3++) +# CHECK: for (int i_4 = 0; i_4 < 5; i_4++) +# CHECK: for (int i_5 = 0; i_5 < 6; i_5++) +# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR"); +} + +TEST(LoopNest, ScheduleInlineWithCompoundIndices) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[i*2,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[0, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast(0), j}, + Add::make( + Load::make(a_buf, {static_cast(0), j}), + Mul::make(j, static_cast(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + // Inlining should fail since the producer has compound expr as index. + ASSERT_FALSE(l.computeInline(a_buf.node())); + + // The input statement must remain as is. + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t i = 0; + # CHECK-NEXT: A[ + # CHECK: for (int64_t j = 0; + # CHECK-NEXT: B[)IR"); +} + +TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[0ll,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make( + a_buf, + {static_cast(0), i}, + Mul::make(i, static_cast(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast(0), j}, + Add::make( + Load::make(a_buf, {0, j}), + Mul::make(j, static_cast(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + ASSERT_TRUE(l.computeInline(a_buf.node())); + + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t j = 0; j < 100; j++) { + # CHECK: B[0ll, j] = j * 500ll + j * 100ll; + # CHECK: })IR"); +} + +TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[(int64_t)0,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[0ll, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make(a_buf, {0, i}, Mul::make(i, static_cast(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast(0), j}, + Add::make( + Load::make(a_buf, {static_cast(0), j}), + Mul::make(j, static_cast(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + ASSERT_TRUE(l.computeInline(a_buf.node())); + + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t j = 0; j < 100; j++) { + # CHECK: B[0ll, j] = j * 500ll + j * 100ll; + # CHECK: })IR"); +} + +TEST(LoopNest, ScheduleFuserStyle) { + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + + Tensor b = + Compute("f", {kTotalSize}, [&](const std::vector& axes) { + return a_buf.load(axes[0]) + 11.0f; + }); + + Tensor c = + Compute("g", {kTotalSize}, [&](const std::vector& axes) { + return b.load(axes[0]) + 1.0f; + }); + + LoopNest l({b, c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + std::vector a_data(kTotalSize, 7.0f); + std::vector b_data(kTotalSize, 0.0f); + std::vector c_data(kTotalSize, 0.0f); + SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); + + for (int i = 0; i < kTotalSize; i++) { + ASSERT_EQ(b_data[i], 18.0f); + ASSERT_EQ(c_data[i], 19.0f); + } +} + +TEST(LoopNest, ScheduleFuserThreeArg) { + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat); + + Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) { + return a.load(i) + b.load(i); + }); + Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) { + return e.load(i) + c.load(i); + }); + Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) { + return f.load(i) + d.load(i); + }); + + LoopNest l({g}, {e, f, g}); + l.computeInline(l.getLoopBodyFor(e)); + l.computeInline(l.getLoopBodyFor(f)); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + std::vector a_data(kTotalSize, 1.0f); + std::vector b_data(kTotalSize, 2.0f); + std::vector c_data(kTotalSize, 3.0f); + std::vector d_data(kTotalSize, 4.0f); + std::vector g_data(kTotalSize, 0.0f); + SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); + + for (int i = 0; i < kTotalSize; i++) { + ASSERT_EQ(g_data[i], 10.0f); + } +} + +TEST(LoopNest, ScheduleDynamicShape2D) { + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle a("a", {m, n}, kFloat); + BufHandle b("b", {m, n}, kFloat); + Tensor c = + Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(i, j); + }); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + SimpleIREvaluator cg(s, {a, b, c, m, n}); + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + cg.call({aData, bData, cData, M, N}); + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + }; + testWithSize(1, 8); + testWithSize(16, 32); + testWithSize(37, 11); +} + +TEST(LoopNest, LoopNestComputeAt_1) { + // Verify that compute_at works on the following example: + // + // for (int i_a = 0; i_a < N; i_a++) { + // A[i_a] = i_a * i_a + // } + // for (int i_b = 0; i_b < N; i_b++) { + // B[i_b] = A[i_b] + // } + // + // After the transformation the i_b loop should have an allocation for a temp + // buffer and that buffer should be used in computation of B. No use of A + // should be in that loop after the transformation. Also, computation of A + // should not be inlined into B. Instead, it should be computed into the temp, + // and the temp should be used in B. + VarHandle N("N", kInt); + Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); + Tensor B = + Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); + LoopNest l({B}, {A, B}); + std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {B, N}); + StmtPtr s = cg.stmt(); + + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[1] +# CHECK: for (int i = 0; i < N; i++) +# CHECK: temp[ +# CHECK-NOT: A[ +# CHECK: B[i_1] = temp[0] +# CHECK: Free(temp))IR"); + + // Now check that the loop still produces the correct result. + std::vector b_data(100, 0); + cg.call({b_data, 100}); + + std::vector b_ref(100, 0); + for (int i = 0; i < 100; i++) { + b_ref[i] = i * i; + } + assertAllEqual(b_data, b_ref); +} + +TEST(LoopNest, LoopNestComputeAt_2) { + // Verify that compute_at works on the following example: + // + // for (int py = 0; py < H+1; py++) { + // for (int px = 0; px < W+1; px++) { + // p[py, px] = py*px + // } + // } + // for (int cy = 0; cy < H; cy++) { + // for (int cx = 0; cx < W; cx++) { + // c[py, px] = p[cy,cx] + p[cy+1,cx] + + // p[cy,cx+1] + p[cy+1,cx+1] + // } + // } + + const int kW = 16, kH = 16; + VarHandle W("W", kInt); + VarHandle H("H", kInt); + Tensor p = Compute( + "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { + return px * py; + }); + Tensor c = + Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + + p.load(y + 1, x + 1); + }); + + std::vector c_ref(kW * kH, 0); + for (int y = 0; y < kH; y++) { + for (int x = 0; x < kW; x++) { + c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); + } + } + LoopNest orig_loopnest({c}, {p, c}); + + { + // First let's try to compute P at axis cy (the outer loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] +# CHECK: for (int i_2 = 0; i_2 < H; i_2++) +# CHECK: for +# CHECK: for +# CHECK: for (int i_3 = 0; i_3 < W; i_3++) +# CHECK-NOT: prod[ +# CHECK: cons[ +# CHECK: Free(temp))IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } + { + // Now let's try to compute P at axis cx (the inner loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] +# CHECK: for (int i_2 = 0; i_2 < H; i_2++) +# CHECK: for (int i_3 = 0; i_3 < W; i_3++) +# CHECK: for +# CHECK: for +# CHECK-NOT: prod[ +# CHECK: cons[ +# CHECK: Free(temp))IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } +} + +TEST(LoopNest, LoopNestComputeAt_3) { + // Verify that compute_at works on the following example: + // + // A(x,y) = x*y + // B(x,y) = A(x, y) + // C(x,y) = B(x+1, y) + // D(x,y) = A(x, y+1) + C(x, y) + // + // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. + + const int kW = 16, kH = 16; + VarHandle W("W", kInt); + VarHandle H("H", kInt); + Tensor A = Compute( + "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { + return ax * ay; + }); + Tensor B = Compute( + "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { + return A.load(by, bx); + }); + Tensor C = + Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { + return B.load(cy, cx + 1); + }); + Tensor D = + Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { + return A.load(dy + 1, dx) + C.load(dy, dx); + }); + + std::vector c_ref(kW * kH, 0); + for (int y = 0; y < kH; y++) { + for (int x = 0; x < kW; x++) { + c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); + } + } + + LoopNest orig_loopnest({D}, {A, B, C, D}); + { + // First let's try to compute A at axis dy (the outer loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[1, W] +# CHECK: for (int i = 0; i < H + 1; i++) +# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) +# CHECK: A[ +# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) +# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) +# CHECK: B[ +# CHECK: for (int i_4 = 0; i_4 < H; i_4++) +# CHECK: for (int i_5 = 0; i_5 < W; i_5++) +# CHECK: C[ +# CHECK: for (int i_6 = 0; i_6 < H; i_6++) +# CHECK: for (int i_7 = 0; i_7 < W; i_7++) +# CHECK-NOT: A[)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } + { + // Now let's try to compute A at axis dx (the inner loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[1, 1] +# CHECK: for (int i = 0; i < H + 1; i++) +# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) +# CHECK: A[ +# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) +# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) +# CHECK: B[ +# CHECK: for (int i_4 = 0; i_4 < H; i_4++) +# CHECK: for (int i_5 = 0; i_5 < W; i_5++) +# CHECK: C[ +# CHECK: for (int i_6 = 0; i_6 < H; i_6++) +# CHECK: for (int i_7 = 0; i_7 < W; i_7++) +# CHECK-NOT: A[)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } +} + +using Axis = const VarHandle&; + +TEST(LoopNest, Reduce2dComputeAt) { + const int kW = 16, kH = 16; + VarHandle W("W", kInt); + VarHandle H("H", kInt); + + Tensor p = Compute( + "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); + Tensor c = Reduce( + "cons", + {H, W}, + Sum(), + [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, + {2, 2}); + + std::vector c_ref(kW * kH, 0); + for (int y = 0; y < kH; y++) { + for (int x = 0; x < kW; x++) { + c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); + } + } + LoopNest orig_loopnest({c}, {p, c}); + checkIR(orig_loopnest.root_stmt(), R"IR( +# CHECK: for (int i = 0; i < H + 1; i++) { +# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { +# CHECK: prod[i, i_1] = i_1 * i; +# CHECK: } +# CHECK: } +# CHECK: for (int i_2 = 0; i_2 < H; i_2++) { +# CHECK: for (int i_3 = 0; i_3 < W; i_3++) { +# CHECK: cons[i_2, i_3] = int(0); +# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { +# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { +# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } +)IR"); + + { + // First let's try to compute P at axis cy (the outer loop) + LoopNest l(orig_loopnest); + auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); + // FIXME: Calling simplify here breaks the IR: + // MALFORMED INPUT: could not find base node in Load - temp[...] + // l.simplify(); + l.eliminateDeadStores(); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + checkIR(cg.stmt(), R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] +# CHECK: for (int i = 0; i < H; i++) { +# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { +# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { +# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); +# CHECK: } +# CHECK: } +# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); +# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { +# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: Free(temp); +)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + assertAllEqual(c_data, c_ref); + } + { + // Now let's try to compute P at axis cx (the inner loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); + l.simplify(); + l.eliminateDeadStores(); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + checkIR(cg.stmt(), R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] +# CHECK: for (int i = 0; i < H; i++) { +# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { +# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { +# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { +# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); +# CHECK: } +# CHECK: } +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; +# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { +# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: Free(temp); +)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + assertAllEqual(c_data, c_ref); + } +} + +TEST(LoopNest, DISABLED_Conv1d_NH) { + // Lots of stuff is broken here. The computeAt swaps the axes for some odd + // reason. Even without that, the index flattener fails due to "dimensions + // mismatch in flatten index". + + int N = 4; + int H = 256; + int R = 3; + int Pad = 1; + BufHandle IP("input", {H}, kFloat); + + Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) { + auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); + cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); + return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); + }); + Tensor B = Reduce( + "B", + {N, H}, + Sum(), + [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, + {R}); + LoopNest l({B}); + checkIR(l.root_stmt(), R"IR( +# CHECK: for (int np = 0; np < 4; np++) { +# CHECK: for (int hp = 0; hp < 258; hp++) { +# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]); +# CHECK: } +# CHECK: } +# CHECK: for (int n = 0; n < 4; n++) { +# CHECK: for (int h = 0; h < 256; h++) { +# CHECK: B[n, h] = float(0); +# CHECK: for (int r = 0; r < 3; r++) { +# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r}); +# CHECK: } +# CHECK: } +# CHECK: } +)IR"); + std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); + // FIXME: The current IR is totally broken. The body of the inlined loop is: + + // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0), + // 0.f, input[idx1 + 0, (idx0 + n) - 1]); + + // Which seems to mix up the axes. The CHECK below is my best guess at what + // the input "should" look like + + checkIR(l.root_stmt(), R"IR( +# CHECK: for (int n = 0; n < 4; n++) { +# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) { +# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) { + temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]); +# CHECK: } +# CHECK: } +# CHECK: for (int h = 0; h < 256; h++) { +# CHECK: B[n, h] = float(0); +# CHECK: for (int r = 0; r < 3; r++) { +# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r}); +# CHECK: } +# CHECK: } +# CHECK: } +)IR"); + + l.simplify(); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + SimpleIREvaluator cg(s, {IP, B}); + // auto At = at::ones({N, H}, at::kFloat); + auto At = at::arange(N * H, at::kFloat).reshape({N, H}); + auto Rt = at::conv1d( + At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3); + auto Bt = at::empty_like(Rt); + cg.call({At.data_ptr(), Bt.data_ptr()}); + ASSERT_TRUE(at::allclose(Rt, Bt)); +} + +class LoopOrderHelper : public IRVisitor { + std::stringstream ordering; + + public: + std::string getOrder(StmtPtr s) { + ordering.str(""); + s->accept(this); + return ordering.str(); + } + + void visit(const ForPtr& v) final { + ordering << v->var()->name_hint() << ","; + IRVisitor::visit(v); + } +}; + +TEST(LoopNest, LoopNestReorderAxis1) { + Tensor tensor = + Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + std::vector stmt1_output(6, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[1]); + StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + ASSERT_NE(stmt1, stmt2); + LoopOrderHelper loopOrderHelper; + std::string order1 = loopOrderHelper.getOrder(stmt1); + std::string order2 = loopOrderHelper.getOrder(stmt2); + + ASSERT_EQ(order1, "j,i,"); + ASSERT_EQ(order2, "i,j,"); + + std::vector stmt2_output(6, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg.call({stmt2_output}); + + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } + + // Reorder them back. + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[1]); + StmtPtr stmt3 = l.root_stmt(); + + std::string order3 = loopOrderHelper.getOrder(stmt3); + ASSERT_EQ(order3, order1); + + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt3; + + // Should be identical to the unreordered statement. + ASSERT_EQ(oss1.str(), oss2.str()); +} + +TEST(LoopNest, LoopNestReorderPartialAxes) { + Tensor tensor = Compute( + "f", + {2, 3, 4}, + [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + LoopOrderHelper loopOrderHelper; + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,"); + + std::vector stmt1_output(24, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[1]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,"); + + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); + + std::vector stmt2_output(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg2.call({stmt2_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } + + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[1], loops[2]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,"); + + StmtPtr stmt3 = Stmt::clone(l.root_stmt()); + + std::vector stmt3_output(24, 0); + SimpleIREvaluator cg3(stmt3, {tensor}); + cg3.call({stmt3_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt3_output[i]); + } +} + +TEST(LoopNest, LoopNestReorderInternalAxis) { + Tensor tensor = Compute( + "f", + {1, 2, 3, 4}, + [](const VarHandle& w, + const VarHandle& x, + const VarHandle& y, + const VarHandle& z) { + return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + LoopOrderHelper loopOrderHelper; + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,"); + + std::vector stmt1_output(24, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[2], loops[1]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,"); + + StmtPtr stmt2 = l.root_stmt(); + + std::vector stmt2_output(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg2.call({stmt2_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } +} + +TEST(LoopNest, LoopNestReorderEnclosingAxis) { + Tensor tensor = Compute( + "f", + {1, 2, 3, 4}, + [](const VarHandle& w, + const VarHandle& x, + const VarHandle& y, + const VarHandle& z) { + return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + LoopOrderHelper loopOrderHelper; + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + std::vector stmt1_output(24, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[3]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,"); + + StmtPtr stmt2 = l.root_stmt(); + + std::vector stmt2_output(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg2.call({stmt2_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } +} + +TEST(LoopNest, LoopNestReorderSameAxis) { + Tensor tensor = + Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[1], loops[1]); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); + + std::ostringstream oss, oss2; + oss << *stmt1; + oss2 << *stmt2; + ASSERT_EQ(oss.str(), oss2.str()); +} + +TEST(LoopNest, LoopNestReorderExtraStatements) { + /* We're going for a structure like this: + * for i in ... + * Stmt 1 + * for j in ... + * Stmt 2 + * for k in ... + * Stmt 3 + * Stmt 4 + */ + + Tensor tensor = Compute( + "f", + {2, 3, 4}, + [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + BufHandle extra("res", {6, 3}, kFloat); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + VarHandle i = VarHandle(loops[0]->var()); + + StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); + StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); + // stmt 3 is the Function body. + StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); + + loops[0]->body()->prepend_stmt(store_1); + loops[1]->body()->prepend_stmt(store_2); + loops[1]->body()->append_stmt(store_3); + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + std::vector extra1(6, 0); + std::vector res1(24, 0); + SimpleIREvaluator cg(stmt1, {tensor, extra}); + cg.call({res1, extra1}); + + /* Then we reorder loop y and z, we want it to look like: + * + * for i in ... + * Stmt 1 + * for j in ... + * Stmt 2 + * for j_1 in ... + * for k in ... + * Stmt 3 + * for j_2 in ... + * Stmt 4 + * + * We need extra loops because we don't have dependency info about stmt 3 + * and 4. + * + */ + + LoopNest::reorderAxis(loops[1], loops[2]); + StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + // Check the IR we produced + checkIR(stmt2, R"IR( +# CHECK: for +# CHECK: res[i, 0] = 1 +# CHECK: for +# CHECK: res[i, 1] = 2 +# CHECK: for +# CHECK: for +# CHECK: f[ +# CHECK: for +# CHECK: res[i, 2] = 4 +)IR"); + + std::vector extra2(6, 0); + std::vector res2(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor, extra}); + cg2.call({res2, extra2}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(res1[i], res2[i]); + } + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(extra1[i], extra2[i]); + } + + /* Now reorder x and the y above stmt 3: + * + * + * for x in ... + * Stmt 1 + * for y in ... + * Stmt 2 + * + * for y in ... + * for z in ... + * for x in ... + * Stmt 3 + * + * for x in ... + * for y in ... + * Stmt 4 + * + * + */ + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[2]); + StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + // Check the IR we produced + checkIR(stmt3, R"IR( +# CHECK: for +# CHECK: res[i, 0] = 1 +# CHECK: for +# CHECK: res[i, 1] = 2 +# CHECK: for +# CHECK: for +# CHECK: for +# CHECK: f[ +# CHECK: for +# CHECK: for +# CHECK: res[i_2, 2] = 4 +)IR"); + + std::vector extra3(6, 0); + std::vector res3(24, 0); + SimpleIREvaluator cg3(stmt3, {tensor, extra}); + cg3.call({res3, extra3}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(res1[i], res3[i]); + } + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(extra1[i], extra3[i]); + } +} + +void LoopNestReorderTestHelper( + bool prepend, + bool append, + int index1, + int index2) { + Tensor c = Compute( + "5d", {2, 3, 2, 3, 2}, [](const std::vector&) { return -1; }); + LoopNest l({c}); + + BufHandle extra("extra", {5}, kInt); + + auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + int j = 0; + for (auto l : loops) { + // Add an increment at each layer of the loop which counts the number of + // times the loop executes. + LoadPtr load = + alloc(extra.node(), std::vector({alloc(j)})); + AddPtr add = alloc(load, alloc(1)); + StmtPtr store = alloc( + extra.node(), std::vector({alloc(j)}), add); + if (prepend) { + l->body()->prepend_stmt(store); + } + if (append) { + l->body()->append_stmt(Stmt::clone(store)); + } + + j++; + } + + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); + + std::vector extra1(5, 0); + std::vector res1(2 * 3 * 2 * 3 * 2, 0); + SimpleIREvaluator cg(stmt1, {c, extra}); + cg.call({res1, extra1}); + + std::vector loopExtents = {2, 3, 2, 3, 2}; + + int expected_loops = 0; + if (prepend) { + expected_loops++; + } + if (append) { + expected_loops++; + } + for (int i = 0; i < 5; ++i) { + expected_loops *= loopExtents[i]; + ASSERT_EQ(extra1[i], expected_loops); + } + + loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::reorderAxis(loops[index1], loops[index2]); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); + + std::ostringstream oss, oss2; + oss << *stmt1; + oss2 << *stmt2; + ASSERT_NE(oss.str(), oss2.str()); + + std::vector extra2(5, 0); + std::vector res2(2 * 3 * 2 * 3 * 2, 0); + SimpleIREvaluator cg2(stmt2, {c, extra}); + cg2.call({res2, extra2}); + + expected_loops = 0; + if (prepend) { + expected_loops++; + } + if (append) { + expected_loops++; + } + + for (int i = 0; i < 5; ++i) { + expected_loops *= loopExtents[i]; + ASSERT_EQ(extra2[i], expected_loops); + } + + for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { + ASSERT_EQ(res2[i], res1[i]); + } +} + +TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + // skip noops, since we check the loop isn't the same after reordering. + if (i != j) { + LoopNestReorderTestHelper(true, false, i, j); + } + } + } +} + +TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + // skip noops, since we check the loop isn't the same after reordering. + if (i != j) { + LoopNestReorderTestHelper(false, true, i, j); + } + } + } +} + +TEST(LoopNest, LoopNestReorderLongStringFull) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + // skip noops, since we check the loop isn't the same after reordering. + if (i != j) { + LoopNestReorderTestHelper(true, true, i, j); + } + } + } +} + +TEST(LoopNest, LoopNestReorderInternalLoopNest) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + BufHandle c_buf("c", {M, N}, kFloat); + BufHandle d_buf("d", {M, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); + }); + Tensor z = Compute( + "z", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + y.load(m, n, k); + }); + + LoopNest l({z}, {x, y, z}); + ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; + ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; + LoopNest::reorderAxis(a, b); + + l.prepareForCodegen(); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + + // Check the IR we produced has the 3 nests in the right order, but k and m + // swapped in the middle. + checkIR(stmt, R"IR( +# CHECK: < 4 +# CHECK: < 5 +# CHECK: < 6 +# CHECK: < 6 +# CHECK: < 5 +# CHECK: < 4 +# CHECK: < 4 +# CHECK: < 5 +# CHECK: < 6)IR"); + + { + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer z_v(M, N, K); + PaddedBuffer z_ref(M, N, K); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); + } + } + } + + SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); + eval(a_v, b_v, c_v, d_v, z_v); + ExpectAllNear(z_v, z_ref, 1e-5); + } +} + +TEST(LoopNest, OuterLoopVectorization) { + Tensor tensor = + Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + + ASSERT_TRUE( + LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); + + StmtPtr root_stmt = l.root_stmt(); + BlockPtr outer_block = to(root_stmt); + ASSERT_NE(outer_block, nullptr); + while (BlockPtr inner_block = to(outer_block->front())) { + outer_block = inner_block; + } + + // Verify that we have only a single loop level remaining after + // vectorization. + ASSERT_EQ(outer_block->nstmts(), 1); + ForPtr for_loop = to(outer_block->front()); + ASSERT_NE(for_loop, nullptr); + BlockPtr for_body = for_loop->body(); + ASSERT_EQ(for_body->nstmts(), 1); + ASSERT_EQ(to(for_body->front()), nullptr); +} + +TEST(LoopNest, VectorizeLoopNotNormalized) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 1; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 1, 5, for_body); + auto outer_for = For::make(i, 0, 10, inner_for); + auto block = Block::make({outer_for}); + LoopNest l(block, {a_buf.node()}); + + ASSERT_TRUE(LoopNest::vectorize(inner_for)); + ASSERT_EQ(outer_for->body()->nstmts(), 1); + ASSERT_EQ(to(outer_for->body()->front()), nullptr); +} + +namespace { + +std::string constantUpperBoundLoopIR(int upper_bound_val) { + ExprHandle upper_bound(upper_bound_val); + Tensor A = + Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(loops[0], &unrolled); + std::ostringstream oss; + oss << *unrolled; + return oss.str(); +} + +} // namespace + +TEST(LoopNest, Unroll) { + const std::string actual = constantUpperBoundLoopIR(3); + const std::string& verification_pattern = + R"IR( +# CHECK: A[0] = 0; +# CHECK: A[1] = 2; +# CHECK: A[2] = 4)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, actual); +} + +TEST(LoopNest, UnrollOuter) { + ExprHandle outer_bound(3); + ExprHandle inner_bound(4); + Tensor A = Compute( + "A", + {outer_bound, inner_bound}, + [&](const VarHandle& x, const VarHandle& y) { return x + y; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(loops[0], &unrolled); + checkIR(unrolled, R"IR( +# CHECK: for (int i = 0; i < 4; i++) { +# CHECK: A[0, i] = i; +# CHECK: } +# CHECK: for (int i = 0; i < 4; i++) { +# CHECK: A[1, i] = i + 1; +# CHECK: } +# CHECK: for (int i = 0; i < 4; i++) { +# CHECK: A[2, i] = i + 2; +# CHECK: })IR"); +} + +TEST(LoopNest, UnrollInner) { + ExprHandle outer_bound(3); + ExprHandle inner_bound(4); + Tensor A = Compute( + "A", + {outer_bound, inner_bound}, + [&](const VarHandle& x, const VarHandle& y) { return x + y; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll( + static_to(loops[0]->body()->stmts().front()), &unrolled); + checkIR(loops[0], R"IR( +# CHECK: for (int i = 0; i < 3; i++) { +# CHECK: A[i, 0] = i; +# CHECK: A[i, 1] = i + 1; +# CHECK: A[i, 2] = i + 2; +# CHECK: A[i, 3] = i + 3; +# CHECK: })IR"); +} + +TEST(LoopNest, UnrollMultipleStatements) { + const int kTotalSize = 3; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle x("x", kInt); + auto f = For::make( + x, + 0, + kTotalSize, + Block::make( + {Store::make(a_buf, {x}, x * 2), + Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); + auto parent_block = Block::make({f}); + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(f, &unrolled); + checkIR(unrolled, R"IR( +# CHECK: A[0] = 0; +# CHECK: B[0] = A[0]; +# CHECK: A[1] = 2; +# CHECK: B[1] = A[1]; +# CHECK: A[2] = 4 +# CHECK: B[2] = A[2];)IR"); +} + +TEST(LoopNest, UnrollNonLiteralConstantBounds) { + // Input IR: + // for (int i = 2 - 1; i < 12 / 3; i++) { + // for (int j = 0; j < 4; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {3, 4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, 4, for_body); + auto outer_for = For::make( + i, + IntImm::make(2) - IntImm::make(1), + IntImm::make(12) / IntImm::make(3), + inner_for); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto b = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(loops[0], &unrolled); + checkIR(unrolled, R"IR( +# CHECK: for (int j = 0; j < 4; j++) { +# CHECK: A[1, j] = j; +# CHECK: } +# CHECK: for (int j = 0; j < 4; j++) { +# CHECK: A[2, j] = 2 * j; +# CHECK: } +# CHECK: for (int j = 0; j < 4; j++) { +# CHECK: A[3, j] = 3 * j; +# CHECK: })IR"); +} + +TEST(LoopNest, UnrollNonConstantBounds) { + // Input IR: + // for (int i = 0; i < M; i++) { + // for (int j = 0; j < N; j++) { + // A[i, j] = i * j; + // } + // } + VarHandle M("M", kInt); + VarHandle N("N", kInt); + BufHandle a_buf("A", {M, N}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, N, for_body); + auto outer_for = For::make(i, 0, M, inner_for); + auto block = Block::make({outer_for}); + LoopNest l(block, {a_buf.node()}); + + LoopNest::unroll(inner_for, 8); + l.simplify(); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { + # CHECK: A[i, 8 * j_outer] = + # CHECK: A[i, 8 * j_outer + 1] = + # CHECK: A[i, 2 * (4 * j_outer + 1)] = + # CHECK: A[i, 8 * j_outer + 3] = + # CHECK: A[i, 4 * (2 * j_outer + 1)] = + # CHECK: A[i, 8 * j_outer + 5] = + # CHECK: A[i, 8 * j_outer + 6] = + # CHECK: A[i, 8 * j_outer + 7] = + # CHECK: } + # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { + # CHECK: A[i, 8 * (N / 8) + j_tail] = + # CHECK: } + # CHECK: } + )IR"); +} + +TEST(LoopNest, UnrollByFactorsLessThan2) { + // Input IR: + // for (int i = 0; i < M; i++) { + // for (int j = 0; j < N; j++) { + // A[i, j] = i * j; + // } + // } + VarHandle M("M", kInt); + VarHandle N("N", kInt); + BufHandle a_buf("A", {M, N}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, N, for_body); + auto outer_for = For::make(i, 0, M, inner_for); + auto block = Block::make({outer_for}); + LoopNest l(block, {a_buf.node()}); + + // Unrolling by factor = 1 should do nothing. + LoopNest::unroll(inner_for, 1); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j = 0; j < N; j++) { + # CHECK: A[i, j] = + # CHECK: } + # CHECK: } + )IR"); + + // Unrolling by factor = 0 should do nothing. + LoopNest::unroll(inner_for, 0); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j = 0; j < N; j++) { + # CHECK: A[i, j] = + # CHECK: } + # CHECK: } + )IR"); + + // Unrolling by negative factor should do nothing. + LoopNest::unroll(inner_for, -2); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j = 0; j < N; j++) { + # CHECK: A[i, j] = + # CHECK: } + # CHECK: } + )IR"); +} + +TEST(LoopNest, UnrollByFactorEqualToIters) { + // Input IR: + // for (int i = 0; i < 5; i++) { + // A[i] = i * i; + // } + BufHandle a_buf("A", {5}, kInt); + VarHandle i("i", kInt); + auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); + auto for_loop = For::make(i, 0, 5, for_body); + auto block = Block::make({for_loop}); + LoopNest l(block, {a_buf.node()}); + + LoopNest::unroll(for_loop, 5); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) + # CHECK: A[5 * i_outer] + # CHECK: A[5 * i_outer + 1] + # CHECK: A[5 * i_outer + 2] + # CHECK: A[5 * i_outer + 3] + # CHECK: A[5 * i_outer + 4] + )IR"); +} + +TEST(LoopNest, UnrollEmpty) { + const std::string actual = constantUpperBoundLoopIR(0); + const std::string& verification_pattern = R"IR( +# CHECK-NOT: A[ + )IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, actual); +} + +TEST(LoopNest, NoUnroll) { + VarHandle upper_bound("N", kInt); + Tensor A = + Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + ASSERT_THROWS_WITH( + LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop"); +} + +TEST(LoopNest, UnrollWithLet) { + const int kTotalSize = 3; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle e("e", kInt); + VarHandle x("x", kInt); + auto f = For::make( + x, + 0, + kTotalSize, + Block::make( + {Let::make(e, 7), + Store::make(a_buf, {x}, e), + Store::make(b_buf, {x}, e + 1)})); + auto parent_block = Block::make({f}); + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(f, &unrolled); + std::ostringstream oss; + oss << *unrolled; + const std::string& verification_pattern = + R"IR( +# CHECK: int e = 7; +# CHECK: A[0] = e; +# CHECK: B[0] = e + 1; +# CHECK: A[1] = e; +# CHECK: B[1] = e + 1; +# CHECK: A[2] = e; +# CHECK: B[2] = e + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector a_v(kTotalSize, 0); + std::vector b_v(kTotalSize, 0); + SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); + eval(a_v, b_v); + for (int i = 0; i < kTotalSize; ++i) { + ASSERT_EQ(a_v[i], 7); + ASSERT_EQ(b_v[i], 8); + } +} + +TEST(LoopNest, IsNormalized) { + // Input IR: + // for (int i = 50; i < 100; i++) { + // A[i] = B[i]; + // } + BufHandle a_buf("A", {ExprHandle(100)}, kInt); + BufHandle b_buf("B", {ExprHandle(100)}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto for_stmt = + For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); + Block::make({for_stmt}); + ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); + + for_stmt->set_start(alloc(0)); + ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); + + VarHandle N("N", kInt); + for_stmt->set_start(N.node()); + ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); +} + +TEST(LoopNest, NormalizeStartPositive) { + // Input IR: + // for (int x = 50; x < 100; x++) { + // A[x] = B[x]; + // B[x] = x * 2; + // } + const int kTotalSize = 50; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), + Store::make(b_buf, {x}, x * 2)}); + auto for_stmt = For::make(x, 50, 100, for_body); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 50; x++) { + # CHECK: A[x + 50] = B[x + 50]; + # CHECK: B[x + 50] = 2 * (x + 50); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeStartNegative) { + // Input IR: + // for (int x = -50; x < 100; x++) { + // A[x + 50] = B[x + 50]; + // B[x + 50] = x * 2; + // } + const int kTotalSize = 150; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), + Store::make(b_buf, {x + 50}, x * 2)}); + auto for_stmt = For::make(x, -50, 100, for_body); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 150; x++) { + # CHECK: A[x] = B[x]; + # CHECK: B[x] = 2 * (x - 50); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeStartZero) { + // Input IR: + // for (int x = 0; x < 100; x++) { + // A[x] = B[x]; + // B[x] = x * 2; + // } + // Should not be modified. + + const int kTotalSize = 100; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), + Store::make(b_buf, {x}, x * 2)}); + auto for_stmt = For::make(x, 0, 100, for_body); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 100; x++) { + # CHECK: A[x] = B[x]; + # CHECK: B[x] = 2 * x; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeStartVariable) { + // Input IR: + // for (int x = y; x < 100; x++) { + // A[x] = B[x]; + // B[x] = x * 2; + // } + + const int kTotalSize = 100; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), + Store::make(b_buf, {x}, x * 2)}); + auto for_stmt = For::make(x, y, 100, for_body); + auto parent_block = Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 100 - y; x++) { + # CHECK: A[x + y] = B[x + y]; + # CHECK: B[x + y] = 2 * (x + y); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeOnNestedOuterLoop) { + // Input IR: + // for (int x = 50; x < 100; x++) { + // for (int y = 10; y < 100; y++) { + // A[x] = A[x] + B[y] + y * 2; + // } + // } + + BufHandle a_buf("A", {ExprHandle(50)}, kInt); + BufHandle b_buf("B", {ExprHandle(100)}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto inner_for_body = Store::make( + a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); + auto inner_for = For::make(y, 10, 100, inner_for_body); + auto for_stmt = For::make(x, 50, 100, inner_for); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 50; x++) { + # CHECK: for (int y = 10; y < 100; y++) { + # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeOnNestedInnerLoop) { + // Input IR: + // for (int x = 50; x < 100; x++) { + // for (int y = 10; y < 100; y++) { + // A[x] = A[x] + B[y] + y * 2; + // } + // } + + BufHandle a_buf("A", {ExprHandle(50)}, kInt); + BufHandle b_buf("B", {ExprHandle(100)}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto inner_for_body = Store::make( + a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); + auto inner_for = For::make(y, 10, 100, inner_for_body); + auto for_stmt = For::make(x, 50, 100, inner_for); + Block::make({for_stmt}); + + LoopNest::normalize(inner_for); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 50; x < 100; x++) { + # CHECK: for (int y = 0; y < 90; y++) { + # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeAndSplitWithTail) { + // Create a dummy tensor to construct LoopNest. + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + + // Input IR: + // for (int x = 5; x < 10; x++) { + // A[x] = x * 2; + // } + const int kTotalSize = 5; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); + auto parent_block = Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + ForPtr x_inner; + ForPtr x_tail; + LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); + + auto x_outer_result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss_outer; + oss_outer << *x_outer_result; + const std::string& expected_outer_ir = + R"IR( + # CHECK: { + # CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); + + auto x_tail_result = IRSimplifier::simplify(x_tail); + std::ostringstream oss_tail; + oss_tail << *x_tail_result; + const std::string& expected_tail_ir = + R"IR( + # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { + # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); + )IR"; + torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); +} + +TEST(LoopNest, NotNormalizeAndSplitWithTail) { + // Create a dummy tensor to construct LoopNest. + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + + // Input IR: + // for (int x = 5; x < 15; x++) { + // A[x] = x * 2; + // } + const int kTotalSize = 10; + BufHandle a_buf("A", {kTotalSize}, kInt); + VarHandle x("x", kInt); + auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); + auto parent_block = Block::make({for_stmt}); + + ForPtr x_inner; + ForPtr x_tail; + LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); + + auto x_outer_result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss_outer; + oss_outer << *x_outer_result; + const std::string& expected_outer_ir = + R"IR( + # CHECK: { + # CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); + + auto x_tail_result = IRSimplifier::simplify(x_tail); + std::ostringstream oss_tail; + oss_tail << *x_tail_result; + const std::string& expected_tail_ir = + R"IR( + # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { + # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); + )IR"; + torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); +} + +TEST(LoopNest, FlattenSimpleLoopNest2D) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, 5, for_body); + auto outer_for = For::make(i, 0, 10, inner_for); + auto parent_block = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { + # CHECK: A[i_flat / 5, i_flat % 5] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenSimpleLoopNest3D) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // for (int k = 0; k < 7; k++) { + // A[i,j,k] = i + j * k; + // } + // } + // } + BufHandle a_buf("A", {10, 5, 7}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); + auto for1 = For::make(k, 0, 7, for_body); + auto for2 = For::make(j, 0, 5, for1); + auto for3 = For::make(i, 0, 10, for2); + auto parent_block = Block::make({for3}); + + std::vector loops = {for3, for2, for1}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { + # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5, 7); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5, 7); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenLoopNestAfterNormalize) { + // Input IR: + // for (int i = 2; i < 10; i++) { + // for (int j = 3; j < 15; j++) { + // A[i - 2,j - 3] = i * j; + // } + // } + BufHandle a_buf("A", {8, 12}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); + auto inner_for = For::make(j, 3, 15, for_body); + auto outer_for = For::make(i, 2, 10, inner_for); + auto parent_block = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { + # CHECK: A[i_flat / 12, i_flat % 12] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(8, 12); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(8, 12); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { + // Input IR: + // for (int i = 0; i < 15-5; i++) { + // for (int j = 0; j < 20/4; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = + For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); + auto outer_for = + For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto b = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + checkIR(result, R"IR( + # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { + # CHECK: A[i_flat / 5, i_flat % 5] = + )IR"); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenImperfectLoopNest) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // A[i, i] = 0; + // for (int j = 0; j < 15; j++) { + // A[i,j] = i * j; + // } + // } + // Do not flatten. + + BufHandle a_buf("A", {10, 15}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, 15, for_body); + auto outer_for = For::make( + i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); + auto par = Block::make({outer_for}); + HashProvider hasher; + auto hash_before = hasher.hash(par); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(par); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, FlattenReductionLoopNest) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // S[i] = 0; + // for (int j = 0; j < 15; j++) { + // S[i] = S[i] + A[i,j]; + // } + // } + // Do not flatten. + + BufHandle a_buf("A", {10, 15}, kInt); + BufHandle s_buf("S", {10}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make( + s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); + auto inner_for = For::make(j, 0, 15, for_body); + auto outer_for = + For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); + auto par = Block::make({outer_for}); + HashProvider hasher; + auto hash_before = hasher.hash(par); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(par); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, FlattenReductionLoopNestFromTensor) { + const int M = 3; + const int N = 7; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle b("b", {m, n}, kFloat); + Tensor c = Reduce("sum", {M}, Sum(), b, {N}); + LoopNest loop({c}); + HashProvider hasher; + auto hash_before = hasher.hash(loop.root_stmt()); + + auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(loop.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, FlattenIncorrectLoopsAsInput) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + // for (int x = 0; x < 10; x++) { + // for (int y = 0; y < 5; y++) { + // A[x,y] = A[x,y] + x + y; + // } + // } + // Flatten({For_i, For_y}) => should not succeed + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + auto par = Block::make({outer_for1, outer_for2}); + HashProvider hasher; + auto hash_before = hasher.hash(par); + + std::vector loops = {outer_for1, inner_for2}; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(par); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, DetectInlineRankMismatch) { + const int kTotalSize = 8; + + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + Tensor a = Compute( + "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); + Tensor reshape = Compute( + "reshape", + {kTotalSize / 2, 2}, + [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); + LoopNest l({reshape}, {a, reshape}); + ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); +} + +TEST(LoopNest, CacheReadsSimple) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 3); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + // just this once: verify the whole thing. + checkIR(result, R"IR( +#CHECK: Allocate(A); // dtype=int, dims=[64, 64] +#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10] +#CHECK: for (int i +#CHECK: for (int j +#CHECK: A[ +#CHECK: } +#CHECK: } +#CHECK: for (int i_1 +#CHECK: for (int j_1 +#CHECK: A_local[j_1] = A[ +#CHECK: } +#CHECK: for (int j_2 +#CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; +#CHECK: } +#CHECK: } +#CHECK: for (int i_2 +#CHECK: for (int j_3 +#CHECK: C[ +#CHECK: } +#CHECK: } +#CHECK: Free(A_local); +#CHECK: Free(A); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 3); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsOuter) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; + LoopNest::cacheAccesses(A.buf(), "A_local", i_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] +#CHECK: A_local[j_1 + 11 * i_1] = +#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsInternal) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] +#CHECK: A_local[k + 11 * j_1] = +#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsInner) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + // note im changing the offset of the first arg of the first call to A. + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr body = l.getLoopBodyFor(B); + LoopNest::cacheAccesses(A.buf(), "A_local", body); + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] +#CHECK: A_local[l + 2 * k] = +#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheWritesSimple) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", a_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] +#CHECK: for (int j = 0; j < 64 +#CHECK: A_local[j] = i * j; +#CHECK: for (int j_1 = 0; j_1 < 64 +#CHECK: A[j_1 + 64 * i] = A_local[ +#CHECK: Free(A_local); +#CHECK-NOT: A_local + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, DeadStoreElimination) { + VarHandle y("y", kInt); + VarHandle x("x_tail", kInt); + BufHandle f("f", {26, 5}, kInt); + BufHandle g("g", {26, 5}, kInt); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + ForPtr stmt1 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(f, {x_2, y}, (x_2 + y)), + Store::make(g, {x_2, y}, (x_2 * y)), + }))); + StmtPtr stmt = Block::make({stmt1}); + + // Will eliminate if not used by an output. + LoopNest loop(Stmt::clone(stmt), {f.node()}); + loop.eliminateDeadStores(); + + checkIR(loop.root_stmt(), R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK-NOT: g[x_tail + 5 * 4, y] + )IR"); + + // But won't eliminate if used by different outputs. + LoopNest loop2(stmt, {f.node(), g.node()}); + loop2.eliminateDeadStores(); + + checkIR(loop2.root_stmt(), R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK: g[x_tail + 5 * 4, y] + )IR"); +} + +TEST(LoopNest, DeadStoreEliminationWithIntermediates) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + BufHandle f("f", {26 * 5}, kInt); + BufHandle g("g", {26 * 5}, kInt); + BufHandle h("h", {26, 5}, kInt); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); + ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); + ForPtr stmt3 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(h, {x, y}, Load::make(f, {x * y})), + }))); + StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); + + // Will eliminate the write to g, but not f since it used by the producer of + // h. + LoopNest loop(Stmt::clone(stmt), {h.node()}); + loop.eliminateDeadStores(); + + checkIR(loop.root_stmt(), R"IR( + #CHECK: f[x] = x; + #CHECK-NOT: g[z] = + #CHECK: h[x, y] = f[x * y]; + )IR"); + + // Sanity check won't eliminate if g is an output. + LoopNest loop2(stmt, {h.node(), g.node()}); + loop2.eliminateDeadStores(); + + checkIR(loop2.root_stmt(), R"IR( + #CHECK: f[x] = x; + #CHECK: g[z] = z + 1; + #CHECK: h[x, y] = f[x * y]; + )IR"); +} + +TEST(LoopNest, CompoundTensorSimple) { + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + BlockPtr body = Block::make({outer_for1, outer_for2}); + + Tensor A = Tensor(a_buf.node(), body); + + LoopNest l({A}); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {A}); + + std::vector a_ref(50, 0); + + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 5; ++j) { + a_ref[i * 5 + j] = (i * j) + i + j; + } + } + cg.call({a_data}); + + assertAllEqual(a_data, a_ref); +} + +TEST(LoopNest, InlineConstantIndex) { + const int N = 10; + BufHandle x_buf("a", {1, N, 1}, kFloat); + Tensor y = Compute( + "f", + {1, N, 1}, + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { + return x_buf.load(m, n, o); + }); + Tensor z = Compute( + "f", + {1, N, 1}, + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { + return y.load(m, n, o); + }); + + LoopNest l({z}, {y, z}); + l.simplify(); + ASSERT_TRUE(l.computeInline(y.buf())); +} + +TEST(LoopNest, CompoundTensorUsed) { + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + BlockPtr body = Block::make({outer_for1, outer_for2}); + + Tensor A = Tensor(a_buf.node(), body); + Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j + 1) + A.load(i, j + 2); + }); + + LoopNest l({B}, {A, B}); + ASSERT_FALSE(l.computeInline(A.buf())); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + std::vector b_data(50, 0); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {B}); + + std::vector b_ref(50, 0); + + auto AT = [](int i, int j) { return i * j + i + j; }; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 3; ++j) { + b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); + } + } + cg.call({b_data}); + + assertAllEqual(b_data, b_ref); +} + +TEST(LoopNest, InlineFromLoad) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); + auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); + LoopNest l(Block::make({store_a, store_b}), {b.node()}); + + l.computeInline(a.node()); + + // Check that A[j] is replaced with j after inlining + std::ostringstream oss; + oss << *l.root_stmt(); + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: for (int j +# CHECK-NOT: B[j] = A[j] +# CHECK-NEXT: B[j] = j +)IR", + oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsSimple) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {15}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK: for (int i = 0; i < 15 +# CHECK-NEXT: A[i + 5] = C[i] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsNestedConditions) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) + // } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i + 5] = C[i] +# CHECK: for (int i = 0; i < 10 +# CHECK-NEXT: A[i + 10] = D[i] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsMultipleStores) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // } + // for (int j = 0; j < 100; j++) { + // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) + // } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto storeA = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, storeA); + auto storeB = Store::make( + b_buf, + {j}, + IfThenElse::make( + CompareSelect::make(j, 30, kLT), + Load::make(c_buf, {j}), + Load::make(d_buf, {j}))); + auto forJ = For::make(j, 0, 100, storeB); + auto par = Block::make({forI, forJ}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK: for (int i = 0; i < 15 +# CHECK-NEXT: A[i + 5] = C[i] +# CHECK: for (int j = 0; j < 30 +# CHECK-NEXT: B[j] = C[j] +# CHECK: for (int j = 0; j < 70 +# CHECK-NEXT: B[j + 30] = D[j + 30] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { + // Input IR: + // for (int i = 0; i < 50; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) + // } + // Only the first conditional, in the write to A, will be optimized. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {100}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto storeA = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5}))); + auto storeB = Store::make( + b_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 30, kLT), + Load::make(c_buf, {i}), + Load::make(d_buf, {i}))); + auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK-NEXT: B[i] = C[i] +# CHECK: for (int i = 0; i < 45 +# CHECK-NEXT: A[i + 5] = C[i] +# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) + // } + // } + // Currently, this case where the condition variable `i` is not the + // inner-most loop variable, is not optimized. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store)); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) + // } + // No optimization should be done here because one of the conditions use '>'. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i'. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + VarHandle N("N", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, N, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsInvalidCondition) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) + // } + // No optimization should be done here because one of the conditions use '>'. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kGT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(10 colReduce(int M, int N) { + BufHandle a("a", {M, N}, kFloat); + Tensor t = Reduce( + "b", + {N}, + Sum(), + [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, + {M}); + return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; +} + +static StmtPtr splitTailReorder(Tensor b) { + constexpr int kVectorWidth = 8; + LoopNest nest({b}); + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; + nest.splitWithTail(loops[0], kVectorWidth); + // Now the loopnests will look like: + // + // for (int i_outer = 0; ... + // for (int i_inner = 0; ... + // b[i_outer * 8 + i_inner] = float(0); + // for (int j = 0; ... + // b[i_outer * 8 + i_inner] = ReduceOp(...); + // + // for (int i_tail = 0; ... + // b[i_tail + ((100 - 0) / 8) * 8] = float(0); + // for (int j = 0; ... + // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); + // + // Since there are 4 writes to b, we will get 4 loopnests from the + // call to `getAllLoopNestsWritingToBuf` below. + // + // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" + // Loopnest #2: {i_outer, i_inner, j}; + // We will have to reorder i_inner and j. + auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); + LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); + nest.prepareForCodegen(); + return nest.root_stmt(); +} + +static StmtPtr splitMaskReorder(Tensor b) { + constexpr int kVectorWidth = 8; + LoopNest nest({b}); + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; + nest.splitWithMask(loops[0], kVectorWidth); + loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; + LoopNest::reorderAxis(loops[1], loops[2]); + nest.prepareForCodegen(); + return nest.root_stmt(); +} + +static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { + int M = immediateAs(p.dim(0)); + int N = immediateAs(p.dim(1)); + PaddedBuffer a(M, N); + PaddedBuffer b(N); + PaddedBuffer ref(N); + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a(i, j) = 1.0f; + } + } + for (int i = 0; i < N; i++) { + b(i) = 0.0f; + } + for (int i = 0; i < N; i++) { + ref(i) = 76.0f; + } + SimpleIREvaluator(s, {p, t}).call({a, b}); + ExpectAllNear(b, ref, 1e-5); +} + +TEST(LoopNest, ColReduceSplitTailEvenReorder) { + constexpr int M = 76, N = 128; + auto p = colReduce(M, N); + StmtPtr s = splitTailReorder(p.second); + + std::ostringstream oss; + oss << *s; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i_outer +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK: for (int j +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ColReduceSplitTailUnevenReorder) { + constexpr int M = 76, N = 100; + auto p = colReduce(M, N); + StmtPtr s = splitTailReorder(p.second); + + std::ostringstream oss; + oss << *s; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i_outer +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK: for (int j +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK: for (int i_tail +# CHECK-NEXT: b[ +# CHECK-NEXT: for (int j +# CHECK-NEXT: b[ + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ColReduceSplitMaskEvenReorder) { + constexpr int M = 76, N = 128; + auto p = colReduce(M, N); + StmtPtr s = splitMaskReorder(p.second); + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { + constexpr int M = 76, N = 100; + auto p = colReduce(M, N); + StmtPtr s = splitMaskReorder(p.second); + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ReorderAxisWithMultipleConds) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // if i > 5 { + // if i < 10 { + // for (int j = 0; j < 100; j++) { + // A[i] = i * j; + // } + // } + // } + // } + BufHandle a_buf("A", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); + auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); + auto outer_cond = + Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); + auto forI = For::make(i, 0, 20, outer_cond); + StmtPtr par = Block::make({forI}); + LoopNest l(par, {a_buf.node()}); + LoopNest::reorderAxis(forI, forJ); + ASSERT_EQ(par, l.root_stmt()); + par = IRSimplifier::simplify(par); + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: for (int i +# CHECK-NEXT: if (i>5 +# CHECK-NEXT: if (i<10 +# CHECK-NEXT: A[i] = i * j +# CHECK-NOT: for ( + )IR"; + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, VectorizeUse) { + constexpr int N = 8; + BufHandle a("a", {N}, kFloat); + Tensor b = + Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); + Tensor c = + Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); + LoopNest nest({c}, {b, c}); + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; + ASSERT_TRUE(LoopNest::vectorize(loops[0])); + loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; + ASSERT_TRUE(LoopNest::vectorize(loops[0])); + nest.prepareForCodegen(); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + StmtPtr s = nest.root_stmt(); + std::ostringstream oss; + oss << *nest.root_stmt(); + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: c[Ramp +)IR", + oss.str()); +} + +const char* int64Loop = R"IR( +# CHECK: for (int64_t i = 0ll; i < 12ll; i++) { +# CHECK: b[i] = (a[i]) + 1ll; +# CHECK: } +)IR"; + +TEST(LoopNest, Int64Direct) { + constexpr int64_t N = 12; + BufHandle a("a", {N}, kLong); + BufHandle b("b", {N}, kLong); + VarHandle n("i", kLong); + StmtPtr s = For::make( + n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run(int64Loop, oss.str()); +} + +TEST(LoopNest, Int64Compute) { + constexpr int64_t N = 12; + BufHandle a("a", {N}, kLong); + Tensor b = Compute("b", {N}, [&](const VarHandle& n) { + return a.load(n) + LongImm::make(1l); + }); + LoopNest nest({b}); + nest.prepareForCodegen(); + nest.simplify(); + std::ostringstream oss; + oss << *nest.root_stmt(); + torch::jit::testing::FileCheck().run(int64Loop, oss.str()); +} + +TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK: for (int i +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoop(forI, {forJ}); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopWithoutAnyPivot) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK: for (int i +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoop(forI); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopOverInnerLoops) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { + // Input IR: + // for (int m = 0; m < 50; m++) { + // for (int i = 0; i < 20; i++) { + // A[m,i] = 0; + // for (int j = 0; j < 100; j++) { + // A[m,i] = A[m,i] + i * j; + // } + // B[m,i] = A[m,i]; + // for (int k = 0; k < 50; k++) { + // B[m,i] = B[m,i] + i * k; + // } + // } + // } + BufHandle a_buf("A", {100, 100}, kInt); + BufHandle b_buf("B", {100, 100}, kInt); + VarHandle m("m", kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {m, i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, + {m, i}, + Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, + {m, i}, + Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + + { + // Check the case of distributing loop and its parents over all the + // statements in the loop. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: A[m, i] = 0 +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[m, i] = +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: B[m, i] = A[m, i] +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[m, i] = +# CHECK-NOT: for ( + )IR"; + + auto newForI = to(Stmt::clone(forI)); + auto forM = For::make(m, 0, 50, newForI); + auto par = Block::make({forM}); + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto newLoops = LoopNest::distributeLoopAndParents(newForI); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(newLoops.front(), forM); + } + + { + // Check the case of distributing loop and its parents over all the inner + // loops. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: A[m, i] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[m, i] = +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: B[m, i] = A[m, i] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[m, i] = +# CHECK-NOT: for ( + )IR"; + + auto newForI = to(Stmt::clone(forI)); + auto forM = For::make(m, 0, 50, newForI); + auto par = Block::make({forM}); + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(newLoops.front(), forM); + } +} + +TEST(LoopNest, fuseLoopsSimple) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsMultiple) { + // Input IR: + // for (int i = 0; i < 100; i++) { + // A[i+100] = 20 + i; + // } + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forI = + For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); + auto par = Block::make({forI, forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i + 100] = +# CHECK-NEXT: A[i] = +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsNested) { + // Input IR: + // for (int m = 0; m < 20; m++) { + // A[m] = 0; + // for (int j = 0; j < 100; j++) { + // A[m] = A[m] + m * j; + // } + // } + // for (int n = 0; n < 20; n++) { + // B[n] = A[n]; + // for (int k = 0; k < 50; k++) { + // B[n] = B[n] + n * k; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {m}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); + auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); + auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); + auto forN = For::make(n, 0, 20, Block::make({initB, forK})); + auto par = Block::make({forM, forN}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m +# CHECK-NEXT: A[m] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[m] = +# CHECK: B[m] = A[m] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[m] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forM); +} + +TEST(LoopNest, fuseLoopsNested2D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 50; n++) { + // B[m,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto forI = For::make( + i, + 0, + 20, + For::make( + j, + 0, + 100, + Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); + auto forM = For::make( + m, + 0, + 20, + For::make( + n, + 0, + 50, + Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); + auto par = Block::make({forI, forM}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK: for (int n +# CHECK-NEXT: B[i, n] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsNested2DInner) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // for (int n = 0; n < 100; n++) { + // B[i,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle n("n", kInt); + auto forJ = For::make( + j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); + auto forN = For::make( + n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); + auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); + + std::ostringstream oss; + oss << *forI; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK-NEXT: B[i, j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsDifferentStopBounds) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 50; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsDifferentStartBounds) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 50; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsNotContiguous) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // B[0] = 0; + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto initB = Store::make(b_buf, {0}, 0); + auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, initB, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsWithDifferentParents) { + // Input IR: + // for (int i = 0; i < 50; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j; + // } + // } + // B[0] = 0; + // for (int k = 50; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {50, 100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); + auto forI = For::make(i, 0, 50, forJ); + auto initB = Store::make(b_buf, {0}, 0); + auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI, initB, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsWithVariableBounds) { + // Input IR: + // for (int j = 0; j < N; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < N; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle N("N", kInt); + auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithExprBounds) { + // Input IR: + // for (int j = 0; j < M + N; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < M + N; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle M("M", kInt); + VarHandle N("N", kInt); + auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { + // Input IR: + // for (int j = M; j < N * 2; j++) { + // A[j] = 10 * j; + // } + // for (int k = M; k < N + N; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle M("M", kInt); + VarHandle N("N", kInt); + auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { + // Input IR: + // for (int j = 10; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 10; k < 100; k++) { + // A[k+100] = 30 * k + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); + auto par = Block::make({forJ, forK}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: A[j + 100] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 50; n++) { + // A[m+20,n+100] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = + Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK: for (int n +# CHECK-NEXT: A[i + 20, n + 100] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWithReductions) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0 + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + B[i,j]; + // } + // } + // for (int m = 0; m < 20; m++) { + // C[m] = A[m]; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + BufHandle c_buf("C", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto sumA = Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); + auto forJ = For::make(j, 0, 100, sumA); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); + auto forM = + For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); + auto par = Block::make({forI, forM}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = (A[i]) + +# CHECK-NOT: for ( +# CHECK: C[i] = A[i] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWith2DReductions) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 50; j++) { + // A[i,j] = 0 + // for (int k = 0; k < 100; k++) { + // A[i,j] = A[i,j] + B[i,j,k]; + // } + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 40; n++) { + // C[m,n] = A[m,n]; + // } + // } + BufHandle a_buf("A", {20, 50}, kInt); + BufHandle b_buf("B", {20, 50, 100}, kInt); + BufHandle c_buf("C", {20, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto initA = Store::make(a_buf, {i, j}, 0); + auto sumA = Store::make( + a_buf, + {i, j}, + Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); + auto forK = For::make(k, 0, 100, sumA); + auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); + auto forI = For::make(i, 0, 20, forJ); + auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK-NEXT: for (int k +# CHECK-NEXT: A[i, j] = (A[i, j]) + +# CHECK: for (int n +# CHECK-NEXT: C[i, n] = A[i, n] +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWithComplexIndices) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 20; j++) { + // A[i,j*20+j+2] = i + j; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 20; n++) { + // B[m,n] = A[m,n*20+n+2]; + // } + // } + BufHandle a_buf("A", {20, 400}, kInt); + BufHandle b_buf("B", {20, 400}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); + auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); + auto storeB = + Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j +# CHECK: for (int n +# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 20; j++) { + // A[i,i*20+j] = i + j; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 20; n++) { + // B[m,n] = A[m,m*20+n]; // Both indices of A use m + // } + // } + BufHandle a_buf("A", {20, 500}, kInt); + BufHandle b_buf("B", {20, 500}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); + auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); + auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsWithTranspose) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 20; j++) { + // A[i,j] = i + j; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 20; n++) { + // B[m,n] = A[n,m]; // Transpose + // } + // } + BufHandle a_buf("A", {20, 20}, kInt); + BufHandle b_buf("B", {20, 20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto writeA = Store::make(a_buf, {i, j}, i + j); + auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); + auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies1) { + // Input IR: + // for (int j = 10; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 10; k < 100; k++) { + // A[k-1] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies2) { + // Input IR: + // for (int j = 10; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 10; k < 100; k++) { + // A[k+50] = 20 * k; + // } + BufHandle a_buf("A", {150}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies3) { + // Input IR: + // for (int m = 0; m < 20; m++) { + // A[m] = 0; + // for (int j = 0; j < 100; j++) { + // A[m] = A[m] + m * j; + // } + // } + // for (int n = 0; n < 20; n++) { + // B[n] = A[n+1]; + // for (int k = 0; k < 50; k++) { + // B[n] = B[n] + n * k; + // } + // } + BufHandle a_buf("A", {25, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {m}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); + auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); + auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); + auto forN = For::make(n, 0, 20, Block::make({initB, forK})); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forM, forN}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies4) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 50; n++) { + // A[m+1,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {30, 100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto forI = For::make( + i, + 0, + 20, + For::make( + j, + 0, + 100, + Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); + auto forM = For::make( + m, + 0, + 20, + For::make( + n, + 0, + 50, + Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI, forM}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies5) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // for (int n = 0; n < 100; n++) { + // A[i,n+1] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle n("n", kInt); + auto forJ = For::make( + j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); + auto forN = For::make( + n, + 0, + 100, + Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies6) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * A[99-k]; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies7) { + // Input IR: + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * A[99-k]; + // } + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forK, forJ}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); +} + +TEST(LoopNest, areLoopsPerfectlyNested) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI}); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + + // Specifying the loops in any other order fails. + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); + + // Adding a statement to forK body should be OK. + auto init = Store::make(a_buf, {i, j}, 0); + forK->body()->insert_stmt_before(init, store); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + + // Adding a statement in forJ body should fail this test. + forK->body()->remove_stmt(init); + forJ->body()->insert_stmt_before(init, forK); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + + // Similarly, adding a statement in forI body should fail this test. + forJ->body()->remove_stmt(init); + forI->body()->insert_stmt_before(init, forJ); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); +} + +TEST(LoopNest, reorderNestedLoops2D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); + auto forJ = For::make(j, 0, 30, store); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); + + ASSERT_EQ(reordered[0], forJ); + ASSERT_EQ(reordered[1], forI); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); + ASSERT_EQ(forJ->get_parent(), par); + ASSERT_EQ(store->get_parent(), forI->body()); +} + +TEST(LoopNest, reorderNestedLoops3D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); + + ASSERT_EQ(reordered[0], forK); + ASSERT_EQ(reordered[1], forI); + ASSERT_EQ(reordered[2], forJ); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); + ASSERT_EQ(forK->get_parent(), par); + ASSERT_EQ(store->get_parent(), forJ->body()); +} + +TEST(LoopNest, reorderNestedLoops4D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // for (int l = 0; l < 50; l++) { + // A[i,j,k,l] = i * j * k * l * 500; + // } + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle l("l", kInt); + auto store = Store::make( + a_buf, + {i, j, k, l}, + Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); + auto forL = For::make(l, 0, 50, store); + auto forK = For::make(k, 0, 40, forL); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); + + ASSERT_EQ(reordered[0], forK); + ASSERT_EQ(reordered[1], forI); + ASSERT_EQ(reordered[2], forL); + ASSERT_EQ(reordered[3], forJ); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); + ASSERT_EQ(forK->get_parent(), par); + ASSERT_EQ(store->get_parent(), forJ->body()); +} + +TEST(LoopNest, reorderTrivialPermutation) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); + + ASSERT_EQ(reordered[0], forI); + ASSERT_EQ(reordered[1], forJ); + ASSERT_EQ(reordered[2], forK); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + ASSERT_EQ(forI->get_parent(), par); + ASSERT_EQ(store->get_parent(), forK->body()); +} + +TEST(LoopNest, reorderInvalidPermutations) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI}); + + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), + "invalid permutation size"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 2}), + "invalid permutation size"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), + "invalid permutation for reorder"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), + "invalid permutation for reorder"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), + "invalid permutation for reorder"); +} + +TEST(LoopNest, reorderInvalidLoopNest) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // A[i,j] = 0 + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI}); + + // Specifying the loops in incorrect order fails. + ASSERT_THROWS_WITH( + LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), + "reorder is only allowed on perfectly nested loops"); + + // Adding a statement to forJ loop fails. + auto init = Store::make(a_buf, {i}, 0); + forJ->body()->insert_stmt_before(init, forK); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), + "reorder is only allowed on perfectly nested loops"); + + // Moving that statement to forI loop also fails. + forJ->body()->remove_stmt(init); + forI->body()->insert_stmt_before(init, forJ); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), + "reorder is only allowed on perfectly nested loops"); +} + +TEST(LoopNest, compressBufferSimple) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 199, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); + auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, j] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressBufferMultipleDims) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // B[i,j] = A[i,j] + A[i,j] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); + auto store2 = Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); + auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); + auto forI = For::make(i, 0, 100, forJ); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, 0] = +# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); +} + +TEST(LoopNest, compressBufferMultipleDims2) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // for (int k = 0; k < 300; ++k) { + // A[i,j,k] = sin(i*j*k) + // } + // for (int k = 0; k < 299; ++j) { + // B[i,j,k] = A[i,j,k] + A[i,j,k+1] + // } + // } + // } + BufHandle aBuf("A", {100, 200, 300}, kInt); + BufHandle bBuf("B", {100, 200, 300}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); + auto forK1 = For::make(k, 0, 300, store1); + auto store2 = Store::make( + bBuf, + {i, j, k}, + Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); + auto forK2 = For::make(k, 0, 299, store2); + auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); + auto forI = For::make(i, 0, 100, forJ); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: for (int k +# CHECK-NEXT: A[0, 0, k] = +# CHECK: for (int k +# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 3); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); +} + +TEST(LoopNest, compressBufferDifferentOrderIndices) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[j, i] = sin(i*j) + // } + // for (int j = 0; j < 99; ++j) { + // B[i, j] = A[j, i] + A[j+1, 0] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 99, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); + auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[j, 0] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); +} + +TEST(LoopNest, compressBufferVariableBounds) { + // Input IR: + // for (int i = 0; i < M; ++i) { + // for (int j = 0; j < N; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int j = 0; j < N-1; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle M("M", kInt); + VarHandle N("N", kInt); + auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + N - 1, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, j] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressBufferNoCommonParentLoops) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // } + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 199, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); + auto forI1 = For::make(i, 0, 100, forJ1); + auto forI2 = For::make(i, 0, 100, forJ2); + auto par = Block::make({forI1, forI2}); + LoopNest::compressBuffer(aBuf.node(), par); + + // There should be no change in the buffer or code. + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressBufferIndicesMixed) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i + j, j] = sin(i*j) + // } + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i + j, j] + A[i + j, j+1] + // } + // } + BufHandle aBuf("A", {300, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 199, + Store::make( + bBuf, + {i, j}, + Add::make( + Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); + auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + // There should be no change in the buffer or code. + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i + j, j] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressMultipleBuffers) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int k = 0; k < 199; ++k) { + // B[i,k] = A[i,k] + A[i, k+1] + // } + // for (int m = 0; m < 50; ++m) { + // C[i,m] = B[i,m] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + BufHandle cBuf("C", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); + auto forK = For::make( + k, + 0, + 199, + Store::make( + bBuf, + {i, k}, + Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); + auto forM = + For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); + auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); + auto par = Block::make({forI}); + + // This should compress all buffers A, B, and C as follows: + // A[100, 200] -> A[1, 200] + // B[100, 200] -> B[1, 200] + // C[100, 200] -> C[1, 1] + LoopNest::compressAllBuffers(par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, j] = +# CHECK: for (int k +# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) +# CHECK: for (int m +# CHECK-NEXT: C[0, 0] = B[0, m] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); + ASSERT_EQ(bBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); + ASSERT_EQ(cBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); +} + +TEST(LoopNest, sanitizeNames) { + std::vector dim_args; + // Let's pick names that would overlap with default index names if not + // sanitized properly: + dim_args.emplace_back(ExprHandle(alloc("i", kInt))); + dim_args.emplace_back(ExprHandle(alloc("N:2", kInt))); + // Now let's create a many dimensions so that we had to use the same letter + // for different loops + for (int i = 0; i < 10; i++) { + dim_args.emplace_back(ExprHandle(alloc("N", kInt))); + } + + // Now create two Computes with conflicting after sanitization names: + Tensor X = Compute("$X:!", dim_args, [&](const std::vector& v) { + return v[0] + v[1] + v[9] + 1; + }); + Tensor Y = Reduce( + "%X\"+", + {}, + Sum(), + [&](const std::vector& v) { return X.load(v); }, + dim_args); + + // Finally, let's verify what we got after sanitization: + LoopNest l({X, Y}); + StmtPtr s = l.root_stmt(); + LoopNest::sanitizeNames(s); + + std::ostringstream oss; + oss << *s; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < i_1; i++) { +# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { +# CHECK-NEXT: for (int k = 0; k < N_9; k++) { +# CHECK-NEXT: for (int l = 0; l < N_8; l++) { +# CHECK-NEXT: for (int m = 0; m < N_7; m++) { +# CHECK-NEXT: for (int n = 0; n < N_6; n++) { +# CHECK-NEXT: for (int o = 0; o < N_5; o++) { +# CHECK-NEXT: for (int p = 0; p < N_4; p++) { +# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { +# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { +# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { +# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { +# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; +# CHECK: v_X___1 = int(0); +# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { +# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { +# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { +# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { +# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { +# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { +# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { +# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { +# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { +# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { +# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { +# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { +# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp new file mode 100644 index 000000000000..5db84eab1f50 --- /dev/null +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -0,0 +1,3252 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +// Test helper function used to determine if two regions of a buffer have an +// overlap. No Overlap & partial overlap is obvious. Contains means A is +// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal +// ranges are ContainedOrEqual. +TEST(MemDependency, BoundOverlap) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + + // Sanity check 3 overlap cases. + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); + + // Partial overlap works in either order. + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); + + // Total Overlap works when one bound encloses the other, and returns which. + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); + + // Total overlap works when the bounds are an identical range, returns + // ContainedOrEqual. + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); + + // Total overlap when only one end of the bound matches. + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10))); + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15))); + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); + + // No overlap when a < b. + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); + + // No overlap when a > b. + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); + + // No overlap when adjacent. + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); + + // Partial overlap when middle bounds match. + ASSERT_EQ( + OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); + ASSERT_EQ( + OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); + + // Total overlap when one bound is single length over one end of the other. + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15))); + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); +} + +TEST(MemDependency, BoundComparison) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ)); + + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE)); + + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT)); + + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE)); + + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT)); + + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE)); +} + +TEST(MemDependency, BoundOverlapSymbolic) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + VarHandle w("w", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + // Sanity check cases where the start and end is symbolic but the diff is + // constant. + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); + ASSERT_EQ( + OverlapKind::PartialOverlap, + boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); + + // We can't infer the sign of y, so cannot tell whether adding y is larger or + // smaller than y/2. + ASSERT_EQ( + OverlapKind::PartialOverlap, + boundOverlap(CB(x, x + y), CB(x, x + y / 2))); + + // No information about this bound, have to take the most conservative option: + // there may be an overlap. + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); + + // Math on opaque terms works. + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); + // Even requiring simplification. + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); +} + +// Tests the helper function for overlap of multi dimensional indices bounds. +// This uses boundOverlap on each dimension and return the "lowest" kind of +// overlap. +TEST(MemDependency, BoundOverlapMultiDim) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + + // Sanity check one dimensional cases. + ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); + ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); + ASSERT_EQ( + OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); + + // Total overlap in 3 dims. + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); + + // Total overlap in 2 dims, no overlap in another. + ASSERT_EQ( + OverlapKind::NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); + + // Total overlap in 2 dims, partial overlap in another. + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); + // This case is most important, so verify the overlap in any dim. (dim 2) + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); + // Dim 1. + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); + // Total overlap in 1 dim, partial in 2. + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); + // Total overlap, partial overlap, no overlap. + ASSERT_EQ( + OverlapKind::NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); + + // Total overlap (B) in 2 dims, total overlap (A) in another. + ASSERT_EQ( + OverlapKind::Contains, + overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); + + // Total overlap (A) in 2 dims, total overlap (B) in another. + ASSERT_EQ( + OverlapKind::Contains, + overlaps( + {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); + + // Total (B), No Overlap, Total (A). + ASSERT_EQ( + OverlapKind::NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); +} + +// Test the helper we use to subtract bounds: returns the regions(s) of A which +// remain after removing the region of B. +TEST(MemDependency, BoundSubtract) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + // One element subtract. + ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); + ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); + + // No Overlap. + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); + + // one side overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); + ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); + ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); + + // both sides overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); + + // internal overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); +} + +TEST(MemDependency, BoundSubtractSymbolic) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + VarHandle w("w", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + // One element subtract. + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); + + // Subtract constant range low. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); + // Subtract constant range high. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); + // Subtract constant range total overlap. + ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); + // Subtract constant range internal. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), + {CB(x, x + 2), CB(x + 8, x + 10)})); + + // Size is inferable but not constant, only works with a single var. + ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); + + // Size is not inferable. + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); +} + +// Tests the helper function that does subtraction, but for multi dimensional +// indices bounds. +TEST(MemDependency, BoundSubtractMultiDim) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](std::vector x, std::vector y) { + if (x.size() != y.size()) { + return false; + } + for (auto i = 0U; i < x.size(); ++i) { + if (!indexBoundsEquals(x[i], y[i])) { + return false; + } + } + return true; + }; + + // sanity check one dimension. + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); + + // Multi dim total overlap. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); + + // Multi dim one way partial in dim 1. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), + {{CB(4, 9), CB(0, 2)}})); + + // Multi dim one way partial in dim 2. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), + {{CB(0, 9), CB(11, 20)}})); + + // Partial overlap in 2 dims. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), + {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); + + // Partial overlap in 3 dims. + ASSERT_TRUE( + EQ(subtractIndicesBounds( + {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), + {{CB(0, 1), CB(0, 5), CB(0, 5)}, + {CB(2, 5), CB(0, 1), CB(0, 5)}, + {CB(2, 5), CB(2, 5), CB(0, 1)}})); +} + +// Tests the multi dimensional subtraction code for bounds that cannot be fully +// materialized. +TEST(MemDependency, BoundSubtractMultiDimSymbolic) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](std::vector x, std::vector y) { + if (x.size() != y.size()) { + return false; + } + for (auto i = 0U; i < x.size(); ++i) { + if (!indexBoundsEquals(x[i], y[i])) { + return false; + } + } + return true; + }; + + // Cannot determine overlaps. + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); + + // Various total Overlaps. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); + + // one-way overlap in first dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), + {{CB(x - 4, x), CB(0, y)}})); + // second dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), + {{CB(0, x), CB(0, 4)}})); + + // Internal overlap in first dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), + {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); + // second dim. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), + {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); + + // Overlap in both dimensions. + ASSERT_TRUE( + EQ(subtractIndicesBounds( + {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), + { + {CB(0, 4), CB(0, y)}, + {CB(x - 4, x), CB(0, y)}, + {CB(0, x), CB(0, 9)}, + {CB(0, x), CB(y - 9, y)}, + })); +} + +// Simple check that the analyzer does anything at all... +TEST(MemDependency, MemDependencyCheckerSimple) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * B[0] = A[0] + 1; + */ + + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); + + StmtPtr stmt = Block::make({aStore, bStore}); + + stmt->accept(&analyzer); + + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); + // sanity check, but anything that depends directly must depend indirectly. + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); +} + +// Check that there is a difference between direct and indirect dependence. +TEST(MemDependency, MemDependencyCheckerMultiStmt) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * B[0] = A[0]; + * C[0] = B[0] + 1; + */ + + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); + + StmtPtr stmt = Block::make({aStore, bStore, cStore}); + + stmt->accept(&analyzer); + + // C depends on A indirectly. + ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); + ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); + + // C depends on B directly, which depends on A directly. + ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + + // Dependency goes top to bottom only. + ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); +} + +// Verify that we do filter writes that are totally overlapped by later writes. +TEST(MemDependency, MemDependencyCheckerOverlap) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * A[0] = 6; + * B[0] = A[0] + 1; + */ + + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr a2Store = Store::make(a, {0}, 6); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); + + StmtPtr stmt = Block::make({aStore, a2Store, bStore}); + + stmt->accept(&analyzer); + + // B store depends on second A store but not first since it is completely + // overlapped. + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); + ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); + + // No dependency between either A store. + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); + ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); +} + +// Verify that bounds match loop iterations, and that dependencies progress +// across loop scopes. +TEST(MemDependency, MemDependencyCheckerLoop) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * for (int x = 0; x < 10; ++x) { + * A[x] = x; + * } + * B[0] = A[0] + 1; + */ + + StorePtr aStore = Store::make(a, {x}, x); + StmtPtr loop = For::make(x, 0, 10, aStore); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); + + StmtPtr stmt = Block::make({loop, bStore}); + + stmt->accept(&analyzer); + + // Same A->B dependency. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop but does not depend on any loop iteration. + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); + + auto aStoreAccess = analyzer.accessFor(aStore); + ASSERT_NE(aStoreAccess, nullptr); + + // It should have bounds covering the range of x: 0 <= x < 10. + ASSERT_TRUE(indexBoundsEquals( + aStoreAccess->bounds(), {Bound(alloc(0), alloc(9))})); +} + +// Reductions should promote dependencies as well. +TEST(MemDependency, MemDependencyCheckerLoopReduce) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * A[0] = 0; + * for (int x = 0; x < 10; ++x) { + * A[0] = A[x] + 1; + * } + * B[0] = A[0]; + */ + + StorePtr aInit = Store::make(a, {0}, 0); + ExprHandle reduce = Sum()(a, 1, {x}, {x}); + StorePtr aReduce = Store::make(a, {0}, reduce); + StmtPtr loop = For::make(x, 0, 10, aReduce); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + + StmtPtr stmt = Block::make({aInit, loop, bStore}); + + stmt->accept(&analyzer); + + // B -> A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); + + // B depends indirectly on the initializer of A, since the reduction depends + // on it. + ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); + + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop and depends on other iterations. + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); + + // The loop contents depend on the initializer too. + ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + + // Find loads within the reduction: + auto reduceLoads = NodeFinder::find(reduce.node()); + // Pull out the access for the load inside the loop. + for (auto load : reduceLoads) { + auto loopLoad = analyzer.accessFor(load); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); + } +} + +// Lowering a reduction doesn't affect dependency analysis. +TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * A[0] = 0; + * for (int x = 0; x < 10; ++x) { + * A[0] = A[x] + 1; + * } + * B[0] = A[0]; + */ + + StorePtr aInit = Store::make(a, {0}, 0); + ExprHandle aLoad = Load::make(a, {x}); + StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); + StmtPtr loop = For::make(x, 0, 10, aReduce); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + + StmtPtr stmt = Block::make({aInit, loop, bStore}); + + stmt->accept(&analyzer); + + // B -> A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); + + // B depends indirectly on the initializer of A, since the reduction depends + // on it. + ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); + + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop and depends on other iterations. + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); + + // The loop contents depend on the initializer too. + ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + + // Pull out the access for the store inside the loop. + auto loopLoad = analyzer.accessFor(aLoad.node()); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); +} + +// Can determine dependencies of outputs, through to inputs. +TEST(MemDependency, MemDependencyCheckerInputsOutputs) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + // initialize analyzer with inputs and outputs. + analysis::MemDependencyChecker analyzer({a}, {b}); + + // Here's a Relu. + /* + * for (int x = 0; x < 10; ++x) { + * B[x] = Max(A[x], 0); + * } + */ + + ExprHandle aLoad = Load::make(a, {x}); + StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); + StmtPtr loop = For::make(x, 0, 10, bStore); + + StmtPtr stmt = Block::make({loop}); + + stmt->accept(&analyzer); + + // Output depends indirectly on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + // aLoad depends directly on the input A. + ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); + // bStore therefore depends directly on the input A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); + // The output depends directly on the store. + ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); + + // Check AccessInfo based overloads. + auto input = analyzer.input(a.node()); + auto output = analyzer.output(b.node()); + + // Output depends indirectly on input. + ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); + // Not directly. + ASSERT_FALSE(analyzer.dependsDirectly(output, input)); + // Not in reverse order. + ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); + + // output -> bStore -> bLoad -> input. + auto storeAccess = analyzer.accessFor(bStore); + auto loadAccess = analyzer.accessFor(aLoad.node()); + + ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); + ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); +} + +// Can tell if an output does not depend on an input. +TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + // initialize analyzer with inputs and outputs. + analysis::MemDependencyChecker analyzer({a}, {b}); + + // Here's a dumb Relu. + /* + * for (int x = 0; x < 10; ++x) { + * B[x] = Max(x, 0); + * } + */ + + StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); + StmtPtr loop = For::make(x, 0, 10, bStore); + + StmtPtr stmt = Block::make({loop}); + + stmt->accept(&analyzer); + + // Output does not depend indirectly on input. + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); + + // The output still depends directly on the store. + ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); + + // Check AccessInfo based overloads. + auto input = analyzer.input(a.node()); + auto output = analyzer.output(b.node()); + + // Output does not depend indirectly on input. + ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); +} + +// Verify different loop extents produce accesses with different bounds, and +// that later accesses find dependencies that overlap their entire bound range. +TEST(MemDependency, MemDependencyCheckerLoopBounds) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + using namespace analysis; + + MemDependencyChecker analyzer({a}, {c}); + + // This enables using the execution order of the loops to determine if some + // loops are self dependent or not. + analyzer.allowLoopExecutionOrderAnalysis(); + + /* + * for (int x = 1; x < 10; ++x) { + * B[x] = A[x]; + * } + * for (int x = 1; x < 9; ++x) { + * B[x] = B[x] * 2; + * } + * for (int x = 3; x < 4; ++x) { + * C[x] = A[x]; + * } + * for (int x = 0; x < 10; ++x) { + * C[x] = B[x]; + * } + */ + + std::vector stmts( + {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), + For::make( + x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), + For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); + + StmtPtr stmt = Block::make(stmts); + + stmt->accept(&analyzer); + + auto input = analyzer.input(a.node()); + auto output = analyzer.output(c.node()); + + // sanity check Output -> Input. + ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); + + // Check the For loop dependencies: + + // Last write to C depends on both writes to B since they contain the last + // write to at least one element. + ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); + ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); + + // The last write to C does not depend on the other write to C. + ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + /* 0. Input: A[(0, 9)] - dependents: 1 5 + * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 + * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 + * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 + * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 + * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6 + * 6. Store: C[(3, 3)] - depends on: 5 + * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 + * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 + * 9. Output: C[(0, 9)] - depends on: 8 + */ + + // Now let's look at the bounds of each access. + // There are 9 accesses in this Stmt, so this is exhaustive, we won't do this + // much. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 10); + VarPtr aVar = a.node()->base_handle(); + VarPtr bVar = b.node()->base_handle(); + VarPtr cVar = c.node()->base_handle(); + + // The first access is the input A. + ASSERT_EQ(history[0]->type(), AccessType::Input); + ASSERT_EQ(history[0]->var(), aVar); + // It has the bounds of the producing Input. + ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); + // sanity check the input we retrieved earlier matches. + ASSERT_EQ(history[0], input); + + // The second access is the load of A in the first loop. + ASSERT_EQ(history[1]->type(), AccessType::Load); + ASSERT_EQ(history[1]->var(), aVar); + // It has the bounds of the loop, i.e. start == 1. + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); + // It reads from A, so it should have a dependency on the last write to this + // range - with is the input. + ASSERT_EQ(history[1]->dependencies().size(), 1); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + // The third access is the store into B in the first loop. + ASSERT_EQ(history[2]->type(), AccessType::Store); + ASSERT_EQ(history[2]->var(), bVar); + // It also has the bounds of the loop, i.e. start == 1. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); + // The previous load is in its RHS, so it depends on it. + ASSERT_EQ(history[2]->dependencies().size(), 1); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The third access is the load from B in the second loop. + ASSERT_EQ(history[3]->type(), AccessType::Load); + ASSERT_EQ(history[3]->var(), bVar); + // It has the bounds of the second loop, i.e. >= 1 < 9. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); + // It reads from B in a smaller range, so should depend on the previous + // store. + ASSERT_EQ(history[3]->dependencies().size(), 1); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The fourth: the store to B in the second loop. + ASSERT_EQ(history[4]->type(), AccessType::Store); + ASSERT_EQ(history[4]->var(), bVar); + // It also has the bounds of the second loop. + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); + // The previous load is in its RHS, so it depends on it as before. + ASSERT_EQ(history[4]->dependencies().size(), 1); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // The fifth access is the load is from the 3rd loop, and skips previous B + // accesses. + ASSERT_EQ(history[5]->type(), AccessType::Load); + ASSERT_EQ(history[5]->var(), aVar); + // It has the bounds of the third loop: >= 3 < 4. + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); + // It depends on the last thing to write to A, which is the A input. + ASSERT_EQ(history[5]->dependencies().size(), 1); + ASSERT_TRUE(history[5]->hasDependency(history[0])); + + // Sixth: the store into the output C. + ASSERT_EQ(history[6]->type(), AccessType::Store); + ASSERT_EQ(history[6]->var(), cVar); + // It also has the bounds of the third loop. + ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); + // The previous load is in its RHS, so it depends on it as always. + ASSERT_EQ(history[6]->dependencies().size(), 1); + ASSERT_TRUE(history[6]->hasDependency(history[5])); + + // The seventh access is the load of B in the fourth loop. + ASSERT_EQ(history[7]->type(), AccessType::Load); + ASSERT_EQ(history[7]->var(), bVar); + // It has the bounds of the final loop, >= 0 < 10 + ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); + // The bounds of this read are larger than the bounds of the previous write, + // so it depends on both previous Stores to B. + ASSERT_EQ(history[7]->dependencies().size(), 2); + ASSERT_TRUE(history[7]->hasDependency(history[2])); + ASSERT_TRUE(history[7]->hasDependency(history[4])); + + // Eight: the final store into the output C. + ASSERT_EQ(history[8]->type(), AccessType::Store); + ASSERT_EQ(history[8]->var(), cVar); + // It also has the bounds of the final loop. + ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); + // The previous load is in its RHS, so it depends on it as always. + ASSERT_EQ(history[8]->dependencies().size(), 1); + ASSERT_TRUE(history[8]->hasDependency(history[7])); + + // The last access represents the output Buf. + ASSERT_EQ(history[9]->type(), AccessType::Output); + ASSERT_EQ(history[9]->var(), cVar); + // It has the bounds of the output Buf. + ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); + // sanity check the input we retrieved earlier matches. + ASSERT_EQ(history[9], output); + // It depends on the last write to C only. + ASSERT_EQ(history[9]->dependencies().size(), 1); + ASSERT_TRUE(history[9]->hasDependency(history[8])); +} + +// Verify that we can still infer bounds when the loop var is offset. +TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer({a}, {b}); + + // This enables using the execution order of the loops to determine if some + // loops are self dependent or not. + analyzer.allowLoopExecutionOrderAnalysis(); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 9; x++) { + * A[x] = A[x + 1]; + * } + * for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = A[9 - x]; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + */ + + StmtPtr stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), + For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), + For::make( + x, + 0, + 9, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))), + For::make( + x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))), + For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + // Sanity check output depends on Input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + /* 0. Input: A[(0, 9)] - dependents: 1 + * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 + * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 + * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 + * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 + * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 + * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 + * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 + * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 + * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 + * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 + * 11. Output: B[(0, 9)] - depends on: 10 + */ + + // Now let's look at the bounds of each access. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 12); + VarPtr aVar = a.node()->base_handle(); + VarPtr bVar = b.node()->base_handle(); + + // The first access is the input A. + ASSERT_EQ(history[0]->type(), AccessType::Input); + ASSERT_EQ(history[0]->var(), aVar); + // It has the bounds of the producing Input. + ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); + + // The second access is the load A[x-1]. + ASSERT_EQ(history[1]->type(), AccessType::Load); + ASSERT_EQ(history[1]->var(), aVar); + // It has the bounds of the loop modified by the offset of each index, in + // this case -1. + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); + // It depends on the input, but also the store in the same loop, since + // different iterations of the loop depend on each other. + ASSERT_EQ(history[1]->dependencies().size(), 2); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + ASSERT_TRUE(history[1]->hasDependency(history[2])); + + // The third access is the Store to A[x] in the first loop. + ASSERT_EQ(history[2]->type(), AccessType::Store); + ASSERT_EQ(history[2]->var(), aVar); + // It has no offset on x, so should have the same bounds as the loop. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); + + // The fourth access is the load A[x+1] in the second loop. + ASSERT_EQ(history[3]->type(), AccessType::Load); + ASSERT_EQ(history[3]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9) modified by the offset of each + // index, in this case 1. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); + // This load totally overlaps the previous write to A, so it depends only on + // it and not the input. + ASSERT_EQ(history[3]->dependencies().size(), 1); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The fifth access is the store to A[x] in the second loop. + ASSERT_EQ(history[4]->type(), AccessType::Store); + ASSERT_EQ(history[4]->var(), aVar); + // It has no offset on x, so should have the same bounds as the loop. + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); + + // The sixth access is the load to A[8 - x] in the third loop. + ASSERT_EQ(history[5]->type(), AccessType::Load); + ASSERT_EQ(history[5]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9) modified by the offset of each + // index, in this case 8 - x. + // This access has a negative stride, which will be normalized. + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); + // This load totally overlaps the most recent write to A, so it depends only + // on it and not the input or the first write to A. + ASSERT_EQ(history[5]->dependencies().size(), 1); + ASSERT_TRUE(history[5]->hasDependency(history[4])); + + // The seventh access is the store to A[9 - x] in the third loop. + ASSERT_EQ(history[6]->type(), AccessType::Store); + ASSERT_EQ(history[6]->var(), aVar); + // This store has a negative stride on it's indices, but is normalized + // internally. + ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); + + // The eighth access is the load A[9-x] in the second loop. + ASSERT_EQ(history[7]->type(), AccessType::Load); + ASSERT_EQ(history[7]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x, + // which essentially traverses the loop backwards. + ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); + // This Load has three write dependencies: + ASSERT_EQ(history[7]->dependencies().size(), 3); + // * The previous store (#6) for elements 1-9 + ASSERT_TRUE(history[7]->hasDependency(history[6])); + // * An earlier store (#4) covering element 0 + ASSERT_TRUE(history[7]->hasDependency(history[4])); + // * A future store inside this loop, since this loop modifies the buffer + // in a non distinct way (due to the load and store having different access + // strides). + ASSERT_TRUE(history[7]->hasDependency(history[8])); + + // The ninth access is the store to A[x] in the fourth loop. + ASSERT_EQ(history[8]->type(), AccessType::Store); + ASSERT_EQ(history[8]->var(), aVar); + // This store has a negative stride on it's indices, but is normalized + // internally. + ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); + + // The tenth and 11th accesses are the copy from A[x] to B[x]. + ASSERT_EQ(history[9]->type(), AccessType::Load); + ASSERT_EQ(history[9]->var(), aVar); + ASSERT_EQ(history[10]->type(), AccessType::Store); + ASSERT_EQ(history[10]->var(), bVar); + + // The last access represents the output Buf. + ASSERT_EQ(history[11]->type(), AccessType::Output); + ASSERT_EQ(history[11]->var(), bVar); + // It has the bounds of the output Buf. + ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); + // It depends on the last write to B only. + ASSERT_EQ(history[11]->dependencies().size(), 1); + ASSERT_TRUE(history[11]->hasDependency(history[10])); + + // ok that's enough of that. +} + +// Check many different cases of loop self dependency - when a load within a +// loop is dependent on a Store later in the same loop but in different +// iteration. This is affected by whether or not we can trust the execution +// order of the loop. +TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + using namespace analysis; + + // This check assumes that the Stmt has a single Store with a single Load on + // the RHS. + auto isSelfDependent = + [](const std::vector>& history) -> bool { + return history.front()->hasDependency(history.back()); + }; + + { + /* for (int y = 0; y < 10; y++) { + * A[y] = (A[y]) + 1; + * } */ + + // Not self dependent since all loop iterations use a different y. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + y, + 0, + 10, + Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))})); + + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int y = 0; y < 10; y++) { + * A[y + 1] = (A[y + 1]) + 1; + * } + */ + + // Not self dependent due to different y (with offset). + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + y, + 0, + 10, + Block::make( + {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))})); + + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + // Is self dependent since all loops use a common constant element of A. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[0] = (B[0]) + x; + * } + */ + + // Is not self dependent because there is no store to the buffer that is + // read. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))})); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + */ + + // Is self dependent since all loops use a common symbolic element of A. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 1]; + * } + */ + + // In this case it depends if we are considering execution order. + + MemDependencyChecker analyzer; + + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + // With analysis of order disabled, this is self dependent since the read + // from X+1 and the write to X+1 could be in reverse order. + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 1]; + * } + */ + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + // If order analysis is enabled, this is not dependent since the read for + // each element occurs before the write to that element. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + MemDependencyChecker analyzer; + + StmtPtr stmt = + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); + stmt->accept(&analyzer); + + // In this case, even with order analysis the Load is dependent on the + // Store, since the write to X occurs before the read from X. + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + */ + + // Still works if the execution order is reversed, so long as the read + // comes before the write. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); + stmt->accept(&analyzer); + + // However here was can determine the A store is earlier in the order than + // the load. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[8 - x] = A[9 - x]; + * } + */ + + // But not if it doesn't. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + */ + + // And not if we're not relying on execution order. + + MemDependencyChecker analyzer; + + StmtPtr stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 3; x < 10; x++) { + * A[x - 2] = A[x - 1]; + * } + */ + + // Forward order but negative indices. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = + For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); + stmt->accept(&analyzer); + + // However here was can determine the A store is earlier in the order than + // the load. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2]; + * } + */ + + // With an access stride. + + MemDependencyChecker analyzer; + // Execution order doesn't matter since the read and the write are totally + // distinct. + + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 1]; + * } + */ + + // Here we can use the common stride of the accesses to determine they are + // distinct. + // Note, this is the only place (loop self dependency) we use this stride + // to avoid unnecessary dependence. + + MemDependencyChecker analyzer; + // Execution order doesn't matter since the read and the write are totally + // distinct. + + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 - 1]; + * } + */ + + // same if the read is behind the write so long as they are distinct. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 2]; + * } + */ + + // But not if the offset is in the stride. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 - 2]; + * } + */ + + // Works with negative offsets too. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 7]; + * } + */ + + // Detects accesses are distinct when offset is large but not a multiple + // of stride. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 4]; + * } + */ + + // Works with offsets which are multiples of the stride. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 6] = A[x * 6 + 5]; + * } + */ + + // detects accesses are distinct with large strides when the offset is + // within. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6]; + * } + */ + + // detects accesses are overlapping when stride is different but a + // multiple. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 4] = A[x * 2]; + * } + */ + + // still works when the read axis is the smaller stride. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6 + 1]; + * } + */ + + // detects accesses are distinct when stride is different but a multiple + // and there is an offset. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6 + 4]; + * } + */ + + // The smaller stride determines whether there is overlap. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2 + 3] = A[x * 6]; + * } + */ + + // The smaller stride determines whether there is overlap, not the larger. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 3 + 1]; + * } + */ + + // If they have strides with no common multiple > 1, they overlap. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 10]; + * } + */ + + // If the offset is greater than the size of the loop, they can't overlap. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[9 - x]; + * } + */ + + // If they have different execution orders they may overlap. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[19 - x * 2]; + * } + */ + + // Or they may not, depending on their start offset and strides. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x / 2] = A[x / 2]; + * } + */ + + // If the stride is not monotonic, they overlap. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x / 2] = A[x / 2] + 1; + * } + */ + + // If the stride is not monotonic, they overlap - even with an offset. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x % 2] = A[x % 2]; + * } + */ + + // Mod too... + + analysis::MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = y; x < z; x++) { + * A[x] = A[x + 1]; + * } + */ + + // Still works with symbolic loop extents. + + { + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + StmtPtr stmt = + For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + } +} + +// Verify that a strided access still works. +// TODO: actually this only works because of the size of the ranges, revisit +// this test after strided overlap is implemented. +TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { + BufHandle a("A", {20}, kInt); + BufHandle b("B", {20}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + MemDependencyChecker analyzer({a.node()}, {b.node()}); + StmtPtr stmt = Block::make( + {For::make( + x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), + For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) + + }); + stmt->accept(&analyzer); + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 2 dependencies... the store in each loop. + auto outputAccess = analyzer.output(b.node()); + ASSERT_EQ(outputAccess->dependencies().size(), 2); +} + +/* TODO(nickg) - this test will fail due to the lack of stride math in Bound +TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { + BufHandle a("A", {20}, kInt); + BufHandle b("B", {20}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); + StmtPtr stmt = Block::make( + {For::make( + x, + 0, + 10, + Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), + For::make( + x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))) + + }); + stmt->accept(&analyzer); + + std::cout << *stmt << "\n"; + for (auto& wi : analyzer.getHistory()) { + wi->print(); + } + } +}*/ + +// analysis on Stmts using Cond. +TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * C[0] = (B[0]) + 1; + * } else { + * C[0] = (B[1]) + 1; + * } + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)), + Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 3); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * C[x] = B[x]; + * } + * } else { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))), + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 3); + + // TODO(nickg): actually since the true and false branch cover the total + // range of the first store this should have 2 dependencies, but we don't + // do that yet. + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Only has true branch. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))), + nullptr)}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * } else { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Only has false branch. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + nullptr, + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (C[0]<5 ? 1 : 0) { + * C[0] = 5; + * } + */ + + // Cond's Condition depends on a previous access. + + MemDependencyChecker analyzer({a}, {c}); + StorePtr initStore = Store::make(c, {x}, Load::make(a, {x})); + ExprHandle conditionalLoad = Load::make(c, {0}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, initStore), + Cond::make( + CompareSelect::make( + conditionalLoad, 5, CompareSelectOperation::kLT), + Store::make(c, {0}, 5), + nullptr)}); + + stmt->accept(&analyzer); + + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + + ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore)); + ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node())); + } +} + +// Stmts using IfThenElse. +TEST(MemDependency, MemDependencyCheckerIfThenElse) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1; + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + StorePtr ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Add::make(Load::make(b, {0}), 1), + Add::make(Load::make(b, {1}), 1))); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + ifStore}); + + stmt->accept(&analyzer); + + // Output C should have 2 dependencies, each of the two stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // Now we need to check the Store containing the IfThenElse. + auto ifStoreAccess = analyzer.accessFor(ifStore); + + // It should have 2 dependencies. + ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * C[0] = (y < 5 ? (B[0]) + 1 : 42; + */ + + // If the load appears in only one side of an IfThenElse the output may be + // dependent on it. + + MemDependencyChecker analyzer({a, b}, {c}); + StorePtr ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Add::make(Load::make(b, {0}), 1), + 42)); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + ifStore}); + + stmt->accept(&analyzer); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = (x < 5 ? B[x] : A[x]; + * } + */ + + // In this case C is dependent on both A and B. + + // TODO: in cases like this it would be possible to split the range of B + // into two bounds, one dependent on A and one dependent on B. We'd need to + // examine conditions relative to previously encountered loop variables. I'm + // uncertain if this would be helpful. + + MemDependencyChecker analyzer({a, b}, {c}); + StorePtr ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Load::make(b, {x}), + Load::make(a, {x}))); + StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); + + stmt->accept(&analyzer); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } +} + +// Cutting a loop with single elem writes +TEST(MemDependency, MemDependencyCheckerCutLoop) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + * B[5] = 100; + */ + + // Cutting a loop with single element writes. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), + Store::make(b, {5}, 100)}); + + stmt->accept(&analyzer); + + // Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 2 dependencies. + auto outputAccess = analyzer.output(b.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + } + + { + /* for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + * for (int x = 4; x < 7; x++) { + * B[x] = B[x] + 3; + * } + * B[5] = 100; + * B[6] = 101; + * B[7] = 102; + */ + + // Cutting a loop with a smaller loop but then totally overlap that second + // loop with one element writes. + + MemDependencyChecker analyzer({a}, {b}); + ForPtr firstLoop = + For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); + StorePtr secondStore = + Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); + ForPtr secondLoop = For::make(x, 4, 7, secondStore); + + StmtPtr stmt = Block::make( + {firstLoop, + secondLoop, + Store::make(b, {4}, 100), + Store::make(b, {5}, 101), + Store::make(b, {6}, 102)}); + + stmt->accept(&analyzer); + + // Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 4 dependencies. + auto outputAccess = analyzer.output(b.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 4); + + // Second loop depends on first loop. + ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); + + // Output does not depend on second loop or store. + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); + } +} + +// Dynamic shapes (load in indices). +TEST(MemDependency, MemDependencyCheckerDynamicShapes) { + BufHandle a("A", {100}, kInt); + BufHandle b("B", {100}, kInt); + BufHandle c("C", {100}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + { + /* for (int x = 0; x < B[0]; x++) { + * C[x] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 + * 1. Input: A[(0, 99)] - dependents: 3 + * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 + * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 + * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Output dependent on A input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + // Also dependent on B input to determine the size of the region written. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The accesses in the loop depend on the load in the stop condition. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // Make a load from B to compare against. + ExprHandle loadFromB = Load::make(b, {0}); + + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); + } + + { + /* for (int x = B[0]; x < B[1]; x++) { + * C[x] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, + Load::make(b, {0}), + Load::make(b, {1}), + Store::make(c, {x}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 3 + * 1. Input: A[(0, 99)] - dependents: 4 + * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 + * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 + * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 + * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 + * 6. Output: C[(0, 99)] - depends on: 5 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 7); + + // The accesses in the loop depend on the load in the start condition. + ASSERT_TRUE(history[5]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[2])); + + // also the stop condition. + ASSERT_TRUE(history[5]->hasDependency(history[3])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // Make loads from B to compare against. + ExprHandle loadFromB0 = Load::make(b, {0}); + ExprHandle loadFromB1 = Load::make(b, {1}); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[B[x]]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 + * 1. Input: A[(0, 99)] - dependents: 3 + * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 + * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 + * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads, the load of A depends on the load of B. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The loads in the indices depend on the relevant input buffer. + ASSERT_TRUE(history[3]->hasDependency(history[1])); + ASSERT_TRUE(history[2]->hasDependency(history[0])); + + // The load from B has the loop bounds. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + + // The load from A has bounds B[0] to B[9]. + ExprHandle loadFromB0 = Load::make(b, {0}); + ExprHandle loadFromB9 = Load::make(b, {9}); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[B[x]] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 3 + * 1. Input: A[(0, 99)] - dependents: 2 + * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 + * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 + * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads, neither load is dependent. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + ASSERT_FALSE(history[3]->hasDependency(history[2])); + ASSERT_FALSE(history[2]->hasDependency(history[3])); + + // The loads each depend on their relevant input. (but accesses are in a + // different order than the last case). + ASSERT_TRUE(history[3]->hasDependency(history[0])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The load from B has the loop bounds. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); + + // And so does the load from A. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[B[A[x]]] = x; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 3 + * 1. Input: A[(0, 99)] - dependents: 2 + * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 + * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 + * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // The outer load depends on the inner. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The loads each depend on their relevant input. (but accesses are in a + // different order than the last case). + ASSERT_TRUE(history[3]->hasDependency(history[0])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The load from A has the loop bounds. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + // The load from B as bounds A[0] to A[9]. + ExprHandle loadFromA0 = Load::make(a, {0}); + ExprHandle loadFromA9 = Load::make(a, {9}); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); + + // The store has bounds of B[A[0]] to B[A[9]]. + ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}); + ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); + } +} + +// Verify multi dimensional bounds work. +TEST(MemDependency, MemDependencyCheckerMultiDim) { + int M = 10, N = 9, K = 12; + BufHandle a("A", {M, N, K}, kInt); + BufHandle b("B", {M, N, K}, kInt); + BufHandle c("C", {M, K}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 9; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, y, z] = A[x, y, z]; + * } + * } + * } + */ + // Full range. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + N, + For::make( + z, + 0, + K, + Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 5; x++) { + * for (int y = 0; y < 5; y++) { + * for (int z = 0; z < 5; z++) { + * B[x, y, z] = A[x, y, z]; + * } + * } + * } + */ + // Partial range. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + For::make( + z, + 0, + 5, + Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 12; y++) { + * B[x, 0, y] = A[x, 0, y]; + * } + * } + */ + + // Partial loops. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + N, + For::make( + y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 100; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); + * } + * } + * } + */ + + // Loops that don't correspond to an index, bufs with different + // dimensionality. + + MemDependencyChecker analyzer({a, c}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + 100, + For::make( + z, + 0, + K, + Store::make( + b, + {x, 0, z}, + Add::make( + Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on both inputs. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); + + // 6 accesses: 2 inputs, 2 loads, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // Simple chain from input to output over the A buf. + // history[0] is the C input, history[3] is the load from C. + ASSERT_TRUE(history[5]->hasDependency(history[4])); + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + // The store also depends on the load from the C input. + ASSERT_TRUE(history[4]->hasDependency(history[3])); + ASSERT_TRUE(history[3]->hasDependency(history[0])); + + // A Buf accesses. + ASSERT_TRUE( + EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); + + // C buf access. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 9; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]); + * } + * } + * } + */ + // Multi-dim reductions. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + N, + For::make( + z, + 0, + K, + Store::make( + b, + {x, 0, 0}, + Add::make( + Load::make(b, {x, y, z}), + Load::make(a, {x, y, z}))))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, 2 loads, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 5); + + // Simple chain from input to output. + ASSERT_TRUE(history[4]->hasDependency(history[3])); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[3]->hasDependency(history[1])); + ASSERT_TRUE(history[2]->hasDependency(history[0])); + + // The load from B depends on the store to B. + ASSERT_TRUE(history[1]->hasDependency(history[3])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)})); + } +} + +// Various tests using the external Compute/Reduce API. +TEST(MemDependency, MemDependencyCheckerComputeAPI) { + using namespace analysis; + + /* for (int m = 0; m < 4; m++) { + * for (int n = 0; n < 5; n++) { + * for (int k = 0; k < 6; k++) { + * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]); + * } + * } + * } + * for (int m_1 = 0; m_1 < 4; m_1++) { + * for (int n_1 = 0; n_1 < 5; n_1++) { + * for (int k_1 = 0; k_1 < 6; k_1++) { + * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1); + * } + * } + * } + */ + + // Can determine if 2 loops created by Compute are dependent. + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor d = Compute( + "d", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c.load(m, n, k) + 1; + }); + + LoopNest l({d}, {c, d}); + + MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); + + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); + + // Second loop depends on first loop. + auto c_loop = l.getLoopStmtsFor(c)[0]; + auto d_loop = l.getLoopStmtsFor(d)[0]; + ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); +} + +TEST(MemDependency, MemDependencyCheckerComputeInline) { + using namespace analysis; + + /* for (int m = 0; m < 4; m++) { + * for (int n = 0; n < 5; n++) { + * for (int k = 0; k < 6; k++) { + * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); + * } + * } + * } + */ + + // Check inlining affects the number of accesses returned. + + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor d = Compute( + "d", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c.load(m, n, k) + 1; + }); + + LoopNest l({d}, {c, d}); + l.computeInline(c.buf()); + + MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); + + // broadcast_add tensor should not appear in trace at all. + for (auto& wi : analyzer.getHistory()) { + ASSERT_NE(wi->var(), c.buf()->base_handle()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeSplit) { + using namespace analysis; + // Split an axis, so the number of loops != the number of dimensions. + + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + + LoopNest l({c}); + + MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); + l.root_stmt()->accept(&analyzer_before); + + l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); + + MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + stmt->accept(&analyzer_after); + + // Splitting should not change accesses at all. + auto history_before = analyzer_before.getHistory(); + auto history_after = analyzer_after.getHistory(); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + ASSERT_EQ( + history_before[i]->bounds().size(), history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + ASSERT_EQ( + history_before[i]->dependencies().size(), + history_after[i]->dependencies().size()); + ASSERT_EQ( + history_before[i]->dependents().size(), + history_after[i]->dependents().size()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeReorder) { + using namespace analysis; + // Reorder an axis, so the loop order doesn't match the indexing order. + + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + + LoopNest l({c}); + + MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); + l.root_stmt()->accept(&analyzer_before); + + auto loops = l.getLoopStmtsFor(c); + l.reorderAxis(loops[0], loops[1]); + + MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + stmt->accept(&analyzer_after); + + // Reordering should not change accesses at all. + auto history_before = analyzer_before.getHistory(); + auto history_after = analyzer_after.getHistory(); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + ASSERT_EQ( + history_before[i]->bounds().size(), history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + ASSERT_EQ( + history_before[i]->dependencies().size(), + history_after[i]->dependencies().size()); + ASSERT_EQ( + history_before[i]->dependents().size(), + history_after[i]->dependents().size()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeReduce) { + using namespace analysis; + /* for (int l2 = 0; l2 < 2; l2++) { + * for (int n1 = 0; n1 < 3; n1++) { + * for (int m1 = 0; m1 < 6; m1++) { + * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); + * } + * } + * } + * for (int l1 = 0; l1 < 2; l1++) { + * sum[l1] = float(0); + * for (int n1_1 = 0; n1_1 < 3; n1_1++) { + * for (int m1_1 = 0; m1_1 < 6; m1_1++) { + * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), + * out_args={l1}, reduce_args={n1, m1}); + * } + * } + * } + */ + + // Can determine dependencies of a Reduction. + + BufHandle a("a", {2, 3, 6}, kFloat); + BufHandle b("b", {2, 3, 6}, kFloat); + + Tensor c = Compute( + "scale", + {2, 3, 6}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6}); + LoopNest l({d}, {c, d}); + + MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()}); + + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node())); + + // Second loop depends on first loop. + auto c_loop = l.getLoopStmtsFor(c)[0]; + auto d_loop = l.getLoopStmtsFor(d)[0]; + ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); + + // Reduction depends on both inputs. + auto reduces = NodeFinder::find(l.root_stmt()); + ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node())); +} + +TEST(MemDependency, MemDependencyCheckerComputeGEMM) { + int M = 1024; + int N = 1024; + int K = 2048; + using namespace analysis; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr m = loops[0]; + loop.splitWithMask(m, 4); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr n = loops[2]; + loop.splitWithMask(n, 16); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[1]; + ForPtr no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr ni = loops[3]; + ForPtr k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[2]; + ForPtr k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); + } + + MemDependencyChecker analyzer_unlowered( + loop.getInputBufs(), loop.getOutputBufs()); + + MemDependencyChecker analyzer_lowered( + loop.getInputBufs(), loop.getOutputBufs()); + + // Test both unlowered and lowered form. + { + StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); + stmt->accept(&analyzer_unlowered); + + // Outputs depend on inputs. + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node())); + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node())); + + // The last write to gemm should cover the total bound of the output. + std::shared_ptr outputAccess = + analyzer_unlowered.output(CT.buf()); + // A single dependency. + ASSERT_EQ(outputAccess->dependencies().size(), 1); + + // dependencies is a set with 1 element, so can just deref begin(). + std::shared_ptr gemmStore = + outputAccess->dependencies().begin()->second; + // Check its a store. + ASSERT_EQ(gemmStore->type(), AccessType::Store); + + ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); + + // Likewise the first read from each input cover the entire range of the + // input. + auto aInput = analyzer_unlowered.input(AP.node()); + auto bInput = analyzer_unlowered.input(BP.node()); + + // A single dependent each. + ASSERT_EQ(aInput->dependents().size(), 1); + ASSERT_EQ(bInput->dependents().size(), 1); + + // They're both loads. + std::shared_ptr aLoad = aInput->dependents().begin()->second; + std::shared_ptr bLoad = bInput->dependents().begin()->second; + ASSERT_EQ(aLoad->type(), AccessType::Load); + ASSERT_EQ(bLoad->type(), AccessType::Load); + + ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); + ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); + } + + loop.prepareForCodegen(); + SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT}); + + // now check lowered dependency graph. + { + StmtPtr stmt = IRSimplifier::simplify(cg.stmt()); + stmt->accept(&analyzer_lowered); + + // Lowering will change the dimensionality of all bounds due to index + // flattening and will insert Allocates and Frees. + + auto history_before = analyzer_unlowered.getHistory(); + auto history_after = analyzer_lowered.getHistory(); + + ASSERT_EQ(history_before.size() + 2, history_after.size()); + + // Filter out the alloc/free; + auto isAllocFree = [](const auto& info) { + return info->type() == AccessType::Alloc || + info->type() == AccessType::Free; + }; + history_after.erase( + std::remove_if(history_after.begin(), history_after.end(), isAllocFree), + history_after.end()); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + + if (history_before[i]->dependencies().size() != + history_after[i]->dependencies().size()) { + // Must depend on an Alloc. + ASSERT_TRUE(std::any_of( + history_after[i]->dependencies().begin(), + history_after[i]->dependencies().end(), + [](const auto& pair) { + return pair.second->type() == AccessType::Alloc; + })); + + ASSERT_EQ( + history_before[i]->dependencies().size() + 1, + history_after[i]->dependencies().size()); + } + + if (history_before[i]->dependents().size() != + history_after[i]->dependents().size()) { + // Must depend on an Free. + ASSERT_TRUE(std::any_of( + history_after[i]->dependents().begin(), + history_after[i]->dependents().end(), + [](const auto& pair) { + return pair.second->type() == AccessType::Free; + })); + + ASSERT_EQ( + history_before[i]->dependents().size() + 1, + history_after[i]->dependents().size()); + } + + // Inputs and outputs are not flattened, only accesses. + if (history_before[i]->type() == AccessType::Input || + history_before[i]->type() == AccessType::Output) { + ASSERT_EQ( + history_before[i]->bounds().size(), + history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + } else { + ASSERT_EQ(history_after[i]->bounds().size(), 1); + ExprPtr flat_bounds = alloc(1); + + for (auto& b : history_before[i]->bounds()) { + flat_bounds = + alloc(flat_bounds, alloc(b.end, alloc(1))); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); + } + + flat_bounds = IRSimplifier::simplify(flat_bounds); + ExprPtr after_bounds = IRSimplifier::simplify( + alloc(history_after[i]->bounds()[0].end, alloc(1))); + ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); + } + } + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_memplanning.cpp b/test/cpp/tensorexpr/test_memplanning.cpp new file mode 100644 index 000000000000..f5ee8747650f --- /dev/null +++ b/test/cpp/tensorexpr/test_memplanning.cpp @@ -0,0 +1,708 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +extern void checkIR(StmtPtr s, const std::string& pattern); + +TEST(BufLiveRange, SingleRangeLine) { + VarHandle i("i", kInt), j("j", kInt); + BufHandle a("a", {32}, kFloat); + BufHandle b("b", {32, 32}, kFloat); + + // Construct Stmt: + // { + // for (int i = 0; i < 32; i++) { + // a[i] = 0; + // for (int j = 0; j < 32; j++) { + // a[i] = (a[i]) + (b[i, j]); + // } + // } + // } + + StorePtr aInit = Store::make(a, {i}, 0); + ExprHandle reduce = a.load({i}) + b.load({i, j}); + StorePtr aReduce = Store::make(a, {i}, reduce); + StmtPtr loop = + For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)})); + + StmtPtr stmt = Block::make({loop}); + + auto range = BufLiveRange::liveRange(stmt, a.node()); + ASSERT_TRUE(std::get<0>(range) == 0); + ASSERT_TRUE(std::get<1>(range) == 0); +} + +TEST(BufLiveRange, MulRangeLine) { + VarHandle i("i", kInt); + BufHandle a("a", {32}, kFloat); + BufHandle b("b", {32}, kFloat); + + // Construct Stmt: + // { + // for (int i = 0; i < 32; i++) { + // if (i<10 ? 1 : 0) { + // a[i] = i + i; + // b[i] = i * i; + // } + // } + // for (int i = 0; i < 32; i++) { + // if (i>10 ? 1 : 0) { + // a[i] = i * i; + // b[i] = i + i; + // } + // } + // } + + StorePtr aStore_1 = Store::make(a, {i}, i + i); + StorePtr bStore_1 = Store::make(b, {i}, i * i); + StmtPtr loop_1 = For::make( + i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL)); + + StorePtr aStore_2 = Store::make(a, {i}, i * i); + StorePtr bStore_2 = Store::make(b, {i}, i + i); + StmtPtr loop_2 = For::make( + i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL)); + + StmtPtr stmt = Block::make({loop_1, loop_2}); + + auto range_a = BufLiveRange::liveRange(stmt, a.node()); + ASSERT_TRUE(std::get<0>(range_a) == 0); + ASSERT_TRUE(std::get<1>(range_a) == 1); + + auto range_b = BufLiveRange::liveRange(stmt, b.node()); + ASSERT_TRUE(std::get<0>(range_b) == 0); + ASSERT_TRUE(std::get<1>(range_b) == 1); +} + +TEST(MemPlanning, MemReuseWithTypeCast) { + int M = 4; + int N = 4; + int K = 4; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + Tensor DT = + Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return CompareSelect::make( + CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n)); + }); + Tensor FT = + Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n); + }); + StmtPtr stmt = + tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are + // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E' + // with typecasting. + //{ + // for (int i = 0; i < 4; i++) { + // for (int i_1 = 0; i_1 < 4; i_1++) { + // gemm[i, i_1] = float(0); + // for (int i_2 = 0; i_2 < 4; i_2++) { + // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, + // i_1]), reduce_args={i_2}); + // } + // } + // } + // for (int i_3 = 0; i_3 < 4; i_3++) { + // for (int i_4 = 0; i_4 < 4; i_4++) { + // relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]); + // } + // } + // for (int i_5 = 0; i_5 < 4; i_5++) { + // for (int i_6 = 0; i_6 < 4; i_6++) { + // E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6])); + // } + // } + // for (int i_7 = 0; i_7 < 4; i_7++) { + // for (int i_8 = 0; i_8 < 4; i_8++) { + // F[i_7, i_8] = E[i_7, i_8]; + // } + // } + //} + + LoopNest l(stmt, {FT.buf()}); + l.prepareForCodegen(); + SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT}); + + checkIR(cg.stmt(), R"IR( +# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] +# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] +# CHECK: Alias(E,gemm); +# CHECK: Free(relu); +# CHECK: Free(gemm))IR"); + + PaddedBuffer a_v(M, K, "a"); + PaddedBuffer b_v(K, N, "b"); + PaddedBuffer o1(M, N, "e_before"); + PaddedBuffer o2(M, N, "e_after"); + + for (const auto m : c10::irange(M)) { + for (const auto k : c10::irange(K)) { + a_v(m, k) = at::randn({1}).item().to(); + } + } + + for (const auto k : c10::irange(K)) { + for (const auto n : c10::irange(N)) { + b_v(k, n) = at::randn({1}).item().to(); + } + } + + cg.call({a_v, b_v, o1}); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); + + checkIR(cg_llvm.stmt(), R"IR( +# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] +# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] +# CHECK: Alias(E,gemm); +# CHECK: Free(relu); +# CHECK: Free(gemm))IR"); + + cg_llvm.call({a_v, b_v, o2}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(o1, o2, 1e-5); +#endif +} + +TEST(MemPlanning, NoMemReuseForLargerType) { + int M = 4; + int N = 4; + int K = 4; + + BufHandle AP("A", {M, K}, kShort); + BufHandle BP("B", {K, N}, kShort); + + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + auto zero = Cast::make(CT.buf()->dtype(), 0); + Tensor DT = + Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n)); + }); + Tensor FT = + Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n); + }); + StmtPtr stmt = + tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are + // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for + // 'E'. + //{ + // for (int i = 0; i < 4; i++) { + // for (int i_1 = 0; i_1 < 4; i_1++) { + // gemm[i, i_1] = int16_t(0); + // for (int i_2 = 0; i_2 < 4; i_2++) { + // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, + // i_1]), reduce_args={i_2}); + // } + // } + // } + // for (int i_3 = 0; i_3 < 4; i_3++) { + // for (int i_4 = 0; i_4 < 4; i_4++) { + // relu[i_3, i_4] = (gemm[i_3, i_4]) a_v(M, K, "a"); + PaddedBuffer b_v(K, N, "b"); + PaddedBuffer o1(M, N, "e_before"); + PaddedBuffer o2(M, N, "e_after"); + + for (const auto m : c10::irange(M)) { + for (const auto k : c10::irange(K)) { + a_v(m, k) = at::randn({1}).item().to(); + } + } + + for (const auto k : c10::irange(K)) { + for (const auto n : c10::irange(N)) { + b_v(k, n) = at::randn({1}).item().to(); + } + } + + cg.call({a_v, b_v, o1}); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); + + checkIR(cg_llvm.stmt(), R"IR( +# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] +# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] +# CHECK: Allocate(E); // dtype=float, dims=[4, 4] +# CHECK: Free(E); +# CHECK: Free(relu); +# CHECK: Free(gemm))IR"); + + cg_llvm.call({a_v, b_v, o2}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(o1, o2, 1e-5); +#endif +} + +TEST(MemPlanning, SameBufSizeMemReuse) { + int M = 1024; + int N = 1024; + int K = 2048; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + Tensor DT = + Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + auto zero = Cast::make(CT.buf()->dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return DT.load(m, n) + DT.load(m, n); + }); + Tensor FT = + Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n) * ET.load(m, n); + }); + auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm' + // for 'add'. + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return DT.load(m, n) + DT.load(m, n); + }); + Tensor FT = + Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n) * ET.load(m, n); + }); + Tensor GT = + Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return FT.load(m, n) - ET.load(m, n); + }); + + auto stmt = + Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same + // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul' + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return DT.load(m, n) + DT.load(m, n); + }); + Tensor FT = + Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n) * ET.load(m, n); + }); + Tensor GT = + Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return FT.load(m, n) - 1; + }); + Tensor HT = + Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return GT.load(m, n) / 2; + }); + + auto stmt = Block::make( + {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and + // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for + // 'mul', and reuse 'gemm' for 'sub'. + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = Compute( + "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) { + return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2); + }); + Tensor FT = Compute( + "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) { + return ET.load(fm, fn) * ET.load(fm, fn); + }); + auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of + // buffer 'gemm' is smaller. + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1]) +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +using Tensors = std::vector; +using Args = std::vector; +std::unique_ptr compile( + const Args& inputs, + const Tensors& outputs) { + LoopNest nest({outputs}); + nest.prepareForCodegen(); + nest.simplify(); + auto join = inputs; + join.insert(join.end(), outputs.begin(), outputs.end()); + return std::make_unique(nest.root_stmt(), join); +} + +TEST(Ops, Sum) { + constexpr int M = 8; + constexpr int N = 16; + std::vector testDims = {{0}, {1}, {0, 1}}; + std::vector> outputShapes = {{N}, {M}, {}}; + for (unsigned idx = 0; idx < testDims.size(); idx++) { + const auto& dims = testDims[idx]; + const auto& outShape = outputShapes[idx]; + + BufHandle a("a", {M, N}, kFloat); + std::vector outStrides = + c10::fmap(make_contiguous_strides(outShape)); + Tensor b = computeSum( + {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); + auto cg = compile({a}, {b}); + + auto at = at::arange(M * N, at::kFloat).view({M, N}); + auto ref = at::sum(at, dims); + auto bt = at::empty_like(ref); + + cg->call({at.data_ptr(), bt.data_ptr()}); + + ASSERT_TRUE(at::allclose(bt, ref)); + } +} + +TEST(Ops, ChannelsLastSum) { + constexpr int A = 2; + constexpr int B = 3; + constexpr int C = 4; + constexpr int D = 5; + constexpr int E = 6; + std::vector testDims = {{0}, {1}, {0, 1}}; + + std::vector> outputShapes = { + {B, C, D, E}, {A, C, D, E}, {C, D, E}}; + for (unsigned idx = 0; idx < testDims.size(); idx++) { + const auto& dims = testDims[idx]; + const auto& outShape = outputShapes[idx]; + + BufHandle a("a", {A, B, C, D, E}, kFloat); + std::vector outStrides = + c10::fmap(make_channels_last_strides(outShape)); + Tensor b = computeSum( + {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); + auto cg = compile({a}, {b}); + + auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E}); + auto ref = at::sum(at, dims); + auto bt = at::empty_like(ref); + + cg->call({at.data_ptr(), bt.data_ptr()}); + + ASSERT_TRUE(at::allclose(bt, ref)); + } +} diff --git a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp new file mode 100644 index 000000000000..af6b539ff33e --- /dev/null +++ b/test/cpp/tensorexpr/test_quantization.cpp @@ -0,0 +1,452 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; +using SimpleIRExprEval = ExprEval; +using namespace torch::indexing; +using namespace torch::jit::tensorexpr; + +class Quantization : public ::testing::Test { + public: + void SetUp() override { + getTEMustUseLLVMOnCPU() = false; + } +}; + +TEST_F(Quantization, QuantDequantInt8) { + const auto graph_string = R"IR( + graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=12]() + %3 : int = prim::Constant[value=13]() + %4 : float = prim::Constant[value=0.1]() + %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) + %6 : Float(2, 2) = aten::dequantize(%q.1) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8); + auto y_expected = at::dequantize(q); + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantDequantUInt8) { + const auto graph_string = R"IR( + graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %3 : int = prim::Constant[value=122]() + %4 : float = prim::Constant[value=0.1]() + %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) + %6 : Float(2, 2) = aten::dequantize(%q.1) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); + auto y_expected = at::dequantize(q); + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantDequantUInt8_NLC) { + const auto graph_string = R"IR( + graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)): + %2 : int = prim::Constant[value=13]() + %3 : int = prim::Constant[value=122]() + %4 : float = prim::Constant[value=0.1]() + %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) + %6 : Float(1, 2, 2) = aten::dequantize(%q.1) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + x.unsafeGetTensorImpl()->set_sizes_and_strides( + std::initializer_list{1, 2, 2}, {4, 1, 2}); + auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); + auto y_expected = at::dequantize(q); + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +at::Tensor quantized_add( + at::Tensor x1, + at::Tensor x2, + double scale, + int64_t zero) { + const auto qadd_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::add", "") + .typed(); + return qadd_op.call(x1, x2, scale, zero); +} + +TEST_F(Quantization, QuantAddDequantInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=12]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8); + auto qa = quantized_add(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantAddDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); + auto qa = quantized_add(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantSigmoidDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %qa : QUInt8(2, 2) = aten::sigmoid(%q1) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto qs = at::sigmoid(q1); + auto y_expected = at::dequantize(qs); + + TensorExprKernel k(graph); + std::vector inputs = {x1}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "qs:\n" << qs << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +at::Tensor quantized_mul( + at::Tensor x1, + at::Tensor x2, + double scale, + int64_t zero) { + const auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::mul", "") + .typed(); + return op.call(x1, x2, scale, zero); +} + +TEST_F(Quantization, QuantMulDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); + auto qa = quantized_mul(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { + const auto graph_string = R"IR( + graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %4 : NoneType = prim::Constant() + %3 : int[] = prim::Constant[value=[6, 6]]() + %qz : int = prim::Constant[value=13]() + %qs : float = prim::Constant[value=0.1]() + %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2) + %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4) + %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); + auto qu = at::upsample_nearest2d(q, {6, 6}); + auto y_expected = at::dequantize(qu); + + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "q:\n" << q << std::endl; + std::cout << "qu:\n" << qu << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, UpsampleNearst2d) { + const auto graph_string = R"IR( + graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): + %4 : NoneType = prim::Constant() + %3 : int[] = prim::Constant[value=[4, 4]]() + %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4) + return (%u))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto y_expected = at::upsample_nearest2d(x, {4, 4}); + + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +at::Tensor quantized_cat( + c10::List const& xs, + int64_t dim, + double scale, + int64_t zero) { + const auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::cat", "") + .typed const&, + int64_t, + std::optional, + std::optional)>(); + return op.redispatch( + DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero); +} + +TEST_F(Quantization, QuantCatDequantUInt8) { + const auto graph_string = R"IR( + graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): + %qdt : int = prim::Constant[value=13]() + %qxz : int = prim::Constant[value=13]() + %qxs : float = prim::Constant[value=0.1]() + %qyz : int = prim::Constant[value=16]() + %qys : float = prim::Constant[value=0.15]() + %qzz : int = prim::Constant[value=19]() + %qzs : float = prim::Constant[value=0.2]() + %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt) + %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt) + %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt) + %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz) + %catd : int = prim::Constant[value=0]() + %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz) + %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat) + return (%cat))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); + auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8); + auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8); + auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13); + auto expected = at::dequantize(qcat); + + TensorExprKernel k(graph); + std::vector inputs = {x, y, z}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto result = stack[0].toTensor(); + bool check = at::allclose(expected, result); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y:\n" << y << std::endl; + std::cout << "z:\n" << z << std::endl; + std::cout << "qx:\n" << qx << std::endl; + std::cout << "qy:\n" << qy << std::endl; + std::cout << "qz:\n" << qz << std::endl; + std::cout << "qcat:\n" << qcat << std::endl; + std::cout << "expected:\n" << expected << std::endl; + std::cout << "result:\n" << result << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp new file mode 100644 index 000000000000..fb83ab85b71e --- /dev/null +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -0,0 +1,1928 @@ +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(Reductions, ReduceSum0D_1) { + const int M = 10; + + BufHandle b("b", {M}, kFloat); + std::vector in(M); + for (const auto j : c10::irange(M)) { + in[j] = j; + } + + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + for (const auto i : c10::irange(M)) { + ASSERT_EQ(out[i], in[i]); + } +} + +TEST(Reductions, ReduceSum0D_2) { + BufHandle b("b", {}, kFloat); + std::vector in(1); + in[0] = 77.7; + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], in[0]); +} + +// Sum an array to a single value. +TEST(Reductions, ReduceSum1D) { + BufHandle b("b", {10}, kFloat); + std::vector in(10); + for (const auto j : c10::irange(10)) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {10}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 45); +} +// Sum a 2D tensor to a 1D tensor with dynamic shapes. +TEST(Reductions, ReduceSum2D) { + const int M = 3; + const int N = 7; + + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + BufHandle b("b", {m, n}, kFloat); + std::vector in(M * N); + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + in[i * N + j] = j; + } + } + + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, n, m}); + + cg.call({in, out, 5, 7}); + + float expected = 0; + for (const auto i : c10::irange(N)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected += i; + } + + for (const auto i : c10::irange(M)) { + ASSERT_EQ(out[i], expected); + } +} + +// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to +// check our work. +TEST(Reductions, ReduceSum3D) { + const int M = 10; + VarHandle m("m", kInt); + + BufHandle b("b", {2, 3, m}, kFloat); + + Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m}); + + std::vector bData(2 * 3 * M, 0); + std::vector cData(2 * 3, 6.0f); + std::vector dData(2, 1.0f); + std::vector eData(2, 1.0f); + + for (int i = 0; i < 2 * 3; ++i) { + for (const auto j : c10::irange(M)) { + bData[i * M + j] = j; + } + } + + cg.call({bData, cData, M}); + float expected = 0; + for (const auto i : c10::irange(M)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected += i; + } + + for (int i = 0; i < 2 * 3; ++i) { + ASSERT_EQ(cData[i], expected); + } + + Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m}); + LoopNest loop2({d}); + loop2.prepareForCodegen(); + StmtPtr s2 = loop2.root_stmt(); + s2 = IRSimplifier::simplify(s2); + + SimpleIREvaluator cg2(s2, {b, d, m}); + cg2.call({bData, dData, M}); + + // We're combining an additional dimension of 3, so the sum is 3x. + expected = expected * 3; + + for (const auto i : c10::irange(2)) { + ASSERT_EQ(dData[i], expected); + } + + // This is the same as just reducing the original result across that axis. + BufHandle c_buf(c.buf()); + Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3}); + LoopNest loop3({e}); + loop3.prepareForCodegen(); + StmtPtr s3 = loop3.root_stmt(); + s3 = IRSimplifier::simplify(s3); + + SimpleIREvaluator cg3(s3, {c, e}); + cg3.call({cData, eData}); + + for (const auto i : c10::irange(2)) { + ASSERT_EQ(eData[i], expected); + } +} + +// Sum a large (10 D) Tensor 5 dimensions in. +TEST(Reductions, ReduceSum10D) { + BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat); + const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; + BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat); + const int OutputSize = 2 * 3 * 2 * 3 * 2; + + std::vector in(InputSize, 1.f); + std::vector out(OutputSize, -1.f); + + Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in_, c}); + + cg.call({in, out}); + + // NOLINTNEXTLINE(bugprone-integer-division) + float expected = InputSize / OutputSize; + for (const auto i : c10::irange(OutputSize)) { + ASSERT_EQ(out[i], expected); + } +} + +// Reduce via Mul rather than Add using a custom Reducer. +TEST(Reductions, ReduceProduct) { + const int M = 4; + const int N = 4; + + BufHandle b("b", {M, N}, kFloat); + std::vector in(M * N); + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + in[i * N + j] = 2 + j; + } + } + + std::vector out(M, -1.f); + + Reducer product( + ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); + + Tensor c = Reduce("product", {M}, product, b, {N}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + + float expected = 1; + for (const auto i : c10::irange(N)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected *= 2 + i; + } + + for (const auto i : c10::irange(M)) { + ASSERT_EQ(out[i], expected); + } +} + +// Maximum reductions. +TEST(Reductions, ReduceMax) { + BufHandle in_("b", {10}, kFloat); + + std::vector in(10); + std::vector out(1, -1.f); + for (const auto j : c10::irange(10)) { + in[j] = j; + } + + Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10}); + + LoopNest loop({dm1}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + SimpleIREvaluator cg(s, {in_, dm1}); + + cg.call({in, out}); + + ASSERT_EQ(out[0], 9); + + BufHandle in2_("b", {2, 5}, kFloat); + std::vector out2(2, -1.f); + + Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5}); + + LoopNest loop2({m2d}); + loop2.prepareForCodegen(); + s = loop2.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg2(s, {in2_, m2d}); + cg2.call({in, out2}); + + ASSERT_EQ(out2[0], 4); + ASSERT_EQ(out2[1], 9); +} + +// Minimum reduction, with custom initialization. +TEST(Reductions, ReduceMinCustomInitializer) { + VarHandle minInit("minInit", kFloat); + BufHandle in_("b", {10}, kFloat); + + std::vector in(10); + std::vector out(1, -1.f); + for (const auto j : c10::irange(10)) { + in[j] = 10 + j; + } + + Tensor min = Reduce( + "min", + {}, + Minimum(ExprHandle(minInit)), + [&](ParameterList& v) { return in_.load(v); }, + {10}); + + LoopNest loop({min}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in_, min, minInit}); + + // Works normally (note that out data starts lower than the correct + // minimum). + cg.call({in, out, std::numeric_limits::max()}); + ASSERT_EQ(out[0], 10); + + // With an initializer lower than the min, that's the min. + cg.call({in, out, 5.f}); + ASSERT_EQ(out[0], 5); +} + +// Example implementation of Any/All. +// TODO: this is very awkward without logical And/Or operators. +TEST(Reductions, ReduceAnyAll) { + VarHandle searchValue("searchValue", kInt); + BufHandle b("b", {4, 10}, kInt); + + Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { + return CompareSelect::make(a, 1, 1, b, kEQ); + }); + + Tensor any = Reduce( + "anyEqual", + {4}, + anyEqSV, + [&](const auto& i, const auto& j) { + return CompareSelect::make(b.load(i, j), searchValue, kEQ); + }, + {10}); + + LoopNest loop({any}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, any, searchValue}); + + std::vector in(40, 0); + std::vector out(4, 0); + + // input has 0-39 in 4 rows. + for (const auto i : c10::irange(40)) { + in[i] = i; + } + cg.call({in, out, 1}); + + // only the first row has 1 + ASSERT_EQ(out[0], 1); + ASSERT_EQ(out[1], 0); + ASSERT_EQ(out[2], 0); + ASSERT_EQ(out[3], 0); + + cg.call({in, out, 15}); + + // 15 in the 3rd row + ASSERT_EQ(out[0], 0); + ASSERT_EQ(out[1], 1); + ASSERT_EQ(out[2], 0); + ASSERT_EQ(out[3], 0); + + Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) { + return CompareSelect::make(a, 0, 0, b, kEQ); + }); + + Tensor allGreaterThan = Reduce( + "allGreaterThan", + {4}, + allGTSV, + [&](const auto& i, const auto& j) { + return CompareSelect::make(b.load(i, j), searchValue, kGT); + }, + {10}); + + LoopNest loop2({allGreaterThan}); + loop2.prepareForCodegen(); + s = loop2.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); + + cg2.call({in, out, 11}); + + // 11 is in row 2. + ASSERT_EQ(out[0], 0); + ASSERT_EQ(out[1], 0); + ASSERT_EQ(out[2], 1); + ASSERT_EQ(out[3], 1); + + cg2.call({in, out, -3}); + + // All are positive. + ASSERT_EQ(out[0], 1); + ASSERT_EQ(out[1], 1); + ASSERT_EQ(out[2], 1); + ASSERT_EQ(out[3], 1); +} + +TEST(Reductions, ReduceMatmul2D) { + BufHandle tA("tA", {3, 2}, kFloat); + BufHandle tB("tB", {2, 3}, kFloat); + + std::vector tA_(6); + std::vector tB_(6); + + std::vector out(9, -1.f); + for (const auto i : c10::irange(3)) { + for (const auto j : c10::irange(2)) { + tA_[i * 2 + j] = i * 2 + j; + tB_[j * 3 + i] = i * 2 + j; + } + } + + Tensor mm = Reduce( + "mm", + {3, 3}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return tA.load(m, k) * tB.load(k, n); + }, + {2}); + + LoopNest loop({mm}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {tA, tB, mm}); + cg.call({tA_, tB_, out}); + + std::vector expected( + {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f}); + + for (const auto i : c10::irange(9)) { + ASSERT_EQ(out[i], expected[i]); + } +} + +TEST(Reductions, ReduceRfactorLike) { + BufHandle in("in", {10, 10}, kFloat); + std::vector in_(100); + for (const auto i : c10::irange(100)) { + in_[i] = i; + } + std::vector in_rf_(10, -2.f); + std::vector out(1, -1.f); + + Tensor l1 = Reduce("l1", {10}, Sum(), in, {10}); + BufHandle in_rf(l1.buf()); + + Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10}); + + LoopNest loop({l1, l2}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in, l1, l2}); + cg.call({in_, in_rf_, out}); + + ASSERT_EQ(out[0], 99 * 50); +} + +TEST(Reductions, ReduceAsProducer) { + const int M = 10; + VarHandle m("m", kInt); + + BufHandle a("a", {2, 3}, kFloat); + BufHandle b("b", {2, 3, m}, kFloat); + + Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); + Tensor d = + Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) { + return c.load(l, n) * a.load(l, n); + }); + LoopNest loop({d}, {c, d}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {a, b, d, m}); + + std::vector aData(2 * 3, 0); + std::vector bData(2 * 3 * M, 0); + std::vector dData(2 * 3, 6.0f); + + for (int i = 0; i < 2 * 3; ++i) { + aData[i] = 6 - i; + for (const auto j : c10::irange(M)) { + bData[i * M + j] = j; + } + } + + cg.call({aData, bData, dData, M}); + float expected = 0; + for (const auto i : c10::irange(M)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected += i; + } + for (int i = 0; i < 2 * 3; ++i) { + ASSERT_EQ(dData[i], expected * (6 - i)); + } +} + +TEST(Reductions, ReduceAsConsumer) { + const int M = 10; + VarHandle m("m", kInt); + + BufHandle a("a", {2, 3, m}, kFloat); + BufHandle b("b", {2, 3, m}, kFloat); + + Tensor c = Compute( + "scale", + {2, 3, m}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {2}, Sum(), c, {3, m}); + LoopNest loop({d}, {c, d}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {a, b, d, m}); + + std::vector aData(2 * 3 * M, 0); + std::vector bData(2 * 3 * M, 0); + std::vector dData(2, 6.0f); + + for (int i = 0; i < 2 * 3; ++i) { + for (const auto j : c10::irange(M)) { + bData[i * M + j] = j + 1; + aData[i * M + j] = 6 - i; + } + } + + cg.call({aData, bData, dData, M}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + float expected[2] = {0, 0}; + for (const auto i : c10::irange(2)) { + for (const auto j : c10::irange(3)) { + for (const auto k : c10::irange(M)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected[i] += (k + 1) * (6 - (i * 3 + j)); + } + } + } + + for (const auto i : c10::irange(2)) { + ASSERT_EQ(dData[i], expected[i]); + } +} + +TEST(Reductions, SplitReduceAxis) { + BufHandle in("in", {16, 8}, kFloat); + + std::vector in_(16 * 8); + for (const auto i : c10::irange(16)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out(16, -1.f); + + Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); + LoopNest l({tensor}); + std::vector loops = l.getLoopStmtsFor(tensor); + LoopNest::splitWithTail(loops[1], 2); + + l.prepareForCodegen(); + + StmtPtr s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in, tensor}); + cg.call({in_, out}); + + for (const auto i : c10::irange(16)) { + ASSERT_EQ(out[i], i * 8); + } +} + +TEST(Reductions, SplitNonReduceAxis) { + BufHandle in("in", {16, 8}, kFloat); + + std::vector in_(16 * 8); + for (const auto i : c10::irange(16)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out(16, -1.f); + Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); + LoopNest l({tensor}); + std::vector loops = l.getLoopStmtsFor(tensor); + LoopNest::splitWithTail(loops[0], 2); + LoopNest::splitWithTail(loops[0], 2); + + l.prepareForCodegen(); + + StmtPtr s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in, tensor}); + cg.call({in_, out}); + + for (const auto i : c10::irange(16)) { + ASSERT_EQ(out[i], i * 8); + } +} + +TEST(Reductions, ReorderedReductionInitializer) { + /* From the quip: + for k in 0..1: // blockIdx + for m in 0..128: + for n in 0..64: // threadIdx + SumOp(c(k, n), 0, a(k, m, n), {m}) + */ + + BufHandle in("in", {1, 12, 6}, kFloat); + std::vector in_(12 * 6, 1.f); + + Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6}); + LoopNest l_({tensor_}); + + l_.prepareForCodegen(); + StmtPtr s_ = Stmt::clone(l_.root_stmt()); + s_ = IRSimplifier::simplify(s_); + + Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6}); + LoopNest l({tensor}); + + auto loops = l.getLoopStmtsFor(tensor); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + LoopNest::reorderAxis(loops[1], loops[2]); + + StmtPtr s = l.root_stmt(); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + s = IRSimplifier::simplify(s); + + l.prepareForCodegen(); + + s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + std::vector out1(16, -1.f); + SimpleIREvaluator cg(s_, {in, tensor_}); + cg.call({in_, out1}); + + std::vector out2(16, -1.f); + SimpleIREvaluator cg2(s, {in, tensor}); + cg2.call({in_, out2}); + + for (const auto i : c10::irange(16)) { + ASSERT_EQ(out1[i], out2[i]); + } +} + +TEST(Reductions, ReduceRfactor) { + const int M = 10; + const int N = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + BufHandle b("b", {m, n}, kFloat); + std::vector in(M * N); + for (int j = 0; j < M * N; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); + auto rc = NodeFinder::find(loop.root_stmt()); + ASSERT_EQ(rc.size(), 2); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m, n}); + + cg.call({in, out, M, N}); + ASSERT_EQ(out[0], 4950); +} + +TEST(Reductions, Reduce3DRfactorInner) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("b", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); + auto rc = NodeFinder::find(loop.root_stmt()); + ASSERT_EQ(rc.size(), 1); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, Reduce3DRfactorOuter) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("b", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); + auto rc = NodeFinder::find(loop.root_stmt()); + ASSERT_EQ(rc.size(), 2); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReduceRepeatedInternalRfactor) { + BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat); + const int InputSize = 2 * 3 * 4 * 5 * 6; + + std::vector in(InputSize, 1.f); + std::vector out(1, -1.f); + std::vector ref(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6}); + LoopNest orig_loop({c}); + + // Try rfactoring N outer loops + for (const auto rfac_number : c10::irange(1, 5)) { + LoopNest refloop(orig_loop); + LoopNest loop(orig_loop); + refloop.prepareForCodegen(); + SimpleIREvaluator ref_cg( + IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); + ref_cg.call({in, ref}); + + BufPtr tmp_buf = c.buf(); + + for (const auto idx : c10::irange(rfac_number)) { + auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; + ASSERT_TRUE(loop.rfactor( + reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); + } + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in_, c}); + cg.call({in, out}); + + ASSERT_EQ(ref[0], out[0]); + } +} + +// Split a reduction axis with a tail loop. +TEST(Reductions, ReduceSplitTail) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[i], 8); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis cleanly so there is no tail loop. +TEST(Reductions, ReduceSplitNoTail) { + const int M = 10; + const int N = 10; + const int K = 10; + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[i], 5); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis with only a tail loop (the split loop will be size 0 +// and eliminated out). +TEST(Reductions, ReduceOverSplitTail) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[i], 16); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis with a mask. +TEST(Reductions, ReduceSplitMask) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithMask(loops[i], 8); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis cleanly not requiring a mask. +TEST(Reductions, ReduceSplitNoMask) { + const int M = 10; + const int N = 10; + const int K = 10; + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithMask(loops[i], 5); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis with all logic in the mask. +TEST(Reductions, ReduceOverSplitMask) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithMask(loops[i], 16); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Test an rfactor when there are two ReduceOps in the graph due to a +// splitWithTail. +TEST(Reductions, ReduceSplitRfactor) { + const int M = 2; + const int N = 10; + const int K = 10; + const int SPLIT_FACTOR = 4; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (const auto m : c10::irange(M)) { + for (int j = 0; j < N * K; ++j) { + in[m * N * K + j] = j; + } + } + + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); + + auto c_body = loop.getAllWritesToBuf(c.buf())[2]; + auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); + ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); + LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); + all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); + ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); + ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); + loop.prepareForCodegen(); + loop.simplify(); + StmtPtr s = loop.root_stmt(); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + for ([[maybe_unused]] const auto i : c10::irange(M)) { + ASSERT_EQ(out[0], 4950); + } +} + +// Test an rfactor which ends up being eliminated since the total loop size is +// smaller than the split factor. +TEST(Reductions, ReduceOverSplitRfactor) { + const int N = 10; + const int K = 10; + const int SPLIT_FACTOR = 16; + + BufHandle b("b", {N, K}, kFloat); + std::vector in(N * K); + for (int j = 0; j < N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + ForPtr i, t; + LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); + LoopNest::reorderAxis(loops[0], i); + + auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); + ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); + LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); + + loop.prepareForCodegen(); + loop.simplify(); + StmtPtr s = loop.root_stmt(); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the IR to verify the rfactored reduce is eliminated. + // TODO: The alloc free should be eliminated here since it is size 0. + /* + const std::string& verification_pattern = + R"IR( +# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] +# CHECK: sum[0] = 0.f; +# CHECK: for (int n = 0; n < 10; n++) { +# CHECK: for (int k_tail = 0; k_tail < 10; k_tail++) { +# CHECK: sum[0] = (sum[0]) + (b[k_tail + 10 * n]); +# CHECK: } +# CHECK: } +# CHECK: Free(tmp_buf);)IR"; + */ + // TODO: rfactor output is not consistent yet, will fix (@nickg). + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Reductions, ReduceInlineReduction) { + const int M = 4; + const int N = 5; + const int K = 6; + + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M, N, K}, kFloat); + + Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K}); + Tensor y = Compute( + "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); }); + + PaddedBuffer a_v(M); + PaddedBuffer b_v(M, N, K); + + for (const auto i : c10::irange(M)) { + a_v(i) = i * i; + } + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + for (const auto k : c10::irange(K)) { + b_v(i, j, k) = j * j * k; + } + } + } + + LoopNest l1({y}, {x, y}); + // Cannot inline a reduction computation + ASSERT_FALSE(l1.computeInline(x.buf())); +} + +TEST(Reductions, ReduceInlineConsumer) { + const int M = 4; + const int N = 5; + const int K = 6; + + BufHandle a_buf("a", {M, N, K}, kFloat); + BufHandle b_buf("b", {M, N, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n, k) + b_buf.load(m, n, k); + }); + Tensor y = Reduce("y", {M}, Sum(), x, {N, K}); + + PaddedBuffer a_v(M, N, K); + PaddedBuffer b_v(M, N, K); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + for (const auto k : c10::irange(K)) { + a_v(i, j, k) = i * i + k; + b_v(i, j, k) = j * j + k; + } + } + } + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); + + PaddedBuffer y_1(M); + PaddedBuffer y_2(M); + + eval1(a_v, b_v, y_1); + eval2(a_v, b_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +TEST(Reductions, ReduceInlineReducerInternal) { + const int M = 4; + const int N = 5; + const int K = 6; + + BufHandle a_buf("a", {M, N, K}, kFloat); + BufHandle b_buf("b", {M, N, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n, k) + b_buf.load(m, n, k); + }); + + Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { + return Add::make(ExprHandle(1.f), Min::make(a, b, false)); + }); + Tensor y = Reduce("y", {M}, minimum, x, {N, K}); + + PaddedBuffer a_v(M, N, K); + PaddedBuffer b_v(M, N, K); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + for (const auto k : c10::irange(K)) { + a_v(i, j, k) = i * i + k; + b_v(i, j, k) = j * j + k; + } + } + } + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); + + PaddedBuffer y_1(M); + PaddedBuffer y_2(M); + + eval1(a_v, b_v, y_1); + eval2(a_v, b_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +TEST(Reductions, ReductionCacheAccessesOperatorAxis) { + int L = 4; + int N = 3; + int M = 2; + + BufHandle a("a", {L, N, M}, kFloat); + BufHandle b("b", {L, N, M}, kFloat); + + Tensor c = Compute( + "scale", + {L, N, M}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); + + Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + LoopNest l_before(l); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before( + LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; + l.cacheAccesses(d.buf(), "d_local", d_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg_after(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg_after.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local); // dtype=float, dims=[4] +#CHECK: for (int i_2 +#CHECK: d_local[i_2] = 0.f +#CHECK: for (int +#CHECK: for (int +#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[ +#CHECK: } +#CHECK: } +#CHECK: } +#CHECK: for (int i_3 +#CHECK: sum[i_3] = d_local[i_3] +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + PaddedBuffer a_v(L, M, N, "a"); + PaddedBuffer b_v(L, M, N, "b"); + PaddedBuffer c_v(L, M, N, "c"); + PaddedBuffer d_v(L, "d"); + PaddedBuffer e_before(L, "e_before"); + PaddedBuffer e_after(L, "e_after"); + + for (const auto l : c10::irange(L)) { + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + a_v(l, m, n) = at::randn({1}).item().to(); + b_v(l, m, n) = at::randn({1}).item().to(); + } + } + } + + cg_before.call({a_v, b_v, e_before}); + cg_after.call({a_v, b_v, e_after}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(e_before, e_after, 1e-5); +} + +TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { + int L = 4; + int N = 3; + int M = 2; + + BufHandle a("a", {L, N, M}, kFloat); + BufHandle b("b", {L, N, M}, kFloat); + + Tensor c = Compute( + "scale", + {L, N, M}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); + + Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + LoopNest l_before(l); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(d.buf(), "d_local", d_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg_after(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg_after.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local); // dtype=float, dims=[1] +#CHECK: sum[i_1] = 0 +#CHECK: d_local[0] = sum[i_1] +#CHECK: for (int j_1 +#CHECK: for (int k_1 +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: } +#CHECK: sum[i_1] = d_local[0] +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + PaddedBuffer a_v(L, M, N, "a"); + PaddedBuffer b_v(L, M, N, "b"); + PaddedBuffer c_v(L, M, N, "c"); + PaddedBuffer d_v(L, "d"); + PaddedBuffer e_before(L, "e_before"); + PaddedBuffer e_after(L, "e_after"); + + for (const auto l : c10::irange(L)) { + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + a_v(l, m, n) = at::randn({1}).item().to(); + b_v(l, m, n) = at::randn({1}).item().to(); + } + } + } + + cg_before.call({a_v, b_v, e_before}); + cg_after.call({a_v, b_v, e_after}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(e_before, e_after, 1e-5); +} + +TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { + int L = 4; + int N = 3; + int M = 2; + + BufHandle a("a", {L, N, M}, kFloat); + BufHandle b("b", {L, N, M}, kFloat); + + Tensor c = Compute( + "scale", + {L, N, M}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); + + Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + LoopNest l_before(l); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; + l.cacheAccesses(d.buf(), "d_local", d_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg_after(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg_after.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local); // dtype=float, dims=[1] +#CHECK: sum[i_1] = 0 +#CHECK: for (int +#CHECK: d_local[0] = 0 +#CHECK: for (int +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0]) +#CHECK: } +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + PaddedBuffer a_v(L, M, N, "a"); + PaddedBuffer b_v(L, M, N, "b"); + PaddedBuffer c_v(L, M, N, "c"); + PaddedBuffer d_v(L, "d"); + PaddedBuffer e_before(L, "e_before"); + PaddedBuffer e_after(L, "e_after"); + + for (const auto l : c10::irange(L)) { + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + a_v(l, m, n) = at::randn({1}).item().to(); + b_v(l, m, n) = at::randn({1}).item().to(); + } + } + } + + cg_before.call({a_v, b_v, e_before}); + cg_after.call({a_v, b_v, e_after}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(e_before, e_after, 1e-5); +} + +TEST(Reductions, ReductionCacheBodyAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(c.buf(), "scale_local", d_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] +#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) { +#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) { +#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1]; +#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]); +#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]); +#CHECK: Free(scale_local); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionCacheConsumerAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); + + StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; + l.cacheAccesses(d.buf(), "sum_local", e_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Alias(sum_local,scale); +#CHECK: sum[i_1] = (sum[i_1]) + (scale[ +#CHECK: for (int j_2 = 0; j_2 < 4 +#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; +#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionSplitCacheConsumerAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + ForPtr inner; + + // Split outer reduction axis. + LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); + + // Split reduction consumer. + LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); + + l.cacheAccesses(d.buf(), "sum_local", inner); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + // reduction changes but cache does not. + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Alias(sum_local,scale); +#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]); +#CHECK: for (int i_2 = 0; i_2 < 6 +#CHECK: for (int j_2 = 0; j_2 < 4 +#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; +#CHECK: for (int j_3 = 0; j_3 < 4 +#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionReorderCacheConsumerAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + ForPtr inner; + + // reorder outer reduction axes. + auto loops = l.getLoopStmtsFor(d); + LoopNest::reorderAxis(loops[0], loops[1]); + + // Split reduction consumer. + LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); + + l.cacheAccesses(d.buf(), "sum_local", inner); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + // neither reduction body not cache changes. + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]); +#CHECK: for (int i_3 = 0; i_3 < 6; +#CHECK: for (int j_2 = 0; j_2 < 4; +#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3]; +#CHECK: for (int j_3 = 0; j_3 < 4; +#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionRfactorCacheTempOuter) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("B", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::reorderAxis(loops.at(0), loops.at(1)); + loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + BufPtr rfac_buf; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); + loop.distributeLoop(loops.at(0)); + + auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); + LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); + + all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]); + loop.simplify(); + loop.prepareForCodegen(); + StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] +#CHECK: Allocate(tmp); // dtype=float, dims=[n] +#CHECK: for (int i_1 = 0; i_1 < m +#CHECK: for (int j = 0; j < n +#CHECK: tmp[j] = 0 +#CHECK: } +#CHECK: for (int j_1 = 0; j_1 < n +#CHECK: for (int k +#CHECK: tmp[j_1] = (tmp[j_1]) + (B[ +#CHECK: } +#CHECK: } +#CHECK: for (int j_2 = 0; j_2 < n +#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]); +#CHECK: } +#CHECK: Free(tmp); +#CHECK-NOT: tmp + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReductionRfactorCacheTempInner) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("B", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + + LoopNest::reorderAxis(loops.at(0), loops.at(1)); + loops = loop.getLoopStmtsFor(c); + BufPtr rfac_buf; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); + loop.distributeLoop(loops.at(0)); + auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); + LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); + + all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); + LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]); + loop.prepareForCodegen(); + loop.simplify(); + StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] +#CHECK: Allocate(tmp); // dtype=float, dims=[1] +#CHECK: for (int i_1 = 0; i_1 < m +#CHECK: for (int j = 0; j < n +#CHECK: tmp[0] = 0 +#CHECK: for (int k +#CHECK: tmp[0] = (tmp[0]) + (B[ +#CHECK: } +#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]); +#CHECK: Free(tmp); +#CHECK-NOT: tmp + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReductionVectorize) { + std::vector in_(8 * 8); + for (const auto i : c10::irange(8)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out_before(8, -1.f); + std::vector out_after(8, -1.f); + + BufHandle in("in", {8, 8}, kFloat); + + Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); + LoopNest l_before({tensor}); + LoopNest l(l_before); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); + cg_before.call({in_, out_before}); + + ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); + + StmtPtr s = l.root_stmt(); + s = LoopNest::sanitizeNames(IRSimplifier::simplify(s)); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); +#CHECK: for (int i = 0; i < 8; i++) { +#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i}); +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Vectorizing should not change result. + l.prepareForCodegen(); + s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg_after(s, {in, tensor}); + cg_after.call({in_, out_after}); + for (const auto i : c10::irange(8)) { + ASSERT_EQ(out_before[i], out_after[i]); + } +} + +TEST(Reductions, ReductionVectorizeInner) { + BufHandle in("in", {8, 8}, kFloat); + + Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); + LoopNest l({tensor}); + + ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); +} + +TEST(Reductions, ReductionVectorizeRfactor) { + std::vector in_(8 * 8); + for (const auto i : c10::irange(8)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out_before(1, -1.f); + std::vector out_after(1, -1.f); + + BufHandle in("in", {8, 8}, kFloat); + + Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8}); + + LoopNest l_before({tensor}); + LoopNest l(l_before); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); + cg_before.call({in_, out_before}); + + ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); + + // But if we rfactor this so it's not a reduce axis we can vectorize that + // loop. + std::vector loops = l.getLoopStmtsFor(tensor); + LoopNest::reorderAxis(loops[0], loops[1]); + loops = l.getLoopStmtsFor(tensor); + auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1]; + BufPtr rfac_buf = nullptr; + ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); + + LoopNest::distributeLoop(loops.at(0)); + auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf); + + ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0])); + l.simplify(); + + StmtPtr s = LoopNest::sanitizeNames(l.root_stmt()); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: sum = 0.f; +#CHECK: for (int i = 0; i < 8; i++) { +#CHECK: sum_rfac[i] = 0.f; +#CHECK: } +#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) { +#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1}); +#CHECK: } +#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) { +#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2}); +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Vectorizing should not change result. + l.prepareForCodegen(); + s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg_after(s, {in, tensor}); + cg_after.call({in_, out_after}); + + ASSERT_EQ(out_before[0], out_after[0]); +} + +TEST(Reductions, InitFunction) { + constexpr int M = 32; + constexpr int N = 16; + BufHandle A("A", {M, N}, kFloat); + BufHandle B("B", {N}, kFloat); + Tensor C = Reduce( + "C", + {N}, + Sum(), + [&](const std::vector& v) { return B.load(v[0]); }, + [&](const std::vector& v) { return A.load(v[1], v[0]); }, + {M}); + LoopNest nest({C}); + nest.prepareForCodegen(); + StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt())); + std::ostringstream oss; + oss << *s << "\n"; + const std::string& expected_ir = + R"IR( +#CHECK: for (int i = 0; i < 16; i++) { +#CHECK: C[i] = B[i]; +#CHECK: for (int j = 0; j < 32; j++) { +#CHECK: C[i] = (C[i]) + (A[i + 16 * j]); +#CHECK: } +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp new file mode 100644 index 000000000000..6cbd04264c32 --- /dev/null +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -0,0 +1,3702 @@ +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include "test/cpp/tensorexpr/test_utils.h" +#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" +#include "torch/csrc/jit/tensorexpr/registerizer.h" + +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +// Can replace a simple scalar access with a local variable. +TEST(Registerizer, RegisterizerSimple) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't do replacement of a loop access. +TEST(Registerizer, RegisterizerLoop) { + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + // No change. + stmt = registerize(stmt); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: A[0] = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A_ +# CHECK: A[x] = +# CHECK-NOT: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't replace even if the load is a fixed scalar, since the store could +// invalidate it. +TEST(Registerizer, RegisterizerLoopFixedLoad) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[0]) + x; + * } + */ + + // No change. + stmt = registerize(stmt); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[0]) + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: A[0] = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A_ +# CHECK: A[x] = +# CHECK-NOT: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize accesses that occur entirely within inner scopes, even if +// they depend on the loop var. +TEST(Registerizer, RegisterizerLoopInternal) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)), + Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * A[x] = (A[x]) + x; + * } + */ + + stmt = registerize(stmt); + + // TODO: the order of terms in addition changes and in general depends on + // some hash value. This results in unpredictable swaps of the operands from + // random changes, which is not great. Ideally, we should ensure some + // specific order (ideally, the original one). + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * A_1 = x + A_1; + * A_1 = x + A_1; + * A[x] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = A_1 + x; +# CHECK: A_1 = A_1 + x; +# CHECK: A[x] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access can be overlapped by another read in the same Expr. In this case +// B[z] and B[y] overlap and prevent registerization of both accesses. +TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))}); + stmt = IRSimplifier::simplify(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (B[y]) + (B[z]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeated) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[1]; + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_1 + x; + * A_2 = A_1 + x; + * } + * for (int x = 0; x < 10; x++) { + * A_2 = A_1 + x; + * A_2 = A_1 + x; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[1]; +# CHECK: int A_2 = A[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = A_1 + x; +# CHECK: A_2 = A_1 + x; +# CHECK: } +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = A_1 + x; +# CHECK: A_2 = A_1 + x; +# CHECK: } +# CHECK-NOT: A[1] +# CHECK: A[0] = A_2; +# CHECK-NOT: A[1] +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})) + + }); + stmt = IRSimplifier::simplify(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = IRSimplifier::simplify(Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), + Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), + Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) + + })); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Will registerize multiple accesses of different items of the same buffer. +TEST(Registerizer, RegisterizerMultiVar) { + BufHandle a("A", {2}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({ + Store::make(a, {0}, 0), + Store::make(a, {1}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), + }); + + /* + * A[0] = 0; + * A[1] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * A[1] = (A[1]) - x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * int A_2 = 0; + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_2; + * A_1 = A_1 - x; + * } + * A[1] = A_2; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: int A_2 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A_2 = +# CHECK: A[1] = A_2 +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Will registerize the valid accesses while skipping invalid replacements. +TEST(Registerizer, RegisterizerVariableLoad) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle x2("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make(x, 0, 10, Store::make(b, {x}, x)), + For::make( + x2, + 0, + 10, + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = x; + * } + * for (int x_1 = 0; x_1 < 10; x_1++) { + * A[0] = (A[0]) + (B[x_1]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = x; + * } + * for (int x_1 = 0; x_1 < 10; x_1++) { + * A_1 = A_1 + (B[x_1]); + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: B[x] = x +# CHECK: for (int x_1 = 0; x_1 < 10; x_1++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize variable accesses so long as the variable does not change. +TEST(Registerizer, RegisterizerSymbolicIndices) { + VarHandle i("i", kInt); + VarHandle N("N", kInt); + BufHandle a("A", {N}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {i}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))}); + + /* + * A[i] = 0; + * for (int x = 0; x < 10; x++) { + * A[i] = (A[i]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[i] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[i] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize accesses dependent on multiple loop vars. +TEST(Registerizer, RegisterizerMultiLoop) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + Block::make({Store::make( + a, + {0}, + Mul::make(Add::make(Load::make(a, {0}), x), y))})))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * A[0] = x * y + (A[0]) * y; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * A_1 = x * y + y * A_1; + * } + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: for (int y = 0; y < 10; y++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize correctly if scalars already exist in the program. +TEST(Registerizer, RegisterizerRepeated) { + BufHandle a("A", {2}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({ + Store::make(a, {0}, 0), + Store::make(a, {1}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), + }); + + // Registerize manually to make sure we only replace a single target. + { + registerizer::RegisterizerAnalysis analysis; + stmt->accept(&analysis); + auto candidates = analysis.getCandidates(); + ASSERT_EQ(candidates.size(), 2); + + candidates.pop_back(); + registerizer::RegisterizerReplacer replacer(candidates); + stmt = stmt->accept_mutator(&replacer); + } + + // Re-analyze and replace the second target. + { + registerizer::RegisterizerAnalysis analysis; + stmt->accept(&analysis); + auto candidates = analysis.getCandidates(); + ASSERT_EQ(candidates.size(), 1); + + registerizer::RegisterizerReplacer replacer(candidates); + stmt = stmt->accept_mutator(&replacer); + } + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: int A_1_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A_1_1 = +# CHECK: A[1] = A_1_1; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize the load of A. +TEST(Registerizer, RegisterizerNoLoads) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = x + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + 1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize the load of A but not the store of B. +TEST(Registerizer, RegisterizerNoRepeatedStores) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + // TODO: its unnecessary to reorder the initializer of A[0], but it's not + // actually worse so lets not worry for now. + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = x + A_1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A_ +# CHECK: B[x] = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't registerize if there are multiple accesses which may overlap. +TEST(Registerizer, RegisterizerMultiVarOverlap) { + BufHandle a("A", {2}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({ + Store::make(a, {0}, 0), + Store::make(a, {1}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})), + }); + stmt = IRSimplifier::simplify(stmt); + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerAllocs) { + BufHandle a("A", {2}, kInt); + BufHandle c("C", {1}, kInt); + VarHandle x("x", kInt); + + BufHandle b("B", {Load::make(c, {0})}, kInt); + + StmtPtr stmt = Block::make( + {Allocate::make(b), + Store::make(a, {0}, Load::make(c, {0})), + Store::make(b, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)), + Store::make(a, {0}, Load::make(c, {0}))})), + Free::make(b)}); + + /* + * Allocate(B, int, {C[0]}); + * A[0] = C[0]; + * B[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[0] = (B[0]) + x; + * A[0] = C[0]; + * } + * Free(B); + */ + + stmt = registerize(stmt); + + /* + * int C_1 = C[0]; + * Allocate(B, int, {C_}); + * int A_1 = C_1; + * int B_1 = 0; + * for (int x = 0; x < 10; x++) { + * B_1 = B_1 + x; + * A_1 = C_1; + * } + * B[0] = B_1; + * A[0] = A_1; + * Free(B); + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int C_1 = C[0]; +# CHECK: Allocate(B +# CHECK: int A_1 = C_1; +# CHECK: int B_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: B_1 = +# CHECK: A_1 = C_ +# CHECK: B[0] = B_1; +# CHECK: A[0] = A_1; +# CHECK: Free(B)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNoInitializer) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNoInitializerLoopVar) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); + stmt = IRSimplifier::simplify(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoadThenStore) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {0}, Load::make(b, {0}))}))}); + + /* + * for (int x = 0; x < 10; x++) { + * B[0] = (A[0]) + x; + * A[0] = B[0]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int B_1 = B[0]; + * for (int x = 0; x < 10; x++) { + * B_1 = x + A_1; + * A_1 = B_1; + * } + * B[0] = B_1; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int B_1 = B[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: B[ +# CHECK: B_1 = +# CHECK-NOT: A[ +# CHECK: A_1 = B_ +# CHECK: B[0] = B_ +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerParallelized) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + LoopOptions loopOpts; + loopOpts.set_gpu_block_index(0); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}), + loopOpts)}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + ASSERT_THROWS_WITH( + registerize(stmt), + "Registerization must occur after parallelism flattening"); +} + +// Should be able to registerize this since the scalar would exist before the +// branch. +TEST(Registerizer, RegisterizerConditionAfter) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this since the scalar exists in the same form +// after the branch and there is no overlap. +TEST(Registerizer, RegisterizerConditionBefore) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x}))}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = B[x]; + * C[x] = A[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_ 1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = B[x]; + * C[x] = A_1; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this as the combination of the two above rules. +TEST(Registerizer, RegisterizerConditionInside) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Store::make(b, {x}, Load::make(a, {x})), + Store::make(a, {x}, Load::make(c, {x}))}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * B[x] = A_1; + * A_1 = C[x]; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: B[x] = A_1; +# CHECK: A_1 = C[x]; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An example where an access is cut by an overlapping access inside a +// condition, and both sides are large enough to be registerized but cannot be +// because there is no safe place to put the initializer or finalizer. +TEST(Registerizer, RegisterizerConditionInsideOverlap1) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Store::make(a, {0}, 3), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x})), + Store::make(a, {x}, Load::make(c, {x}))}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + // The A[0] store overlaps, A[x] cutting the region that can be registerized + // into two groups. + // Each group has 2 loads and 2 stores however, so we could registerize it, + // but the first group would need to be finalized inside the condition block, + // the second would need to be initialized inside the condition block. There's + // no safe place to put these that's visible to the other uses in the group + // and so neither registerization is possible. + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Same as the above, but the access group before the condition (and after the +// condition) are large enough to be registerized without needing the access +// from the loop. Registerization occurs but does not include any accesses in +// the condition, and the first group must be finalized before the Cond, the +// second initialized after it. +TEST(Registerizer, RegisterizerConditionInsideOverlap2) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(a, {x}, Load::make(b, {x + 1})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Store::make(a, {0}, 3), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x})), + Store::make(b, {x + 1}, Load::make(a, {x})), + Store::make(a, {x}, Load::make(c, {x}))}); + + /* + * A[x] = B[x]; + * A[x] = B[x + 1]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * B[x + 1] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; // A_1 initializer + * A_1 = B[x + 1]; // + * C[x] = A_1; // + * A[x] = A_1; // A_1 finalizer + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * int A_2 = A[x]; // A_2 initializer + * B[x] = A_2; // + * B[x + 1] = A_2; // + * A_2 = C[x]; // + * A[x] = A_2; // A_2 finalizer + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: A_1 = B[x + 1]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1; +# CHECK: if ( +# CHECK-NOT: A_1 = A_1 + 1; +# CHECK: A[x] = (A[x] +# CHECK: A[0] = +# CHECK: A[x] = (A[x] +# CHECK: } +# CHECK: int A_2 = A[x]; +# CHECK: B[x] = A_2; +# CHECK: B[x + 1] = A_2; +# CHECK: A_2 = C[x]; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// When accesses are within conditional blocks they are not visible to the wider +// program, because we don't know if the branch would be taken and if it isn't +// the accesses in it don't need to be valid (think size checks on the index). +// In this case the accesses cannot be registerized. +TEST(Registerizer, RegisterizerConditionHidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// But... if the same access is found in a non conditional scope, that means +// that that access is valid in the higher scope (or at least if its not it's +// the user's fault). It "unhides" the conditional accesses, allowing +// registerization to occur. +TEST(Registerizer, RegisterizerConditionUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = (A[x]) + 1; <-- this is doing the unhiding. + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = A_1 + 1; + * if (x>5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (x<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x>5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of a Cond. +TEST(Registerizer, RegisterizerCondCondition) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if ((A[x])<5 ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * int C_1 = A_1; + * if (A_1<5 ? 1 : 0) { + * C_1 = C_1 + 1; + * } + * C[x] = C_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: int C_1 = A_1; +# CHECK: if (A_1<5 +# CHECK: C_1 = C_1 + 1; +# CHECK: C[x] = C_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +TEST(Registerizer, RegisterizerCondConditionUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); + + /* + * if ((A[x])<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } else { + * A[x] = (A[x]) + 10; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (A_1<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } else { + * A_1 = A_1 + 10; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (A_1<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } else { +# CHECK: A_1 = A_1 + 10; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Conditional hiding also works for IfThenElse exprs. +TEST(Registerizer, RegisterizerIfThenElseHidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make( + {Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2))), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2)))}); + + /* + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Conditional unhiding also works for IfThenElse exprs. +TEST(Registerizer, RegisterizerIfThenElseUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make({ + Store::make(a, {x}, 0), + Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2))), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2))), + }); + + /* + * A[x] = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested IfThenElse exprs can't promote to higher level scopes. +TEST(Registerizer, RegisterizerIfThenElseNested) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + BufHandle d("D", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + IfThenElse::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Load::make(d, {x}), + Load::make(b, {x})), + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kEQ), + Load::make(c, {x}), + Load::make(d, {x}))))}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, + * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), + * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Cannot registerize an access completely contained within an IfThenElse +// branch, since it is not a Stmt and cannot hold variable definitions. We need +// to check that we don't promote the initializer/finalizer to the enclosing +// Block. +TEST(Registerizer, RegisterizerIfThenElseInternal) { + // Making these floats so they don't get simplified to a single access. + BufHandle a("A", {5}, kFloat); + BufHandle b("B", {5}, kFloat); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Add::make(Load::make(b, {x}), Load::make(b, {x})), + Load::make(b, {x})))}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // If this was a Cond instead of an IfThenElse then we could registerize the + // two accesses to B[x] in the True branch. + + // Actually lets verify that. + + stmt = Block::make({Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))), + Store::make(a, {x}, Load::make(b, {x})))}); + + /* + * if (x<3 ? 1 : 0) { + * A[x] = (B[x]) + (B[x]); + * } else { + * A[x] = B[x]; + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<3 ? 1 : 0) { + * float B_1 = B[x]; + * A[x] = B_1 + B_1; + * } else { + * A[x] = B[x]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK-NOT: float +# CHECK: if (x<3 +# CHECK: float B_1 = +# CHECK: A[x] = B_1 + B_1 +# CHECK: } else { +# CHECK: A[x] = B[x] +# CHECK: } +# CHECK-NOT: A[x] +# CHECK-NOT: B[x])IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of an IfThenElse; +TEST(Registerizer, RegisterizerIfThenElseCondition) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(a, {x})), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Load::make(b, {0}), + Load::make(c, {0})))}); + + /* + * A[x] = A[x]; <---- just here so there are enough accesses to combine. + * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * A_1 = A_1; + * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Store::make( + b, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x}), 10)))}); + + /* + * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot promote accesses internal to IfThenElse branches even if the enclosing +// scope if conditional. +TEST(Registerizer, RegisterizerConditionBranchOnly) { + BufHandle a("A", {5}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({ + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), x), + Add::make(Load::make(a, {x - 5}), x))), + Store::make( + a, + {x - 5}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), x), + Add::make(Load::make(a, {x - 5}), x)))), + }))}); + stmt = IRSimplifier::simplify(stmt); + + std::ostringstream before; + before << *stmt; + + /* for (int x = 0; x < 10; x++) { + * if (x<5 ? 1 : 0) { + * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } else { + * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } + * } + */ + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// We can registerize an IfThenElse that appears in the condition branch of a +// Cond. This is a weird but valid thing to do. +TEST(Registerizer, RegisterizerCondIfThenElse) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make( + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Load::make(a, {x}), + Load::make(b, {x})), + x, + CompareSelectOperation::kEQ), + Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), + nullptr)}); + + /* + * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + // access to A can be registerized, but not B or C + + /* + * int A_1 = A[x]; + * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] +# CHECK: C[x] = (C[x]) + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a conditional access in the RHS of a store unhidden by it's +// LHS, and hoist it out of a loop. +TEST(Registerizer, RegisterizerIfThenElseLoop) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}), + Load::make(b, {y})))); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: for ( +# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot registerize if the RHS overlaps the access creating visibility. +TEST(Registerizer, RegisterizerIfThenElseLoopCut) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make({For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}), + Load::make(a, {y}))))}); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Simple case where an access is cut by an overlapping access later in the +// program, we can registerize up until the overlap. +TEST(Registerizer, RegisterizerPartialAfter) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})), + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK-NOT: A)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize an access which overlaps a previous access, the +// initializer must be inserted after the previous access. +TEST(Registerizer, RegisterizerPartialBefore) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), + Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The combination of the previous two tests, an access is cut by an overlapping +// access in both directions. +TEST(Registerizer, RegisterizerPartialInside) { + BufHandle a("A", {1}, kInt); + VarHandle x1("x1", kInt); + VarHandle x2("x2", kInt); + VarHandle x3("x3", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 2), + For::make( + x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), + For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))), + For::make( + x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))}); + + /* + * A[0] = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A[0] = (A[0]) + x1; + * } + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * for (int x3 = 0; x3 < 10; x3++) { + * A[0] = (A[0]) + x3; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A_1 = A_1 + x1; + * } + * A[0] = A_1; + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * int A_2 = A[0]; + * for (int x3 = 0; x3 < 10; x3++) { + * A_2 = A_2 + x3; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x2] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x3; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An element could be registerized program wide but is cut by a conditional +// access, we should break this into two scalars and write back to the buffer +// before the condition. +TEST(Registerizer, RegisterizerPartialCondition) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 2), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Load::make(a, {x - 1})), + nullptr), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))}); + + /* + * A[0] = 2; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + x; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: A[x] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Tests case where an access is cut by an internal conditional access which +// itself is registerized. +TEST(Registerizer, RegisterizerPartialConditionInternalCut) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 1), + Store::make(a, {0}, 3), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), + nullptr), + Store::make(a, {0}, 4), + Store::make(a, {0}, 6)}); + + /* + * A[0] = 1; + * A[0] = 3; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * A[x] = 3; + * } + * A[0] = 4; + * A[0] = 6; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 1; + * A_1 = 3; + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * int A_2 = 1; + * A_2 = 3; + * A[x] = A_2; + * } + * int A_3 = 4; + * A_3 = 6; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 1; +# CHECK: A_1 = 3 +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: int A_2 = 1; +# CHECK: A_2 = 3; +# CHECK: A[x] = A_2; +# CHECK: } +# CHECK: int A_3 = 4; +# CHECK: A_3 = 6; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// First statement in condition closes outer access, but can be registerized +// with later statements. +TEST(Registerizer, RegisterizerPartialConditionInternalStart) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 1), + Store::make(a, {0}, 3), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), + nullptr), + Store::make(a, {x}, 4), + Store::make(a, {x}, 6)}); + + /* + * A[0] = 1; + * A[0] = 3; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * A[x] = 3; + * } + * A[x] = 4; + * A[x] = 6; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 1; + * A_1 = 3; + * A[0] = A_1; + * int A_2 = A[x]; <--- must read from the input here. + * if (x<5 ? 1 : 0) { + * A_2 = 1; + * A_2 = 3; + * } + * A_2 = 4; + * A_2 = 6; + * A[x] = A_2; + */ + + // TODO: I suppose we could refactor with a conditional initializer? + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 1; +# CHECK: A_1 = 3 +# CHECK: A[0] = A_1; +# CHECK: int A_2 = A[x]; +# CHECK: if ( +# CHECK: A_2 = 1; +# CHECK: A_2 = 3; +# CHECK: } +# CHECK: A_2 = 4; +# CHECK: A_2 = 6; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access cuts two open overlaps and creates four scalar variables. +TEST(Registerizer, RegisterizerPartialOverlapsTwo) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {1}, Load::make(a, {0})), + Store::make(a, {0}, Load::make(a, {1})), + Store::make(a, {0}, Load::make(a, {1})), + For::make(x, 1, 10, Store::make(a, {x}, x)), + Store::make(a, {1}, Load::make(a, {0})), + Store::make(a, {0}, Load::make(a, {1})), + Store::make(a, {0}, Load::make(a, {1}))}); + + /* + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int A_2 = A_1; + * A_1 = A_2; + * A_1 = A_2; + * A[1] = A_2; + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * int A_3 = A[0]; + * int A_4 = A_3; + * A_3 = A_4; + * A_3 = A_4; + * A[1] = A_4; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int A_2 = A_1; +# CHECK: A_1 = A_2; +# CHECK: A_1 = A_2; +# CHECK: A[1] = A_2; +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = x; +# CHECK: } +# CHECK: int A_3 = A[0]; +# CHECK: int A_4 = A_3; +# CHECK: A_3 = A_4; +# CHECK: A_3 = A_4; +# CHECK: A[1] = A_4; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested blocks will automatically be flattened and do not provent +// registerization of enclosed accesses. +TEST(Registerizer, RegisterizerNestedBlocks) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})}); + + /* + * A[0] = (A[0]) + 1; + * { + * A[0] = (A[0]) + 2; + * } + * { + * A[0] = (A[0]) + 3; + * { + * A[0] = (A[0]) + 4; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * A_1 = A_1 + 2; + * A_1 = A_1 + 3; + * A_1 = A_1 + 4; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_1 = A_1 + 2; +# CHECK: A_1 = A_1 + 3; +# CHECK: A_1 = A_1 + 4; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The access can be registerized internally to a condition, but must ensure +// that both initializer and finalizer are within the same condition. +TEST(Registerizer, RegisterizerNestedConditions) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If an access exists outside the scope of the condition then we can lift +// nested conditional usages into the same scalar. +TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {1}, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x<5 +# CHECK: A[1] = 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + stmt = registerize(stmt); +} + +TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + stmt = registerize(stmt); +} + +// If an access is cut by another access internal to a condition block, it still +// cuts the access. +TEST(Registerizer, RegisterizerNestedConditionsCut) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {x}, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}))}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Three loops and four element regions, three of which should be registerized +// at different levels of the IR. +TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {4}, 0), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kGT), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kGT), + Block::make({ + Cond::make( + CompareSelect::make(x, 4, CompareSelectOperation::kGT), + Block::make({ + Store::make( + a, {1}, Add::make(Load::make(a, {1}), 1)), + Store::make( + a, {2}, Add::make(Load::make(a, {2}), 1)), + Store::make( + a, {3}, Add::make(Load::make(a, {3}), 1)), + Store::make( + a, {4}, Add::make(Load::make(a, {4}), 1)), + Store::make( + a, {1}, Add::make(Load::make(a, {1}), 1)), + }), + nullptr), + Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)), + }), + nullptr), + nullptr)}); + + /* + * A[4] = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * if (x>4 ? 1 : 0) { + * A[1] = (A[1]) + 1; + * A[2] = (A[2]) + 1; + * A[3] = (A[3]) + 1; + * A[4] = (A[4]) + 1; + * A[1] = (A[1]) + 1; + * } + * A[2] = (A[2]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * int A_3 = A[2]; + * if (x>4 ? 1 : 0) { + * int A_2 = A[1]; + * A_2 = A_2 + 1; + * A_3 = A_3 + 1; + * A[3] = (A[3]) + 1; + * A_1 = A_1 + 1; + * A_2 = A_2 + 1; + * A[1] = A_2; + * } + * A_3 = A_3 + 1; + * A[2] = A_3; + * } + * } + * A[4] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: if (x>2 ? 1 : 0) { +# CHECK: if (x>3 ? 1 : 0) { +# CHECK: int A_3 = A[2]; +# CHECK: if (x>4 ? 1 : 0) { +# CHECK: int A_2 = A[1]; +# CHECK: A_2 = A_2 + 1; +# CHECK: A_3 = A_3 + 1; +# CHECK: A[3] = (A[3]) + 1; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_2 = A_2 + 1; +# CHECK: A[1] = A_2; +# CHECK: } +# CHECK: A_3 = A_3 + 1; +# CHECK: A[2] = A_3; +# CHECK: } +# CHECK: } +# CHECK: A[4] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can replace a simple scalar access with a local variable even when that +// variable is an outer loop var. +TEST(Registerizer, RegisterizerNestedLoopSimple) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({For::make( + y, + 0, + 10, + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))}); + + /* + * for (int y = 0; y < 10; y++) { + * for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * for (int y = 0; y < 10; y++) { + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[y] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int y +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[y] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the positive case of the hiddenAccess split, where an internal +// conditional access can be hoisted up through a loop to match an existing +// access in a higher scope and the two can be registerized. +TEST(Registerizer, RegisterizerHiddenAccessYes) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0), + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make( + a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the negative case of the hiddenAccess split, where the hoisted access is +// never unhidden at a higher scope and registerization occurs at the lower +// scope. +TEST(Registerizer, RegisterizerHiddenAccessNo) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0), + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * int A_1 = A[0]; + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: int A_1 = A[0]; +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: } +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// In this case the conditional access must be hoisted by two loops, there are +// two accesses here one is unhidden and the other isn't. A[0] can be +// registerized but B[0] cannot. +TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + Block::make({Cond::make( + CompareSelect::make(y, 3, CompareSelectOperation::kEQ), + Block::make( + {Store::make( + a, {0}, Add::make(Load::make(a, {0}), 1)), + Store::make( + b, {0}, Add::make(Load::make(b, {0}), 1))}), + nullptr)})))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A_1 = A_1 + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: for (int y +# CHECK: if (y==3 +# CHECK: A_1 = A_1 + 1; +# CHECK: B[0] = (B[0]) + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, but the immediate parent is +// not a condition. +TEST(Registerizer, RegisterizerTwoConditionalLoops) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, cut in the middle. +TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr), + For::make(x, 0, 10, Store::make(a, {x}, 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: for (int x +# CHECK: A[x] = 1; +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// references a Let var in a local scope which cannot be hoisted out of the +// loop. +TEST(Registerizer, RegisterizerLoopLetVar) { + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( + x, + 0, + 10, + Block::make( + {Let::make(y, 30), + Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); + + /* + * for (int x = 0; x < 10; x++) { + * int y = 30; + * A[y] = x + (A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// references a Let var in an outer scope that does not prevent hoisting the +// initializer. +TEST(Registerizer, RegisterizerLoopLetVarOuter) { + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Let::make(y, 30), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); + + /* + * int y = 30; + * for (int x = 0; x < 10; x++) { + * A[y] = x + (A[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int y = 30; + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[y] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int y = 30; +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + x; +# CHECK: A[y] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Okay so the registerizer generally goes after index flattening, but just in +// case. Test multi index registerization. +TEST(Registerizer, RegisterizerMultiDim) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 1, 2] = (A[0, 1, 2]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0, 1, 2] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0, 1, 2] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't registerize if only some dims match, but will still registerize +// distinct elements. +TEST(Registerizer, RegisterizerMultiDimPartial) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 2, 2] = (A[0, 1, 4]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[0, 1, 4]; + * int A_2 = A[0, 2, 2]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_1 + x; + * } + * A[0, 2, 2] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[0, 1, 4]; +# CHECK: int A_2 = A[0, 2, 2]; +# CHECK: for ( +# CHECK: A_2 = A_1 + x; +# CHECK: A[0, 2, 2] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If they could overlap across all dimensions we cannot registerize. +TEST(Registerizer, RegisterizerMultiDimOverlap) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))}); + stmt = IRSimplifier::simplify(stmt); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 2]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// But, if one dimension is known to be distinct they do not overlap. +TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))}); + + /* + * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[y, 2, 4]; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = A_1 + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[y, 2, 4]; +# CHECK: for ( +# CHECK: A[0, x, 2] = A_1 + x; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with different input dimensionality. +TEST(Registerizer, RegisterizerMultiDim3DReduction1) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10, 10}, kInt); + BufHandle c("C", {10, 10, 10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + StmtPtr stmt = For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x, y, z}, + Add::make( + Load::make(c, {x, y, z}), + Mul::make(Load::make(b, {x, y}), Load::make(a, {x}))))))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize the A and B access since they can be hoisted before + // hitting a dependent loop var. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[x, y]; + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[x, y]; +# CHECK: for (int z +# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with the same smaller dimensionality using different loop +// vars. +TEST(Registerizer, RegisterizerMultiDim3DReduction2) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + StmtPtr stmt = For::make( + x, + 0, + 10, + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x}, + Add::make( + Load::make(c, {x}), + Mul::make(Load::make(b, {y}), Load::make(a, {x}))))))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x] = (C[x]) + (B[y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize all accesses, the A and C access can be hoisted to the + // outer loop since they depend only on it's loop var while the B can only be + // raised to the loop of y. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * int C_1 = C[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[y]; + * for (int z = 0; z < 10; z++) { + * C_1 = A_1 * B_1 + C_1; + * } + * } + * C[x] = C_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: int C_1 = C[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[y]; +# CHECK: for (int z +# CHECK: C_1 = A_1 * B_1 + C_1; +# CHECK: } +# CHECK: } +# CHECK: C[x] = C_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp new file mode 100644 index 000000000000..7ca2b74eaa76 --- /dev/null +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -0,0 +1,5680 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using SimpleIRExprEval = ExprEval; + +TEST(Simplify, ConstantFoldSimple) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle f = (a + b); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 5); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 5.f); +} + +TEST(Simplify, ConstantFoldTwoLayer) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), -4); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), -4.f); +} + +TEST(Simplify, ConstantFoldShifts) { + ExprHandle a(7); + ExprHandle b(2); + ExprHandle c(3); + ExprHandle f = ((a << b) << b) >> c; + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 14); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 7 << (4 - 3)); +} + +TEST(Simplify, ConstantFoldBitwise) { + ExprHandle a(59); + ExprHandle b(22); + ExprHandle c(101); + ExprHandle f = (a ^ b) & c; + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 37); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), (59 ^ 22) & 101); +} + +TEST(Simplify, ConstantFoldMultiOp) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle e(6.0f); + ExprHandle f(7.0f); + ExprHandle fn = ((a / e) - (c + d)) * (f / b); + + ExprHandle newF = IRSimplifier::simplify(fn); + ASSERT_NE(newF.AsNode(), nullptr); + + SimpleIRExprEval eval(newF); + SimpleIRExprEval ref(fn); + + ASSERT_EQ(eval.value(), ref.value()); +} + +TEST(Simplify, ConstantFoldMinMax) { + ExprHandle a(12.0f); + ExprHandle b(15.0f); + ExprHandle c(17.0f); + + // x = max(12, min(15, 17)). + ExprHandle minHandle = Min::make(b, c, true); + ExprHandle fn = Max::make(a, minHandle, false); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); + + ExprHandle newF = IRSimplifier::simplify(fn); + ASSERT_NE(newF.AsNode(), nullptr); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 15.f); +} + +TEST(Simplify, ConstantFoldIntrinsics) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle powHandle = Intrinsics::make(kPow, a, b); + ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); + ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); + ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); + ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); + ExprHandle fn = Intrinsics::make(kAbs, rndHandle); + + ExprHandle newF = IRSimplifier::simplify(fn); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + SimpleIRExprEval ref(fn); + + ASSERT_EQ(eval.value(), ref.value()); +} + +TEST(Simplify, ConstantFoldCastToBool) { + ExprHandle f = Cast::make(kBool, IntImm::make(0)); + ExprHandle newF = IRSimplifier::simplify(f); + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), false); +} + +TEST(Simplify, ConstantFoldWithVar) { + { + VarHandle x("x", kInt); + ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); + + ExprHandle newF = IRSimplifier::simplify(body); + MulPtr root = newF.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_NE(to(root->lhs()), nullptr); + + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3)); + ASSERT_EQ(eval.value(), 3 * (2 + 4)); + } + + { + VarHandle x("x", kFloat); + ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); + + ExprHandle newF = IRSimplifier::simplify(body); + MulPtr root = newF.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_NE(to(root->rhs()), nullptr); + + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 3 * (2 + 4)); + } +} + +TEST(Simplify, ConditionalSelectFoldSimple) { + ExprHandle a(3.0f); + ExprHandle b(4.0f); + ExprHandle c(3.0f); + { + ExprHandle f = (a > b); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } + { + ExprHandle f = (a < b); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a == c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a != c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } +} + +TEST(Simplify, ConditionalSelectFoldTwoLayer) { + ExprHandle a(3.0f); + ExprHandle b(2.0f); + ExprHandle c(2.0f); + ExprHandle d(1.0f); + { + ExprHandle f = (a + b < c + d); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } + { + ExprHandle f = (a + b > c + d); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a + d == b + c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a + d != b + c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } +} + +TEST(Simplify, ConditionalSelectFoldWithVar) { + VarHandle x("x", kFloat); + ExprHandle f = x < 4.f; + + ExprHandle newF = IRSimplifier::simplify(f); + IntImmPtr folded = newF.AsNode(); + ASSERT_EQ(folded, nullptr); + + { + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 1); + } + { + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(5.f)); + ASSERT_EQ(eval.value(), 0); + } +} + +TEST(Simplify, UnFoldableExpr) { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); + + ExprHandle newF = IRSimplifier::simplify(body); + AddPtr root = newF.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_EQ(to(root->lhs()), nullptr); + ASSERT_EQ(to(root->rhs()), nullptr); + + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3.f)); + eval.bindVar(y, ExprHandle(2.f)); + ASSERT_EQ(eval.value(), 9 + 10); +} + +TEST(Simplify, HashSimple) { + VarHandle x("x", kFloat); + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle f = a + b * x; + + HashProvider hasher; + + auto hash_x = hasher.hash(x.node()); + auto hash_a = hasher.hash(a.node()); + auto hash_f = hasher.hash(f.node()); + + ASSERT_NE(hash_x, (size_t)0); + ASSERT_NE(hash_a, (size_t)0); + ASSERT_NE(hash_f, (size_t)0); + ASSERT_NE(hash_x, hash_a); + ASSERT_NE(hash_x, hash_f); + ASSERT_NE(hash_a, hash_f); +} + +TEST(Simplify, HashEquivalence) { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle f = (x * y) + (x * y); + + AddPtr root = f.AsNode(); + ASSERT_NE(root, nullptr); + + HashProvider hasher; + auto hash_f = hasher.hash(f.node()); + auto hash_l = hasher.hash(root->lhs()); + auto hash_r = hasher.hash(root->rhs()); + + // Root not equal to either branch. + ASSERT_NE(hash_f, hash_l); + ASSERT_NE(hash_f, hash_r); + // but branches are equal. + ASSERT_EQ(hash_l, hash_r); + + // Still equivalent if separate. + ExprHandle a(2); + ExprHandle f2 = x + a / y; + ExprHandle b(2); + ExprHandle f3 = x + b / y; + ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); + + // Not equivalent if different vars (even with same name). + VarHandle z("x", kFloat); + ExprHandle f4 = z + b / y; + ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); + + // Intrinsics sanity check. + ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); + ASSERT_NE(hasher.hash(f5.node()), (size_t)0); +} + +TEST(Simplify, HashEquivalenceRand) { + ExprHandle f = + Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); + + AddPtr root = f.AsNode(); + ASSERT_NE(root, nullptr); + + HashProvider hasher; + auto hash_f = hasher.hash(f.node()); + auto hash_l = hasher.hash(root->lhs()); + auto hash_r = hasher.hash(root->rhs()); + + // Root not equal to either branch. + ASSERT_NE(hash_f, hash_l); + ASSERT_NE(hash_f, hash_r); + // and branches are NOT equal. + ASSERT_NE(hash_l, hash_r); +} + +TEST(Simplify, HashEquivalenceAfterFolding) { + VarHandle x("x", kFloat); + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(5.0f); + + ExprHandle f1 = ((a + b) * x); + ExprHandle f2 = (c * x); + + HashProvider hasher; + auto hash_l = hasher.hash(f1.node()); + auto hash_r = hasher.hash(f2.node()); + + // Root not equal to either branch, and branches not equal. + ASSERT_NE(hash_l, hash_r); + + ExprHandle ff1 = IRSimplifier::simplify(f1); + ExprHandle ff2 = IRSimplifier::simplify(f2); + + auto hash_l_n = hasher.hash(ff1.node()); + auto hash_r_n = hasher.hash(ff2.node()); + // but branches are now equal. + ASSERT_EQ(hash_l_n, hash_r_n); +} + +TEST(Simplify, HashDifferenceTypes) { + HashProvider hasher; + std::vector immediates; + + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + // NOLINTNEXTLINE(modernize-use-bool-literals) + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + + // Immediates of different types are not equal. + for (unsigned int i = 0; i < immediates.size(); ++i) { + for (unsigned int j = i + 1; j < immediates.size(); ++j) { + ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); + } + } + + // But coerced immediates are if they are the same type: + ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); + ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); + + ExprHandle ff1 = IRSimplifier::simplify(f1); + ExprHandle ff2 = IRSimplifier::simplify(f2); + + ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); +} + +TEST(Simplify, HashLargeExpression) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + VarHandle i("i", kInt); + auto memcpy_stmt = For::make( + i, + 0, + N, + Store::make( + c, + {i}, + CompareSelect::make( + Load::make(a, {i}), + Load::make(b, {i}), + CompareSelectOperation::kEQ))); + + BufHandle d("D", {1}, kInt); + BufHandle e("E", {1}, kInt); + auto store_ramp_stmt = Store::make( + e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); + + auto if_stmt = Cond::make( + CompareSelect::make( + Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), + memcpy_stmt, + store_ramp_stmt); + + HashProvider hasher; + auto hash_r = hasher.hash(if_stmt); + // We should not have to do any more work. + ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); + auto hash_t = hasher.hash(memcpy_stmt); + ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); + auto hash_f = hasher.hash(store_ramp_stmt); + + // Root not equal to either branch, and branches not equal. + ASSERT_NE(hash_r, hash_t); + ASSERT_NE(hash_r, hash_f); + ASSERT_NE(hash_t, hash_f); +} + +TEST(Simplify, HashForLoopOptions) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + VarHandle i("i", kInt); + auto for_stmt = For::make( + i, + 0, + N, + Store::make( + c, + {i}, + CompareSelect::make( + Load::make(a, {i}), + Load::make(b, {i}), + CompareSelectOperation::kEQ))); + + HashProvider hasher; + auto hash_before = hasher.hash(for_stmt); + hasher.clearCache(); + + for_stmt->set_gpu_block_index(LoopOptions::IDX_X); + auto hash_block_idx = hasher.hash(for_stmt); + hasher.clearCache(); + + ASSERT_NE(hash_before, hash_block_idx); + + for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); + auto hash_reset = hasher.hash(for_stmt); + hasher.clearCache(); + + ASSERT_EQ(hash_before, hash_reset); + for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); + auto hash_thread_idx = hasher.hash(for_stmt); + + ASSERT_NE(hash_before, hash_thread_idx); + ASSERT_NE(hash_block_idx, hash_thread_idx); +} + +/// (2 + x) + 4 => x + 6 +TEST(Simplify, SimplifyAdd) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + VarHandle m("m", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + VarHandle n("n", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + VarHandle n_1("n_1", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); + + ExprHandle simplified = IRSimplifier::simplify(body); + AddPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + VarPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->name_hint(), "x"); + IntImmPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->value(), 6.f); +} + +/// (2 - x) - 4 => -2 - x +TEST(Simplify, SimplifySub) { + VarHandle x("x", kInt); + ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); + + ExprHandle simplified = IRSimplifier::simplify(body); + SubPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + IntImmPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->value(), -2.f); + VarPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->name_hint(), "x"); +} + +/// 2 * (1 - x) - 4 => 2 * (-3 - x) +TEST(Simplify, SimplifyMultiLayer) { + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_IMM_WITH_VAL(Int, sub->lhs(), -3); + IS_VAR_WITH_NAME(sub->rhs(), "x"); +} + +/// 2 * (3 * x) - (x * 4) => 2 * x +TEST(Simplify, SimplifyMultiTerm) { + VarHandle x("x", kInt); + ExprHandle body = + (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); + + ExprHandle simplified = IRSimplifier::simplify(body); + MulPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + IntImmPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->value(), 2); + VarPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->name_hint(), "x"); +} + +/// 2 * (3 * (long)x) - (x * 4) => 2 * x +TEST(Simplify, SimplifyCasts) { + VarHandle x("x", kLong); + ExprHandle body = + (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); + + ExprHandle simplified = IRSimplifier::simplify(body); + MulPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + LongImmPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->value(), 2); + VarPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->name_hint(), "x"); +} + +/// (x + 0) * 1 => x +TEST(Simplify, SimplifyEliminatesNoOps) { + VarHandle x("x", kInt); + ExprHandle body = (x + ExprHandle(0)) * 1; + + ExprHandle simplified = IRSimplifier::simplify(body); + VarPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_EQ(root->name_hint(), "x"); +} + +/// Cannot simplify this. +TEST(Simplify, SimplifyMultiVar) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = x * 24 + y * 34; + + ExprHandle simplified = IRSimplifier::simplify(body); + + AddPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + MulPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + VarPtr varX = to(lhs->rhs()); + ASSERT_NE(varX, nullptr); + ASSERT_EQ(varX->name_hint(), "x"); + MulPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + VarPtr varY = to(rhs->rhs()); + ASSERT_NE(varY, nullptr); + ASSERT_EQ(varY->name_hint(), "y"); +} + +// x + 2 + y => x + y + 2 +TEST(Simplify, DISABLED_SimplifyReorderings) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = x + 2 + y; + ExprHandle simplified = IRSimplifier::simplify(body); + + AddPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + + IS_NODE_WITH_NAME(Add, root->lhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + IS_IMM_WITH_VAL(Int, root->rhs(), 2); +} + +/// y + x * 0 => y +TEST(Simplify, SimplifyEliminatesVar) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = y + x * ExprHandle(0); + + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "y"); +} + +TEST(Simplify, SimplifyAdds) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) + (x + y) => 2 * (x + y) + ExprHandle body = (x + y) + (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Add, root->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // (x * y) + (x * y) => 2 * (x * y) + ExprHandle body = (x * y) + (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Mul, root->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - y) + (x - y) => 2 * (x - y) + ExprHandle body = (x - y) + (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // (x + x + x + x) => 4 * x + ExprHandle body = (x + x + x + x); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 4); + IS_VAR_WITH_NAME(root->rhs(), "x"); + } + + { + // (x + 0) => x. + ExprHandle body = x + 0; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x + 0.f) => float(x). + ExprHandle body = x + 0.f; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } +} + +TEST(Simplify, SimplifyMuls) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) * (x + y) => (x + y) * (x + y) + // We don't attempt to simplify multiplication of polynomials since the + // result is only very rarely more efficient. + ExprHandle body = (x + y) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // x * y * x * y => x * x * y * y + // These get reordered only. + ExprHandle body = x * y * x * y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); + IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); + IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); + IS_VAR_WITH_NAME(mul1->rhs(), "y"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + IS_VAR_WITH_NAME(mul3->lhs(), "x"); + IS_VAR_WITH_NAME(mul3->rhs(), "x"); + } + + { + // 1 * (x * 1) => x + // Ones cancel cleanly. + ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // 1.f * (x * 1.f) => x + // Even float ones cancel cleanly, but carry their type. + ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // 1 * (x * 1.f) => x + // One float is enough to cast the expr. + ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // 1 * (x * 0) => 0 + // Zeroes are eliminated. + ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // 1 * (x * 0) => 0 + // But not for Float since nan * 0 = nan. + ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); + } + + { + // (x - y) * (x - y) => (x - y) * (x - y) + // As with Add we don't attempt simplification of this. + ExprHandle body = (x - y) * (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // (x + y) * (x - y) => (x + y) * (x - y) + // Don't simplify with different ops on each side. + ExprHandle body = (x + y) * (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with no scalar, poly with non-identity scalar. + // x * (y + 1) => x + x * y + ExprHandle body = x * (y + ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with identity scalar, poly with non-identity scalar. + // (x * 1) * (y + 1) => x + x * y + ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with non-identity scalar, poly with non-identity scalar. + // (x * 2) * (y + 1) => 2 * (x + x * y) + ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); + IS_VAR_WITH_NAME(mul2->lhs(), "x"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with non-identity scalar, poly with identity scalar. + // (x * 2) * (y + 0) => 2 * (x * y) + ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); + IS_VAR_WITH_NAME(mul2->lhs(), "x"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with identity scalar, poly with identity scalar. + // (x * 1) * (y + 0) => x * y + ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with no scalar, poly with identity scalar. + // x * (y + 0) => x * y + ExprHandle body = x * (y + ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } +} + +// Sub an expr from itself will result in zero. +TEST(Simplify, SimplifySubs) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) - (x + y) => 0 + ExprHandle body = (x + y) - (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x * y) - (x * y) => 0 + ExprHandle body = (x * y) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x - y) - (x - y) => 0 + ExprHandle body = (x - y) - (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x + y) - 2 * (x + y) => -1 * x - y + ExprHandle body = (x + y) - ExprHandle(2) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -1); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // (x + y) - y => x + ExprHandle body = (x + y) - y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - 0) => x. + ExprHandle body = x - 0; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - 0.f) => x. + // Simple enough to cancel in float. + ExprHandle body = x - ExprHandle(0.f); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // (x - (float)(y - y)) => x. + ExprHandle body = x - Cast::make(kFloat, y - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // (x - y) - y => x - 2 * y + ExprHandle body = (x - y) - y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // 2 * x - x => x + ExprHandle body = (ExprHandle(2) * x) - x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // x - 2 * x = -1 * x + // We don't have a unary negate, but this could be 0 -x I guess? + ExprHandle body = x - (ExprHandle(2) * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + + IS_IMM_WITH_VAL(Int, mul->lhs(), -1); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // (x + y + 5) * (x - x) => 0 + // Cancelling out one side of Mul cancels both. + ExprHandle body = (x + y + 5) * (x - x); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Cancel out opaque modulus. + ExprHandle body = (x % y + 2) - (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } + + { + // Cancel out opaque modulus with a bit more going on. + ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } + + { + // Sub where result is negative. + ExprHandle body = x - (x + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), -1); + } + + { + // Sub where result is positive due to negative scalar on RHS. + ExprHandle body = x - (x - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 1); + } + + { + // Term - Polynomial sub where RHS must be negated. + ExprHandle body = (x * 2) - (x * 2 + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), -1); + } + + { + // Term - Polynomial sub where the result is a Term. + ExprHandle body = (y * x * 2) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Term - Polynomial sub where the result is a Polynomial. + ExprHandle body = (x * 2) - (x + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_IMM_WITH_VAL(Int, sub->rhs(), 1); + } +} + +TEST(Simplify, SimplifyDiv) { + VarHandle x("x", kInt); + + { + ExprHandle body = ExprHandle(0) / x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + ExprHandle body = x / 1; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } +} + +TEST(Simplify, SimplifyDivWithLoopContext0) { + // Stmt to simplify: + // for (int i = 0; i < 100; i++) { + // A[i] = i / 100; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {100}, kInt); + auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext1) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) / 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 4; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext2) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i + 25) / 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 4; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext3) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) / (-6); + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = -4; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext4) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i - 5) / 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = 0; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext5) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) / 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NEXT: A[i, j] = j; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext6) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (int j = -1; j < 9; j++) { + // A[i, j+1] = (i + 6*j) / 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = j; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext7) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) / (-6); + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = -j; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext0) { + // Stmt to simplify: + // for (const auto i : c10::irange(100)) { + // A[i] = i % 100; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {100}, kInt); + auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext1) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) % 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext2) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i + 25) % 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = i + 1; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext3) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) % (-6); + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext4) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i - 5) % 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = i - 5; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext5) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) % 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NEXT: A[i, j] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext6) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (int j = -1; j < 9; j++) { + // A[i, j+1] = (i + 6*j) % 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext7) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) % (-6); + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyMod) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + { + // Constant folding works. + ExprHandle body = ExprHandle(10) % 8; + ExprHandle simplified = IRSimplifier::simplify(body); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } + + { + // x % x => 0 + ExprHandle body = x % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // 0 % x => 0 + ExprHandle body = ExprHandle(0) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // x % 1 => 0 + ExprHandle body = x % 1; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Doesn't change unknown mods. + // x % y => x % y + ExprHandle body = x % y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_VAR_WITH_NAME(mod->rhs(), "y"); + } + + { + // don't touch if RHS is unknown. + // 4 % x => 4 % x + ExprHandle body = ExprHandle(4) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_IMM_WITH_VAL(Int, mod->lhs(), 4); + IS_VAR_WITH_NAME(mod->rhs(), "x"); + } + + { + // don't touch if LHS is unknown. + // x % 4 => x % 4 + ExprHandle body = x % 4; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 4); + } + + { + // if LHS is a multiple of RHS, mod is zero. + // 2 * x % x => 0 + ExprHandle body = (x * 2) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true even if the multiple is not constant. + // x * y % x => 0 + ExprHandle body = (x * y) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true with multiple unknown values in LHS. + // x * y * z % x => 0 + ExprHandle body = (x * y * z) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true if the denom is compound. + // x * y * z % y * z => 0 + ExprHandle body = (x * y * z) % (y * z); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Sanity check true with scalars that are multiples. + // 12 * x % 4 => 0 + ExprHandle body = (x * 12) % 4; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Sanity check not true if the smaller scalar is on LHS. + // 4 * x % 12 => 4 * x % 12 + ExprHandle body = (x * 4) % 12; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 4); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 12); + } + + { + // Both scalar and symbolic in multiple. + // (6 * x * y) % (3 * x * y) => 0 + ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } +} + +// Test that mixing ops together simplifies as expected. +TEST(Simplify, SimplifyMultiOp) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x * y) + (x - y) => (x + x * y) - y + ExprHandle body = (x * y) + (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // (x + y) - x * y => (x + y) - x * y + ExprHandle body = (x + y) - x * y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - y) - (x + y) => -2 * y + ExprHandle body = (x - y) - (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - 0) + (x * 1) - (x + 0) => x + ExprHandle body = (x - 0) + (x * 1) - (x + 0); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) + // Even in Float simple terms cancel out, but the variable ones cannot. + ExprHandle body = + (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); + IS_VAR_WITH_NAME(cast1->src_value(), "x"); + IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); + IS_VAR_WITH_NAME(cast2->src_value(), "x"); + IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); + IS_VAR_WITH_NAME(cast3->src_value(), "x"); + } +} + +// Test that chaining many ops together works as expected. +TEST(Simplify, SimplifyManyOps) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x + ExprHandle body = x + y + x + x + y + y + x + y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "x"); + } + + { + // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y + ExprHandle body = x - y + x + x - y - y + x - y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // x + y + x - x - y - y + x + y + x = 3 * x + ExprHandle body = x + y + x - x - y - y + x + y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 3); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } +} + +TEST(Simplify, SimplifyFactorization) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (2 * x) + (2 * y) => 2 * (x + y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Factorization when scalars have common divider. + // (2 * x) + (4 * y) => 2 * (2 * y + x) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Factorization attempt without a common divider. + // (2 * x) + (5 * y) => (5 * y) + (2 * x) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // Factorization after merging. + // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + + (ExprHandle(8) * x + ExprHandle(6) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 10); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Factorization with common divider but different signs. + // (2 * x) + (-4 * y) => 2 * (x - 2 * y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Factorization with all negative numbers. + // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) + ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); + IS_VAR_WITH_NAME(mul2->rhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); + IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); + IS_VAR_WITH_NAME(mul3->rhs(), "y"); + } + + { + // The following test ensures that there in no infinite recursion during + // factorization when negative numbers are involved. + VarHandle a("a", kInt); + VarHandle b("b", kInt); + VarHandle c("c", kInt); + VarHandle d("d", kInt); + VarHandle e("e", kInt); + VarHandle f("f", kInt); + VarHandle g("g", kInt); + VarHandle h("h", kInt); + + ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + + f * 32 + g * (-1024) + h * (-32); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR( + simplified, + "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h"); + } +} + +// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) +TEST(Simplify, SimplifyFactorizeUneven) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = + (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Add, root->rhs(), add1); + IS_NODE_WITH_NAME(Add, add1->lhs(), add2); + + IS_VAR_WITH_NAME(add2->lhs(), "y"); + IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul); + IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); + + IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); + IS_VAR_WITH_NAME(xmul->rhs(), "x"); + + IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); + IS_VAR_WITH_NAME(zmul->rhs(), "z"); +} + +// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) +// This is kind of a placeholder test for variable factorization. +TEST(Simplify, SimplifyDeeperTerms) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); + IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); + IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); +} + +// Tests the difference between two less trivial expressions. +// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 +TEST(Simplify, SimplifyDeeperDifference) { + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + VarHandle m("m", kInt); + ExprHandle body = + (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 1); +} + +// Test constant folding into the difference between expressions. +// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 +TEST(Simplify, SimplifyFoldComplexDifference) { + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + VarHandle m("m", kInt); + ExprHandle body = + (IntImm::make(2) + + (Cast::make( + kChar, + (m * (ExprHandle(1) * n_1) + (n + 1)) - + (m * (ExprHandle(1) * n_1) + n)))); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 3); +} + +TEST(Simplify, SimplifyIfComponents) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make( + ((ExprHandle(5) - ExprHandle(4)) * x) > y, + ExprHandle(2) * x - x, + ExprHandle(2) * y - y); + + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); + + IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kGT); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_VAR_WITH_NAME(cmp->rhs(), "y"); + + IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); + IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); +} + +TEST(Simplify, SimplifyOpaqueTerms) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // 2 * x/y * y - x/y * y => x/y * y + ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Div, mul->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // x%y - (x%y - 1) => 1 + ExprHandle body = (x % y) - ((x % y) - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 1); + } +} + +TEST(Simplify, SimplifySymbolicMinMax) { + { + // Minimum with constant difference between terms. + VarHandle x("x", kInt); + ExprHandle body = Min::make(x + 3, x + 7, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_IMM_WITH_VAL(Int, add->rhs(), 3); + } + + { + // Maximum with constant difference between terms. + VarHandle x("x", kInt); + ExprHandle body = Max::make(x + 3, x + 7, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_IMM_WITH_VAL(Int, add->rhs(), 7); + } + + { + // Can't simplify multiples because of signedness of variable component. + // TODO: maybe we could for unsigned types? + VarHandle x("x", kInt); + ExprHandle body = Max::make(x * 3, x * 7, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE(Max, simplified.node()); + } +} + +TEST(Simplify, SimplifyNestedMax) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + { + // Max(x + y, x + y) => x + y + ExprHandle body = Max::make(x + y, x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); + } + + { + // Max(x + y, Max(x + y, z)) => Max(x + y, z) + ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(x + y, Max(z, x + y)) => Max(x + y, z) + ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(Max(x + y, z), x + y) => Max(x + y, z) + ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(Max(z, x + y), x + y) => Max(x + y, z) + ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(Max(x, y), x) => Max(Max(x, y), x) + // Nested Max ops with different propagate_nans should not be simplified. + ExprHandle body = Max::make(Max::make(x, y, true), x, false); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y"); + ASSERT_TRUE(max1->propagate_nans()); + IS_VAR_WITH_NAME(max->rhs(), "x"); + ASSERT_FALSE(max->propagate_nans()); + } + + { + // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(x, y, true), Min::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(x, y, true), Min::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(y, x, true), Min::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(y, x, true), Min::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) + // When all the ops in the pattern do not have the same propagate_nans, + // it should not be simplified. + ExprHandle body = + Max::make(Min::make(y, x, true), Min::make(z, x, false), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); + ASSERT_TRUE(min1->propagate_nans()); + IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); + ASSERT_FALSE(min2->propagate_nans()); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(5, Max(x, 8)) => Max(x, 8) + ExprHandle body = Max::make(5, Max::make(x, 8, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(8, Max(x, 5)) => Max(x, 8) + ExprHandle body = Max::make(8, Max::make(x, 5, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(Max(x, 8), 5) => Max(x, 8) + ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(Max(x, 5), 8) => Max(x, 8) + ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) + // Do not simplify when all the Max ops do not have the same + // propagate_nans. + ExprHandle body = Max::make( + Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), + 8, + false); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); + } + + { + // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } +} + +TEST(Simplify, SimplifyNestedMin) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + { + // Min(x + y, x + y) => x + y + ExprHandle body = Min::make(x + y, x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); + } + + { + // Min(x + y, Min(x + y, z)) => Min(x + y, z) + ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(x + y, Min(z, x + y)) => Min(x + y, z) + ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(Min(x + y, z), x + y) => Min(x + y, z) + ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(Min(z, x + y), x + y) => Min(x + y, z) + ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(Min(x, y), x) => Min(Min(x, y), x) + // Nested Min ops with different propagate_nans should not be simplified. + ExprHandle body = Min::make(Min::make(x, y, true), x, false); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y"); + ASSERT_TRUE(min2->propagate_nans()); + IS_VAR_WITH_NAME(min1->rhs(), "x"); + ASSERT_FALSE(min1->propagate_nans()); + } + + { + // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(x, y, true), Max::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(x, y, true), Max::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(y, x, true), Max::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(y, x, true), Max::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) + // When all the ops in the pattern do not have the same propagate_nans, + // it should not be simplified. + ExprHandle body = + Min::make(Max::make(y, x, true), Max::make(z, x, false), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y"); + ASSERT_TRUE(max1->propagate_nans()); + IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z"); + ASSERT_FALSE(max2->propagate_nans()); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(5, Min(x, 8)) => Min(x, 8) + ExprHandle body = Min::make(5, Min::make(x, 8, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(8, Min(x, 5)) => Min(x, 8) + ExprHandle body = Min::make(8, Min::make(x, 5, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(Min(x, 8), 5) => Min(x, 8) + ExprHandle body = Min::make(Min::make(x, 8, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(Min(x, 5), 8) => Min(x, 8) + ExprHandle body = Min::make(Min::make(x, 5, true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) + // Do not simplify when all the Min ops do not have the same + // propagate_nans. + ExprHandle body = Min::make( + Min::make(Min::make(Min::make(z, 5, true), y, false), x, true), + 8, + false); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)"); + } + + { + // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } +} + +TEST(Simplify, SimplifyWontReorderFloat) { + { + // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) + // This is an expression we can simplify. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 9); + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). + // If the vars are floating point, ops are not associative and we can't + // reorder. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); + IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); + IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); + IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); + } + + { + // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). + // We will simplify subexprs if they dont reorder floating point ops. + VarHandle x("x", kDouble); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); + IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); + IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); + IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); + } + + { + // Prevent reordering if FP propagated from dtypes. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3.f) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); + IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); + IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); + IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); + IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); + IS_VAR_WITH_NAME(yCast->src_value(), "y"); + } + + { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + // x%y - (x%y - 1) => x%y - (x%y - 1). + // We won't reorder opaque ops if they are FP. + ExprHandle body = (x % y) - ((x % y) - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); + IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); + IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); + + IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); + IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); + IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); + IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); + IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); + } +} + +TEST(Simplify, SimplifyRoundModPattern) { + { + // (x/y)*y + x%y => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / y) * y) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Reverse order. + // x%y + (x/y)*y => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x % y) + ((x / y) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Non opaque denominator. + // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) + + (x % (y + ExprHandle(4))); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Reverse order. + // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x % (y + ExprHandle(4))) + + ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Opaque denominator. + // (x / (2/y)) * (2/y)) + (x % (2/y)) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) + + (x % (ExprHandle(2) / y)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Non opaque numerator + // ((2*x)/y * y) + ((2*x) % y) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Opaque numerator. + // ((x/2) / y * y) + (x/2 % y) => x / 2. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_IMM_WITH_VAL(Int, div->rhs(), 2); + } + + { + // Numerator and denominator. + // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + + ((ExprHandle(2) * x) % (ExprHandle(2) * y)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Reverse order. + // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + + (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Negated Subtraction of Round Mod. + // (x/y) * y - (0 - x%y) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Other terms are preserved. + // (x/y)*y + x%y + (y * x) => x + (y * x). + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / y) * y) + (x % y) + (y * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Sanity checking we won't do the optimization on floats. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = ((x / y) * y) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); + IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); + IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); + IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); + IS_VAR_WITH_NAME(roundMul->rhs(), "y"); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_VAR_WITH_NAME(mod->rhs(), "y"); + } + + { + // Sanity check we won't do it if the mod term doesn't match. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = ((x / y) * y) + (x % z); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "(x / y) * y + x % z"); + } + + { + // Sanity check we won't do it if the div term doesn't match. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = (y * (x / z)) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "x % y + (x / z) * y"); + } + + { + // Sanity check we won't do it if the mul term doesn't match. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = ((x / y) * z) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "x % y + (x / y) * z"); + } +} + +TEST(Simplify, SimplifyRoundModPatternFactorization) { + { + // Full factorization. + // 2 * (x/y * y) + 2 * (x%y) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Partial Factorization. + // 32 * (x/8) + 4 * (x % 8) => 4 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 4); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Factorization requiring constant folding. + // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x. + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) + + (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 5); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + VarHandle x("x", kInt); + ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + VarHandle x("x", kInt); + ExprHandle body = (x / 10) * 0 + x % 5; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 5); + } +} + +TEST(Simplify, SimplifyRoundModPatternMultivar) { + { + // Multivar. + // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + + (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Find the right var. + // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = + (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Add, add->lhs(), add2); + IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod); + IS_VAR_WITH_NAME(xMod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, xMod->rhs(), 8); + IS_VAR_WITH_NAME(add2->rhs(), "y"); + IS_NODE_WITH_NAME(Mod, add->rhs(), zMod); + IS_VAR_WITH_NAME(zMod->lhs(), "z"); + IS_IMM_WITH_VAL(Int, zMod->rhs(), 8); + } + + { + // Compound. + // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16) + // => (z + 512 * y) + x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "x + (z + 512 * y)"); + } +} + +TEST(Simplify, SimplifyModRoundModPattern) { + { + // t/7 % 9 * 7 + t % 7 => t%63 + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 7 + t % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63 + VarHandle t("t", kInt); + ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t/x % y * x + t % x => t%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (t / x % y) * x + t % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // k*t/x % y * x + k*t % x => k*t%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = (k * t / x % y) * x + k * t % x; + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "(k * t) % (x * y)"); + } + + { + // t/k/x % y * x + t/k % x => t/k%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = (t / k / x % y) * x + t / k % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Div, mod->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "t"); + IS_VAR_WITH_NAME(div->rhs(), "k"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Sanity checking we won't do the optimization on floats. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + VarHandle z("z", kFloat); + ExprHandle body = ((x / y % z) * y) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Mul, add->lhs(), mul); + IS_NODE_WITH_NAME(Mod, mul->lhs(), mod); + IS_NODE_WITH_NAME(Div, mod->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + IS_VAR_WITH_NAME(mod->rhs(), "z"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); + IS_VAR_WITH_NAME(mod2->lhs(), "x"); + IS_VAR_WITH_NAME(mod2->rhs(), "y"); + } +} + +TEST(Simplify, SimplifyModRoundModPatternFactorization) { + { + // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) + VarHandle t("t", kInt); + ExprHandle body = + ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63) + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63 + VarHandle t("t", kInt); + ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Div, mod->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "t"); + IS_IMM_WITH_VAL(Int, div->rhs(), 2); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189 + VarHandle t("t", kInt); + ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 + + t % (ExprHandle(7) * ExprHandle(3)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 189); + } + + { + // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y)) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); + IS_VAR_WITH_NAME(mul2->lhs(), "x"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } +} + +TEST(Simplify, SimplifyModRoundModPatternMultivar) { + { + // t/7 % 9 * 7 + t % 7 + t => t % 63 + t + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "t % 63 + t"); + } + + { + // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72 + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Mod, add->lhs(), mod1); + IS_VAR_WITH_NAME(mod1->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod1->rhs(), 63); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); + IS_VAR_WITH_NAME(mod2->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod2->rhs(), 72); + } + + { + // k + t/x % y * x + t % x => k + t%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = k + (t / x % y) * x + t % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "k"); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x + // => t%(x*y) + t/k % (x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)"); + } + + { + // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63) + // => io_flat + VarHandle t("io_flat", kInt); + ExprHandle body = + ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } + + { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) + + // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 + + // (i0_flat / (9 * 7) % 10) * 7 * 9 + + // (i0_flat / 7 % 9) * 7 + + // i0_flat % 7 => io_flat + VarHandle t("io_flat", kInt); + ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) + + (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 + + (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } + + { + // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) * + // (i0_flat / (m * n)) => io_flat + VarHandle t("io_flat", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } + + { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) + + // (i0_flat / (l * n * m) % k) * m * n * l + + // (i0_flat / (n * m) % l) * m * n + + // (i0_flat / m % n) * m + + // i0_flat % m => io_flat + VarHandle t("io_flat", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle l("l", kInt); + VarHandle k("k", kInt); + ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) + + (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n + + (t / m % n) * m + t % m; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } +} + +TEST(Simplify, SimplifyDivisionScalarFactorization) { + { + // Simple factorization of numerator and denominator. + // 8x / 4y => 2x / y. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * 8) / (y * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } + + { + // Don't change anything if we can't factorize. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * 7) / (y * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 7); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // Don't reorder floats. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = (x * 8) / (y * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f); + IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "y"); + IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f); + } + + { + // Sanity check we do nothing if there are only scalar parts. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * 1) / (y * 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } + + { + // Can factorize amounts of variables. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x + x + x + x) / (y + y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } +} + +TEST(Simplify, SimplifyConstantBranches) { + { + // If the condition is constant true then take the true_value. + // 1 ? x : y => x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle t(1); + ExprHandle body = IfThenElse::make(t, x, y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // If the condition is constant false then take the false_value. + // 0 ? x : y => y + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle t(0); + ExprHandle body = IfThenElse::make(t, x, y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "y"); + } + + { + // condition is simplified before checking. + // (x-x) ? x : y => y + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make(x - x, x, y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "y"); + } + + { + // If both branches are the same then don't do the condition. + // y ? x : x => x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make(y, x, x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // If both branches simplify to the same thing it still works. + // y ? (x + x) : (2 * x) => x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } +} + +TEST(Simplify, SimplifyConstantCond) { + { + // If the condition is constant true then take the true_value. + // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + ExprHandle condition(1); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); + + CondPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "A"); + } + + { + // If the condition is constant false then take the false_value. + // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + ExprHandle condition(0); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "B"); + } + + { + // condition is simplified before checking. + // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + ExprHandle condition(x - x); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "B"); + } + + { + // If both branches are the same then don't do the condition. + // x ? A[0] = x : A[0] = x => A[0] = x + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + ExprHandle condition(x - x); + StmtPtr true_val = Store::make(a, {0}, x); + StmtPtr false_val = Store::make(a, {0}, x); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "A"); + } + + { + // If both branches simplify to the same thing it still works. + // x ? (x + x) : (2 * x) => x + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + ExprHandle condition(x - x); + StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); + StmtPtr false_val = Store::make(a, {0}, x + x); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "A"); + } + + { + // But not if they dont + // x ? x : (2 * x) => x ? x : (2 * x) + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + ExprHandle condition(x); + StmtPtr true_val = Store::make(a, {0}, x); + StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block, nullptr); + } + + { + StmtPtr cond = alloc( + ExprHandle(false).node(), + alloc(std::vector({})), + nullptr); + StmtPtr simplified = IRSimplifier::simplify(cond); + ASSERT_EQ(simplified, nullptr); + } + + { + StmtPtr cond = alloc( + ExprHandle(true).node(), + nullptr, + alloc(std::vector({}))); + StmtPtr simplified = IRSimplifier::simplify(cond); + ASSERT_EQ(simplified, nullptr); + } +} + +TEST(Simplify, SimplifyEliminateEmptyCond) { + // If the branches are empty in different ways, eliminate. + { + VarHandle x("x", kInt); + ExprHandle condition(x); + StmtPtr true_val = alloc(std::vector({})); + + StmtPtr body = alloc(condition.node(), true_val, nullptr); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_NE(block, nullptr); + ASSERT_EQ(block->nstmts(), 0); + } + + { + VarHandle x("x", kInt); + ExprHandle condition(x); + StmtPtr false_val = alloc(std::vector({})); + + StmtPtr body = alloc(condition.node(), nullptr, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_NE(block, nullptr); + ASSERT_EQ(block->nstmts(), 0); + } +} + +TEST(Simplify, SimplifyConstantComparisons) { + auto ComparisonTest = + [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { + ExprHandle body = CompareSelect::make(a, b, op); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), result); + }; + + // Equals. + ComparisonTest(2, 2, kEQ, 1); + ComparisonTest(1, 2, kEQ, 0); + ComparisonTest(2, 1, kEQ, 0); + + // Greater than. + ComparisonTest(2, 2, kGT, 0); + ComparisonTest(1, 2, kGT, 0); + ComparisonTest(2, 1, kGT, 1); + + // Greater or Equal. + ComparisonTest(2, 2, kGE, 1); + ComparisonTest(1, 2, kGE, 0); + ComparisonTest(2, 1, kGE, 1); + + // Less Than. + ComparisonTest(2, 2, kLT, 0); + ComparisonTest(1, 2, kLT, 1); + ComparisonTest(2, 1, kLT, 0); + + // Less or Equal. + ComparisonTest(2, 2, kLE, 1); + ComparisonTest(1, 2, kLE, 1); + ComparisonTest(2, 1, kLE, 0); + + // Not equal. + ComparisonTest(2, 2, kNE, 0); + ComparisonTest(1, 2, kNE, 1); + ComparisonTest(2, 1, kNE, 1); + + // With specified results: + ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 42); +} + +TEST(Simplify, SimplifySymbolicComparisons) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; + auto TookFalseBranch = [](ExprHandle a) { + IS_IMM_WITH_VAL(Int, a.node(), 0); + }; + + // EQ + + // x == x => 1 + ExprHandle body = CompareSelect::make(x, x, kEQ); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x == x+1 => 0 + body = CompareSelect::make(x, x + 1, kEQ); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x == x * 2 cannot simplify since we don't know x is nonzero. + body = CompareSelect::make(x, x * 2, kEQ); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // x == x * 1 => 1 + body = CompareSelect::make(x, x * 1, kEQ); + TookTrueBranch(IRSimplifier::simplify(body)); + + { + // x == y => x == y + body = CompareSelect::make(x, y, kEQ); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kEQ); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_VAR_WITH_NAME(cmp->rhs(), "y"); + } + + { + // x == 5 => x == 5 + body = CompareSelect::make(x, 5, kEQ); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kEQ); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); + } + + // GT + + // x+1 > x => 1 + body = CompareSelect::make(x + 1, x, kGT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x > x + 1 => 0 + body = CompareSelect::make(x, x + 1, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x > x - 1 => 1 + body = CompareSelect::make(x, x - 1, kGT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x - 1 > x => 0 + body = CompareSelect::make(x - 1, x, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x > x => 0 + body = CompareSelect::make(x, x, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x * 2 > x => x * 2 > x + // since we don't know the sign of x. + body = CompareSelect::make(x * 2, x, kGT); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // GE + + // x+1 >= x => 1 + body = CompareSelect::make(x + 1, x, kGE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x >= x + 1 => 0 + body = CompareSelect::make(x, x + 1, kGE); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x >= x => 1 + body = CompareSelect::make(x, x, kGE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x * 2 >= x => x * 2 >= x + // since we don't know the sign of x. + body = CompareSelect::make(x * 2, x, kGE); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // LT + + // x+1 < x => 0 + body = CompareSelect::make(x + 1, x, kLT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x < x + 1 => 1 + body = CompareSelect::make(x, x + 1, kLT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x < x => 0 + body = CompareSelect::make(x, x, kLT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // LE + + // x+1 <= x => 0 + body = CompareSelect::make(x + 1, x, kLE); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x <= x + 1 => 1 + body = CompareSelect::make(x, x + 1, kLE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x <= x => 1 + body = CompareSelect::make(x, x, kLE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // NE + + // x+1 != x => 1 + body = CompareSelect::make(x + 1, x, kNE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x != x + 1 => 1 + body = CompareSelect::make(x, x + 1, kNE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x != x => 0 + body = CompareSelect::make(x, x, kNE); + TookFalseBranch(IRSimplifier::simplify(body)); +} + +TEST(Simplify, SimplifyEliminateZeroLengthFor) { + { + // Will eliminate zero loop For. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // still works if start is not zero. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // works if both terms are variable. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // works if one term simplifies down. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // Sanity check does nothing if the condition is not met. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE(For, simplified); + } +} + +TEST(Simplify, SimplifyOneLoopFor) { + { + // Will remove the loop if the body is run once. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // still works if start is not zero. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 2); + } + + { + // works if both terms are variable. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_VAR_WITH_NAME(store->flat_index(), "x"); + } + + { + // works if one term simplifies down. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = + For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // Sanity check does nothing if the condition is not met. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE(For, simplified); + } +} + +TEST(Simplify, SimplifyForWontLoseLoopOptions) { + { + // Sanity check does nothing if the condition is not met. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + LoopOptions options; + options.set_gpu_block_index(LoopOptions::IDX_W); + auto body = + For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, for_); + LoopOptions options2 = for_->loop_options(); + ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); + } +} + +TEST(Simplify, SimplifyMultilevelFor) { + { + // Multiple layers of For will be simplified out. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + auto outer = For::make(j, 0, 1, body); + StmtPtr simplified = IRSimplifier::simplify(outer); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // Will maintain an outer loop if the inner loop is eliminated. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + auto outer = For::make(j, 0, 2, body); + StmtPtr simplified = IRSimplifier::simplify(outer); + ForPtr for__ = static_to(simplified); + IS_NODE_WITH_NAME(For, for__, for_); + IS_VAR_WITH_NAME(for_->var(), "j"); + IS_IMM_WITH_VAL(Int, for_->start(), 0); + IS_IMM_WITH_VAL(Int, for_->stop(), 2); + BlockPtr block = to(for_->body()); + ASSERT_NE(block, nullptr); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // Will maintain inner loop if outer loops is eliminated. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); + auto outer = For::make(j, 0, 1, body); + StmtPtr simplified = IRSimplifier::simplify(outer); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(For, block->front(), for_); + IS_VAR_WITH_NAME(for_->var(), "i"); + IS_IMM_WITH_VAL(Int, for_->start(), 0); + IS_IMM_WITH_VAL(Int, for_->stop(), 2); + IS_NODE_WITH_NAME(Store, for_->body()->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_VAR_WITH_NAME(store->flat_index(), "i"); + } +} + +TEST(Simplify, SimplifyForCleansUp) { + { + BufHandle a("a", {1, 12, 1}, kFloat); + VarHandle x("x", kInt); + Tensor b = Compute( + "x", + {1, 12, 1}, + [](const VarHandle& i, const VarHandle& m, const VarHandle& n) { + return i + m + n; + }); + LoopNest l({b}); + l.prepareForCodegen(); + + StmtPtr body = LoopNest::sanitizeNames(l.root_stmt()); + StmtPtr simplified = IRSimplifier::simplify(body); + + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(For, block->front(), for_); + // for is over "m". + IS_VAR_WITH_NAME(for_->var(), "j"); + // x[m] = m; + IS_NODE_WITH_NAME(Store, for_->body()->front(), store); + IS_VAR_WITH_NAME(store->flat_index(), "j"); + IS_VAR_WITH_NAME(store->value(), "j"); + } +} + +TEST(Simplify, SimplifyEliminateEmptyFor) { + { + // Flatten many layers around an empty block to an empty block. + StmtPtr last = alloc(std::vector({})); + for ([[maybe_unused]] const auto i : c10::irange(11)) { + VarHandle loopVar("loopVar", kInt); + last = For::make(loopVar, 0, 10, last); + } + + StmtPtr simplified = IRSimplifier::simplify(last); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 0); + } +} + +TEST(Simplify, SimplifyFlattenBlock) { + { + // Flatten multiple blocks down to one. + // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } + BufHandle a("A", {1}, kInt); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); + + BlockPtr block1 = alloc(std::vector({store1, store2})); + BlockPtr block2 = alloc(std::vector({block1})); + + BlockPtr enclosing = alloc(std::vector({block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); + + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + + IS_NODE_WITH_NAME(Store, block->front(), store1_); + IS_NODE_WITH_NAME(Store, block->back(), store2_); + + ASSERT_EQ(store1->value(), store1_->value()); + ASSERT_EQ(store2->value(), store2_->value()); + } + + { + // Flatten multiple sub blocks containing statements. + // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } + BufHandle a("A", {1}, kInt); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); + + BlockPtr block1 = alloc(std::vector({store1})); + BlockPtr block2 = alloc(std::vector({store2})); + + BlockPtr enclosing = alloc(std::vector({block1, block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); + + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + + IS_NODE_WITH_NAME(Store, block->front(), store1_); + IS_NODE_WITH_NAME(Store, block->back(), store2_); + + ASSERT_EQ(store1->value(), store1_->value()); + ASSERT_EQ(store2->value(), store2_->value()); + } + + { + // Flatten sub blocks with different depths. + // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } + BufHandle a("A", {1}, kInt); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); + + BlockPtr block1 = alloc(std::vector({store2})); + BlockPtr block2 = alloc(std::vector({block1})); + + BlockPtr enclosing = alloc(std::vector({store1, block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); + + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + + IS_NODE_WITH_NAME(Store, block->front(), store1_); + IS_NODE_WITH_NAME(Store, block->back(), store2_); + + ASSERT_EQ(store1->value(), store1_->value()); + ASSERT_EQ(store2->value(), store2_->value()); + } + + { + // Flatten many layers around an empty block to an empty block. + StmtPtr last = alloc(std::vector({})); + for ([[maybe_unused]] const auto i : c10::irange(11)) { + last = alloc(std::vector({last})); + } + + StmtPtr simplified = IRSimplifier::simplify(last); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 0); + } +} + +TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { + { + // Simple positive case. + BufHandle b("x", {0}, kInt); + + AllocatePtr alloc_ = Allocate::make(b); + FreePtr free_ = Free::make(b); + + BlockPtr block1 = alloc(std::vector({alloc_, free_})); + ASSERT_EQ(block1->nstmts(), 2); + + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 0); + } + + { + // Simple negative case. + BufHandle b("x", {2}, kInt); + + AllocatePtr alloc_ = Allocate::make(b); + FreePtr free_ = Free::make(b); + + BlockPtr block1 = alloc(std::vector({alloc_, free_})); + ASSERT_EQ(block1->nstmts(), 2); + + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 2); + } + + { + // Finds right Alloc/Free. + BufHandle b1("x", {0}, kInt); + BufHandle b2("y", {2}, kInt); + + AllocatePtr alloc1 = Allocate::make(b1); + AllocatePtr alloc2 = Allocate::make(b2); + FreePtr free2_ = Free::make(b2); + FreePtr free1_ = Free::make(b1); + + BlockPtr block1 = + alloc(std::vector({alloc1, alloc2, free2_, free1_})); + ASSERT_EQ(block1->nstmts(), 4); + + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 2); + IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); + IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y"); + IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free); + ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var()); + } + + { + // Dynamic shape. + VarHandle z("z", kInt); + BufHandle b1("x", {0}, kInt); + BufHandle b2("y", {z}, kInt); + + AllocatePtr alloc1 = Allocate::make(b1); + AllocatePtr alloc2 = Allocate::make(b2); + FreePtr free2_ = Free::make(b2); + FreePtr free1_ = Free::make(b1); + + BlockPtr block1 = + alloc(std::vector({alloc1, alloc2, free2_, free1_})); + ASSERT_EQ(block1->nstmts(), 4); + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 2); + } +} + +TEST(Simplify, DontSimplifyRand) { + { + // rand() + rand() = rand() + rand() NOT 2 * rand(). + ExprHandle body = + Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_RAND(add->lhs()); + IS_RAND(add->rhs()); + } + + { + // rand() - rand() = rand() - rand() NOT 0. + ExprHandle body = + Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_RAND(sub->lhs()); + IS_RAND(sub->rhs()); + } + + { + // rand() * rand() = rand() * rand(). + ExprHandle body = + Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_RAND(mul->lhs()); + IS_RAND(mul->rhs()); + } +} + +TEST(Simplify, SimplifyReorderForCond) { + BufHandle a("A", {4}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + { + // for ( if ( ... ) ) => if ( for ( ... ) ). + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(c, {i}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } + + { + // Can't reorder if condition is dependent on the loop var. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make(i, 2, CompareSelectOperation::kEQ), + Store::make(c, {i}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } + + { + // Can't reorder if condition is dependent on a var that is modified inside + // the loop. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(c, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } + + { + // Condition based on buffer not referenced in body. Can reorder here. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(b, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } + + { + // Condition based on buffer read only in body. Can reorder here. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(a, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } + + { + // Condition depends on Let in the loop. Cannot reorder. + auto body = For::make( + i, + 0, + 4, + Block::make( + {Let::make(j, 3), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)})); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Let, loop->body()->front(), let); + IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); + } + + { + // Multi level Ifs where all conditions are distinct. Move BOTH Cond + // statements outside the loop. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(a, {0}), 10, CompareSelectOperation::kLT), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kEQ), + Store::make(c, {0}, Load::make(a, {i})), + nullptr), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2); + IS_NODE_WITH_NAME(For, true_block2->front(), loop); + } + + { + // Multi level Ifs where the inner condition does depend on a loop var, + // reorder only the first Cond. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(a, {0}), 10, CompareSelectOperation::kLT), + Cond::make( + CompareSelect::make(i, 3, CompareSelectOperation::kEQ), + Store::make(c, {0}, Load::make(a, {i})), + nullptr), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + IS_NODE_WITH_NAME(Block, loop->body(), loop_body); + IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2); + } + + { + // Don't reorder if there's an else block of the Cond. + // We could, but is it much better? + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + Store::make(c, {0}, 0))); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } + + { + // Condition uses distinct region of Tensor. + // We could reorder here with better analysis, but we don't. Included for + // completeness. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(c, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {1}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } +} + +TEST(Simplify, SimplifyFuseConditions) { + BufHandle a("A", {2}, kInt); + BufHandle b("B", {2}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + { + // Can fuse since the conditions are identical. + // if (A) { X }; if (A) { Y }; => if (A) { X; Y } + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can't fuse, conditions are not identical in lhs (i != j). + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + { + // Can't fuse, conditions are not identical in rhs (10 != 11). + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 11, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + + { + // Can't fuse, conditions are not identical in operation (LT vs GT). + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kGT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + + { + // Can't fuse, CompareSelect results are different. + // Actually we totally could if we normalized CompareSelect results, but + // TODO for later. + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + + { + // Can fuse with false stmt only. + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + nullptr, + Store::make(a, {0}, i)), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + nullptr, + Store::make(a, {1}, i))}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt); + ASSERT_EQ(false_stmt->nstmts(), 2); + ASSERT_EQ(cond->true_stmt(), nullptr); + } + + { + // Can fuse with both true and false stmt. + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + Store::make(b, {0}, i)), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + Store::make(b, {1}, i))}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); + ASSERT_EQ(false_stmt->nstmts(), 2); + } + + { + // Can fuse with mismatched true / false stmt existing + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + nullptr, + Store::make(b, {1}, i))}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 1); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); + ASSERT_EQ(false_stmt->nstmts(), 1); + } + + { + // Can fuse partial block contents, ie when there are non fused stmts before + // and after. + // before: + // if (j < 10) { A[0] = j; } + // if (i < 10) { A[0] = i; } + // if (i < 10) { A[1] = i; } + // if (i < 11) { A[1] = j; } + // + // after: + // + // if (j < 10) { A[0] = j; } + // if (i < 10) { + // A[0] = i; + // A[1] = i; + // } + // if (i < 11) { A[1] = j; } + + auto body = Block::make({ + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, j), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 11, CompareSelectOperation::kLT), + Store::make(a, {1}, j), + nullptr), + }); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + auto it = block->begin(); + it++; + IS_NODE_WITH_NAME(Cond, *it, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can fuse longer sequences of identical conditions. + auto body = Block::make({ + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, j), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, j), + nullptr), + }); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 4); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can't fuse through a non condition. + auto body = Block::make({ + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, j), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Store::make(b, {1}, i + j), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, j), + nullptr), + }); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt2->nstmts(), 2); + ASSERT_EQ(cond2->false_stmt(), nullptr); + + auto it = block->begin(); + it++; + IS_NODE_WITH_NAME(Store, *it, middle); + } + + { + // Can fuse if the conditions simplify to the same thing. + auto body = Block::make( + {Cond::make( + CompareSelect::make( + i * 2, + ExprHandle(87) % ExprHandle(11), + CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make( + i * 2, + ExprHandle(300) / ExprHandle(30), + CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can fuse non-CompareSelects. + // if (i) { X } if (i) { Y } => if (i) { X; Y } + auto body = Block::make( + {Cond::make(i, Store::make(a, {0}, i), nullptr), + Cond::make(i, Store::make(a, {1}, i), nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Sanity check won't fuse different non-CompareSelects. + auto body = Block::make( + {Cond::make(i, Store::make(a, {0}, i), nullptr), + Cond::make(j, Store::make(a, {1}, i), nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + } + + { + // Sanity check constant condition elimination still occurs when merging is + // possible. + auto body = Block::make( + {Cond::make(1, Store::make(a, {0}, i), nullptr), + Cond::make(1, Store::make(a, {1}, i), nullptr)}); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Store, block->front(), store1); + IS_NODE_WITH_NAME(Store, block->back(), store2); + } + + { + // Sanity check for-cond reordering occurs after fusing. + auto body = For::make( + i, + 0, + 4, + Block::make( + {Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, Load::make(b, {0})), + nullptr), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {2}, Load::make(b, {0})), + nullptr)})); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } +} + +TEST(Simplify, SimplifySyncThreads) { + BufHandle a("A", {4}, kInt); + VarHandle i("i", kInt); + + { + // Merge two inner SyncThreads. + auto body = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {0}, 1), + alloc(), + alloc(), + Store::make(a, {1}, 0)}); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + auto it = block->begin(); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + } + + { + // Eliminate outer SyncThreads. + auto body = Block::make( + {alloc(), Store::make(a, {1}, 0), alloc()}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + auto it = block->begin(); + IS_NODE(Store, *it); + } + + { + // Merge many inner SyncThreads. + auto body = Block::make( + {Store::make(a, {0}, 1), + alloc(), + alloc(), + alloc(), + alloc(), + alloc(), + Store::make(a, {1}, 0)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + auto it = block->begin(); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + } + + { + // Merge multiple outer SyncThreads. + auto body = Block::make( + {alloc(), + alloc(), + Store::make(a, {1}, 0), + alloc(), + alloc(), + alloc(), + alloc()}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + auto it = block->begin(); + IS_NODE(Store, *it); + } + + { + // Merge multiple sections; + auto body = Block::make( + {Store::make(a, {0}, 1), + alloc(), + alloc(), + Store::make(a, {1}, 0), + Store::make(a, {2}, 0), + alloc(), + alloc(), + alloc(), + Store::make(a, {3}, 0)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 6); + auto it = block->begin(); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + } +} + +TEST(Simplify, SimplifyRampSubBroadcast) { + int num_lanes = 4; + ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); + ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); + ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); + RampPtr newRamp = simplified.AsNode(); + IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); + ASSERT_EQ(base->value(), 5); + IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); + ASSERT_EQ(stride->value(), 6); + ASSERT_EQ(newRamp->lanes(), num_lanes); +} + +TEST(Simplify, SimplifyBroadcastTermExpander) { + int num_lanes = 8; + ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); + ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); + ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); + // NB: We need a term in the middle which isn't simplified to trigger the + // relevant path in TermExpander::mutate. The two bc1 terms are brought + // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. + ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); + BufHandle buf("buf", {num_lanes}, kInt); + // The result isn't fully simplified currently and thus would be brittle to + // match. Observe its value instead. + auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); + SimpleIREvaluator eval(store, {buf}); + std::vector output(num_lanes); + eval(output); + for (const auto i : c10::irange(num_lanes)) { + ASSERT_EQ(output[i], 2); + } +} + +TEST(Simplify, CompareSelectLoopBounds) { + constexpr int N = 8; + BufHandle b("b", {N}, kFloat); + VarHandle n("n", kInt); + VarHandle m("m", kInt); + VarHandle var_N("var_N", kInt); + VarHandle var_M("var_M", kInt); + + auto test_case_fn = [](const VarHandle& n, + const BufHandle& b, + const ExprHandle& start, + const ExprHandle& stop, + const int& cmp_val, + const CompareSelectOperation& cmp_op, + const std::string& check_string) { + StmtPtr s = For::make( + n, + start, + stop, + b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + std::string target_string = "# CHECK: "; + target_string += check_string; + torch::jit::testing::FileCheck().run(target_string, oss.str()); + }; + + auto test_case_nest_loops_fn = [](const VarHandle& n, + const VarHandle& m, + const BufHandle& b, + const ExprHandle& n_start, + const ExprHandle& n_stop, + const ExprHandle& m_start, + const ExprHandle& m_stop, + const CompareSelectOperation& cmp_op, + const std::string& check_string) { + StmtPtr s = For::make( + m, + m_start, + m_stop, + b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); + StmtPtr root_s = For::make(n, n_start, n_stop, s); + root_s = IRSimplifier::simplify(root_s); + std::ostringstream oss; + oss << *root_s; + std::string target_string = "# CHECK: "; + target_string += check_string; + torch::jit::testing::FileCheck().run(target_string, oss.str()); + }; + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 5 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 5 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 5 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 5 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, 2)) { + // b[n] = n == 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, 2)) { + // b[1] = 0.f; + // } + test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 5 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 5 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_N + 30, + var_N + 40, + kNE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_M + 30, + var_M + 40, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 20, + kNE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 20, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + test_case_nest_loops_fn( + n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 31, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 31, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + test_case_nest_loops_fn( + n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_N + 30, + var_N + 40, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_M + 30, + var_M + 40, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n < m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_N + 30, + var_N + 40, + kLT, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_M + 30, + var_M + 40, + kLT, + "b[n, m] = n m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 20, + kGT, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 20, + kGT, + "b[n, m] = n>m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n > m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 1.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_N + 30, + var_N + 40, + kGT, + "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_M + 30, + var_M + 40, + kGT, + "b[n, m] = n>m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = (n >= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 31, + kGE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 31, + kGE, + "b[n, m] = n>=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n >= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 1.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_N + 30, + var_N + 40, + kGE, + "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_M + 30, + var_M + 40, + kGE, + "b[n, m] = n>=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n <= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_N + 30, + var_N + 40, + kLE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_M + 30, + var_M + 40, + kLE, + "b[n, m] = n<=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = (n <= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 20, + kLE, + "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 20, + kLE, + "b[n, m] = n<=m ? 0.f : 1.f;"); +} + +TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + constexpr int N = 8; + BufHandle b("b", {N}, kFloat); + VarHandle n("n", kInt); + StmtPtr s = For::make( + n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: b[n] = 1.f; +)IR", + oss.str()); +} + +TEST(Simplify, IfThenCondAlwaysInLoopBounds) { + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + constexpr int N = 8; + BufHandle b("b", {N}, kFloat); + VarHandle n("n", kInt); + StmtPtr s = + For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: b[n] = 1.f; +)IR", + oss.str()); +} + +TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { + // This test mimics the unpadded region of a conv2d. We want to remove any + // conditional that is provably satisfied (or unsatisfied) by the entire loop + // range. + // Before: + // for (const auto i : c10::irange(1, 7)) { + // for (const auto j : c10::irange(1, 7)) { + // b[i, j] = IfThenElse( + // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f); + // After: + // for (const auto i : c10::irange(1, 7)) { + // for (const auto j : c10::irange(1, 7)) { + // b[i, j] = 1.f; + constexpr int N = 8; + BufHandle b("b", {N, N}, kFloat); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto csel = CompareSelect::make(i, 1, kLT); + csel = CompareSelect::make(j, 1, 1, csel, kLT); + csel = CompareSelect::make(i, N - 1, 1, csel, kGE); + csel = CompareSelect::make(j, N - 1, 1, csel, kGE); + StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); + s = For::make(j, 1, N - 1, s); + s = For::make(i, 1, N - 1, s); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: b[i, j] = 1.f; +)IR", + oss.str()); +} + +TEST(Simplify, DISABLED_SimplifyLoopBounds) { + // This test mimics the padded region of a conv2d. We want to adjust the + // loop bounds such that the condition will be always met. Note that this + // could be solved by peeling, and applying the range-based conditional + // simplification in the previous tests. + // Before: + // for (const auto i : c10::irange(3)) { + // for (const auto j : c10::irange(3)) { + // b[i, j] = (b[i, j]) + (IfThenElse( + // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); + // After: + // for (const auto i : c10::irange(1, 3)) { + // for (const auto j : c10::irange(1, 3)) { + // b[i, j] = (b[i, j]) + 1.f; + constexpr int N = 8; + constexpr int K = 3; + BufHandle a("a", {N, N}, kFloat); + BufHandle b("b", {N, N}, kFloat); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto csel = CompareSelect::make(i, 1, kLT); + csel = CompareSelect::make(j, 1, 1, csel, kLT); + csel = CompareSelect::make(i, N - 1, 1, csel, kGE); + csel = CompareSelect::make(j, N - 1, 1, csel, kGE); + StmtPtr s = b.store( + {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); + s = For::make(j, 0, K, s); + s = For::make(i, 0, K, s); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: for (const auto i : c10::irange(1, 3)) { +# CHECK: for (const auto j : c10::irange(1, 3)) { +# CHECK-NOT: IfThenElse +)IR", + oss.str()); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp new file mode 100644 index 000000000000..56535de914e4 --- /dev/null +++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp @@ -0,0 +1,402 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +struct WithCPUFuser { + WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { + overrideCanFuseOnCPU(val); + } + + ~WithCPUFuser() { + overrideCanFuseOnCPU(cpuFuserEnabled); + } + + bool cpuFuserEnabled; +}; + +TEST(TEFuserPass, FuserPass_1) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): + %12 : int = prim::Constant[value=1]() + %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) + %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) + %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) + %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should not be able to fuse across the in-place operation here. + testing::FileCheck() + .check("prim::TensorExprGroup_") + ->check("aten::add_") + ->check("prim::TensorExprGroup_") + ->run(*g); +} + +TEST(TEFuserPass, FuserPass_2) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): + %12 : int = prim::Constant[value=1]() + %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) + %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) + %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) + return (%d))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should not be able to fuse across the in-place operation here. + testing::FileCheck() + .check("aten::add_") + ->check("prim::TensorExprGroup_0") + ->run(*g); +} + +TEST(TEFuserPass, FuserPass_3) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(128, strides=[1], device=cpu), + %y : Float(128, strides=[1], device=cpu)): + %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) + return (%r))IR"; + { + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // We should not create a fusion group since its size would be too small + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should create a fusion group since its size is above the threshold + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); + } +} + +TEST(TEFuserPass, FuserPass_0DimInput) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(device=cpu), + %y : Float(device=cpu)): + %one : int = prim::Constant[value=1]() + %a : Float(device=cpu) = aten::mul(%x, %y) + %b : Float(device=cpu) = aten::add(%x, %a, %one) + return (%b))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should fuse 0-dim tensors too + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_UnfusibleDevice) { + WithCPUFuser cf(false); + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(10, strides=[1], device=cpu)): + %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) + return (%a))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // Test that we're not starting fusion groups from nodes with unfusible device + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_UnknownShapes) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Tensor, + %y : Tensor): + %a : Tensor = aten::mul(%x, %y) + %b : Tensor = aten::mul(%x, %a) + return (%b))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // Test that we're not generating fusion groups when shapes are not known + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_Multidevice) { + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + return (%cat))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should be able to fuse this + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + return (%cat))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should not fuse this aten::cat since its inputs are from different + // devices + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): + %dim : int = prim::Constant[value=0]() + %xy_list : Tensor[] = prim::ListConstruct(%x, %y) + %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) + %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) + return (%r))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // Test that we check device before merging one node (cat) into another + // (mul) + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): + %z2 : Tensor = aten::mul(%z, %z) + %dim : int = prim::Constant[value=0]() + %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) + return (%cat))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // Test that we check device before merging one node (mul) into another + // (cat) + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0)): + %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) + return (%r))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should not fuse this graph since its inputs are from different devices + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cuda:0), + %y : Float(20, strides=[1], device=cuda:1), + %z : Float(20, strides=[1], device=cpu)): + %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) + %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) + %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) + return (%x2, %y2, %z2))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // We should not fuse these two computations since they use different + // devices + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } +} + +TEST(TEFuserPass, FuserPass_MergeGroups) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%a : Float(128, strides=[1], device=cpu), + %b : Float(128, strides=[1], device=cpu)): + %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) + %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) + return (%x, %y))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // The %x and %y computations are completely independent and yet we should put + // them into a single fusion group rather than having two separate ones. + testing::FileCheck() + .check("= prim::TensorExprGroup_") + ->check_not("= prim::TensorExprGroup_") + ->run(*g); +} + +TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Bool(8, strides=[1], device=cpu), + %y : Bool(8, strides=[1], device=cpu)): + %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) + %b : Tensor = aten::__or__(%a, %y) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_Where) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(8, strides=[1], device=cpu), + %y : Float(8, strides=[1], device=cpu), + %z : Float(8, strides=[1], device=cpu)): + %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) + %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_WhereList) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(8, strides=[1], device=cpu), + %y : Float(8, strides=[1], device=cpu), + %z : Float(8, strides=[1], device=cpu)): + %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) + %b : Tensor[] = aten::where(%cond) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, DynamicShapeFusion) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), + %1 : Float(10, 5, strides=[5, 1], device=cpu)): + %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) + %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) + return (%3))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs( + g, + /* min_group_size = */ 2, + /* add_composed_op = */ true, + /* fuse_to_dynamic_shapes = */ true); + Code code(g, ""); + + testing::FileCheck() + .check("prim::TensorExprDynamicGroup_") + ->check("prim::TensorExprDynamicGuard") + ->check("prim::TensorExprGroup_") + ->run(*g); + + auto run_and_compare = [&](const std::vector& inputs) { + TORCH_INTERNAL_ASSERT(inputs.size() == 2); + + auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); + + InterpreterState interp(code); + Stack stack(inputs.begin(), inputs.end()); + interp.run(stack); + at::Tensor out = pop(stack).toTensor(); + ASSERT_TRUE(at::allclose(out, ref)); + }; + + std::vector inputs = {at::rand({10, 5}), at::rand({10, 5})}; + run_and_compare(inputs); + + std::vector inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; + run_and_compare(inputs2); + + std::vector inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; + run_and_compare(inputs3); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp new file mode 100644 index 000000000000..6758503f4de7 --- /dev/null +++ b/test/cpp/tensorexpr/test_type.cpp @@ -0,0 +1,202 @@ +#include + +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +TEST(Type, Test01) { + { + Dtype dt1 = kInt; + ASSERT_EQ(dt1, kInt); + } + { + Dtype dt2_a(kInt, 8); + Dtype dt2_b(kInt, 4); + Dtype dt2_c(ScalarType::Int, 8); + ASSERT_EQ(dt2_a, dt2_c); + ASSERT_NE(dt2_a, dt2_b); + } + { + ASSERT_EQ(kInt, ToDtype()); + ASSERT_EQ(kFloat, ToDtype()); + ASSERT_EQ(kByte, ToDtype()); + ASSERT_EQ(kChar, ToDtype()); + ASSERT_EQ(kShort, ToDtype()); + ASSERT_EQ(kLong, ToDtype()); + ASSERT_EQ(kHalf, ToDtype()); + ASSERT_EQ(kDouble, ToDtype()); + ASSERT_EQ(kBool, ToDtype()); + } + { + Dtype int32x8(kInt, 8); + Dtype float32x8(kFloat, 8); + ASSERT_NE(int32x8, float32x8); + ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); + ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); + ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); + ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); + } +} + +TEST(Type, BitCasting) { + { + VarHandle x("x", kFloat); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kInt); + } + { + VarHandle x("x", kInt); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kFloat); + } + { + VarHandle x("x", kShort); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kHalf); + } + { + VarHandle x("x", kHalf); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kShort); + } + + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + using SimpleIRExprEval = ExprEval; + // this is broken + /*{ + constexpr int16_t ref16 = 1337; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(*k); + auto b = BitCast::make(kShort, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } + + // This segfaults :( + /*{ + VarHandle x("x", kDouble); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kFloat); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kLong); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kShort); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kInt); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + }*/ +} + +TEST(Type, Propagation) { + // Same types: + { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = FloatImm::make(2.f) + + (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); + ASSERT_EQ(body.dtype(), kFloat); + } + // Int to bigger int: + { + VarHandle x("x", kShort); + VarHandle y("y", kLong); + ExprHandle body = + ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); + ASSERT_EQ(body.dtype(), kLong); + } + // Float to bigger float: + { + VarHandle x("x", kHalf); + VarHandle y("y", kDouble); + ExprHandle body = + HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); + ASSERT_EQ(body.dtype(), kDouble); + } + // Int to Float: + { + VarHandle x("x", kFloat); + VarHandle y("y", kInt); + ExprHandle body = + IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); + ASSERT_EQ(body.dtype(), kFloat); + } + // Smaller float, bigger Int: + { + VarHandle x("x", kHalf); + VarHandle y("y", kLong); + ExprHandle body = + HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); + ASSERT_EQ(body.dtype(), kHalf); + } + // Bigger float, smaller Int: + { + VarHandle x("x", kChar); + VarHandle y("y", kDouble); + ExprHandle body = + CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); + ASSERT_EQ(body.dtype(), kDouble); + } + // Sign change char/byte upgrades to short: + { + VarHandle x("x", kChar); + VarHandle y("y", kByte); + ExprHandle body = + CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); + ASSERT_EQ(body.dtype(), kShort); + } +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_type_specializations.cpp b/test/cpp/tensorexpr/test_type_specializations.cpp new file mode 100644 index 000000000000..d9756627fa74 --- /dev/null +++ b/test/cpp/tensorexpr/test_type_specializations.cpp @@ -0,0 +1,75 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +// Test that tensor type specializations are available in +// the custom passes + +namespace torch { +namespace jit { + +namespace { + +bool hasTensorTypeSpecializations(torch::jit::Block* block) { + for (Value* v : block->inputs()) { + if (hasTensorTypeSpecialization(v)) + return true; + } + for (Node* n : block->nodes()) { + for (torch::jit::Block* b : n->blocks()) { + if (hasTensorTypeSpecializations(b)) + return true; + } + for (Value* v : n->outputs()) { + if (hasTensorTypeSpecialization(v)) + return true; + } + } + return false; +} + +static bool hasSpecializations = false; +void detectTTSpecializationPass(std::shared_ptr& graph) { + GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph); + hasSpecializations = hasTensorTypeSpecializations(graph->block()); +} + +} // namespace + +TEST(SpecializationsInCustomPasses, Basic) { + RegisterPass p(detectTTSpecializationPass); + hasSpecializations = false; + std::shared_ptr graph = std::make_shared(); + parseIR( + R"IR( +graph(%a.1 : Tensor, + %b.1 : Tensor): + %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8 + %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8 + return (%d.1) + )IR", + &*graph); + + IValue ival = IValue(torch::randn({22}, at::kCPU)); + std::vector stack = {ival, ival}; + auto run = [&](std::shared_ptr& graph, std::vector stack) { + GraphExecutor executor(graph, ""); + executor.run(stack); + return stack; + }; + run(graph, stack); + + // Profiling mode will not be run with simple executor + if (!getExecutorMode()) { + EXPECT_TRUE(hasSpecializations); + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h new file mode 100644 index 000000000000..065e513c1a64 --- /dev/null +++ b/test/cpp/tensorexpr/test_utils.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +#define IS_NODE(T, node) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + } + +#define IS_NODE_WITH_NAME(T, node, name) \ + auto name = to(node); \ + ASSERT_NE(nullptr, name); + +#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ + NodePtr name = nullptr; \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ + name = to(node_->src_value()); \ + } \ + ASSERT_NE(nullptr, name); + +#define IS_IMM_WITH_VAL(T, node, val) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->value(), val); \ + } + +#define IS_VAR_WITH_NAME(node, name) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->name_hint(), name); \ + } + +#define IS_BINOP_W_VARS(T, node, name, v1, v2) \ + NodePtr name = nullptr; \ + { \ + name = to(node); \ + ASSERT_NE(nullptr, name); \ + IS_VAR_WITH_NAME(name->lhs(), v1); \ + IS_VAR_WITH_NAME(name->rhs(), v2); \ + } + +#define IS_BINOP_W_CONST(T, node, name, v, c) \ + NodePtr name = nullptr; \ + { \ + name = to(node); \ + ASSERT_NE(nullptr, name); \ + IS_VAR_WITH_NAME(name->lhs(), v); \ + IS_IMM_WITH_VAL(Int, name->rhs(), c); \ + } + +#define IS_RAND(node) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->op_type(), kRand); \ + } + +void checkIR(StmtPtr s, const std::string& pattern); +void checkExprIR(ExprPtr e, const std::string& pattern); +void checkExprIR(const ExprHandle& e, const std::string& pattern); + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp new file mode 100644 index 000000000000..3f4c32af463b --- /dev/null +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -0,0 +1,542 @@ +// *** Tensor Expressions *** +// +// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to +// work with them, and outlines how they are used in the overall TorchScript +// compilation pipeline. This doc is permanently a "work in progress" since NNC +// is under active development and things change fast. +// +// This Tutorial's code is compiled in the standard pytorch build, and the +// executable can be found in `build/bin/tutorial_tensorexpr`. +// +// *** What is NNC *** +// +// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT +// and it performs on-the-fly code generation for kernels, which are often a +// combination of multiple aten (torch) operators. +// +// When the JIT interpreter executes a torchscript model, it automatically +// extracts subgraphs from the torchscript IR graph for which specialized code +// can be JIT generated. This usually improves performance as the 'combined' +// kernel created from the subgraph could avoid unnecessary memory traffic that +// is unavoidable when the subgraph is interpreted as-is, operator by operator. +// This optimization is often referred to as 'fusion'. Relatedly, the process of +// finding and extracting subgraphs suitable for NNC code generation is done by +// a JIT pass called 'fuser'. +// +// *** What is TE *** +// +// TE stands for Tensor Expressions. TE is a commonly used approach for +// compiling kernels performing tensor (~matrix) computation. The idea behind it +// is that operators are represented as a mathematical formula describing what +// computation they do (as TEs) and then the TE engine can perform mathematical +// simplification and other optimizations using those formulas and eventually +// generate executable code that would produce the same results as the original +// sequence of operators, but more efficiently. +// +// NNC's design and implementation of TE was heavily inspired by Halide and TVM +// projects. +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +#ifdef TORCH_ENABLE_LLVM + +// Helper function to print a snippet from a big multi-line string +static void printLinesToFrom(const std::string& input_str, int from, int to); + +#endif + +int main(int argc, char* argv[]) { + std::cout << "*** Structure of tensor expressions and statements ***" + << std::endl; + { + // A tensor expression is a tree of expressions. Each expression has a type, + // and that type defines what sub-expressions the current expression has. + // For instance, an expression of type 'Mul' would have a type 'kMul' and + // two subexpressions: LHS and RHS. Each of these two sub-expressions could + // also be a 'Mul' or some other expression. + // + // Let's construct a simple TE: + ExprPtr lhs = alloc(5); + ExprPtr rhs = alloc("x", kInt); + ExprPtr mul = alloc(lhs, rhs); + std::cout << "Tensor expression: " << *mul << std::endl; + // Prints: Tensor expression: 5 * x + + // Here we created an expression representing a 5*x computation, where x is + // an int variable. + + // Another, probably a more convenient, way to construct tensor expressions + // is to use so called expression handles (as opposed to raw expressions + // like we did in the previous example). Expression handles overload common + // operations and allow us to express the same semantics in a more natural + // way: + ExprHandle l = 5; + ExprHandle r = Var::make("x", kInt); + ExprHandle m = l * r; + std::cout << "Tensor expression: " << *m.node() << std::endl; + // Prints: Tensor expression: 5 * x + + // Converting from handles to raw expressions and back is easy: + ExprHandle handle = Var::make("x", kInt); + ExprPtr raw_expr_from_handle = handle.node(); + ExprPtr raw_expr = alloc("x", kInt); + ExprHandle handle_from_raw_expr = ExprHandle(raw_expr); + + // We could construct arbitrarily complex expressions using mathematical + // and logical operations, casts between various data types, and a bunch of + // intrinsics. + ExprHandle a = Var::make("a", kInt); + ExprHandle b = Var::make("b", kFloat); + ExprHandle c = Var::make("c", kFloat); + ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); + std::cout << "Tensor expression: " << *x.node() << std::endl; + // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) + + // An ultimate purpose of tensor expressions is to optimize tensor + // computations, and in order to represent accesses to tensors data, there + // is a special kind of expression - a load. + // To construct a load we need two pieces: the base and the indices. The + // base of a load is a Buf expression, which could be thought of as a + // placeholder similar to Var, but with dimensions info. + // + // Let's construct a simple load: + BufHandle A("A", {64, 32}, kInt); + VarPtr i_var = alloc("i", kInt), j_var = alloc("j", kInt); + ExprHandle i(i_var), j(j_var); + ExprHandle load = Load::make(A.dtype(), A, {i, j}); + std::cout << "Tensor expression: " << *load.node() << std::endl; + // Prints: Tensor expression: A[i, j] + + // Tensor Expressions constitute Tensor Statements, which are used to + // represent computation of a given operator or a group of operators from a + // fusion group. + // + // There are three main kinds of tensor statements: + // - block + // - store + // - loop + // + // A Store represents a store to a single element of a tensor (or to a + // group of elements if it's a vectorized store). Store statements, + // similarly to Load expressions, have a base and indices, but on top of + // that they also include a value - an expression representing what needs + // to be stored at the given memory location. Let's create a Store stmt: + StmtPtr store_a = Store::make(A, {i, j}, i + j); + std::cout << "Store statement: " << *store_a << std::endl; + // Prints: Store statement: A[i, j] = i + j; + + // An operator fills the entire tensor, not just a single element, and to + // represent this we need to use For stmt: let's wrap our store stmt with + // two nested loops to represent that variables i and j need to iterate + // over some ranges. + ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); + ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); + + std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; + // Prints: + // Nested for loops: + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // A[i, j] = i + j; + // } + // } + + // A Block statement is used when we need a sequence of other statements. + // E.g. if a fusion group contains several operators, we initially define + // separate loopnest for each of them and put them all into a common block: + BufHandle B("B", {64, 32}, kInt); + StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); + ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); + ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); + + BlockPtr block = Block::make({loop_i_a, loop_i_b}); + std::cout << "Compound Block statement: " << std::endl + << *block << std::endl; + // Prints: + // Compound Block statement: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // A[i, j] = i + j; + // } + // } + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // B[i, j] = A[i, j]; + // } + // } + // } + + // Manually constructing nested loops and blocks to represent a computation + // might be laborious, and instead we can use a 'Compute' API. This API + // requires us to specify dimensions and a lambda to compute a single + // element of the resulting tensor and returns a `Tensor` structure. This + // structure is simply a pair of a buffer that was created to represent the + // result of the computation (BufPtr) and a statement representing the + // computation itself (StmtPtr). + Tensor C = + Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + std::cout << "Stmt produced by 'Compute' API: " << std::endl + << *C.stmt() << std::endl; + // Prints: + // Stmt produced by 'Compute' API: + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // C[i, j] = i * j; + // } + // } + + // To construct statements to represent computations with reductions, we + // can use a 'Reduce' API - it is similar to 'Compute' but takes a couple + // of extra arguments defining how to perform the reduction. Let's define a + // simple 2D sum of C using that: + Tensor D = Reduce( + "D", + {}, + Sum(), + [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, + {64, 32}); + std::cout << "Stmt produced by 'Reduce' API: " << std::endl + << *D.stmt() << std::endl; + } + + std::cout << "*** Loopnests transformations ***" << std::endl; + { + // When a statement for the computation is generated, we might want to + // apply some optimizations to it. These transformations allow us to end up + // with a statement producing the same results, but more efficiently. + // + // Let's look at a couple of transformations that are used in NNC. We will + // begin with constructing a Block statement like we did before. + + Tensor C = + Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return i * (j + 1); + }); + BufHandle c_buf(C.buf()); + Tensor D = + Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return c_buf.load(i, j) - i; + }); + StmtPtr block = Block::make({C.stmt(), D.stmt()}); + std::cout << "Stmt produced by 'Compute' API: " << std::endl + << *block << std::endl; + // Prints: + // Stmt produced by 'Compute' API: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // C[i, j] = i * (j + 1); + // } + // } + // for (const auto i_1 : c10::irange(64)) { + // for (const auto j_1 : c10::irange(32)) { + // D[i_1, j_1] = (C[i_1, j_1]) - i_1; + // } + // } + // } + + // One transformation we can apply to this computation is inlining: i.e. + // taking the expression that defines values of C and substituting a load + // from C with it. + // To do that, we first need to create a special object called LoopNest - + // all transformations are methods of this class. To create a loopnest we + // need to provide a list of output buffers and the root statement: + LoopNest nest(block, {D.buf()}); + + // We can always retrieve the Stmt back from LoopNest: + std::cout << "LoopNest root stmt: " << std::endl + << *nest.root_stmt() << std::endl; + // Prints: + // LoopNest root stmt: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // C[i, j] = i * (j + 1); + // } + // } + // for (const auto i_1 : c10::irange(64)) { + // for (const auto j_1 : c10::irange(32)) { + // D[i_1, j_1] = (C[i_1, j_1]) - i_1; + // } + // } + // } + + // Now we can apply the inlining transformation: + nest.computeInline(C.buf()); + std::cout << "Stmt after inlining:" << std::endl + << *nest.root_stmt() << std::endl; + // Prints: + // Stmt after inlining: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // D[i, j] = i * (j + 1) - i; + // } + // } + // } + + // We can also apply algebraic simplification to a statement: + StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); + std::cout << "Stmt after simplification:" << std::endl + << *simplified << std::endl; + // Prints: + // Stmt after simplification: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // D[i, j] = i * j; + // } + // } + // } + + // Many loopnest transformations are stateless and can be applied without + // creating a LoopNest object. In fact, we plan to make all transformations + // stateless. + // splitWithTail is one such transformation: it splits an iteration space + // of a given loop into two with a given factor. + ForPtr outer_loop = to(to(simplified)->stmts().front()); + LoopNest::splitWithTail(outer_loop, 13); + // Call simplifier once more to fold some arithmetic. + simplified = IRSimplifier::simplify(simplified); + std::cout << "Stmt after splitWithTail:" << std::endl + << *simplified << std::endl; + // Prints: + // Stmt after splitWithTail: + // { + // for (const auto i_outer : c10::irange(4)) { + // for (const auto i_inner : c10::irange(13)) { + // for (const auto j : c10::irange(32)) { + // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); + // } + // } + // } + // for (const auto i_tail : c10::irange(12)) { + // for (const auto j : c10::irange(32)) { + // D[i_tail + 52, j] = i_tail * j + 52 * j; + // } + // } + // } + + // NNC supports a wide range of loop nest transformations, which we are not + // listing here. Please refer to documentation in + // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h + // for more details. + } + + std::cout << "*** Codegen ***" << std::endl; + { + // An ultimate goal of tensor expressions is to be provide a mechanism to + // execute a given computation in the fastest possible way. So far we've + // looked at how we could describe what computation we're interested in, but + // we haven't looked at how to actually execute it. + // + // All we've been dealing with was just symbols with no actual data + // associated, in this section we would look at how we can bridge that gap. + + // Let's start by constructing a simple computation for us to work with: + BufHandle A("A", {64, 32}, kInt); + BufHandle B("B", {64, 32}, kInt); + Tensor X = + Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j) + B.load(i, j); + }); + + // And let's lower it to a loop nest, as we did in the previous section. We + // can pass Tensor object directly: + LoopNest loopnest({X}); + std::cout << *loopnest.root_stmt() << std::endl; + // Prints: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // X[i, j] = (A[i, j]) + (B[i, j]); + // } + // } + + // Now imagine that we have two actual tensors 64x32 that we want sum + // together, how do we pass those tensors to the computation and how do we + // carry it out? + // + // Codegen object is aimed at providing exactly that functionality. Codegen + // is an abstract class and concrete codegens are derived from it. + // Currently, we have three codegens: + // 1) Simple Evaluator, + // 2) LLVM Codegen for CPU, + // 3) CUDA Codegen. + // In this example we will be using Simple Evaluator, since it's available + // everywhere. + + // To create a codegen, we need to provide the statement - it specifies the + // computation we want to perform - and a list of placeholders and tensors + // used in the computation. The latter part is crucial since that's the only + // way the codegen could use to correlate symbols in the statement to actual + // data arrays that we will be passing when we will actually be performing + // the computation. + // + // Let's create a Simple IR Evaluator codegen for our computation: + SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X}); + + // We are using the simplest codegen and in it almost no work is done at the + // construction step. Real codegens such as CUDA and LLVM perform + // compilation during that stage so that when we're about to run the + // computation everything is ready. + + // Let's now create some inputs and run our computation with them: + std::vector data_A(64 * 32, 3); // This will be the input A + std::vector data_B(64 * 32, 5); // This will be the input B + std::vector data_X(64 * 32, 0); // This will be used for the result + + // Now let's invoke our codegen to perform the computation on our data. We + // need to provide as many arguments as how many placeholders and tensors we + // passed at the codegen construction time. A position in these lists would + // define how real data arrays from the latter call (these arguments are + // referred to as 'CallArg's in our codebase) correspond to symbols + // (placeholders and tensors) used in the tensor expressions we constructed + // (these are referred to as 'BufferArg'). + // Thus, we will provide three arguments: data_A, data_B, and data_X. data_A + // contains data for the placeholder A, data_B - for the placeholder B, and + // data_X would be used for contents of tensor X. + ir_eval(data_A, data_B, data_X); + + // Let's print one of the elements from each array to verify that the + // computation did happen: + std::cout << "A[10] = " << data_A[10] << std::endl + << "B[10] = " << data_B[10] << std::endl + << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl; + // Prints: + // A[10] = 3 + // B[10] = 5 + // X[10] = A[10] + B[10] = 8 + } + + std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl; + { + // This section requires a LLVM-enabled PyTorch build, so we have to use a + // guard: +#ifdef TORCH_ENABLE_LLVM + + // Often we would like to convert a TorchScript IR to TE rather than + // construct TE IR from scratch. NNC provides an API to perform such + // lowering: it takes a TorchScript graph and returns an object that can be + // used to invoke the generated kernel. + // This API is currently used by the TorchScript JIT fuser and can also be + // used ahead of time to pre-compile parts of a model. + // + // To get familiar with this API let's first start with defining a simple + // TorchScript graph: + const auto graph_string = R"IR( + graph(%A : Float(5, 3, strides=[3, 1], device=cpu), + %B : Float(5, 3, strides=[3, 1], device=cpu)): + %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B) + %one : int = prim::Constant[value=1]() + %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB) + %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one) + return (%AAB_plus_B))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + // This graph defines a simple computation of A*A*B + B where A and B are + // input 5x3 tensors. + + // To lower this TorchScript graph to TE, we just need to create a + // TensorExprKernel object. In its constructor it constructs the + // corresponding TE IR and compiles it for the given backend (in this + // example for CPU using LLVM compiler). + TensorExprKernel kernel(graph); + + // We can retrieve the generated TE stmt from the kernel object: + StmtPtr kernel_stmt = kernel.getCodeGenStmt(); + std::cout << "TE Stmt constructed from TorchScript: " << std::endl + << *kernel_stmt << std::endl; + // Prints: + // TE Stmt constructed from TorchScript: + // { + // for (const auto v : c10::irange(5)) { + // for (const auto _tail_tail : c10::irange(3)) { + // aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) * + // ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) + + // (tB[_tail_tail + 3 * v]); + // } + // } + // } + + // We can also examine generated LLVM IR and assembly code: + std::cout << "Generated LLVM IR: " << std::endl; + auto ir_str = kernel.getCodeText("ir"); + printLinesToFrom(ir_str, 15, 20); + // Prints: + // Generated LLVM IR: + // %9 = bitcast float* %2 to <8 x float>* + // %10 = load <8 x float>, <8 x float>* %9 ... + // %11 = bitcast float* %5 to <8 x float>* + // %12 = load <8 x float>, <8 x float>* %11 ... + // %13 = fmul <8 x float> %10, %12 + // %14 = fmul <8 x float> %10, %13 + + std::cout << "Generated assembly: " << std::endl; + auto asm_str = kernel.getCodeText("asm"); + printLinesToFrom(asm_str, 10, 15); + // Prints: + // Generated assembly: + // vmulps %ymm1, %ymm0, %ymm2 + // vfmadd213ps %ymm1, %ymm0, %ymm2 + // vmovups %ymm2, (%rax) + // vmovss 32(%rcx), %xmm0 + // vmovss 32(%rdx), %xmm1 + // vmulss %xmm1, %xmm0, %xmm2 + + // We can also execute the generated kernel: + auto A = + at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * + 2.0; + auto B = + at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * + 3.0; + std::vector inputs = {A, B}; + std::vector stack = torch::fmap(inputs); + kernel.run(stack); + auto R = stack[0].toTensor(); + + // Let's print one of the elements from the result tensor to verify that the + // computation did happen and was correct: + std::cout << "R[2][2] = " << R[2][2] << std::endl; + // Prints: + // R[2][2] = 15 + // [ CPUFloatType{} ] +#endif + } + return 0; +} + +void printLinesToFrom(const std::string& input_str, int from, int to) { + std::istringstream f(input_str); + std::string s; + int idx = 0; + while (getline(f, s)) { + if (idx > from) { + std::cout << s << "\n"; + } + if (idx++ > to) { + break; + } + } +} diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3e26d37da1b..8d3a8090c67a 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -2939,10 +2939,7 @@ def f({", ".join(param_names)}): @slowTest @onlyCPU - @ops( - [op for op in op_db if get_name(op) not in known_failures], - dtypes=OpDTypes.supported, - ) + @ops(op_db, dtypes=OpDTypes.supported) def test_nnc_correctness(self, device, dtype, op): if not op.supports_tracing: self.skipTest("Requires tracing support") diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 9e408682ca6c..d5586a5b9cd7 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1910,7 +1910,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { } auto& out_t = p_node->Output(0).toTensor(); - if (te && te->checkInput(in0_t) && in0_t.sizes() == in1_t.sizes() && + if (in0_t.sizes() == in1_t.sizes() && in0_t.scalar_type() == in1_t.scalar_type() && in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() && in0_t.scalar_type() == at::kFloat) {